当前位置: 首页 > article >正文

从GitHub热门项目到实战:手把手教你复现一篇ICLR‘24时间序列预测论文(附完整代码)

从GitHub热门项目到实战手把手教你复现一篇ICLR24时间序列预测论文附完整代码在人工智能领域前沿论文与开源代码的结合正成为推动技术进步的重要动力。GitHub上涌现出大量包含顶会论文和配套实现的仓库如AI4TS这样的专业资源库为研究者提供了宝贵的学习材料。然而面对海量的论文和代码许多工程师和研究生常常感到无从下手——如何从论文阅读者转变为代码实践者将理论转化为可运行的解决方案本文将聚焦ICLR 2024最新时间序列预测论文iTransformer带你完成从环境配置到结果复现的全流程实战。1. 论文选择与环境准备选择适合复现的论文是成功的第一步。ICLR 2024的iTransformer《Inverted Transformers Are Effective for Time Series Forecasting》因其创新的架构设计和出色的性能表现成为理想选择。该论文提出了倒置Transformer的概念通过调整传统Transformer的注意力机制和位置编码方式显著提升了长序列预测的准确性。复现环境配置步骤硬件要求GPUNVIDIA RTX 3090或更高至少24GB显存内存32GB以上存储100GB可用空间用于数据集和模型缓存软件依赖# 创建conda环境 conda create -n itransformer python3.9 conda activate itransformer # 安装PyTorch pip install torch1.13.1cu116 torchvision0.14.1cu116 torchaudio0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 # 安装其他依赖 pip install numpy pandas scikit-learn matplotlib tqdm tensorboard关键版本控制库名称版本重要性说明PyTorch1.13.1必须匹配CUDA 11.6CUDA Toolkit11.6与PyTorch版本强关联cuDNN8.4.0影响GPU计算效率注意环境配置是复现过程中最容易出错的环节。建议先在小规模数据上验证环境是否正确再开展完整实验。2. 数据准备与预处理iTransformer论文使用了多个标准时间序列数据集进行验证包括ETTElectricity Transformer Temperature、Weather和Traffic等。我们将以ETTh1电力变压器温度每小时数据为例展示完整的数据处理流程。数据获取与清洗import pandas as pd from sklearn.preprocessing import StandardScaler # 加载原始数据 data pd.read_csv(ETTh1.csv) # 处理缺失值 data.fillna(methodffill, inplaceTrue) # 前向填充 data.dropna(inplaceTrue) # 删除剩余缺失值 # 标准化处理 scaler StandardScaler() scaled_data scaler.fit_transform(data[[HUFL,HULL,MUFL,MULL,LUFL,LULL,OT]])数据集划分策略训练集2016年7月-2021年6月60个月验证集2021年7月-2021年12月6个月测试集2022年1月-2022年7月7个月滑动窗口生成import torch from torch.utils.data import Dataset class TimeSeriesDataset(Dataset): def __init__(self, data, seq_len, pred_len): self.data data self.seq_len seq_len self.pred_len pred_len def __getitem__(self, index): x self.data[index:indexself.seq_len] y self.data[indexself.seq_len:indexself.seq_lenself.pred_len] return torch.FloatTensor(x), torch.FloatTensor(y) def __len__(self): return len(self.data) - self.seq_len - self.pred_len 1 # 参数设置与论文一致 seq_len 96 # 输入序列长度 pred_len 336 # 预测长度 batch_size 32 # 创建数据加载器 train_dataset TimeSeriesDataset(train_data, seq_len, pred_len) train_loader torch.utils.data.DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue)3. 模型实现详解iTransformer的核心创新在于其倒置设计——与传统Transformer处理序列的方式不同它将时间点视为token将变量维度视为序列。这种结构特别适合多元时间序列预测任务。关键组件实现倒置注意力层class InvertedAttention(nn.Module): def __init__(self, d_model, n_heads, dropout0.1): super().__init__() self.d_model d_model self.n_heads n_heads self.head_dim d_model // n_heads self.qkv nn.Linear(d_model, d_model * 3) self.dropout nn.Dropout(dropout) self.proj nn.Linear(d_model, d_model) def forward(self, x): B, L, D x.shape # Batch, SeqLen, Dim # 倒置处理将维度作为序列 x x.transpose(1, 2) # [B, D, L] qkv self.qkv(x).reshape(B, D, 3, self.n_heads, self.head_dim) q, k, v qkv.unbind(2) # [B, D, n_heads, head_dim] attn (q k.transpose(-2, -1)) / math.sqrt(self.head_dim) attn attn.softmax(dim-1) attn self.dropout(attn) out (attn v).transpose(1, 2).reshape(B, D, L) out self.proj(out) return out.transpose(1, 2) # 恢复原始维度完整模型架构class iTransformer(nn.Module): def __init__(self, enc_in7, dec_in7, c_out7, seq_len96, pred_len336, d_model512, n_heads8, e_layers3, d_ff2048, dropout0.05): super().__init__() self.pred_len pred_len # 编码器 self.enc_embedding DataEmbedding(enc_in, d_model, dropout) self.encoder Encoder( [ EncoderLayer( InvertedAttention(d_model, n_heads, dropoutdropout), d_model, d_ff, dropoutdropout ) for _ in range(e_layers) ] ) # 解码器 self.dec_embedding DataEmbedding(dec_in, d_model, dropout) self.projection nn.Linear(d_model, c_out, biasTrue) def forward(self, x_enc, x_dec): enc_out self.enc_embedding(x_enc) enc_out self.encoder(enc_out) dec_out self.dec_embedding(x_dec) dec_out self.projection(dec_out) return dec_out[:, -self.pred_len:, :]自定义数据嵌入层class DataEmbedding(nn.Module): def __init__(self, c_in, d_model, dropout0.1): super().__init__() self.value_embedding nn.Linear(c_in, d_model) self.position_embedding PositionalEncoding(d_model) self.dropout nn.Dropout(dropout) def forward(self, x): x self.value_embedding(x) self.position_embedding(x) return self.dropout(x)4. 训练技巧与调优策略成功复现顶会论文不仅需要正确实现模型还需要掌握关键的训练技巧。以下是经过验证的有效方法分阶段训练策略预热阶段前10%训练步数使用较低学习率1e-4只更新嵌入层和最后一层参数目标建立稳定的特征表示主体训练阶段学习率5e-4使用余弦退火调度批量大小32根据显存调整梯度裁剪max_norm3.0微调阶段最后5%训练步数学习率降至1e-5冻结部分层如底层编码器只更新高层网络参数关键训练代码from torch.optim.lr_scheduler import CosineAnnealingLR model iTransformer().cuda() criterion nn.MSELoss() optimizer torch.optim.AdamW(model.parameters(), lr5e-4) scheduler CosineAnnealingLR(optimizer, T_max100, eta_min1e-5) for epoch in range(100): model.train() for x, y in train_loader: x, y x.cuda(), y.cuda() # 前向传播 outputs model(x, x[:, -96:, :]) loss criterion(outputs, y) # 反向传播 optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0) optimizer.step() scheduler.step() # 验证集评估 model.eval() with torch.no_grad(): val_loss evaluate(model, val_loader, criterion) print(fEpoch {epoch}: Train Loss {loss.item():.4f}, Val Loss {val_loss:.4f})常见问题解决方案问题现象可能原因解决方案验证损失震荡大学习率过高降低学习率增加warmup阶段训练损失不下降梯度消失/爆炸检查初始化添加LayerNormGPU内存不足批量过大/序列过长减小batch_size或使用梯度累积预测结果全为均值损失函数设计问题尝试MAE损失或分位数损失过拟合严重模型容量过大增加Dropout添加L2正则提示使用TensorBoard或Weights Biases记录训练过程可视化损失曲线和注意力权重分布这对调试模型非常有用。5. 结果复现与性能对比完成模型训练后我们需要在测试集上评估性能并与论文报告的结果进行对比。iTransformer论文中报告的主要指标包括MSE均方误差和MAE平均绝对误差。评估代码实现def evaluate(model, data_loader, metrics[mse, mae]): model.eval() total_loss {m: 0 for m in metrics} count 0 with torch.no_grad(): for x, y in data_loader: x, y x.cuda(), y.cuda() outputs model(x, x[:, -96:, :]) if mse in metrics: total_loss[mse] F.mse_loss(outputs, y).item() * x.size(0) if mae in metrics: total_loss[mae] F.l1_loss(outputs, y).item() * x.size(0) count x.size(0) return {k: v/count for k,v in total_loss.items()} # 测试集评估 test_metrics evaluate(model, test_loader) print(fTest MSE: {test_metrics[mse]:.4f}, MAE: {test_metrics[mae]:.4f})ETTh1数据集上的预期结果预测长度论文报告MSE复现MSE (预期)允许误差范围960.3840.39-0.42±10%1920.4360.44-0.48±10%3360.4970.50-0.55±10%7200.5630.57-0.62±10%如果复现结果与论文差异超过15%建议检查以下方面数据预处理是否完全一致特别是归一化方法模型超参数层数、注意力头数等是否正确训练策略学习率调度、正则化等是否匹配随机种子是否固定影响初始化可视化预测结果import matplotlib.pyplot as plt def plot_predictions(model, dataset, num_samples3): fig, axes plt.subplots(num_samples, 1, figsize(15, 5*num_samples)) for i in range(num_samples): idx torch.randint(0, len(dataset), (1,)).item() x, y dataset[idx] with torch.no_grad(): pred model(x.unsqueeze(0).cuda(), x[-96:].unsqueeze(0).cuda()) pred pred.cpu().squeeze() axes[i].plot(y[:, -1], labelGround Truth) axes[i].plot(pred[:, -1], labelPrediction) axes[i].set_title(fSample {i1}) axes[i].legend() plt.tight_layout() plt.show() plot_predictions(model, test_dataset)6. 进阶优化与迁移实践成功复现基础模型后可以考虑以下方向进行优化和改进性能优化技巧混合精度训练from torch.cuda.amp import autocast, GradScaler scaler GradScaler() for x, y in train_loader: optimizer.zero_grad() with autocast(): outputs model(x.cuda(), x[:, -96:, :].cuda()) loss criterion(outputs, y.cuda()) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()模型量化部署优化quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtypetorch.qint8 ) torch.save(quantized_model.state_dict(), quantized_itransformer.pth)自定义注意力机制class EfficientInvertedAttention(nn.Module): def __init__(self, d_model, n_heads, dropout0.1): super().__init__() # 实现更高效的低秩注意力变体 pass迁移到新数据集 当将iTransformer应用于其他时间序列数据如金融、医疗等领域时需要注意数据特性分析检查序列的周期性、趋势性分析变量间的相关性确定合适的输入/输出尺度必要的架构调整修改嵌入层维度调整注意力头数根据变量数量优化位置编码方式对非均匀采样数据领域适配技巧添加领域特定的特征工程使用迁移学习预训练微调设计领域相关的损失函数# 金融时间序列适配示例 class FinancialiTransformer(iTransformer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 添加技术指标嵌入 self.tech_indicator nn.Linear(5, kwargs[d_model]) # 5个常用技术指标 def forward(self, x_enc, x_dec, indicators): enc_out self.enc_embedding(x_enc) self.tech_indicator(indicators) # 其余部分保持不变 ...通过以上步骤我们不仅完成了iTransformer论文的复现还掌握了将前沿时间序列预测模型应用于实际问题的完整方法论。这种从论文到实践的转化能力正是现代AI工程师和研究者的核心竞争力。

