当前位置: 首页 > news >正文

【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。

一、引言

单站点多变量输入多变量输出单步预测问题----基于LSTM实现。

多输入就是输入多个特征变量

多输出就是同时预测出多个标签的结果

单步就是利用过去N天预测未来1天的结果

二、实现过程

2.1 读取数据集

df=pd.read_csv("data.csv", parse_dates=["Date"], index_col=[0])
print(df.shape)
print(df.head())
fea_num = len(df.columns)

df:

图片

2.2 划分数据集

# 拆分数据集为训练集和测试集
test_split=round(len(df)*0.20)
df_for_training=df[:-test_split]
df_for_testing=df[-test_split:]# 绘制训练集和测试集的折线图
plt.figure(figsize=(10, 6))
plt.plot(train_data, label='Training Data')
plt.plot(test_data, label='Testing Data')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Training and Testing Data')
plt.legend()
plt.show()

共5203条数据,8:2划分:训练集4162,测试集1041。

训练集和测试集:

图片

2.3 归一化

# 将数据归一化到 0~1 范围(整体一起做归一化)
scaler = MinMaxScaler(feature_range=(0,1))
df_for_training_scaled = scaler.fit_transform(df_for_training)
df_for_testing_scaled=scaler.transform(df_for_testing)

2.4 构造LSTM数据集(时序-->监督学习)

def createXY(data, win_size, target_feature_idxs):passwin_size = 12 # 时间窗口
target_feature_idxs = [0, 1, 2, 3, 4] # 指定待预测特征列索引
trainX, trainY = createXY(df_for_training_scaled, win_size, target_feature_idxs)
testX, testY = createXY(df_for_testing_scaled, win_size, target_feature_idxs)
print("训练集形状:", trainX.shape, trainY.shape)
print("测试集形状:", testX.shape, testY.shape)# 将数据集转换为 LSTM 模型所需的形状(样本数,时间步长,特征数)
trainX = np.reshape(trainX, (trainX.shape[0], win_size, fea_num))
testX = np.reshape(testX, (testX.shape[0], win_size, fea_num))print("trainX Shape-- ",trainX.shape)
print("trainY Shape-- ",trainY.shape)
print("testX Shape-- ",testX.shape)
print("testY Shape-- ",testY.shape)

滑动窗口设置为12:

取出df_for_training_scaled第【1-12】行第【1-5】列的12条数据作为trainX[0],取出df_for_training_scaled第【13】行第【1-5】列的1条数据作为trainY[0];依此类推。最终构造出的训练集数量(4150)比划分时候的训练集数量(4162)少一个滑动窗口(12)。

图片

trainX是一个(4150,12,5)的三维数组,三个维度分布表示(样本数量,步长,特征数),每一个样本比如trainX[0]是一个(12,5)二维数组表示(步长,特征数),这也是LSTM模型每一步的输入。

图片

trainY是一个(4150,5)的二维数组,二个维度分布表示(样本数量,标签数),每一个样本比如trainY[0]是一个(5,)一维数组表示(标签数,),这也是LSTM模型每一步的输出。

图片

2.5 建立模拟合模型

# 输入维度
input_shape = Input(shape=(trainX.shape[1], trainX.shape[2]))
# LSTM层
lstm_layer = LSTM(128, activation='relu')(input_shape)
# 全连接层
dense_1 = Dense(64, activation='relu')(lstm_layer)
dense_2 = Dense(32, activation='relu')(dense_1)
# 输出层
output_1 = Dense(1, name='Open')(dense_2)
output_2 = Dense(1, name='High')(dense_2)
output_3 = Dense(1, name='Low')(dense_2)
output_4 = Dense(1, name='Close')(dense_2)
output_5 = Dense(1, name='AdjClose')(dense_2)
model = Model(inputs = input_shape, outputs = [output_1, output_2, output_3, output_4, output_5])
model.compile(loss='mse', optimizer='adam')
model.summary()

这是一个多输入多输出的 LSTM 模型,接受包含12个时间步长和5个特征的输入序列,在经过一层128个神经元的 LSTM 层和5个全连接层后,输出5个单独的预测结果,分别是 Open、High、 Low、Close和 AdjClose。

图片

进行训练,这里[trainY[:,i] for i in range(trainY.shape[1])]把原来的trainY做了转置,是一个(5,4150)的二维数组,分别表示(标签数,样本数)。相当于建立了5个通道,每个通道是(4150,)的一维数组。

history = model.fit(trainX, [trainY[:,i] for i in range(trainY.shape[1])], epochs=20, batch_size=32)

2.6 进行预测

进行预测,上面我们分析过模型每一步的输入是一个(12,5)二维数组表示(步长,特征数),模型每一步的输出是是一个(5,)一维数组表示(标签数,)

prediction_test = model.predict(testX)

