使用TimesFM 对车辆销售进行预测
代码功能概述
- 导入相关包与设置环境变量:
- 首先导入了如
os
、numpy
、pandas
等常用的 Python 库,同时设置了一些与特定库(如XLA_PYTHON_CLIENT_PREALLOCATE
和JAX_PM AP_USE_TENSORSTORE
)相关的环境变量,用于优化计算等操作。
- 首先导入了如
- 加载预训练的 TimesFM 模型:
- 通过指定相关超参数(如后端为
gpu
、每核心批处理大小等)以及预训练模型在 Hugging Face 上的仓库id
,实例化了TimesFm
模型对象,用于后续的评估和微调等操作。
- 通过指定相关超参数(如后端为
- 准备数据集相关信息并加载数据:
- 定义了一个数据集字典
DATA_DICT
,包含多个数据集(如ettm1
等)的详细信息,包括数据文件路径、时间频率、划分边界等。 - 根据选定的数据集(示例中初始化为
ettm1
),读取对应的数据文件为DataFrame
,然后配置TimeSeriesdata
类的实例来进行数据加载、划分训练集、验证集和测试集,同时对数据进行了一些规范化等预处理操作,并生成对应的批次数据(train_batches
、val_batches
、test_batches
)。
- 定义了一个数据集字典
- 评估预训练模型在测试集上的 MAE(平均绝对误差):
- 通过迭代测试集批次数据,利用预训练模型进行预测,计算预测值和实际值之间的平均绝对误差,以此来评估模型在当前数据集上的性能表现。
- 微调模型:
- 导入了一系列用于构建和训练模型的
praxis
、paxml
相关的模块和类,进行了诸如定义学习器(包括优化器、学习率调度等配置)、构建任务、初始化模型状态等操作。 - 将预训练模型的参数设置为微调模型的初始权重,然后通过定义训练步和评估步函数,在多个
epoch
内循环进行训练和定期评估(利用早停机制,根据验证集损失决定是否提前停止训练),在每个训练步中对模型参数进行更新,每个评估步计算验证集上的损失,保存最优模型状态的检查点。
- 导入了一系列用于构建和训练模型的
- 加载并评估微调后的模型:
- 从保存的检查点中恢复最优的模型状态,将其参数更新到原
TimesFM
模型中,然后再次在测试集上计算平均绝对误差,以对比微调前后模型性能的变化情况。
- 从保存的检查点中恢复最优的模型状态,将其参数更新到原
针对车辆销售数据的改写步骤
-
数据准备与加载部分(适配车辆销售数据):
- 修改数据集字典
DATA_DICT
:- 创建一个新的字典项来对应你的车辆销售数据集,例如取名为
vehicle_sales
。 - 填写对应的数据文件路径(假设你的车辆销售数据存储在
../datasets/vehicle_sales.csv
,则data_path
设置为此路径)。 - 根据你的数据时间粒度来设置
freq
,比如如果是按天记录的,就可以设置为"D"
(代表 Daily),如果是按月记录的,可设置为"M"
(代表 Monthly)等。 - 按照你的训练、验证、测试集划分的时间范围来设置
boundaries
列表中的值,例如如果前 3 年数据作为训练集,第 4 年作为验证集,第 5 年作为测试集,你需要根据数据点数量等信息确定对应的时间点边界值填入该列表。
- 创建一个新的字典项来对应你的车辆销售数据集,例如取名为
- 调整数据读取和数据加载配置部分:
- 在
data_df = pd.read_csv(open(data_path, "r"))
这行代码中,确认数据文件格式正确能被read_csv
方法读取,如果数据有特定的分隔符、编码等情况,按需调整参数(比如添加sep
参数指定分隔符、encoding
参数指定编码格式等)。 - 根据你的车辆销售数据列名,修改
ts_cols
、num_cov_cols
、cat_cov_cols
的定义。例如,销售量和销售价格等数值型的时间序列列可添加到ts_cols
,车型、经销商这些分类列可以根据需求分配到num_cov_cols
(如果要进行数值编码等处理)或者cat_cov_cols
(作为分类特征)中。同时修改TimeSeriesdata
实例化时传入的参数,确保数据能正确划分和预处理,例如datetime_col
设置为数据中代表日期的列名。
- 在
- 修改数据集字典
-
模型微调部分(可能无需大改,但检查配置合理性):
- 确认微调时定义的学习器配置(如优化器、学习率调度等参数)是否适合车辆销售数据预测任务。你可能需要根据实际情况调整学习率、训练总步数等参数,例如车辆销售数据如果比较复杂,可能需要适当调小学习率、增加训练总步数等,以保证模型能更好地收敛和学习到数据中的模式。
- 检查
build_learner
函数中设置的bprop_variable_exclusion
参数是否合理,对于车辆销售数据微调场景下想要固定或者放开训练的模型层,根据模型结构和需求进行调整,确保只训练希望更新参数的那些部分。
-
模型评估部分(保持逻辑基本一致):
- 在计算微调前后模型在测试集上的平均绝对误差(MAE)部分,确保数据维度等处理符合车辆销售数据的特点。例如,在预测结果和实际结果对比计算
MAE
时,确认预测的销售量、销售价格等和实际值的对应关系和维度对齐正确,特别是如果有多个时间序列维度或者特征维度时,保证forecasts
和actuals
的形状匹配能正确计算误差。
- 在计算微调前后模型在测试集上的平均绝对误差(MAE)部分,确保数据维度等处理符合车辆销售数据的特点。例如,在预测结果和实际结果对比计算
以下是一个python代码(假设你的车辆销售数据 vehicle_sales.csv
有 date
(日期)、car_model
(车型)、dealer
(经销商)、sales_volume
(销售量)、sales_price
(销售价格)这几列,并且想将车型和经销商作为分类特征,销售量和销售价格作为时间序列特征,数据按年划分训练、验证、测试集,这里简化假设前 3 年训练、第 4 年验证、第 5 年测试,并且时间频率是按年 "Y"
):
# 以下代码主要用于基于TimesFM模型对车辆销售数据进行预训练模型评估、模型微调以及微调后模型的评估操作
# 导入相关包用于微调操作,同时设置一些环境变量来优化计算等相关配置
## Importing relevant packages for finetuning
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'
import timesfm
import gc
import numpy as np
import pandas as pd
from timesfm import patched_decoder
from timesfm import data_loader
from tqdm import tqdm
import dataclasses
import IPython
import IPython.display
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['axes.grid'] = False# 加载预训练的TimesFM模型,通过指定相关超参数(如后端使用的设备、每核心批处理大小、预测长度等)以及从Hugging Face获取预训练模型的仓库id
# 实例化TimesFm模型对象,后续将利用该模型进行数据评估和微调等操作
## Loading TimesFM pretrained checkpoint
tfm = timesfm.TimesFm(hparams=timesfm.TimesFmHparams(backend="gpu",per_core_batch_size=32,horizon_len=128,),checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m"),)# 配置车辆销售数据集相关信息,包括数据集划分边界、数据文件路径、时间频率等,用于后续的数据加载和处理
# 此处简化假设按年划分数据集,前3年作为训练集,第4年作为验证集,第5年作为测试集,时间频率设置为按年("Y")
# 根据实际数据情况和需求,这些设置都可以进行相应调整
## Evaluating pretrained checkpoint on vehicle sales dataset
DATA_DICT = {"vehicle_sales": {"boundaries": [3, 4, 5], # 简化按年划分,前3年训练,第4年验证,第5年测试"data_path": "../datasets/vehicle_sales.csv","freq": "Y", # 按年的时间频率}
}
dataset = "vehicle_sales"
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]# 读取车辆销售数据文件为DataFrame格式,后续将基于此数据进行进一步处理,例如划分数据集、提取特征列等操作
# 需要确保数据文件路径正确以及数据格式能被read_csv方法正常读取,如有特殊格式可按需调整参数(如分隔符、编码等)
data_df = pd.read_csv(open(data_path, "r"))# 定义时间序列特征列,这里选取销售量和销售价格作为时间序列特征,将用于模型的输入和预测等相关操作
# 根据实际业务需求和数据特点,可调整此列表包含的列名
ts_cols = ["sales_volume", "sales_price"]
# 暂未定义数值型协变量列,可根据后续是否需要添加额外数值型特征进行设置
num_cov_cols = None
# 定义分类特征列,这里选取车型和经销商作为分类特征,模型可以根据这些特征学习不同分类下的销售模式等信息
cat_cov_cols = ["car_model", "dealer"] context_len = 512
pred_len = 96num_ts = len(ts_cols)
batch_size = 16# 实例化TimeSeriesdata类,用于加载、划分和预处理车辆销售数据,配置训练集、验证集、测试集的范围,以及设置数据归一化等参数
# 该类内部会根据设置对数据进行相应处理,生成对应的批次数据,便于后续模型训练和评估使用
dtl = data_loader.TimeSeriesdata(data_path=data_path,datetime_col="date",num_cov_cols=num_cov_cols,cat_cov_cols=cat_cov_cols,ts_cols=np.array(ts_cols),train_range=[0, boundaries[0]],val_range=[boundaries[0], boundaries[1]],test_range=[boundaries[1], boundaries[2]],hist_len=context_len,pred_len=pred_len,batch_size=num_ts,freq=freq,normalize=True,epoch_len=None,holiday=False,permute=True,)
# 获取训练集批次数据,每个批次的数据将按照设置的batch_size进行划分,便于在训练循环中迭代使用
train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
# 获取验证集批次数据,同样按照设置进行划分,用于在模型训练过程中的定期验证,以监控模型性能和防止过拟合等
val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
# 获取测试集批次数据,用于最终评估模型在未见过的数据上的性能表现
test_batches = dtl.tf_dataset(mode="test", shift=pred_len)
# 简单遍历训练集批次数据的迭代器,此处主要是为了触发数据加载等相关操作,确保数据可以正常获取,暂未对数据做具体处理
for tbatch in tqdm(train_batches.as_numpy_iterator()):pass
# 打印训练集批次数据中第一个元素(通常是输入数据部分)的形状,用于检查数据维度是否符合预期
print(tbatch[0].shape)# 以下代码块用于计算预训练模型在测试集上的平均绝对误差(MAE),通过迭代测试集批次数据
# 利用预训练模型进行预测,然后对比预测值和实际值计算平均绝对误差,以此评估模型初始性能
### MAE on the test split for the pretrained TimesFM model
mae_losses = []
for batch in tqdm(test_batches.as_numpy_iterator()):past = batch[0]actuals = batch[3]forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)forecasts = forecasts[:, 0 : actuals.shape[1]]mae_losses.append(np.abs(forecasts - actuals).mean())print(f"MAE: {np.mean(mae_losses)}")# 导入一系列用于构建和训练模型的praxis、paxml相关的模块和类,这些模块提供了配置模型、定义学习器、优化训练过程等功能
# 后续将利用这些工具来对模型进行微调操作,使其能更好地适应车辆销售数据特点和预测任务
## Finetuning the model on the vehicle sales dataset
import jax
from jax import numpy as jnp
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis import base_model
from praxis import optimizers
from praxis import schedules
from praxis import base_hyperparams
from praxis import base_layer
from paxml import tasks_lib
from paxml import trainer_lib
from paxml import checkpoints
from paxml import learners
from paxml import partitioning
from paxml import checkpoint_types
# PAX shortcuts,定义一些便捷使用的类型和函数别名,方便后续代码中调用相关功能时书写简洁
NestedMap = py_utils.NestedMap
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
InstantiableParams = py_utils.InstantiableParams
JTensor = pytypes.JTensor
NpTensor = pytypes.NpTensor
WeightedScalars = pytypes.WeightedScalars
instantiate = base_hyperparams.instantiate
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
AuxLossStruct = base_layer.AuxLossStructAUX_LOSS = base_layer.AUX_LOSS
template_field = base_layer.template_field# 定义标准的伪随机数生成器(PRNG)键名称,用于在模型训练等过程中需要随机数的地方保持一致性和可复现性
# PARAMS和RANDOM是常见的用于区分不同用途随机数的标识
PARAMS = base_layer.PARAMS
RANDOM = base_layer.RANDOM# 生成一个初始的随机数生成器的键,设置种子为1234,以便后续在需要随机初始化等操作时能复现结果
key = jax.random.PRNGKey(seed=1234)
# 配置微调模型的结构,使用PatchedDecoderFinetuneModel作为基础结构,并将之前加载的预训练模型的核心层配置传入
# 以此构建微调模型的初始结构,后续将在此基础上进行参数更新等微调操作
model = pax_fiddle.Config(patched_decoder.PatchedDecoderFinetuneModel,name='patched_decoder_finetune',core_layer_tpl=tfm.model_p,
)# 定义构建学习器的函数,配置学习器相关参数,如损失函数名称、优化器(这里使用Adam优化器)及其参数(学习率、学习率调度策略、梯度裁剪阈值、指数移动平均衰减等)
# 同时设置在微调过程中要固定的模型层(通过bprop_variable_exclusion参数指定,这里示例中固定了变压器层,可根据实际情况调整)
### We will hold the transformer layers fixed while finetuning, while training all other components.
@pax_fiddle.auto_config
def build_learner() -> learners.Learner():return pax_fiddle.Config(learners.Learner,name='learner',loss_name='avg_qloss',optimizer=optimizers.Adam(epsilon=1e-7,clip_threshold=1e2,learning_rate=1e-3, # 示例中适当调整学习率,可根据实际情况进一步优化lr_schedule=pax_fiddle.Config(schedules.Cosine,initial_value=1e-4,final_value=1e-5,total_steps=40000,),ema_decay=0.9999,),# 线性探测,固定变压器层(可根据实际情况调整要固定的层)bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],)# 构建训练任务配置,将之前定义的模型和学习器配置组合起来,同时设置模型相关的分布式训练的一些参数(如mesh形状和轴名称等)
# 用于后续在多设备等分布式环境下进行模型训练时的配置管理
task_p = tasks_lib.SingleTask(name='vehicle_sales_learn',model=model,train=tasks_lib.SingleTask.Train(learner=build_learner(),),
)
task_p.model.ici_mesh_shape = [1, 1, 1]
task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']# 获取可用的设备(如GPU或CPU等)信息,并将其整理为特定的形状,用于构建分布式训练的Mesh对象,以支持模型在多设备上并行训练
DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])# 获取本地设备数量,用于后续在数据划分、并行操作等方面根据设备数量进行相应处理,并打印设备相关信息方便查看
num_devices = jax.local_device_count()
print(f'num_devices: {num_devices}')
print(f'device kind: {jax.local_devices()[0].device_kind}')
jax_task = task_p
key, init_key = jax.random.split(key)# 以下两个函数用于处理训练批次数据和评估批次数据,主要功能是对数据的形状进行调整和整理
# 确保数据格式符合模型输入和后续处理的要求,例如将数据按照设备数量和批次大小等进行合理重塑
# 方便在分布式训练和评估过程中正确使用数据
# To correctly prepare a batch of data for model initialization (now that shape
# inference is merged), we take one devices*batch_size tensor tuple of data,
# slice out just one batch, then run the prepare_input_batch function over it.
def process_train_batch(batch):past_ts = batch[0].reshape(batch_size * num_ts, -1)actual_ts = batch[3].reshape(batch_size * num_ts, -1)return NestedMap(input_ts=past_ts, actual_ts=actual_ts)def process_eval_batch(batch):past_ts = batch[0]actual_ts = batch[3]return NestedMap(input_ts=past_ts, actual_ts=actual_ts)# 初始化模型状态,传入训练任务配置、初始化随机数键以及处理后的训练批次数据等信息
# 根据指定的检查点类型(这里是GDA类型)进行模型状态的初始化操作,得到初始的模型状态信息
jax_model_states, _ = trainer_lib.initialize_model_state(jax_task,init_key,process_train_batch(tbatch),checkpoint_type=checkpoint_types.CheckpointType.GDA,
)# 将预训练模型的参数设置为微调模型的初始权重,具体是将预训练模型的参数赋值给微调模型状态中对应核心层的参数部分
# 这样微调模型就可以在预训练的基础上进行进一步优化,加快收敛并利用预训练学到的通用特征表示
### Setting the initial model weights to the pretrained TimesFM parameters.
jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
jax_vars = jax_model_states.mdl_vars
gc.collect()# 以下是模型微调的训练循环部分,定义了训练步和评估步的函数,在多个训练轮次(epoch)内循环进行训练和定期评估
# 通过早停机制(根据验证集损失决定是否提前停止训练)来避免过拟合,在每个训练步中更新模型参数,每个评估步计算验证集上的损失并保存最优模型状态的检查点
### Training loop# 将之前配置好的训练任务(task_p)赋值给jax_task变量,后续在训练和评估步骤中会用到这个任务配置信息
jax_task = task_p# 定义训练步函数train_step,该函数内部调用了trainer_lib.train_step_single_learner函数,用于执行单个学习器的一次训练步骤
# 它接收当前的模型状态(states)、伪随机数生成器的键(prng_key)以及输入数据(inputs)作为参数,返回训练后的模型状态等相关信息
def train_step(states, prng_key, inputs):return trainer_lib.train_step_single_learner(jax_task, states, prng_key, inputs)# 定义评估步函数eval_step,首先将模型状态转换为评估状态(通过to_eval_state方法),然后调用trainer_lib.eval_step_single_learner函数执行评估步骤
# 同样接收模型状态、伪随机数生成器的键以及输入数据作为参数,返回评估相关的结果信息(例如损失值等)
def eval_step(states, prng_key, inputs):states = states.to_eval_state()return trainer_lib.eval_step_single_learner(jax_task, states, prng_key, inputs)# 对初始的随机数生成器的键(key)进行分割,生成三个新的随机数生成器的键,分别用于后续的训练、评估以及其他可能的操作
# 这样可以保证在不同的步骤中使用不同的随机数流,便于控制随机性和复现实验结果
key, train_key, eval_key = jax.random.split(key, 3)# 根据本地设备数量(jax.local_device_count()),将训练用的随机数生成器的键(train_key)分割成多个子键,每个子键对应一个设备
# 用于在分布式训练中为每个设备提供独立的随机数种子,确保随机性在不同设备上的正确应用
train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())# 同理,将评估用的随机数生成器的键(eval_key)也按照本地设备数量进行分割,为每个设备在评估过程中提供独立的随机数种子
eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())# 使用jax.pmap对训练步函数(train_step)进行并行化处理,指定按照'batch'轴进行并行,使得训练可以在多个设备上并行执行,提高训练效率
p_train_step = jax.pmap(train_step, axis_name='batch')# 同样地,对评估步函数(eval_step)进行并行化处理,使其能在多个设备上并行执行评估操作,也是按照'batch'轴进行并行
p_eval_step = jax.pmap(eval_step, axis_name='batch')# 对初始的模型状态(jax_model_states)进行复制操作,以适配分布式训练环境,使得每个设备都有一份相同的初始模型状态副本
# 后续每个设备可以基于这份副本进行独立的参数更新等操作,最终再进行汇总等处理
replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)# 获取复制后的模型状态中的变量部分(mdl_vars),方便后续在训练和评估过程中对模型参数等变量进行操作和访问
replicated_jax_vars = replicated_jax_states.mdl_vars# 初始化最优验证集损失值为一个较大的数(1e7),在训练过程中,一旦发现更小的验证集损失值,就会更新这个最优值,并保存对应的模型状态
best_eval_loss = 1e7# 记录当前已经执行的训练步数,初始化为0,随着训练循环的进行,每执行一次训练步就会加1,用于判断是否达到定期评估的步数等条件
step_count = 0# 早停机制相关的耐心值(patience),初始化为0,代表目前还没有出现验证集损失不再下降的情况
# 当验证集损失连续多次(由PATIENCE变量定义)没有下降时,就会触发早停机制,提前结束训练
patience = 0# 设定总的训练轮次(epoch)数量,这里设置为100,表示模型将对整个训练数据集完整遍历100次,可根据实际情况调整该值
NUM_EPOCHS = 100# 设定早停机制中的耐心值,即验证集损失连续多少次没有下降就触发早停,这里设置为5,意味着如果连续5次评估验证集损失都没有变小,就停止训练
PATIENCE = 5# 定义每经过多少个训练步就进行一次模型在验证集上的评估操作,这里设置为1000步评估一次,用于定期监控模型在验证集上的性能表现
TRAIN_STEPS_PER_EVAL = 1000# 指定保存模型检查点的目录路径,训练过程中,当发现当前模型在验证集上的性能更好(验证集损失更小)时,会将模型状态保存到这个目录下
CHECKPOINT_DIR = '/home/senrajat_google_com/vehicle_sales_finetune'# 定义一个函数reshape_batch_for_pmap,用于根据设备数量对批次数据进行形状重塑,使其能正确地在分布式训练环境下分配到各个设备上
# 具体操作是将输入张量的第一个维度(通常是批次大小维度)按照设备数量进行划分,重新调整张量的形状
def reshape_batch_for_pmap(batch, num_devices):def _reshape(input_tensor):bsize = input_tensor.shape[0]residual_shape = list(input_tensor.shape[1:])nbsize = bsize // num_devicesreturn jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)return jax.tree.map(_reshape, batch)# 外层循环,按照设定的训练轮次(NUM_EPOCHS)进行循环训练,每个epoch代表对整个训练数据集的一次完整遍历
for epoch in range(NUM_EPOCHS):# 打印当前所处的训练轮次信息,方便在训练过程中查看训练进度,flush=True用于立即输出信息,不缓冲print(f"__________________Epoch: {epoch}__________________", flush=True)# 获取训练集批次数据的迭代器,用于在当前epoch内逐个批次地遍历训练数据train_its = train_batches.as_numpy_iterator()# 判断如果早停的耐心值(patience)达到设定的阈值(PATIENCE),则触发早停机制,结束训练if patience >= PATIENCE:print("Early stopping.", flush=True)break# 内层循环,遍历当前epoch的每个训练批次数据for batch in tqdm(train_its):train_losses = []# 再次检查早停条件,若满足则提前停止当前批次的训练if patience >= PATIENCE:print("Early stopping.", flush=True)break# 调用函数处理训练批次数据,主要是对数据形状进行调整,使其符合模型训练输入要求tbatch = process_train_batch(batch)# 根据设备数量对处理后的批次数据进行重塑,以便在分布式训练环境下能正确分配到各个设备上进行并行计算tbatch = reshape_batch_for_pmap(tbatch, num_devices)# 执行分布式训练的一个训练步,传入当前模型状态、训练随机数种子以及处理好的批次数据# 返回更新后的模型状态以及包含训练损失等信息的输出结果(step_fun_out)replicated_jax_states, step_fun_out = p_train_step(replicated_jax_states, train_prng_seed, tbatch)# 将当前训练步的损失值添加到训练损失列表(train_losses)中,后续可以用于计算平均训练损失等操作train_losses.append(step_fun_out.loss[0])# 判断当前训练步数是否达到了设定的定期评估步数(TRAIN_STEPS_PER_EVAL),如果达到则进行模型在验证集上的评估操作if step_count % TRAIN_STEPS_PER_EVAL == 0:# 打印当前训练步数下的平均训练损失值,方便查看训练过程中模型在训练集上的损失变化情况,flush=True用于立即输出信息print(f"Train loss at step {step_count}: {np.mean(train_losses)}",flush=True,)# 清空训练损失列表,为下一个评估周期准备,避免累计之前的损失值影响下一次平均损失的计算train_losses = []# 打印提示信息,表示开始进行模型在验证集上的评估操作print("Starting eval.", flush=True)# 获取验证集批次数据的迭代器,用于在验证过程中逐个批次地遍历验证数据val_its = val_batches.as_numpy_iterator()# 初始化用于存储每个验证批次损失值的列表,用于后续计算平均验证集损失eval_losses = []# 遍历验证集的每个批次数据for ev_batch in tqdm(val_its):# 调用函数处理验证批次数据,对数据形状进行调整,使其符合模型评估输入要求ebatch = process_eval_batch(ev_batch)# 根据设备数量对处理后的验证批次数据进行重塑,适配分布式评估环境ebatch = reshape_batch_for_pmap(ebatch, num_devices)# 执行分布式评估的一个评估步,传入当前模型状态、评估随机数种子以及处理好的验证批次数据# 返回包含评估损失等信息的输出结果(这里只关心损失值,所以用下划线忽略其他返回信息)_, step_fun_out = p_eval_step(replicated_jax_states, eval_prng_seed, ebatch)# 将当前验证批次的损失值添加到验证损失列表(eval_losses)中eval_losses.append(step_fun_out.loss[0])# 计算平均验证集损失值,通过对验证损失列表中的所有损失值求平均得到mean_loss = np.mean(eval_losses)# 打印当前训练步数下的平均验证集损失值,方便查看模型在验证集上的性能表现,flush=True用于立即输出信息print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)# 判断当前平均验证集损失值是否小于之前记录的最优验证集损失值(best_eval_loss),或者是否为NaN(表示出现异常情况)# 如果满足条件,说明当前模型在验证集上的性能更好,需要保存当前的模型状态作为最优模型状态if mean_loss < best_eval_loss or np.isnan(mean_loss):# 更新最优验证集损失值为当前的平均验证集损失值best_eval_loss = mean_loss# 打印提示信息,表示正在保存模型检查点print("Saving checkpoint.")# 对复制后的模型状态进行处理,将其转换为适合保存的格式(可能涉及去除一些分布式相关的冗余信息等操作)jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(replicated_jax_states)# 调用函数保存模型检查点,将处理后的模型状态保存到指定的目录(CHECKPOINT_DIR)下,并且设置覆盖已存在的同名检查点checkpoints.save_checkpoint(jax_state_for_saving, CHECKPOINT_DIR, overwrite=True)# 将早停机制的耐心值重置为0,因为当前模型性能有提升,重新开始计算耐心值patience = 0# 删除已经保存的模型状态变量,释放内存空间,避免内存占用过多del jax_state_for_saving# 手动触发垃圾回收,及时回收不再使用的内存,优化内存使用情况gc.collect()# 如果当前平均验证集损失值没有小于最优验证集损失值,说明模型在验证集上的性能没有提升,则增加早停机制的耐心值else:patience += 1# 打印当前的耐心值,方便查看早停机制的触发进度情况print(f"patience: {patience}")# 每执行完一个训练步,训练步数加1,用于跟踪训练的进度以及判断是否达到定期评估等条件step_count += 1# 以下代码用于加载根据验证集损失选出的最优微调后的模型检查点,并在测试集上对其进行评估,计算平均绝对误差(MAE)来衡量模型性能
## Loading and evaluating the best (according to validation loss) finetuned checkpoint# 调用函数从指定的目录(CHECKPOINT_DIR)中恢复之前保存的最优模型状态,将其赋值给train_state变量,用于后续的模型参数更新和评估操作
train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)# 打印恢复的模型状态对应的训练步数信息,可用于查看加载的是哪个阶段保存的模型状态
print(train_state.step)# 将微调后模型的参数更新到原TimesFM模型中,具体是将恢复的模型状态中的核心层参数赋值给原TimesFM模型的对应参数部分
# 使得原模型可以使用微调后的参数进行预测等操作,用于在测试集上评估微调后的模型性能
tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']# 对TimesFM模型执行即时编译(jit)相关的解码操作,可能是为了优化模型在后续预测过程中的性能,加快预测速度
tfm.jit_decode()# 初始化用于存储测试集上每个批次预测结果与实际结果的平均绝对误差(MAE)的列表,后续将通过循环计算并填充该列表
mae_losses = []# 遍历测试集的每个批次数据,用于计算在整个测试集上的平均绝对误差(MAE)
for batch in tqdm(test_batches.as_numpy_iterator()):# 获取当前批次的输入数据(通常是历史时间序列数据等),作为模型预测的输入past = batch[0]# 获取当前批次的实际值(真实的目标数据,例如实际销售量、销售价格等),用于与模型预测结果进行对比计算误差actuals = batch[3]# 使用更新参数后的TimesFM模型进行预测,传入当前批次的输入数据以及一些相关的辅助参数(这里辅助参数都设置为0,具体含义可能取决于模型的定义)# 返回预测结果(forecasts)以及其他可能的相关信息(这里用下划线忽略)_, forecasts = tfm.forecast(list(past), [0] * past.shape[0])# 对预测结果进行维度处理,选取与实际值维度对应的部分,确保两者可以正确地进行误差计算(这里假设实际值和预测值在维度上需要进行一定的对齐操作)forecasts = forecasts[:, 0 : actuals.shape[1], 5]# 计算当前批次预测结果与实际结果的平均绝对误差(MAE),通过计算预测值与实际值差值的绝对值的平均值得到# 将每个批次的MAE值添加到mae_losses列表中,后续可以通过求平均得到整个测试集上的平均MAE值mae_losses.append(np.abs(forecasts - actuals).mean())print(f"MAE: {np.mean(mae_losses)}")
请注意:
- 上述代码中的路径等相关设置(如
CHECKPOINT_DIR
、数据文件路径等)需要根据你的实际运行环境进行调整,确保可以正确读写文件以及保存和加载模型检查点。 - 代码中关于模型的一些超参数(如学习率、训练轮数、批处理大小等)都是示例值,你可能需要根据车辆销售数据的特点、模型训练情况等进行多次试验和调整,以获得更好的预测性能。
- 假设数据文件
vehicle_sales.csv
的格式是比较规范的,能被pandas
的read_csv
方法正常读取,如果实际数据有特殊格式(例如包含标题行、特定的日期格式、缺失值表示等情况),可能需要进一步对数据读取部分进行修改完善。
相关文章:
使用TimesFM 对车辆销售进行预测
代码功能概述 导入相关包与设置环境变量: 首先导入了如 os、numpy、pandas 等常用的 Python 库,同时设置了一些与特定库(如 XLA_PYTHON_CLIENT_PREALLOCATE 和 JAX_PM AP_USE_TENSORSTORE)相关的环境变量,用于优化计算…...
OpenEuler 22.03 不依赖zookeeper安装 kafka 3.3.2集群
零:规划 本次计划安装三台OpenEuler 22.03 版本操作系统的服务器,用于搭建 kafka和flink 集群。因为从kafka 2.8 版本以后开始不依赖 zookeeper ,同时考虑到需要找一个发布时间早于 flink 1.17 的kafka 版本且应尽量稳定,综合考虑…...

