当前位置: 首页 > news >正文

LSTM预测模型复现笔记和问题记录

LSTM复现笔记和问题记录

  • 1 LSTM复现记录
    • 1.1 复现环境配置
    • 1.2 LSTM_Fly文件夹
      • 1.2.1 LSTM回归网络(1→1).py
        • 1.2.1.1 加载数据
        • 1.2.1.2 数据处理
        • 1.2.1.3 输入模型维度
      • 1.2.2 移动窗口型回归(3→1).py
        • 1.2.2.1 数据处理
        • 1.2.2.2 输入模型维度
      • 1.2.3 时间步长型回归(3→1).py
        • 1.2.3.1 数据处理
        • 1.2.3.2 输入模型维度
    • 1.3 LSTM系列文件夹
      • 1.3.1 LSTM单变量4
        • 1.3.1.1 输入模型维度1
      • 1.3.2 LSTM多变量3
        • 1.3.2.1 输入模型维度
      • 1.3.3 Multi-Step LSTM预测2
        • 1.3.3.1 输入模型维度
    • 1.4 stock_predict
      • 1.4.1 stock_predict_1.py
        • 1.4.1.1 输入模型维度
    • 1.5 洗发水销量(单步预测)
      • 1.5.1 6.LSTM模型实例.py
  • 总结

1 LSTM复现记录

复现github链接:https://github.com/yangwohenmai/LSTM.git

1.1 复现环境配置

采用cuda10.1;1050ti显卡; python版本:3.8.20

absl-py==2.1.0
astunparse==1.6.3
cachetools==4.2.4
certifi==2025.1.31
charset-normalizer==3.4.1
contourpy==1.1.1
cycler==0.12.1
fonttools==4.56.0
gast==0.3.3
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.70.0
h5py==2.10.0
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
joblib==1.4.2
Keras-Preprocessing==1.1.2
kiwisolver==1.4.7
Markdown==3.7
MarkupSafe==2.1.5
matplotlib==3.3.4
numpy==1.18.5
oauthlib==3.2.2
opt_einsum==3.4.0
packaging==24.2
pandas==1.1.5
pillow==10.4.0
protobuf==3.20.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pyparsing==3.1.4
python-dateutil==2.9.0.post0
pytz==2025.1
PyYAML==6.0.2
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
scikit-learn==0.24.2
scipy==1.4.1
six==1.17.0
tensorboard==2.2.2
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
tensorflow-gpu==2.2.0
tensorflow-gpu-estimator==2.2.0
termcolor==2.4.0
Theano==1.0.5
threadpoolctl==3.5.0
tzdata==2025.1
urllib3==2.2.3
Werkzeug==3.0.6
wrapt==1.17.2
zipp==3.20.2

注:这个版本得tensorflow中自带keras,后面keras.xxx相关模块得导入,修改成tensorflow.keras.xxx

vscode配置:

{"version": "1.95.2","configurations": [{"name": "Python Debugger: Current File","type": "debugpy","request": "launch","program": "${file}","console": "integratedTerminal","cwd":"${fileDirname}",}]
}

1.2 LSTM_Fly文件夹

1.2.1 LSTM回归网络(1→1).py

将数据截取成1->1的监督学习格式:即用前一个数据预测后一个数据

1.2.1.1 加载数据

加载数据时未考虑第一列的时间序列,故命名未回归网络

# 加载数据
dataframe = read_csv('airline-passengers.csv', usecols=[1], engine='python')
1.2.1.2 数据处理

look_back参数设置为1,实现前一个值预测后一个值