如果直接model.predict(testX),testX的形状是(1029,12,5),是一个批量预测,输出prediction_test是一个(5,1029,1)的三维数组,prediction_test[0]就是第一个标签的预测结果,prediction_test[1]就是第二个标签的预测结果...多输出就是同时预测出多个标签的结果

图片

2.7 预测效果展示

分析一下第一个变量open的效果,i=0:

prediction_train = model.predict(trainX)
prediction_train0=model.predict(trainX)[i]
prediction_train_copies_array = ...
pred_train=...
original_train_copies_array = trainY
original_train=...
print("train Pred Values-- ", pred_train)
print("\ntrain Original Values-- ", original_train)
plt.plot(df_for_training.index[win_size:,], original_train, color = 'red', label = '真实值')
plt.plot(df_for_training.index[win_size:,], pred_train, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

训练集真实值与预测值:

图片

prediction_test = model.predict(testX)
prediction_test0=model.predict(testX)[i]
prediction_test_copies_array = ...
pred_test=...
original_test_copies_array = testY
original_test=...
print("\ntest Original Values-- ", original_test)
plt.plot(df_for_testing.index[win_size:,], original_test, color = 'red', label = '真实值')
plt.plot(df_for_testing.index[win_size:,], pred_test, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

测试集真实值与预测值:

图片

2.8 评估指标

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

相关文章:

【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。 一、引言 单站点多变量输入多变量输出单步预测问题----基于LSTM实现。 多输入就是输入多个特征变量 多输出就是同时预测出多个标签的结果 单步就是利用过去N天预测未来1天的结果 二、实现过程 2.1 读取数据集 dfpd.read_csv("data.csv&qu…...

git客户端工具之Github,适用于windows和mac

对于我本人,我已经习惯了使用Github Desktop,不同的公司使用的代码管理平台不一样,就好奇Github Desktop是不是也适用于其他平台,结果是可以的。 一、克隆代码 File --> Clone repository… 选择第三种URL方式,输入url &…...

ai除安卓手机版APP软件一键操作自动渲染去擦消稀缺资源下载

安卓手机版:点击下载 苹果手机版:点击下载 电脑版(支持Mac和Windows):点击下载 一款全新的AI除安卓手机版APP,一键操作,轻松实现自动渲染和去擦消效果,稀缺资源下载 1、一键操作&…...

Unity获取剪切板内容粘贴板图片文件文字

最近做了一个发送消息的unity项目,需要访问剪切板里面的图片文字文件等,翻遍了网上的东西,看了不是需要导入System.Windows.Forms(关键导入了unity还不好用,只能用在纯c#项目中),所以我看了下py…...

利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API

谷歌在2024年4月发布了全新一代的多模态模型Gemini 1.5 Pro,Gemini 1.5 Pro不仅能够生成创意文本和代码,还能理解、总结上传的图片、视频和音频内容,并且支持高达100万tokens的上下文。在多个基准测试中表现优异,性能超越了ChatGP…...

极狐GitLab 17.0 重磅发布,100+ DevSecOps功能更新来啦~【一】

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab :https://gitlab.cn/install?channelcontent&utm_sourcecsdn 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署…...

python实现符文加、解密

在历史悠久的加密技术中,恺撒密码以其简单却有效的原理闻名。通过固定的字母位移,明文可以被转换成密文,而解密则是逆向操作。这种技术不仅适用于英文字母,还可以扩展到其他语言的字符体系,如日语的平假名或汉语的拼音…...

【解释】i.MX6ULL_IO_电气属性说明

【解释】i.MX6ULL_IO_电气属性说明 文章目录 1 Hyst1.1 迟滞(Hysteresis)是什么?1.2 GPIO的Hyst. Enable Field 参数1.3 应用场景 2 Pull / Keep Select Field2.1 PUE_0_Keeper — Keeper2.2 PUE_1_Pull — Pull2.3 选择Keeper还是Pull 3 Dr…...

02-《石莲》

石 莲 石莲(学名:Sinocrassula indica A.Berger),别名因地卡,为二年生草本植物,全株无毛,具须根。花茎高15-60厘米,直立,常被微乳头状突起。茎生叶互生,宽倒披…...

MySQL之聚簇索引和非聚簇索引

1、什么是聚簇索引和非聚簇索引? 聚簇索引,通常也叫聚集索引。 非聚簇索引,指的是二级索引。 下面看一下它们的含义: 1.1、聚集索引选取规则 如果存在主键,主键索引就是聚集索引。如果不存在主键,将使…...

Web后端开发之前后端交互

http协议 http ● 超文本传输协议 (HyperText Transfer Protocol)服务器传输超文本到本地浏览器的传送协议 是互联网上应用最为流行的一种网络协议,用于定义客户端浏览器和服务器之间交换数据的过程。 HTTP是一个基于TCP/IP通信协议来传递数据. HTT…...

520. 检测大写字母 Easy

我们定义,在以下情况时,单词的大写用法是正确的: 全部字母都是大写,比如 "USA" 。 单词中所有字母都不是大写,比如 "leetcode" 。 如果单词不只含有一个字母,只有首字母大写&#xff0…...

vue的跳转传参

1、接收参数使用route,route包含路由信息,接收参数有两种方式,params和query path跳转只能使用query传参,name跳转都可以 params:获取来自动态路由的参数 query:获取来自search部分的参数 写法 path跳,query传 传参数 import { useRout…...

docker配置镜像源

1)打开 docker配置文件 sudo nano /etc/docker/daemon.json 2)添加 国内镜像源 {"registry-mirrors": ["https://akchsmlh.mirror.aliyuncs.com","https://registry.docker-cn.com","https://docker.mirrors.ustc…...

MySQL高级-SQL优化-insert优化-批量插入-手动提交事务-主键顺序插入

文章目录 1、批量插入1.1、大批量插入数据1.2、启动Linux中的mysql服务1.3、客户端连接到mysql数据库,加上参数 --local-infile1.4、查询当前会话中 local_infile 系统变量的值。1.5、开启从本地文件加载数据到服务器的功能1.6、创建表 tb_user 结构1.7、上传文件到…...

认识100种电路之振荡电路

在电子电路领域,振荡是一项至关重要的功能。那么,为什么电路中需要振荡?其背后的原理是什么?让我们一同深入探究。 【为什么需要振荡电路】 简单来说,振荡电路的存在是为了产生周期性的信号。在众多电子设备中&#…...

SSH 无密登录配置流程

一、免密登录原理 非对称加密: 由于对称加密的存在弊端,就产生了非对称加密,非对称加密中有两个密钥:公钥和私钥。公钥由私钥产生,但却无法推算出私钥;公钥加密后的密文,只能通过对应的私钥来解…...

Python自动化运维 系统基础信息模块

1.系统信息的收集 系统信息的收集,对于服务质量的把控,服务的监控等来说是非常重要的组成部分,甚至是核心的基础支撑部分。我们可以通过大量的核心指标数据,结合对应的检测体系,快速的发现异常现象的苗头,进…...

如何安装和配置Monit

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 关于 Monit Monit 是一个有用的程序,可以自动监控和管理服务器程序,以确保它们不仅保持在线,而且文…...

【redis】redis分片集群基础知识

1、基本概念 1.1定义 分片:数据按照某种规则(比如哈希)被分割成多个片段(或分片),每个片段被称为一个槽(slot)。槽是Redis分片集群中数据的基本单元。节点:Redis分片集…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...

Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析

Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问(基础概念问题) 1. 请解释Spring框架的核心容器是什么?它在Spring中起到什么作用? Spring框架的核心容器是IoC容器&#…...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

uniapp 字符包含的相关方法

在uniapp中,如果你想检查一个字符串是否包含另一个子字符串,你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的,但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...

c++第七天 继承与派生2

这一篇文章主要内容是 派生类构造函数与析构函数 在派生类中重写基类成员 以及多继承 第一部分:派生类构造函数与析构函数 当创建一个派生类对象时,基类成员是如何初始化的? 1.当派生类对象创建的时候,基类成员的初始化顺序 …...

uniapp 小程序 学习(一)

利用Hbuilder 创建项目 运行到内置浏览器看效果 下载微信小程序 安装到Hbuilder 下载地址 :开发者工具默认安装 设置服务端口号 在Hbuilder中设置微信小程序 配置 找到运行设置,将微信开发者工具放入到Hbuilder中, 打开后出现 如下 bug 解…...

【Linux手册】探秘系统世界:从用户交互到硬件底层的全链路工作之旅

目录 前言 操作系统与驱动程序 是什么,为什么 怎么做 system call 用户操作接口 总结 前言 日常生活中,我们在使用电子设备时,我们所输入执行的每一条指令最终大多都会作用到硬件上,比如下载一款软件最终会下载到硬盘上&am…...

【无标题】湖北理元理律师事务所:债务优化中的生活保障与法律平衡之道

文/法律实务观察组 在债务重组领域,专业机构的核心价值不仅在于减轻债务数字,更在于帮助债务人在履行义务的同时维持基本生活尊严。湖北理元理律师事务所的服务实践表明,合法债务优化需同步实现三重平衡: 法律刚性(债…...

Monorepo架构: Nx Cloud 扩展能力与缓存加速

借助 Nx Cloud 实现项目协同与加速构建 1 ) 缓存工作原理分析 在了解了本地缓存和远程缓存之后,我们来探究缓存是如何工作的。以计算文件的哈希串为例,若后续运行任务时文件哈希串未变,系统会直接使用对应的输出和制品文件。 2 …...

Python爬虫实战:研究Restkit库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的有价值数据。如何高效地采集这些数据并将其应用于实际业务中,成为了许多企业和开发者关注的焦点。网络爬虫技术作为一种自动化的数据采集工具,可以帮助我们从网页中提取所需的信息。而 RESTful API …...