ubuntu 将python3.8 升级为python3.10并进行版本切换
ubuntu 将python3.8 升级为python3.10并进行版本切换 前言将python3.8 升级为3.10安装pippython版本切换 前言 有一个功能包编译环境需要为python3.10 ,但是当前环境为python3.8 ,所以需要进行版本升级,编译完还需要把环境切换回来。 将pyt…...

3. Kafka入门—安装与基本命令
Kafka基础操作 一. 章节简介二. kafka简介三. Kafka安装1. 准备工作2. Zookeeper安装2.1 配置文件2.2 启动相关命令3. Kafka安装3.1 配置文件3.2 启动相关命令-------------------------------------------------------------------------------------------------------------…...

如何使用 python创建图片格式转换器
在本篇博客中,我们将通过一个简单的实例来展示如何使用 wxPython 创建一个图形用户界面(GUI)应用程序,用于将图片从一种格式转换为另一种格式。我们将通过以下几个步骤实现这一目标: C:\pythoncode\new\imageconvertty…...

命令行之巅:Linux Shell编程的至高艺术(上)
文章一览 前言一、shell概述1.1 shell的特点和类型1.1.1 **shell的特点:**1.1.2 常用shell类型 1.2 shell脚本的建立和执行1.2.1 建立shell脚本1.2.2 执行shell脚本的方式1.2.3 shell程序实例 二、shell变量与算数运算2.1 简单shell变量2.1.1 简单变量定义和赋值2.1…...
【gulp】gulp 的基本使用
gulp 是一个基于node的自动化打包构建工具,前端开发者可以使用它来处理常见任务: 创建项目 进入项目 npm init -ynpm i gulp -g (使用命令 gulp)npm i gulp -D # 开发依赖(前端工具都是开发依赖 本地安装 代…...
Linux 下处理 ^M 字符的最佳实践
Linux 下处理 ^M 字符的最佳实践 一、快速解决方案 按照优先级排序的三种解决方案: 1. 使用 dos2unix(推荐) # 安装 sudo apt-get install dos2unix # Ubuntu/Debian sudo yum install dos2unix # CentOS# 使用 dos2unix 文件名2. 使用 sed sed...