# 将数据截取成1->1的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []for i in range(len(dataset)-look_back-1):a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)# 预测数据步长为1,一个预测一个,1->1
look_back = 1
trainX, trainY = create_dataset(train, look_back)
testX, testY = create_dataset(test, look_back)
1.2.1.3 输入模型维度
# 重构输入数据格式 [samples, time steps, features] = [93,1,1]
trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = numpy.reshape(testX, (testX.shape[0], 1, testX.shape[1]))# 构建 LSTM 网络
model = Sequential()
model.add(LSTM(4, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)

1.2.2 移动窗口型回归(3→1).py

用重构的长度为3的数据预测一个数据,相当于进行了特征重构,前三个时间步作为特征进行变换,预测后一个值。

1.2.2.1 数据处理
# 将数据截取成3个一组的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []# 这里没有充分利用数据,若修改为len(dataset)-look_back,然后加上dataset[i + look_back, 0]才刚好遍历到最后一个数据# 如果没有加上Y,则是len(dataset)-look_back+1for i in range(len(dataset)-look_back-1):  a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)
1.2.2.2 输入模型维度
# 预测数据步长为3,三个预测一个,3->1
look_back = 3
trainX, trainY = create_dataset(train, look_back)
testX, testY = create_dataset(test, look_back)
# 重构输入数据格式 [samples, time steps, features] = [93,1,3]
trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = numpy.reshape(testX, (testX.shape[0], 1, testX.shape[1]))
# 构建 LSTM 网络
model = Sequential()
model.add(LSTM(4, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(trainX,trainY, epochs=100, batch_size=1, verbose=2)
# 对训练数据的Y进行预测
trainPredict = model.predict(trainX)

1.2.3 时间步长型回归(3→1).py

用时间步长为3,特征维度为1的数据预测后一个数据,前三个时间步里的一个元素作为特征用于预测后一个值。

1.2.3.1 数据处理
# 将数据截取成3个一组的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []for i in range(len(dataset)-look_back-1):a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)
1.2.3.2 输入模型维度
# 重构输入数据格式 [samples, time steps, features] = [93,3,1]
trainX = numpy.reshape(trainX, (trainX.shape[0], trainX.shape[1], 1))
testX = numpy.reshape(testX, (testX.shape[0], testX.shape[1], 1))
# 构建 LSTM 网络
model = Sequential()
# model.add(LSTM(4, input_shape=(1, look_back)))
model.add(LSTM(4, input_shape=(look_back, 1)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)
# 对训练数据的Y进行预测
trainPredict = model.predict(trainX)

1.3 LSTM系列文件夹

1.3.1 LSTM单变量4

使用前面1个步长里的一个特征元素的数据预测后一个值。

1.3.1.1 输入模型维度1

这里代码中可以看出是使用的前一个时间步的一个特征元素的值来预测后一个值。

# fit LSTM来训练数据
def fit_lstm(train, batch_size, nb_epoch, neurons):X, y = train[:, 0:-1], train[:, -1]X = X.reshape(X.shape[0], 1, X.shape[1])  # 注意这个reshape,一般会需要reshape从二维变成三维,因为这里time_step为1,所以中间的值为1model = Sequential()# 添加LSTM层model.add(LSTM(neurons, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))model.add(Dense(1))  # 输出层1个node# 编译,损失函数mse+优化算法adammodel.compile(loss='mean_squared_error', optimizer='adam')for i in range(nb_epoch):# 按照batch_size,一次读取batch_size个数据model.fit(X, y, epochs=1, batch_size=batch_size, verbose=0, shuffle=False)model.reset_states()print("当前计算次数:"+str(i))return model

1.3.2 LSTM多变量3

使用前面1个时间步长里的多个特征元素的数据预测后一个值。

1.3.2.1 输入模型维度

这里代码中可以看出是使用的前一个时间步的多个特征元素的值来预测后一个值。

#拆分输入输出 split into input and outputs
train_X, train_y = train[:, :-1], train[:, -1]
test_X, test_y = test[:, :-1], test[:, -1]
#reshape输入为LSTM的输入格式 reshape input to be 3D [samples, timesteps, features]
train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
print ('train_x.shape, train_y.shape, test_x.shape, test_y.shape')
print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)##模型定义 design network
model = Sequential()
model.add(LSTM(50, input_shape=(train_X.shape[1], train_X.shape[2])))  # 注意这段,input_shape的前一个元素表示前面步长,reshape时变为1,后一个元素表示每个时间步长里的元素特征数量
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')
#模型训练 fit network
history = model.fit(train_X, train_y, epochs=5, batch_size=72, validation_data=(test_X, test_y), verbose=2,shuffle=False)

1.3.3 Multi-Step LSTM预测2

使用前面1个时间步长里的单个特征元素的数据预测后面多个时间步长的单个特征元素值。

1.3.3.1 输入模型维度

具体数据变换维度见代码。

1.4 stock_predict

1.4.1 stock_predict_1.py

用时间步长为20,特征维度为1的数据预测后20数据,前20个时间步里的一个元素作为特征用于预测后20个时间步里的单个元素值。

1.4.1.1 输入模型维度

具体数据变换维度见代码。

#———————————————————形成训练集—————————————————————
time_step = 20      #时间步
rnn_unit = 10       #hidden layer units
lstm_layers = 2     #每一批次训练多少个样例
batch_size = 60     #输入层维度  #每一批次训练多少个样例
input_size = 1      #输入层维度
output_size = 1     #输出层维度
lr = 0.0006         #学习率
train_x, train_y = [], []#训练集
for i in range(len(normalize_data) - time_step - 1):x = normalize_data[i:i + time_step]y = normalize_data[i + 1:i + time_step + 1]train_x.append(x.tolist())train_y.append(y.tolist())

1.5 洗发水销量(单步预测)

1.5.1 6.LSTM模型实例.py

用前面lag步长的数据,特征维度为1,预测后面特征维度为1的数据
… …

总结

总之,使用前面列的数据预测最后一个列,大多是时间步长为1,每个时间步长都包括变量X和预测值y;另外一种时间步长为n,每个时间步长都包括变量X特征维度可能为1维或多维,另外一个位置时间步长为n对应一个y。

相关文章:

LSTM预测模型复现笔记和问题记录

LSTM复现笔记和问题记录 1 LSTM复现记录1.1 复现环境配置1.2 LSTM_Fly文件夹1.2.1 LSTM回归网络(1→1).py1.2.1.1 加载数据1.2.1.2 数据处理1.2.1.3 输入模型维度 1.2.2 移动窗口型回归(3→1).py1.2.2.1 数据处理1.2.2.2 输入模型维度 1.2.3 时间步长型回归(3→1).py1.2.3.1 数…...

开篇词 | Go 项目开发极速入门课介绍

欢迎加入我的训练营:云原生 AI 实战营,一个助力 Go 开发者在 AI 时代建立技术竞争力的实战营。实战营中包含大量 Go、云原生、AI Infra 相关的优质实战课程和项目。欢迎关注我的公众号:令飞编程,持续分享 Go、云原生、AI Infra 技…...

《论软件测试中缺陷管理及其应用》审题技巧 - 系统架构设计师

论软件测试中缺陷管理及其应用写作框架 一、考点概述 本论题“论软件测试中缺陷管理及其应用”主要考查的是软件测试领域中的缺陷管理相关知识与实践应用。论题涵盖了以下几个核心内容: 首先,需要理解软件缺陷的基本概念,即软件中存在的破坏正常运行能力的问题、错误或隐…...

虚拟机快照与linux的目录结构

虚拟机快照是对虚拟机某一时刻状态的完整捕获,包括内存、磁盘、配置及虚拟硬件状态等,保存为独立文件。 其作用主要有数据备份恢复、方便系统测试实验、用于灾难恢复以及数据对比分析。具有快速创建和恢复、占用空间小、可多个快照并存的特点。在管理维…...

FPGA时许约束与分析 1

1、时钟的基本概念 1.1 时钟定义: 同步设计:电路的状态变化总是由某个周期信号的变化进行控制的,在这个信号的 posedge 或者是 negedge 都可以作为电路状态的触发条件。 时钟:在同步设计中,这个信号 叫做时钟。 理…...

【STM32F103ZET6——库函数】6.PWM

目录 配置PWM输出引脚 使能引脚时钟 配置PWM 使能PWM 配置定时器 使能定时器时钟 使能定时器 例程 例程说明 main.h main.c PWM.h PWM.c led.h led.c DSQ.h DSQ.c 配置PWM输出引脚 PWM的输出引脚必须配置为复用功能。 注意:需要使用哪个引脚&…...

基于SpringBoot + Vue的商城购物系统实战

一:简介 使用springboot框架编写后端服务,并使用若依框架搭建管理端界面。在原有基础功能基础上有加入了人工客服、收货地址、智能助手(接入通义千问,暂时关闭)、抽奖功能、支付宝沙箱支付、优惠卷等功能。 目前已部…...

Perl 调用 DeepSeek API 脚本

向 chat.deepseek.com 提问:请将这个 python 脚本翻译为 perl 语言脚本 参阅:Python 调用 DeepSeek API 完整指南 将 Python 脚本翻译为 Perl 语言脚本时,需要注意两种语言之间的语法差异。以下是将给定的 Python 脚本翻译为 Perl 的版本&a…...

2025国家护网HVV高频面试题总结来了01(题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 0x1 高频面试题第一套 0x2 高频面试题第二套 0x3 高频面试题第三套 0x4 高频面试题第四套 0x5 高频面…...

【前端基础】3、HTML的常用元素(h、p、img、a、iframe、div、span)、不常用元素(strong、i、code、br)

HTML结构 一个HTML包含以下部分&#xff1a; 文档类型声明html元素 head元素body元素 例&#xff08;CSDN&#xff09;&#xff1a; 一、文档类型声明 HTML最一方的文档称为&#xff1a;文档类型声明&#xff0c;用于声明文档类型。即&#xff1a;<!DOCTYPE html>…...

【前端场景题】如何应对页面请求接口的大规模并发问题

如何应对页面请求接口的大规模并发问题&#xff0c;尤其是前端方面的解决方案&#xff0c;并且需要给出详细的代码解释。首先&#xff0c;我需要仔细阅读我搜索到的资料&#xff0c;找出相关的信息&#xff0c;然后综合这些信息来形成答案。 首先看&#xff0c;它提到前端优化策…...

Sublime Text4安装、汉化

-------------2025-02-22可用---------------------- 官方网址下载&#xff1a;https://www.sublimetext.com 打开https://hexed.it 点击打开文件找到软件安装目录下的 ctrlf 查找 8079 0500 0f94 c2右边启用替换替换为:c641 0501 b200 90点击替换按钮 替换完成后 另存为本地…...

Python PDF文件拆分-详解

目录 使用工具 将PDF按页数拆分 将PDF的每一页拆分为单独的文件 将PDF按指定页数拆分 根据页码范围拆分PDF 根据指定内容拆分PDF 将PDF的一页拆分为多页 在日常生活中&#xff0c;我们常常会遇到大型的PDF文件&#xff0c;这些文件可能难以发送、管理和查阅。将PDF拆分成…...

MacDroid for Mac v2.3 安卓手机文件传输助手 支持M、Intel芯片 4.7K

MacDroid 是Mac毒搜集到的一款安卓手机文件传输助手&#xff0c;在Mac和Android设备之间传输文件。您只需要将安卓手机使用 USB 连接到 Mac 电脑上即可将安卓设备挂载为本地磁盘&#xff0c;就像编辑mac磁盘上的文件一样编辑安卓设备上的文件&#xff0c;MacDroid支持所有 Andr…...

人大金仓国产数据库与PostgreSQL

一、简介 在前面项目中&#xff0c;我们使用若依前后端分离整合人大金仓&#xff0c;在后续开发过程中&#xff0c;我们经常因为各种”不适配“问题&#xff0c;但可以感觉得到大部分问题&#xff0c;将人大金仓视为postgreSQL就能去解决大部分问题。据了解&#xff0c;Kingba…...

阿里云 Qwen2.5-Max:超大规模 MoE 模型架构和性能评估

大家好,我是大 F,深耕AI算法十余年,互联网大厂技术岗。分享AI算法干货、技术心得。 欢迎关注《大模型理论和实战》、《DeepSeek技术解析和实战》,一起探索技术的无限可能! 一、引言 Qwen2.5-Max 是阿里云通义千问团队研发的超大规模 Mixture-of-Expert(MoE)模型,旨在通…...

C++ 标准库容器的常用成员函数

目录 C 标准库容器简介 通用成员函数 1. 大小相关 size() empty() max_size() 2. 元素访问 operator[] at(size_t n) front() back() 3. 修改容器 push_back(const T& value) pop_back() clear() insert() erase() 4. 迭代器相关 begin() end() rbegi…...

MySQL双主搭建-5.7.35

文章目录 上传并安装MySQL 5.7.35双主复制的配置实例一&#xff1a;172.25.0.19&#xff1a;实例二&#xff1a;172.25.0.20&#xff1a; 配置复制用户在实例 1 &#xff08;172.25.0.19&#xff09;上执行&#xff1a;在实例 2 &#xff08;172.25.0.20&#xff09;上执行&…...

Uniapp开发微信小程序插件的一些心得

一、uniapp 开发微信小程序框架搭建 1. 通过 vue-cli 创建 uni-ap // nodejs使用18以上的版本 nvm use 18.14.1 // 安装vue-cli npm install -g vue/cli4 // 选择默认模版 vue create -p dcloudio/uni-preset-vue plugindemo // 运行 uniapp2wxpack-cli npx uniapp2wxpack --…...

Vscode通过Roo Cline接入Deepseek

文章目录 背景第一步、安装插件第二步、申请API key第三步、Vscode中配置第四步、Deepseek对话 背景 在前期介绍【IDEA通过Contince接入Deepseek】步骤和流程&#xff0c;那如何在vscode编译器中使用deepseek&#xff0c;记录下来&#xff0c;方便备查。 第一步、安装插件 在…...

超厉害!AI写教材,低查重且内容连贯,快速产出专业教材!

整理教材知识点实在是一项“精细工作”&#xff0c;最大的挑战在于如何保持平衡与衔接&#xff01;我们常常担忧会遗漏核心概念&#xff0c;或是难以掌握合适的难度梯度——小学教材常常写得过于复杂&#xff0c;导致学生难以理解&#xff1b;而高中教材则可能显得过于简单&…...

ExternalDNS自动化DNS管理实践:实现Kubernetes服务自动注册

ExternalDNS自动化DNS管理实践&#xff1a;实现Kubernetes服务自动注册 一、ExternalDNS概述 ExternalDNS是一个Kubernetes控制器&#xff0c;能够自动同步Kubernetes资源&#xff08;如Service和Ingress&#xff09;到外部DNS服务商。它消除了手动管理DNS记录的繁琐工作&…...

My-TODOs:免费开源跨平台桌面待办清单应用终极指南

My-TODOs&#xff1a;免费开源跨平台桌面待办清单应用终极指南 【免费下载链接】My-TODOs A cross-platform desktop To-Do list. 跨平台桌面待办小工具 项目地址: https://gitcode.com/gh_mirrors/my/My-TODOs 你是否经常忘记重要任务&#xff1f;是否在多个待办应用间…...

FlashAttention 反向传播:删掉 O(N²) 的中间结果,怎么还能算对梯度?

FlashAttention 反向传播&#xff1a;删掉 O(N) 的中间结果&#xff0c;怎么还能算对梯度&#xff1f; 之前有人跟我争&#xff1a;FlashAttention 反向传播不存注意力矩阵&#xff0c;那梯度从哪来&#xff1f;你前向传播的时候 Softmax 的分母、分子都扔了&#xff0c;反向传…...

告别手动对照!用OrCAD Design Sync功能,5分钟自动化同步你的原理图与Allegro PCB变更

告别手动对照&#xff01;用OrCAD Design Sync功能&#xff0c;5分钟自动化同步你的原理图与Allegro PCB变更 在高速迭代的电子设计领域&#xff0c;每一次原理图修改都可能引发PCB布局的连锁反应。传统手动同步方式不仅耗时费力&#xff0c;还容易遗漏关键变更。OrCAD Design…...

破解Windows安装程序本地化难题:Inno Setup简体中文翻译的技术实现与架构设计

破解Windows安装程序本地化难题&#xff1a;Inno Setup简体中文翻译的技术实现与架构设计 【免费下载链接】Inno-Setup-Chinese-Simplified-Translation :earth_asia: Inno Setup Chinese Simplified Translation 项目地址: https://gitcode.com/gh_mirrors/in/Inno-Setup-Ch…...

KMS智能激活工具:三步永久激活Windows和Office系统完整指南

KMS智能激活工具&#xff1a;三步永久激活Windows和Office系统完整指南 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 还在为Windows系统频繁弹出激活提示而烦恼吗&#xff1f;Office文档突然变…...

5步实现《鸣潮》游戏体验全面升级:WuWa-Mod模组高效部署指南

5步实现《鸣潮》游戏体验全面升级&#xff1a;WuWa-Mod模组高效部署指南 【免费下载链接】wuwa-mod Wuthering Waves pak mods 项目地址: https://gitcode.com/GitHub_Trending/wu/wuwa-mod 还在为《鸣潮》游戏中的技能冷却、体力限制和繁琐操作而烦恼吗&#xff1f;WuW…...

终极AMD Ryzen性能调优指南:5分钟掌握SMUDebugTool免费调试神器

终极AMD Ryzen性能调优指南&#xff1a;5分钟掌握SMUDebugTool免费调试神器 【免费下载链接】SMUDebugTool A dedicated tool to help write/read various parameters of Ryzen-based systems, such as manual overclock, SMU, PCI, CPUID, MSR and Power Table. 项目地址: h…...

3步掌握抖音批量下载:终极免费无水印下载器完整指南

3步掌握抖音批量下载&#xff1a;终极免费无水印下载器完整指南 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback support…...