相关文章:

从GitHub热门项目到实战:手把手教你复现一篇ICLR‘24时间序列预测论文(附完整代码)

从GitHub热门项目到实战:手把手教你复现一篇ICLR24时间序列预测论文(附完整代码) 在人工智能领域,前沿论文与开源代码的结合正成为推动技术进步的重要动力。GitHub上涌现出大量包含顶会论文和配套实现的仓库,如AI4TS这…...

香熏哪个更值得推荐

在快节奏的现代生活中,香薰已成为许多人放松心情、提升生活品质的重要方式。然而,市面上的香薰产品琳琅满目,如何选择一款既安全又高效的香薰呢?本文将从多个角度分析,为什么树边香氛更值得推荐。1. 天然植萃&#xff…...

基于R语言的自动数据收集:网络抓取和文本挖掘实用指南【1.8】

3.6 JSON文档示例在本节,我们要熟悉数据交换标准JSON的优点。这个首字母缩写(发音是“Jason”)代表JavaScript对象标记(JavaScript Object Notation)。JSON的设计和XML如出一辙,两者通常都是用来存储和交换…...

基于R语言的自动数据收集:网络抓取和文本挖掘实用指南【1.7】

3.5 XML和R的实践现在让我们转到实际例子。XML文件在R会话中如何查看、如何导入、如何访问,以及如何把来自XML文档的信息转化为更便于进一步图形化或统计化分析的数据结构,例如常规的数据框(data frame)呢?正如我们前面…...