【优选算法】—复写零(双指针算法)
云边有个稻草人-CSDN博客 每天至少一道算法题,接着干,以额现在的实力想完成那个目标确实难。算法题确实烧脑,挺煎熬的,但脑子烧多了是不是就该好些了?。。。 记得那句话,必须有为成功付出代价的决心&#x…...
2024国赛A问题三和四
问题三 最小螺距单目标优化模型的建立 问题二考虑了在螺距固定的条件下计算舞龙队盘入的终止时间,问题三在第二问的基础提出了改变螺距的要求,即求解在螺距最小为多少时,龙头前把手能够沿着相应的螺线盘入到调头空间的边界。故可将其转换为…...

asp.net 高校学生勤工俭学系统设计与实现
博主介绍:专注于Java(springboot ssm 等开发框架) vue .net php python(flask Django) 小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不然下次找…...

《计算机组成及汇编语言原理》阅读笔记:p116-p120
《计算机组成及汇编语言原理》学习第 7 天,p116-p120 总结,总计 5 页。 一、技术总结 1.CPU优化 (1)increase overall performance number 例如:16位电脑提升到32位电脑。 (2)multiprocessing One way to make computers more useful i…...

C# OpenCvSharp DNN 卡证检测矫正
目录 说明 效果 模型 项目 代码 下载 参考 说明 源码地址:https://modelscope.cn/models/iic/cv_resnet_carddetection_scrfd34gkps 在实人认证、文档电子化等场景中需要自动化提取卡证的信息,以便进一步做录入处理。这类场景通常存在两类问题&…...
Spring Boot 中 Map 的最佳实践
在Spring Boot中使用Map时,请遵循以下最佳实践: 1.避免在Controller中 直接使用Map。应该使用RequestBody 接收-个DTO对象或者 RequestParam接收参数,然后在Service中处 理Map。 2.避免在Service中 直接使用原始的Map。应该使用Autowired 注入-个专门…...
J-LangChain - 智能链构建
介绍 j-langchain是一个Java版的LangChain开发框架,旨在简化和加速各类大模型应用在Java平台的落地开发。它提供了一组实用的工具和类,使得开发人员能够更轻松地构建类似于LangChain的Java应用程序。 依赖 Maven <dependency><groupId>i…...

