当前位置: 首页 > news >正文

LSTM预测模型复现笔记和问题记录

LSTM复现笔记和问题记录

  • 1 LSTM复现记录
    • 1.1 复现环境配置
    • 1.2 LSTM_Fly文件夹
      • 1.2.1 LSTM回归网络(1→1).py
        • 1.2.1.1 加载数据
        • 1.2.1.2 数据处理
        • 1.2.1.3 输入模型维度
      • 1.2.2 移动窗口型回归(3→1).py
        • 1.2.2.1 数据处理
        • 1.2.2.2 输入模型维度
      • 1.2.3 时间步长型回归(3→1).py
        • 1.2.3.1 数据处理
        • 1.2.3.2 输入模型维度
    • 1.3 LSTM系列文件夹
      • 1.3.1 LSTM单变量4
        • 1.3.1.1 输入模型维度1
      • 1.3.2 LSTM多变量3
        • 1.3.2.1 输入模型维度
      • 1.3.3 Multi-Step LSTM预测2
        • 1.3.3.1 输入模型维度
    • 1.4 stock_predict
      • 1.4.1 stock_predict_1.py
        • 1.4.1.1 输入模型维度
    • 1.5 洗发水销量(单步预测)
      • 1.5.1 6.LSTM模型实例.py
  • 总结

1 LSTM复现记录

复现github链接:https://github.com/yangwohenmai/LSTM.git

1.1 复现环境配置

采用cuda10.1;1050ti显卡; python版本:3.8.20

absl-py==2.1.0
astunparse==1.6.3
cachetools==4.2.4
certifi==2025.1.31
charset-normalizer==3.4.1
contourpy==1.1.1
cycler==0.12.1
fonttools==4.56.0
gast==0.3.3
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.70.0
h5py==2.10.0
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
joblib==1.4.2
Keras-Preprocessing==1.1.2
kiwisolver==1.4.7
Markdown==3.7
MarkupSafe==2.1.5
matplotlib==3.3.4
numpy==1.18.5
oauthlib==3.2.2
opt_einsum==3.4.0
packaging==24.2
pandas==1.1.5
pillow==10.4.0
protobuf==3.20.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pyparsing==3.1.4
python-dateutil==2.9.0.post0
pytz==2025.1
PyYAML==6.0.2
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
scikit-learn==0.24.2
scipy==1.4.1
six==1.17.0
tensorboard==2.2.2
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
tensorflow-gpu==2.2.0
tensorflow-gpu-estimator==2.2.0
termcolor==2.4.0
Theano==1.0.5
threadpoolctl==3.5.0
tzdata==2025.1
urllib3==2.2.3
Werkzeug==3.0.6
wrapt==1.17.2
zipp==3.20.2

注:这个版本得tensorflow中自带keras,后面keras.xxx相关模块得导入,修改成tensorflow.keras.xxx

vscode配置:

{"version": "1.95.2","configurations": [{"name": "Python Debugger: Current File","type": "debugpy","request": "launch","program": "${file}","console": "integratedTerminal","cwd":"${fileDirname}",}]
}

1.2 LSTM_Fly文件夹

1.2.1 LSTM回归网络(1→1).py

将数据截取成1->1的监督学习格式:即用前一个数据预测后一个数据

1.2.1.1 加载数据

加载数据时未考虑第一列的时间序列,故命名未回归网络

# 加载数据
dataframe = read_csv('airline-passengers.csv', usecols=[1], engine='python')
1.2.1.2 数据处理

look_back参数设置为1,实现前一个值预测后一个值