基于R语言的自动数据收集:网络抓取和文本挖掘实用指南【1.6】

3.2.4 注释及字符数据XML的语法提供了一种对内容进行注释的方式在<&#xff01;--和-->之间的所有内容都不被当作XML代码的一部分&#xff0c;从而会被解析器所忽略。注释可以用在标签之间或元素内容之内&#xff0c;但不能在元素名或属性名的内部使用。在数据值中有较多…...

JDK 1.8 vs JDK 17:jvisualvm 安装配置全攻略(附Visual GC插件避坑指南)

JDK 1.8 vs JDK 17&#xff1a;jvisualvm 安装配置全攻略&#xff08;附Visual GC插件避坑指南&#xff09; 在Java开发的世界里&#xff0c;JVM性能调优一直是开发者进阶的必修课。而jvisualvm作为Oracle官方提供的免费性能分析工具&#xff0c;可以说是我们窥探JVM内部运行状…...

机器学习实践指南【1.0】

第1章 机器学习引言本章将介绍机器学习及其涵盖的多个话题。你将了解以下内容&#xff1a;什么是机器学习分类方法概述聚类方法概述模型的选择和正则化概述非线性方法概述监督学习概述无监督学习概述增强学习概述结构化预测概述神经网络概述深度学习概述1.1 什么是机器学习人类…...