开源低代码平台-Microi吾码 打印引擎使用
引言 在开发中,会遇到很多记录的表单数据需要下载打印下来使用到线下各种应用场景中。在传统的方法中可能是需要先导出数据,然后将数据填入word表格中在打印下来。 但Microi吾码提供了一项新功能,便是打印引擎。打印引擎即可在线设计…...

【MySQL】索引 面试题
文章目录 适合创建索引的情况创建索引的注意事项MySQL中不适合创建索引的情况索引失效的常见情况 索引定义与作用 索引是帮助MySQL高效获取数据的有序数据结构,通过维护特定查找算法的数据结构(如B树),以某种方式引用数据…...

【高阶数据结构】AVL树
AVL树 1.AVL的概念2.AVL树的实现1.AVL树的结构2.AVL树的插入1.更新平衡因子2.旋转1.右单旋2.左单旋3.左右双旋4.右左双旋 3.AVL树的查找4.AVL树的平衡检测5.AVL树的性能分析6.AVL树的删除 3.总代码1.AVLTree.h2.Test.cpp 1.AVL的概念 AVL树是最先发明的自平衡⼆叉查找树&#…...
【Spring】基于XML的Spring容器配置——<bean>标签与属性解析
Spring框架是一个非常流行的应用程序框架,它通过控制反转(IoC)和依赖注入(DI)来简化企业级应用的开发。Spring容器是其核心部分,负责管理对象的创建、配置和生命周期。在Spring中,XML配置是一种…...
docker mysql5.7安装
一.更改 /etc/docker/daemon.json sudo mkdir -p /etc/dockersudo tee /etc/docker/daemon.json <<-EOF {"registry-mirrors": ["https://do.nark.eu.org","https://dc.j8.work","https://docker.m.daocloud.io","https:/…...

docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...

【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...

HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)
本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...

深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
从面试角度回答Android中ContentProvider启动原理
Android中ContentProvider原理的面试角度解析,分为已启动和未启动两种场景: 一、ContentProvider已启动的情况 1. 核心流程 触发条件:当其他组件(如Activity、Service)通过ContentR…...