【预测】-双注意LSTM自动编码器记录
预测-双注意LSTM自动编码器
- 1 预测-双注意LSTM自动编码器
- 1.1 复现环境配置
- 1.2 数据流记录
- 1.2.1 **构建Dataset**
- (1) **`X` 的取数**
- (2) **`y` 的取数**
- (3) **`target` 的取数**
- 1.2.2 **举例说明**
- (1)**`X` 的取数**
- (2)**`y` 的取数**
- (3)**`target` 的取数**
- 1.2.3 **`y` 取数的问题**
- **修正后的代码**
- 1.2.4 **总结**
- 1.2.5 数据流总结:
- 1.2.6 数据流图示:
- 1.2.7 参考:
- 2 数据维度变化流程
- 2.1 流程图
- 2.2 总结
1 预测-双注意LSTM自动编码器
复现github链接:https://github.com/JulesBelveze/time-series-autoencoder.git
论文:A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction:https://arxiv.org/abs/1704.02971
1.1 复现环境配置
python版本:python3.8.20
cuda版本:cuda111
包版本环境参考:
Package Version Editable project location
-------------------- ------------ ---------------------------
build 1.2.2.post1
CacheControl 0.14.2
certifi 2025.1.31
charset-normalizer 3.4.1
cleo 2.1.0
colorama 0.4.6
contourpy 1.1.1
crashtest 0.4.1
cycler 0.10.0
distlib 0.3.9
dulwich 0.21.7
fastjsonschema 2.21.1
filelock 3.16.1
fonttools 4.56.0
future 0.18.2
idna 3.10
importlib_metadata 8.5.0
importlib_resources 6.4.5
installer 0.7.0
jaraco.classes 3.4.0
joblib 0.15.1
keyring 24.3.1
kiwisolver 1.2.0
matplotlib 3.2.1
more-itertools 10.5.0
msgpack 1.1.0
numpy 1.21.0
packaging 24.2
pandas 1.1.5
pexpect 4.9.0
pillow 10.4.0
pip 24.3.1
pkginfo 1.12.1.2
platformdirs 4.3.6
poetry 1.8.5
poetry-core 1.9.1
poetry-plugin-export 1.8.0
protobuf 5.29.3
ptyprocess 0.7.0
pyparsing 2.4.7
pyproject_hooks 1.2.0
python-dateutil 2.8.1
pytz 2025.1
pywin32-ctypes 0.2.3
rapidfuzz 3.9.7
requests 2.32.3
requests-toolbelt 1.0.0
scikit-learn 0.23.1
scipy 1.4.1
setuptools 75.3.0
shellingham 1.5.4
six 1.15.0
sklearn 0.0
tensorboardX 2.6.2.2
threadpoolctl 2.1.0
tomli 2.2.1
tomlkit 0.13.2
torch 1.9.1+cu111
torchaudio 0.9.1
torchvision 0.10.1+cu111
tqdm 4.46.1
trove-classifiers 2025.2.18.16
tsa 0.1.0 D:\temp\Pytorch双注意LSTM自动编码器
typing_extensions 4.12.2
urllib3 2.2.3
virtualenv 20.29.2
wheel 0.45.1
zipp 3.20.2
注:
vscode配置:
{// Use IntelliSense to learn about possible attributes.// Hover to view descriptions of existing attributes.// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387"version": "0.2.0","configurations": [{"name": "Python Debugger: Current File with Arguments","type": "debugpy","request": "launch","program": "${file}","cwd":"${fileDirname}","console": "integratedTerminal",// "args": [// "--ckpt", "output/checkpoint-5000.ckpt" // 添加 --ckpt 参数及其值// ]}]
}
1.2 数据流记录
源代码所用数据字段:
| 列名 | 含义 | 单位 |
| ------------- | ------------------ | -------- |
| Date_Time | 日期和时间 | - |
| CO(GT) | 一氧化碳浓度 | mg/m³ |
| PT08.S1(CO) | 一氧化碳传感器响应值 | 无量纲 |
| NMHC(GT) | 非甲烷烃浓度 | µg/m³ |
| C6H6(GT) | 苯浓度 | µg/m³ |
| PT08.S2(NMHC) | 非甲烷烃传感器响应值 | 无量纲 |
| NOx(GT) | 氮氧化物浓度 | µg/m³ |
| PT08.S3(NOx) | 氮氧化物传感器响应值 | 无量纲 |
| NO2(GT) | 二氧化氮浓度 | µg/m³ |
| PT08.S4(NO2) | 二氧化氮传感器响应值 | 无量纲 |
| PT08.S5(O3) | 臭氧传感器响应值 | 无量纲 |
| T | 温度 | °C |
| RH | 相对湿度 | % |
| AH | 绝对湿度 | g/m³ |
将时间序列数据转换为适合时间序列预测的格式,具体来说,它通过滑动窗口的方式从输入数据 X 和标签 y 中提取特征和标签,并生成一个 TensorDataset。提取特征和标签,预测的是后面的预测窗口长度的标签。下面我将详细解释 X、y 和 target 的取数逻辑,并指出 y 取数可能存在的问题。
1.2.1 构建Dataset
(1) X 的取数
X是输入特征数据,形状为(nb_obs, nb_features),其中nb_obs是样本数量,nb_features是特征数量。- 通过滑动窗口的方式,从
X中提取长度为seq_length的序列:features.append(torch.FloatTensor(X[i:i + self.seq_length, :]).unsqueeze(0))- 例如,如果
seq_length = 10,则每次提取X[i:i+10, :],即从第i个时间步开始的 10 个时间步的特征数据。 unsqueeze(0)是为了增加一个批次维度。
- 例如,如果
(2) y 的取数
y是目标值(标签),通常是与X对应的输出值。- 代码中从
y中提取的是滞后一期的历史值(y[i-1:i+self.seq_length-1]):y_hist.append(torch.FloatTensor(y[i - 1:i + self.seq_length - 1]).unsqueeze(0))- 例如,如果
seq_length = 10,则提取的是y[i-1:i+9],即从第i-1个时间步开始的 10 个时间步的标签值。 - 这里
y[i-1]的使用可能有问题,因为y[i-1]是前一个时间步的值,而不是当前时间步的值。如果y是当前时间步的标签,那么这里应该直接使用y[i:i+self.seq_length]。
- 例如,如果
(3) target 的取数
target是预测的目标值,即未来prediction_window个时间步的标签值:target.append(torch.FloatTensor(y[i + self.seq_length:i + self.seq_length + self.prediction_window]))- 例如,如果
seq_length = 10且prediction_window = 5,则提取的是y[i+10:i+15],即从第i+10个时间步开始的 5 个时间步的标签值。
- 例如,如果
1.2.2 举例说明
假设有以下数据:
X和y的长度为 20。seq_length = 3,prediction_window = 2。
(1)X 的取数
- 当
i = 1时,提取X[1:4, :]。 - 当
i = 2时,提取X[2:5, :]。 - 以此类推。
(2)y 的取数
- 当
i = 1时,提取y[0:3](即y[i-1:i+seq_length-1])。 - 当
i = 2时,提取y[1:4]。 - 以此类推。
(3)target 的取数
- 当
i = 1时,提取y[4:6](即y[i+seq_length:i+seq_length+prediction_window])。 - 当
i = 2时,提取y[5:7]。 - 以此类推。
1.2.3 y 取数的问题
在代码中,y 的取数逻辑是:
y_hist.append(torch.FloatTensor(y[i - 1:i + self.seq_length - 1]).unsqueeze(0))
这里使用了 y[i-1],即前一个时间步的值。如果 y 是当前时间步的标签,那么这里应该直接使用 y[i:i+self.seq_length],而不是 y[i-1:i+self.seq_length-1]。修正后的代码应该是:
y_hist.append(torch.FloatTensor(y[i:i + self.seq_length]).unsqueeze(0))
修正后的代码
def frame_series(self, X, y=None):'''Function used to prepare the data for time series prediction:param X: set of features:param y: targeted value to predict:return: TensorDataset'''nb_obs, nb_features = X.shapefeatures, target, y_hist = [], [], []for i in range(1, nb_obs - self.seq_length - self.prediction_window):features.append(torch.FloatTensor(X[i:i + self.seq_length, :]).unsqueeze(0))# 修正后的 y 取数逻辑y_hist.append(torch.FloatTensor(y[i:i + self.seq_length]).unsqueeze(0))features_var, y_hist_var = torch.cat(features), torch.cat(y_hist)if y is not None:for i in range(1, nb_obs - self.seq_length - self.prediction_window):target.append(torch.FloatTensor(y[i + self.seq_length:i + self.seq_length + self.prediction_window]))target_var = torch.cat(target)return TensorDataset(features_var, y_hist_var, target_var)return TensorDataset(features_var)
1.2.4 总结
X的取数是滑动窗口提取特征序列。y的取数逻辑存在问题,不应使用y[i-1],而应直接使用y[i:i+self.seq_length]。target的取数是提取未来prediction_window个时间步的标签值。
这段代码的数据流可以分为以下几个步骤:
-
数据预处理:
- 调用
self.preprocess_data()方法,生成训练集和测试集的特征和标签:X_train,X_test,y_train,y_test。 - 从
X_train中获取特征的数量nb_features。
- 调用
-
数据集封装:
- 调用
self.frame_series(X_train, y_train)方法,将训练集的特征和标签封装成一个train_dataset对象。 - 调用
self.frame_series(X_test, y_test)方法,将测试集的特征和标签封装成一个test_dataset对象。
- 调用
-
DataLoader 创建:
- 使用
DataLoader类创建train_iter,用于加载训练数据集。参数包括batch_size(批次大小)、shuffle=False(不打乱数据)、drop_last=True(丢弃最后一个不完整的批次)。 - 使用
DataLoader类创建test_iter,用于加载测试数据集。参数与train_iter相同。
- 使用
-
返回结果:
- 返回
train_iter(训练数据加载器)、test_iter(测试数据加载器)和nb_features(特征数量)。
- 返回
1.2.5 数据流总结:
- 输入:原始数据通过
self.preprocess_data()进行预处理,生成特征和标签。 - 处理:特征和标签被封装成
Dataset对象,然后通过DataLoader进行批次加载。 - 输出:返回训练和测试的
DataLoader对象,以及特征数量。
1.2.6 数据流图示:
原始数据 → preprocess_data() → (X_train, X_test, y_train, y_test) → frame_series() → (train_dataset, test_dataset) → DataLoader() → (train_iter, test_iter)
1.2.7 参考:
DataLoader是 PyTorch 中用于批量加载数据的工具,支持多线程加载、数据打乱等功能。Dataset是 PyTorch 中用于封装数据集的基类,通常需要实现__len__和__getitem__方法。
为了更好地理解数据维度的变化情况,我们可以通过一个具体的例子来逐步分析代码中的数据维度变化。假设我们有一个时间序列数据集,包含以下列:
date: 时间戳feature1: 数值特征feature2: 数值特征category: 类别特征target: 目标值
2 数据维度变化流程
-
原始数据 (
data):- 假设数据集有 1000 行,5 列(
date,feature1,feature2,category,target)。 - 维度:
(1000, 5)
- 假设数据集有 1000 行,5 列(
-
预处理 (
preprocess_data):X = data.drop('target', axis=1):去掉目标列,剩下 4 列。- 维度:
(1000, 4)
- 维度:
y = data['target']:目标列。- 维度:
(1000,)
- 维度:
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, shuffle=False):X_train维度:(800, 4)X_test维度:(200, 4)y_train维度:(800,)y_test维度:(200,)
X_train = preprocessor.fit_transform(X_train):经过ColumnTransformer处理,假设category列被编码为 3 个新列。- 维度:
(800, 5)(feature1,feature2,category_encoded_1,category_encoded_2,category_encoded_3)
- 维度:
X_test = preprocessor.transform(X_test):- 维度:
(200, 5)
- 维度:
-
时间序列帧化 (
frame_series):- 假设
seq_length = 10,prediction_window = 1。 nb_obs, nb_features = X_train.shape:nb_obs = 800,nb_features = 5
features和y_hist的生成:- 对于
i从 1 到800 - 10 - 1 = 789,每次取 10 个时间步的数据。 features维度:(789, 10, 5)y_hist维度:(789, 10)
- 对于
target的生成:- 对于
i从 1 到789,每次取 1 个时间步的目标值。 target维度:(789, 1)
- 对于
TensorDataset的生成:features_var维度:(789, 10, 5)y_hist_var维度:(789, 10)target_var维度:(789, 1)
- 假设
-
DataLoader (
get_loaders):train_iter = DataLoader(train_dataset, batch_size=32, shuffle=False, drop_last=True):- 每个 batch 的维度:
(32, 10, 5)(特征),(32, 10)(历史目标),(32, 1)(目标)
- 每个 batch 的维度:
test_iter = DataLoader(test_dataset, batch_size=32, shuffle=False, drop_last=True):- 每个 batch 的维度:
(32, 10, 5)(特征),(32, 10)(历史目标),(32, 1)(目标)
- 每个 batch 的维度:
2.1 流程图
原始数据 (1000, 5)|v
预处理 (X_train: 800, 5, y_train: 800)|v
时间序列帧化 (features: 789, 10, 5, y_hist: 789, 10, target: 789, 1)|v
DataLoader (batch_size=32, features: 32, 10, 5, y_hist: 32, 10, target: 32, 1)
2.2 总结
通过上述步骤,可以看到数据从原始形式逐步转换为适合时间序列模型训练的格式。每个步骤中的数据维度变化如下:
- 原始数据:
(1000, 5) - 预处理后:
(800, 5)(训练集特征),(800,)(训练集目标) - 时间序列帧化后:
(789, 10, 5)(特征),(789, 10)(历史目标),(789, 1)(目标) - DataLoader 中:
(32, 10, 5)(特征),(32, 10)(历史目标),(32, 1)(目标)
相关文章:
【预测】-双注意LSTM自动编码器记录
预测-双注意LSTM自动编码器 1 预测-双注意LSTM自动编码器1.1 复现环境配置1.2 数据流记录1.2.1 **构建Dataset**(1) **X 的取数**(2) **y 的取数**(3) **target 的取数** 1.2.2 **举例说明**(1)**X 的取数**(2)**y 的取数**(3)**target 的取数** 1.2.3 **y 取数的问题****修正后…...
S32K3 MCU时钟部分
S32K3 MCU时钟部分 1.系统时钟发生器SCG 系统时钟发生器SCG模块提供MCU的系统时钟,SCG包含一个系统锁相环SPLL,一个慢速的内部参考时钟SIRC,一个快速内部参考时钟FIRC和系统振荡时钟SOSC. 时钟生成的电路提供了多个时钟分频器和选择器允许为不同的模块提供以特定于该模块的频率…...
java开发常用注解
在Java开发中,注解(Annotation)广泛用于简化代码、配置元数据、框架集成等场景。以下是不同场景下常用的注解分类整理: 一、核心Java注解(内置) Override 表示方法重写父类或接口的方法,编译器会…...
Doris vs ClickHouse 企业级实时分析引擎怎么选?
Apache Doris 与 ClickHouse 同作为OLAP领域的佼佼者,在企业级实时分析引擎该如何选择呢。本文将详细介绍 Doris 的优势,并通过直观对比展示两者的关键差异,同时分享一个企业成功用 Doris 替换 ClickHouse 的实践案例,帮助您做出明…...
解锁Egg.js:从Node.js小白到Web开发高手的进阶之路
一、Egg.js 是什么 在当今的 Web 开发领域,Node.js 凭借其事件驱动、非阻塞 I/O 的模型,在构建高性能、可扩展的网络应用方面展现出独特的优势 ,受到了广大开发者的青睐。它让 JavaScript 不仅局限于前端,还能在服务器端大展身手&…...
学习前端前需要了解的一些概念(详细版)
网站的定义与概述 网站(Website)是一个由网络服务器托管的、通过网络访问的、由相关网页和资源组成的集合。它为用户提供信息、服务或娱乐平台,是现代互联网的重要组成部分。网站的基本功能是展示信息和提供服务,用户可以通过浏览…...
分布式数据库中的四种透明性:逻辑透明、位置透明、分片透明和复制透明
四种透明性 1. 逻辑透明(Logical Transparency)2. 位置透明(Location Transparency)3. 分片透明(Fragmentation Transparency)4. 复制透明(Replication Transparency)注意点…...
SSM架构 +java后台 实现rtsp流转hls流,在前端html上实现视频播放
序言:书接上文,我们继续 SSM架构 NginxFFmpeg实现rtsp流转hls流,在前端html上实现视频播放 步骤一:把rtsp流转化为hls流,用Java代码进行转换 package com.tools;import java.io.BufferedReader; import java.io.IOExc…...
时序数据库 TDengine 化工新签约:存储降本一半,查询提速十倍
化工行业在数字化转型过程中面临数据接入复杂、实时性要求高、系统集成难度大等诸多挑战。福州力川数码科技有限公司科技依托深厚的行业积累,精准聚焦行业痛点,并携手 TDengine 提供高效解决方案。通过应用 TDengine,力川科技助力化工企业实现…...
信号完整性基础:高速信号的扩频时钟SSC测试
扩频时钟 SSC 是 Spread Spectrum Clock 的英文缩写,目前很多数字电路芯片都支持 SSC 功能,如:PCIE、USB3.0、SATA 等等。那么扩频时钟是用来做什么的呢? SSC背景: 扩频时钟是出于解决电磁干扰(EMI&#…...
深入理解与配置 Nginx TCP 日志输出
一、背景介绍 在现代网络架构中,Nginx 作为一款高性能的 Web 服务器和反向代理服务器,广泛应用于各种场景。除了对 HTTP/HTTPS 协议的出色支持,Nginx 从 1.9.0 版本开始引入了对 TCP 和 UDP 协议的代理功能,这使得它在处理数据库…...
Java为什么是跨平台的
一、Java虚拟机(JVM)的抽象层作用 JVAM是Java跨平台的核心技术。Java代码编译后生成字节码(.class文件),这些字节码并非直接由操作系统执行,而是由JVM解释或编译为特定平台的机器码。 屏蔽底层差异:JVM为不同操作系统提供统一的运行时环境,开…...
Sora与AGI的结合:从多模态模型到智能体推理的演进
全文目录: 开篇语前言前言:AGI的挑战与Sora的突破Sora的多模态学习架构:支撑智能体推理的基础1. **多模态学习的核心:信息融合与交叉理解**2. **智能体推理:从感知到决策** Sora如何推动AGI的发展:自主学习…...
一个针对煤炭市场的人工智能项目的开发示例
以下是一个针对煤炭市场的人工智能项目的开发示例,此项目将涵盖数据收集、数据预处理、模型构建、模型训练和预测等步骤。这里我们以预测煤炭价格为例,使用 Python 语言结合常见的机器学习库(如pandas、scikit - learn)来完成。 …...
QILSTE H6-S115FOKYG高亮橙光和黄绿光LED灯珠
型号:H6-S115FOKYG --- 在众多电子元件中,H6-S115FOKYG型号的LED以其独特的性能脱颖而出。这款产品采用了高亮橙光和黄绿光两种颜色,尺寸仅为1.6x1.5x0.55mm,却蕴含着强大的光电性能。其透明平面胶体设计,不仅美观&a…...
EasyDSS视频推拉流/直播点播平台:Mysql数据库接口报错502处理方法
视频推拉流/视频直播点播EasyDSS互联网直播平台支持一站式的上传、转码、直播、回放、嵌入、分享功能,具有多屏播放、自由组合、接口丰富等特点。平台可以为用户提供专业、稳定的直播推流、转码、分发和播放服务,全面满足超低延迟、超高画质、超大并发访…...
测试直播postman+Jenkins所学
接口自动化 什么是接口?本质上就是一个url,用于提供数据。后台程序提供一种数据地址,接口的数据一般是从数据库中查出来的。 postman自动化实操: 一般来说公司会给接口文档,如果没有,通过拦截,…...
上线DeepSeek大模型,黄山“大位”智算中心正式点亮
2月28日,智启黄山,算领未来——黄山“大位”智算中心点亮仪式在黄山市大位人工智能计算中心举行,标志着黄山“大位”智算中心正式投入运营。同日,DeepSeek-R1大模型在黄山“大位”正式上线,通过“顶尖大模型普惠算力底…...
计算机毕业设计SpringBoot+Vue.js医院药品管理系统(源码+文档+PPT+讲解)
温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…...
Linux安装nvm和node
执行curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.4/install.sh | bash命令下载安装nvm 执行 source ~/.bashrc命令重新加载shell配置文件以使NVM生效 执行nvm ls-remote 查看可用node版本 如果确定版本,可以直接执行npm install 版本号࿰…...
接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...
【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例
文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...
转转集团旗下首家二手多品类循环仓店“超级转转”开业
6月9日,国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解,“超级…...
Mac下Android Studio扫描根目录卡死问题记录
环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中,提示一个依赖外部头文件的cpp源文件需要同步,点…...
HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...
VM虚拟机网络配置(ubuntu24桥接模式):配置静态IP
编辑-虚拟网络编辑器-更改设置 选择桥接模式,然后找到相应的网卡(可以查看自己本机的网络连接) windows连接的网络点击查看属性 编辑虚拟机设置更改网络配置,选择刚才配置的桥接模式 静态ip设置: 我用的ubuntu24桌…...
使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用
一、方案背景 在现代生产与生活场景中,如工厂高危作业区、医院手术室、公共场景等,人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式,存在效率低、覆盖面不足、判断主观性强等问题,难以满足对人员打手机行为精…...