极验滑动验证码自动化实战:背景提取、缺口定位与Playwright滑动模拟

滑动验证码自动化实战&#xff1a;背景提取、缺口定位与Playwright滑动模拟 一、前言 在爬虫自动化、Web端自动化测试、业务流程自动化等场景中&#xff0c;人机验证是保障系统安全的重要防线&#xff0c;也是自动化流程中最常见的“拦路虎”。极验&#xff08;Geetest&#…...

OpenAI Agents SDK 中文实战指南:从入门到多代理协作

1. 为什么你需要OpenAI Agents SDK 第一次接触这个SDK时&#xff0c;我正为一个客户设计智能客服系统。传统方案需要写大量if-else逻辑判断用户意图&#xff0c;而Agents SDK的多代理协作机制让我眼前一亮——就像组建了一支各有所长的AI团队&#xff0c;数学问题自动转交数学专…...

OpenClaw安全加固:Phi-3-vision服务接口的权限控制实践

OpenClaw安全加固&#xff1a;Phi-3-vision服务接口的权限控制实践 1. 为什么需要安全加固&#xff1f; 上周我在本地部署了Phi-3-vision多模态模型&#xff0c;通过OpenClaw实现了一个智能图片分析工作流。但当我用手机测试时&#xff0c;意外发现任何人都能通过公网IP访问我…...

测试小白福音:在快马上通过实战代码轻松攻克软件测试面试题