# 将数据截取成1->1的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []for i in range(len(dataset)-look_back-1):a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)# 预测数据步长为1,一个预测一个,1->1
look_back = 1
trainX, trainY = create_dataset(train, look_back)
testX, testY = create_dataset(test, look_back)
1.2.1.3 输入模型维度
# 重构输入数据格式 [samples, time steps, features] = [93,1,1]
trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = numpy.reshape(testX, (testX.shape[0], 1, testX.shape[1]))# 构建 LSTM 网络
model = Sequential()
model.add(LSTM(4, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)

1.2.2 移动窗口型回归(3→1).py

用重构的长度为3的数据预测一个数据,相当于进行了特征重构,前三个时间步作为特征进行变换,预测后一个值。

1.2.2.1 数据处理
# 将数据截取成3个一组的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []# 这里没有充分利用数据,若修改为len(dataset)-look_back,然后加上dataset[i + look_back, 0]才刚好遍历到最后一个数据# 如果没有加上Y,则是len(dataset)-look_back+1for i in range(len(dataset)-look_back-1):  a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)
1.2.2.2 输入模型维度
# 预测数据步长为3,三个预测一个,3->1
look_back = 3
trainX, trainY = create_dataset(train, look_back)
testX, testY = create_dataset(test, look_back)
# 重构输入数据格式 [samples, time steps, features] = [93,1,3]
trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
testX = numpy.reshape(testX, (testX.shape[0], 1, testX.shape[1]))
# 构建 LSTM 网络
model = Sequential()
model.add(LSTM(4, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(trainX,trainY, epochs=100, batch_size=1, verbose=2)
# 对训练数据的Y进行预测
trainPredict = model.predict(trainX)

1.2.3 时间步长型回归(3→1).py

用时间步长为3,特征维度为1的数据预测后一个数据,前三个时间步里的一个元素作为特征用于预测后一个值。

1.2.3.1 数据处理
# 将数据截取成3个一组的监督学习格式
def create_dataset(dataset, look_back=1):dataX, dataY = [], []for i in range(len(dataset)-look_back-1):a = dataset[i:(i+look_back), 0]dataX.append(a)dataY.append(dataset[i + look_back, 0])return numpy.array(dataX), numpy.array(dataY)
1.2.3.2 输入模型维度
# 重构输入数据格式 [samples, time steps, features] = [93,3,1]
trainX = numpy.reshape(trainX, (trainX.shape[0], trainX.shape[1], 1))
testX = numpy.reshape(testX, (testX.shape[0], testX.shape[1], 1))
# 构建 LSTM 网络
model = Sequential()
# model.add(LSTM(4, input_shape=(1, look_back)))
model.add(LSTM(4, input_shape=(look_back, 1)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)
# 对训练数据的Y进行预测
trainPredict = model.predict(trainX)

1.3 LSTM系列文件夹

1.3.1 LSTM单变量4

使用前面1个步长里的一个特征元素的数据预测后一个值。

1.3.1.1 输入模型维度1

这里代码中可以看出是使用的前一个时间步的一个特征元素的值来预测后一个值。

# fit LSTM来训练数据
def fit_lstm(train, batch_size, nb_epoch, neurons):X, y = train[:, 0:-1], train[:, -1]X = X.reshape(X.shape[0], 1, X.shape[1])  # 注意这个reshape,一般会需要reshape从二维变成三维,因为这里time_step为1,所以中间的值为1model = Sequential()# 添加LSTM层model.add(LSTM(neurons, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))model.add(Dense(1))  # 输出层1个node# 编译,损失函数mse+优化算法adammodel.compile(loss='mean_squared_error', optimizer='adam')for i in range(nb_epoch):# 按照batch_size,一次读取batch_size个数据model.fit(X, y, epochs=1, batch_size=batch_size, verbose=0, shuffle=False)model.reset_states()print("当前计算次数:"+str(i))return model

1.3.2 LSTM多变量3

使用前面1个时间步长里的多个特征元素的数据预测后一个值。

1.3.2.1 输入模型维度

这里代码中可以看出是使用的前一个时间步的多个特征元素的值来预测后一个值。

#拆分输入输出 split into input and outputs
train_X, train_y = train[:, :-1], train[:, -1]
test_X, test_y = test[:, :-1], test[:, -1]
#reshape输入为LSTM的输入格式 reshape input to be 3D [samples, timesteps, features]
train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
print ('train_x.shape, train_y.shape, test_x.shape, test_y.shape')
print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)##模型定义 design network
model = Sequential()
model.add(LSTM(50, input_shape=(train_X.shape[1], train_X.shape[2])))  # 注意这段,input_shape的前一个元素表示前面步长,reshape时变为1,后一个元素表示每个时间步长里的元素特征数量
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')
#模型训练 fit network
history = model.fit(train_X, train_y, epochs=5, batch_size=72, validation_data=(test_X, test_y), verbose=2,shuffle=False)

1.3.3 Multi-Step LSTM预测2

使用前面1个时间步长里的单个特征元素的数据预测后面多个时间步长的单个特征元素值。

1.3.3.1 输入模型维度

具体数据变换维度见代码。

1.4 stock_predict

1.4.1 stock_predict_1.py

用时间步长为20,特征维度为1的数据预测后20数据,前20个时间步里的一个元素作为特征用于预测后20个时间步里的单个元素值。

1.4.1.1 输入模型维度

具体数据变换维度见代码。

#———————————————————形成训练集—————————————————————
time_step = 20      #时间步
rnn_unit = 10       #hidden layer units
lstm_layers = 2     #每一批次训练多少个样例
batch_size = 60     #输入层维度  #每一批次训练多少个样例
input_size = 1      #输入层维度
output_size = 1     #输出层维度
lr = 0.0006         #学习率
train_x, train_y = [], []#训练集
for i in range(len(normalize_data) - time_step - 1):x = normalize_data[i:i + time_step]y = normalize_data[i + 1:i + time_step + 1]train_x.append(x.tolist())train_y.append(y.tolist())

1.5 洗发水销量(单步预测)

1.5.1 6.LSTM模型实例.py

用前面lag步长的数据,特征维度为1,预测后面特征维度为1的数据
… …

总结

总之,使用前面列的数据预测最后一个列,大多是时间步长为1,每个时间步长都包括变量X和预测值y;另外一种时间步长为n,每个时间步长都包括变量X特征维度可能为1维或多维,另外一个位置时间步长为n对应一个y。

相关文章:

LSTM预测模型复现笔记和问题记录

LSTM复现笔记和问题记录 1 LSTM复现记录1.1 复现环境配置1.2 LSTM_Fly文件夹1.2.1 LSTM回归网络(1→1).py1.2.1.1 加载数据1.2.1.2 数据处理1.2.1.3 输入模型维度 1.2.2 移动窗口型回归(3→1).py1.2.2.1 数据处理1.2.2.2 输入模型维度 1.2.3 时间步长型回归(3→1).py1.2.3.1 数…...

开篇词 | Go 项目开发极速入门课介绍

欢迎加入我的训练营:云原生 AI 实战营,一个助力 Go 开发者在 AI 时代建立技术竞争力的实战营。实战营中包含大量 Go、云原生、AI Infra 相关的优质实战课程和项目。欢迎关注我的公众号:令飞编程,持续分享 Go、云原生、AI Infra 技…...

《论软件测试中缺陷管理及其应用》审题技巧 - 系统架构设计师

论软件测试中缺陷管理及其应用写作框架 一、考点概述 本论题“论软件测试中缺陷管理及其应用”主要考查的是软件测试领域中的缺陷管理相关知识与实践应用。论题涵盖了以下几个核心内容: 首先,需要理解软件缺陷的基本概念,即软件中存在的破坏正常运行能力的问题、错误或隐…...

虚拟机快照与linux的目录结构

虚拟机快照是对虚拟机某一时刻状态的完整捕获,包括内存、磁盘、配置及虚拟硬件状态等,保存为独立文件。 其作用主要有数据备份恢复、方便系统测试实验、用于灾难恢复以及数据对比分析。具有快速创建和恢复、占用空间小、可多个快照并存的特点。在管理维…...

FPGA时许约束与分析 1

1、时钟的基本概念 1.1 时钟定义: 同步设计:电路的状态变化总是由某个周期信号的变化进行控制的,在这个信号的 posedge 或者是 negedge 都可以作为电路状态的触发条件。 时钟:在同步设计中,这个信号 叫做时钟。 理…...

【STM32F103ZET6——库函数】6.PWM

目录 配置PWM输出引脚 使能引脚时钟 配置PWM 使能PWM 配置定时器 使能定时器时钟 使能定时器 例程 例程说明 main.h main.c PWM.h PWM.c led.h led.c DSQ.h DSQ.c 配置PWM输出引脚 PWM的输出引脚必须配置为复用功能。 注意:需要使用哪个引脚&…...

基于SpringBoot + Vue的商城购物系统实战

一:简介 使用springboot框架编写后端服务,并使用若依框架搭建管理端界面。在原有基础功能基础上有加入了人工客服、收货地址、智能助手(接入通义千问,暂时关闭)、抽奖功能、支付宝沙箱支付、优惠卷等功能。 目前已部…...

Perl 调用 DeepSeek API 脚本

向 chat.deepseek.com 提问:请将这个 python 脚本翻译为 perl 语言脚本 参阅:Python 调用 DeepSeek API 完整指南 将 Python 脚本翻译为 Perl 语言脚本时,需要注意两种语言之间的语法差异。以下是将给定的 Python 脚本翻译为 Perl 的版本&a…...

2025国家护网HVV高频面试题总结来了01(题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 0x1 高频面试题第一套 0x2 高频面试题第二套 0x3 高频面试题第三套 0x4 高频面试题第四套 0x5 高频面…...

【前端基础】3、HTML的常用元素(h、p、img、a、iframe、div、span)、不常用元素(strong、i、code、br)

HTML结构 一个HTML包含以下部分&#xff1a; 文档类型声明html元素 head元素body元素 例&#xff08;CSDN&#xff09;&#xff1a; 一、文档类型声明 HTML最一方的文档称为&#xff1a;文档类型声明&#xff0c;用于声明文档类型。即&#xff1a;<!DOCTYPE html>…...

【前端场景题】如何应对页面请求接口的大规模并发问题

如何应对页面请求接口的大规模并发问题&#xff0c;尤其是前端方面的解决方案&#xff0c;并且需要给出详细的代码解释。首先&#xff0c;我需要仔细阅读我搜索到的资料&#xff0c;找出相关的信息&#xff0c;然后综合这些信息来形成答案。 首先看&#xff0c;它提到前端优化策…...

Sublime Text4安装、汉化

-------------2025-02-22可用---------------------- 官方网址下载&#xff1a;https://www.sublimetext.com 打开https://hexed.it 点击打开文件找到软件安装目录下的 ctrlf 查找 8079 0500 0f94 c2右边启用替换替换为:c641 0501 b200 90点击替换按钮 替换完成后 另存为本地…...

Python PDF文件拆分-详解

目录 使用工具 将PDF按页数拆分 将PDF的每一页拆分为单独的文件 将PDF按指定页数拆分 根据页码范围拆分PDF 根据指定内容拆分PDF 将PDF的一页拆分为多页 在日常生活中&#xff0c;我们常常会遇到大型的PDF文件&#xff0c;这些文件可能难以发送、管理和查阅。将PDF拆分成…...

MacDroid for Mac v2.3 安卓手机文件传输助手 支持M、Intel芯片 4.7K

MacDroid 是Mac毒搜集到的一款安卓手机文件传输助手&#xff0c;在Mac和Android设备之间传输文件。您只需要将安卓手机使用 USB 连接到 Mac 电脑上即可将安卓设备挂载为本地磁盘&#xff0c;就像编辑mac磁盘上的文件一样编辑安卓设备上的文件&#xff0c;MacDroid支持所有 Andr…...

人大金仓国产数据库与PostgreSQL

一、简介 在前面项目中&#xff0c;我们使用若依前后端分离整合人大金仓&#xff0c;在后续开发过程中&#xff0c;我们经常因为各种”不适配“问题&#xff0c;但可以感觉得到大部分问题&#xff0c;将人大金仓视为postgreSQL就能去解决大部分问题。据了解&#xff0c;Kingba…...

阿里云 Qwen2.5-Max:超大规模 MoE 模型架构和性能评估

大家好,我是大 F,深耕AI算法十余年,互联网大厂技术岗。分享AI算法干货、技术心得。 欢迎关注《大模型理论和实战》、《DeepSeek技术解析和实战》,一起探索技术的无限可能! 一、引言 Qwen2.5-Max 是阿里云通义千问团队研发的超大规模 Mixture-of-Expert(MoE)模型,旨在通…...

C++ 标准库容器的常用成员函数

目录 C 标准库容器简介 通用成员函数 1. 大小相关 size() empty() max_size() 2. 元素访问 operator[] at(size_t n) front() back() 3. 修改容器 push_back(const T& value) pop_back() clear() insert() erase() 4. 迭代器相关 begin() end() rbegi…...

MySQL双主搭建-5.7.35

文章目录 上传并安装MySQL 5.7.35双主复制的配置实例一&#xff1a;172.25.0.19&#xff1a;实例二&#xff1a;172.25.0.20&#xff1a; 配置复制用户在实例 1 &#xff08;172.25.0.19&#xff09;上执行&#xff1a;在实例 2 &#xff08;172.25.0.20&#xff09;上执行&…...

Uniapp开发微信小程序插件的一些心得

一、uniapp 开发微信小程序框架搭建 1. 通过 vue-cli 创建 uni-ap // nodejs使用18以上的版本 nvm use 18.14.1 // 安装vue-cli npm install -g vue/cli4 // 选择默认模版 vue create -p dcloudio/uni-preset-vue plugindemo // 运行 uniapp2wxpack-cli npx uniapp2wxpack --…...

Vscode通过Roo Cline接入Deepseek

文章目录 背景第一步、安装插件第二步、申请API key第三步、Vscode中配置第四步、Deepseek对话 背景 在前期介绍【IDEA通过Contince接入Deepseek】步骤和流程&#xff0c;那如何在vscode编译器中使用deepseek&#xff0c;记录下来&#xff0c;方便备查。 第一步、安装插件 在…...

Python爬虫实战:研究MechanicalSoup库相关技术

一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

Robots.txt 文件

什么是robots.txt&#xff1f; robots.txt 是一个位于网站根目录下的文本文件&#xff08;如&#xff1a;https://example.com/robots.txt&#xff09;&#xff0c;它用于指导网络爬虫&#xff08;如搜索引擎的蜘蛛程序&#xff09;如何抓取该网站的内容。这个文件遵循 Robots…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习

禁止商业或二改转载&#xff0c;仅供自学使用&#xff0c;侵权必究&#xff0c;如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...

蓝桥杯 冶炼金属

原题目链接 &#x1f527; 冶炼金属转换率推测题解 &#x1f4dc; 原题描述 小蓝有一个神奇的炉子用于将普通金属 O O O 冶炼成为一种特殊金属 X X X。这个炉子有一个属性叫转换率 V V V&#xff0c;是一个正整数&#xff0c;表示每 V V V 个普通金属 O O O 可以冶炼出 …...

算法岗面试经验分享-大模型篇

文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer &#xff08;1&#xff09;资源 论文&a…...

Java毕业设计:WML信息查询与后端信息发布系统开发

JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发&#xff0c;实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构&#xff0c;服务器端使用Java Servlet处理请求&#xff0c;数据库采用MySQL存储信息&#xff0…...