作为一名刚入门的软件测试新手&#xff0c;面对各种面试题时常常感到一头雾水。最近我发现了一个特别实用的学习方法 - 通过动手实践来理解测试理论。今天就来分享一下我的经验。 从基础概念入手 刚开始学习时&#xff0c;我连黑盒测试和白盒测试的区别都搞不清楚。后来发现&…...

国内网站 SEO 推广需要多长时间见效

国内网站 SEO 推广需要多长时间见效 在当今互联网时代&#xff0c;搜索引擎优化&#xff08;SEO&#xff09;已经成为提升国内网站流量和品牌知名度的关键手段。很多人都会问&#xff0c;国内网站 SEO 推广需要多长时间才能见效&#xff1f;答案并不简单&#xff0c;因为这涉及…...

2026届必备的十大降重复率工具实测分析

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 维普 AIGC 检测系统&#xff0c;是特意为学术机构还有研究者用心设计的&#xff0c;它的主要…...

2026届学术党必备的十大降重复率工具推荐

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 当下&#xff0c;各种各样的降AI工具纷纷出现&#xff0c;其关键功能是借助文本改写、句式重…...

Mac开发者必备:OpenClaw对接Qwen3-32B实现Xcode日志自动分析

Mac开发者必备&#xff1a;OpenClaw对接Qwen3-32B实现Xcode日志自动分析 1. 为什么需要自动化Xcode日志分析 作为一个长期与Xcode打交道的iOS开发者&#xff0c;我每天至少有2小时耗在编译错误和运行时日志的排查上。那些冗长的符号化崩溃日志、晦涩的Swift类型推断错误、以及…...

无感方波控制方案-脉冲启动与凸极性电机保护功能全面标题:‘无感方波方案-无抖动无反转启动...

无感方波方案&#xff0c;无感启动无抖动&#xff0c;无反转&#xff0c;启动方式为脉冲注入检测位置&#xff0c;换相方式为AD比较器&#xff0c;电机要有一定凸极性 &#xff0c;电机要有一定凸极性&#xff0c;电机要有一定凸极性&#xff01; 软件做有各种保护功能&#x…...

LabVIEW调用VisionPro框架代码:VisionPro labview 2020版

LabVIEW调用VisionPro框架代码 VisionPro labview 2020 最近在折腾LabVIEW和VisionPro的联动开发&#xff0c;发现这俩工业视觉领域的老搭档配合起来确实能玩出不少花样。今天咱们就聊聊怎么在LabVIEW 2020里直接调用VisionPro框架的代码&#xff0c;手头有工控机的朋友可以直接…...

如何为 3D 轮播文本添加可点击的 URL 链接

...

2026 AI行业封神之年:国产模型反超海外,AI短剧/视频/编程三大赛道掘金指南

2026年,AI行业正式迈入工业化落地的关键拐点,不再是技术圈的自嗨,而是全面渗透进写作、设计、影视、开发的各行各业。想抓住这波时代红利,又不想在数十个平台间反复横跳?https://n.kulaai.cn 给出了最优解——这个一站式AI模型聚合平台,直接把ChatGPT、Claude、Gemini、D…...

Windows下OpenClaw安装指南:对接Phi-3-vision-128k-instruct图文模型

Windows下OpenClaw安装指南&#xff1a;对接Phi-3-vision-128k-instruct图文模型 1. 为什么选择OpenClawPhi-3-vision组合 去年我在处理大量图文混排的学术资料时&#xff0c;发现传统自动化工具难以理解图片中的表格和公式。直到尝试将OpenClaw与多模态模型结合&#xff0c;…...

如何在phpMyAdmin中根据结果集生成图表_折线图与柱状图的可视化展示

phpMyAdmin 不支持折线图或柱状图&#xff0c;新版已移除 Charts 标签页&#xff0c;旧版仅依赖弃用的 jpgraph 库支持极简饼图&#xff1b;可行方案是导出 CSV 后用 Excel 或 Chart.js 等外部工具绘图。phpMyAdmin 本身不支持折线图或柱状图phpmyadmin 是一个数据库管理工具&a…...

AI设计抗体,成功率低怎么办?从David Baker新论文看RFdiffusion的三大局限与未来优化方向

AI抗体设计的三大技术瓶颈与突破路径&#xff1a;从RFdiffusion的实践启示 抗体药物市场正以惊人的速度扩张&#xff0c;预计2025年将达到4450亿美元规模。在这个充满机遇的领域&#xff0c;AI技术正在改写传统抗体开发的游戏规则。David Baker团队最新发表在bioRxiv的研究成果…...

如何高效使用付费墙绕过工具:Chrome扩展的完整实践指南

如何高效使用付费墙绕过工具&#xff1a;Chrome扩展的完整实践指南 【免费下载链接】bypass-paywalls-chrome-clean 项目地址: https://gitcode.com/GitHub_Trending/by/bypass-paywalls-chrome-clean 在信息获取日益重要的今天&#xff0c;付费墙成为许多用户访问优质…...

【需求改变与测试如何】

需求一旦修改&#xff0c;测试该如何进行呢&#xff1f; 最近面临的项目&#xff0c;经过很多次需求更改或者是前期没有需求&#xff0c;实际操作起来&#xff0c;让人很是头疼&#xff0c;恰到也看到大家也有着相同的讨论。 来源于微信公众号&#xff1a;测试论道学习&#x…...

萌新梦开始的地方

大家好&#xff0c;我是一名双非本科的大一新生&#xff0c;目前就读于计算机科学与技术这个专业&#xff0c;平时的兴趣爱好就是听听歌&#xff0c;健健身&#xff0c;这是我写的第一篇博客&#xff0c;我想以此来作为我学习编程的开始&#xff0c;同同时也以此来见证我在编程…...

实战演练:基于Next.js与快马AI接口,构建可交互的qoderwork官网演示版

今天想和大家分享一个实战项目&#xff1a;用Next.js模拟搭建qoderwork官网&#xff0c;并集成快马AI的代码生成能力。这个项目特别适合想学习全栈开发的朋友&#xff0c;既能练手Next.js&#xff0c;又能体验AI接口的集成。 项目整体设计思路 这个模拟官网主要包含两大核心功…...

obsidian claudian 插件配置使用minimax模型

首先&#xff0c;打开.claude/settings.json文件 sudo gedit .claude/settings.json参考官网配置 “ANTHROPIC_BASE_URL”: “https://api.minimaxi.com/anthropic”, “ANTHROPIC_AUTH_TOKEN”: “MINIMAX_API_KEY”, 等参数然后在claudian插件中在配置一遍&#xff0c;即可正…...

C++的std--ranges视图转换函数异常安全与资源清理在惰性求值中的处理

C的std::ranges视图转换函数异常安全与资源清理在惰性求值中的处理 现代C引入的std::ranges库为序列操作提供了声明式编程支持&#xff0c;其中视图转换函数&#xff08;如transform、filter等&#xff09;通过惰性求值优化性能。惰性求值机制与异常安全、资源清理的交互可能引…...

FinalBurn Neo终极指南:如何打造完美的复古游戏体验

FinalBurn Neo终极指南&#xff1a;如何打造完美的复古游戏体验 【免费下载链接】FBNeo FinalBurn Neo - We are Team FBNeo. 项目地址: https://gitcode.com/gh_mirrors/fb/FBNeo FinalBurn Neo&#xff08;简称FBNeo&#xff09;是一款开源街机游戏模拟器&#xff0c;…...

CTFshow-PWN实战:利用NOP Sled绕过栈保护获取Shell

1. 理解NOP Sled技术原理 NOP Sled&#xff08;空操作雪橇&#xff09;是二进制漏洞利用中的经典技术&#xff0c;特别适合应对地址随机化&#xff08;ASLR&#xff09;或栈地址不确定的情况。它的核心思想就像滑雪场里的缓冲坡道——通过布置大量无操作指令&#xff08;NOP&am…...