当前位置: 首页 > news >正文

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。

一、LSTM网络简单介绍

LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。对于相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

引入LSTM网络的原因:由于 RNN 网络主要问题是长期依赖,即隐藏状态在时间上传递过程中可能会丢失之前的信息。为了解决这个问题,引入了长短时记忆网络 (LSTM) 和门控循环单元 (GRU)。这两种网络结构在隐藏层中增加了门控机制,能够更好地控制信息的传递。

 其中符号及表示意思如下:

 LSTM中有三个门:
(1)遗忘门f:决定上一个时刻的记忆单元状态需要遗忘多少信息,保留多少信息到当前记忆单元状态。
(2)输入门i:控制当前时刻输入信息候选状态有多少信息需要保存到当前记忆单元状态。
(3)输出门o:控制当前时刻的记忆单元状态有多少信息需要输出给外部状态。

形象的例子让我们更好的理解LSTM的原理:

假设你是一个梦想远大的学生,你想通过学习一门课程获得更多的知识。在学习过程中,LSTM模型帮助你,它就像是一个老师,它的遗忘门就像是老师的提醒,它让你挑出不用的知识,以保持你对重要知识的清晰记忆。它的输入门就像是老师的指导,它会重新审视你学习过的知识,按照自己的逻辑把知识结合起来,进化出更多有用的知识。最后,它的输出门就像老师的监督,它会确保你学习到了有用的知识,不要浪费时间去学习无用的知识。

二、LSTM网络运用-预测上证指数走势

# 使用LSTM预测沪市指数
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from pandas import DataFrame
from pandas import concat
from itertools import chain
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']# 转化为可以用于监督学习的数据
def get_train_set(data_set, timesteps_in, timesteps_out=1):train_data_set = np.array(data_set)reframed_train_data_set = np.array(series_to_supervised(train_data_set, timesteps_in, timesteps_out).values)train_x, train_y = reframed_train_data_set[:, :-timesteps_out], reframed_train_data_set[:, -timesteps_out:]# 将数据集重构为符合LSTM要求的数据格式,即 [样本数,时间步,特征]train_x = train_x.reshape((train_x.shape[0], timesteps_in, 1))return train_x, train_y"""
将时间序列数据转换为适用于监督学习的数据
给定输入、输出序列的长度
data: 观察序列
n_in: 观测数据input(X)的步长,范围[1, len(data)], 默认为1
n_out: 观测数据output(y)的步长, 范围为[0, len(data)-1], 默认为1
dropnan: 是否删除NaN行
返回值:适用于监督学习的 DataFrame
"""
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):print(data.shape)n_vars = 1 if type(data) is list else data.shape[1]df = DataFrame(data)cols, names = list(), list()# input sequence (t-n, ... t-1)for i in range(n_in, 0, -1):cols.append(df.shift(i))names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]# 预测序列 (t, t+1, ... t+n)for i in range(0, n_out):cols.append(df.shift(-i))if i == 0:names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]else:names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]# 拼接到一起agg = concat(cols, axis=1)agg.columns = names# 去掉NaN行if dropnan:agg.dropna(inplace=True)return agg# 使用LSTM进行预测
def lstm_model(source_data_set, train_x, label_y, input_epochs, input_batch_size, timesteps_out):model = Sequential()# 第一层, 隐藏层神经元节点个数为128, 返回整个序列model.add(LSTM(128, return_sequences=True, activation='tanh', input_shape=(train_x.shape[1], train_x.shape[2])))# 第二层,隐藏层神经元节点个数为128, 只返回序列最后一个输出model.add(LSTM(128, return_sequences=False))model.add(Dropout(0.5))# 第三层 因为是回归问题所以使用linearmodel.add(Dense(timesteps_out, activation='linear'))model.compile(loss='mean_squared_error', optimizer='adam')# LSTM训练 input_epochs次数res = model.fit(train_x, label_y, epochs=input_epochs, batch_size=input_batch_size, verbose=2, shuffle=False)# 模型预测train_predict = model.predict(train_x)#test_data_list = list(chain(*test_data))train_predict_list = list(chain(*train_predict))plt.plot(res.history['loss'], label='train')plt.show()#print(model.summary())plot_img(source_data_set, train_predict)# 呈现原始数据,训练结果,验证结果,预测结果
def plot_img(source_data_set, train_predict):plt.figure(figsize=(24, 8))# 原始数据蓝色plt.plot(source_data_set[:, -1], c='b',label = '标签')# 训练数据绿色plt.plot([x for x in train_predict], c='g')plt.legend()plt.show()# 设置观测数据input(X)的步长(时间步),epochs,batch_size
timesteps_in = 3
timesteps_out = 3
epochs = 1000
batch_size = 100
data = pd.read_csv('./shanghai_index_1990_12_19_to_2019_12_11.csv')
data_set = data[['Price']].values.astype('float64')
# 转化为可以用于监督学习的数据
train_x, label_y = get_train_set(data_set, timesteps_in=timesteps_in, timesteps_out=timesteps_out)print(train_x, label_y )
print(train_x.shape)
print(train_x.shape[1], train_x.shape[2])# 使用LSTM进行训练、预测
lstm_model(data_set, train_x, label_y, epochs, batch_size, timesteps_out=timesteps_out)

运行结果:

相关文章:

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。 一、LSTM网络简单介绍 LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题…...

内网穿透/组网/设备上云平台EasyNTS上云网关的安装操作指南

EasyNTS上云网关的主要作用是解决异地视频共享/组网/上云的需求,网页对域名进行添加映射时,添加成功后会生成一个外网访问地址,在浏览器中输入外网访问地址,即可查看内网应用。无需开放端口,EasyNTS上云网关平台会向Ea…...

易点天下基于 StarRocks 全面构建实时离线一体的湖仓方案

作者:易点天下数据平台团队易点天下是一家技术驱动发展的企业国际化智能营销服务公司,致力于为客户提供全球营销推广服务,通过效果营销、品牌塑造、垂直行业解决方案等一体化服务,帮助企业在全球范围内高效地获取用户、提升品牌知…...

Tomcat的类加载机制

不遵循双亲委托 在JVM中并不是一次性地把所有的文件都加载到,而是按需加载,加载机制采用 双亲委托原则,如下图所示: BootStrapClassLoader 引导类加载器ExtClassLoader 扩展类加载器AppClassLoader 应用类加载器CustomClassLoad…...

【shell 编程大全】数组,逻辑判断以及循环

数组,逻辑判断以及循环1. 概述 大家好,我又来了。今天呢我们继续学习shell相关的知识。还是老样子我们先回顾下上一次【脚本交互 以及表达式】学习到的知识 登录shell 关联配置文件什么是子shellumask 修改默认权限read 基础表达式 简单计算表达式expr 计…...

Android13 Bluetooth更新

目录 Android 13 版本说明 LE Audio 代码更新 Android 12代码路径 Android 13代码路径 Android 13 版本说明 里面对蓝牙更新的描述较少,一出提到蓝牙的一...

手工测试混了5年,年底接到了被裁员的消息....

大家都比较看好软件测试行业,只是因为表面上看起来:钱多事少加班少。其实这个都是针对个人运气好的童人才会有此待遇。在不同的阶段做好不同阶段的事情,才有可能离这个目标更近,作为一枚软件测试人员,也许下面才是我们…...

Umi框架

什么是 umi umi 是由 dva 的开发者 云谦 编写的一个新的 React 开发框架。umi 既是一个框架也是一个工具,可以将它简单的理解为一个专注性能的类 next.js 前端框架,并通过约定、自动生成和解析代码等方式来辅助开发,减少开发者的代码量。 u…...

教你学git

前言 git是一种用于多人合作写项目。详细说明如下 文章目录前言什么是版本控制?什么是 Git?它就属于人工版本控制器版本控制工具常见版本控制工具怎么工作的?git 文件生命周期状态区域安装配置-- global检查配置创建仓库工作流与基本操作查看…...

【工作笔记】syslog,kern.log大量写入invalid cookie错误信息问题

任务描述 错误出现出现过四五次,应该是诊断单元tf卡读写出问题导致下面这条告警一直高频写入到/var/log/下的syslog、kern.log、messages中 Nov 23 06:25:12 embest kernel: omap_hsmmc 48060000.mmc: [omap_hsmmc_pre_dma_transfer] invalid cookie: data->hos…...

【C++】多线程

多任务处理有两种形式,即:多进程和多线程。 基于进程的多任务处理是程序的并发执行。基于线程的多任务处理是同一程序的片段的并发执行 文章目录1. 多线程介绍2. Windows多线程1. 多线程介绍 每一个进程(可执行程序)都有一个主线…...

0202插入删除-算法第四版红黑树-红黑树-数据结构和算法(Java)

文章目录4 插入4.1 序4.2 向单个2-结点插入新键4.3 向树底部的2-结点插入新键4.4 向一棵双键树(3-结点)中插入新键4.5 颜色调整4.6 根结点总是黑色4.7 向树底部的3-结点插入新键4.8 将红链接在树中向上传递4.9 实现5 删除5.1 删除最小键5.2 删除6 有序性…...

vue 生成二维码插件 vue-qr使用方法

一、安装 npm install vue-qr --save二、引入 import VueQr from vue-qrcomponents:{VueQr,},三、使用 <vue-qr:text"dyQrcode":size"170":logoSrc"logo":margin"6":logoScale"0.2"></vue-qr>四、属性说明 …...

网络工程课(二)

ensp配置vlan 一、配置计算机ip地址和子网掩码 二、配置交换机LSW1 system-view [Huawei]sysname SW1 [SW1]vlan batch 10 20 [SW1]interface Ethernet0/0/1 [SW1-Ethernet0/0/1]port link-type access 将接口设为access接口 [SW1-Ethernet0/0/1]port default vlan 10 [SW1-E…...

Pytorch并行计算(三): 梯度累加

梯度累加 梯度累加&#xff08;Gradient Accmulation&#xff09;是一种增大训练时batch size的技巧。当batch size在一张卡放不下时&#xff0c;可以将很大的batch size分解为一个个小的mini batch&#xff0c;分别计算每一个mini batch的梯度&#xff0c;然后将其累加起来优…...

蓝桥杯入门即劝退(十八)最小覆盖子串(滑动窗口解法)

欢迎关注点赞评论&#xff0c;共同学习&#xff0c;共同进步&#xff01; ------持续更新蓝桥杯入门系列算法实例-------- 如果你也喜欢Java和算法&#xff0c;欢迎订阅专栏共同学习交流&#xff01; 你的点赞、关注、评论、是我创作的动力&#xff01; -------希望我的文章…...

Android一~

进程和线程的区别https://zhuanlan.zhihu.com/p/60375108https://zhuanlan.zhihu.com/p/138689342线程池的用法和原理tcp三次握手和四次挥手、tcp基础http请求报文格式二叉树中序遍历&#xff08;算法&#xff09;activity启动模式OKhttp源码讲解Java修饰符Java线程同步的方法s…...

一月券商金工精选

✦研报目录✦ ✦简述✦ 按发布时间排序 国盛证券 “薪火”量化分析系列研究&#xff08;二&#xff09;-票据逾期数据中的选股信息 发布日期&#xff1a;2023-01-04 关键词&#xff1a;股票、票据、票据预期 主要内容&#xff1a;本文深入探讨了“票据持续逾期名单”这一…...

UML中常见的9种图

UML是Unified Model Language的缩写&#xff0c;中文是统一建模语言&#xff0c;是由一整套图表组成的标准化建模语言。UML用于帮助系统开发人员阐明&#xff0c;展示&#xff0c;构建和记录软件系统的产出。通过使用UML使得在软件开发之前&#xff0c; 对整个软件设计有更好的…...

使用SpringBoot实现无限级评论回复功能

评论功能已经成为APP和网站开发中的必备功能。本文采用springbootmybatis-plus框架,通过代码主要介绍评论功能的数据库设计和接口数据返回。我们返回的格式可以分三种方案,第一种方案是先返回评论,再根据评论id返回回复信息,第二种方案是将评论回复直接封装成一个类似于树的数据…...

AD快捷键避坑指南:为什么你的自定义快捷键总是不生效?

AD快捷键避坑指南&#xff1a;为什么你的自定义快捷键总是不生效&#xff1f; 在AD&#xff08;Altium Designer&#xff09;这个功能强大的电子设计自动化软件中&#xff0c;快捷键是提升工作效率的利器。但很多用户都遇到过这样的困扰&#xff1a;明明按照教程设置了自定义快…...

别再踩坑了!Jetson Nano/Xavier NX上PyTorch和torchvision版本匹配保姆级指南(含JetPack 5/6)

Jetson设备PyTorch环境配置终极避坑手册&#xff1a;从版本匹配到性能调优 刚拿到Jetson Nano或Xavier NX的开发者们&#xff0c;十个里有九个会在PyTorch环境配置上栽跟头。不是torchvision报错就是CUDA不可用&#xff0c;最崩溃的是好不容易装好了却发现性能还不如树莓派。本…...

Mojo加速Python科学计算:如何在72小时内将AI推理速度提升8.6倍(附完整可运行代码)

第一章&#xff1a;Mojo与Python混合编程概述Mojo 是一种为 AI 系统量身打造的现代系统编程语言&#xff0c;兼具 Python 的易用性与 C/C 的执行效率。它原生兼容 Python 生态&#xff0c;允许开发者在同一个项目中无缝调用 Python 模块、复用现有 NumPy/Torch 代码&#xff0c…...

万物识别在智能体(Skills Agent)中的集成应用

万物识别在智能体(Skills Agent)中的集成应用 想象一下&#xff0c;你正在开发一个智能客服机器人&#xff0c;用户发来一张照片&#xff0c;里面是自家厨房水槽下漏水的一堆零件。用户问&#xff1a;“这是什么东西坏了&#xff1f;我该买什么配件&#xff1f;” 传统的文本对…...

忍者像素绘卷参数详解:如何通过提示词触发‘火之意志’专属风格权重

忍者像素绘卷参数详解&#xff1a;如何通过提示词触发火之意志专属风格权重 1. 认识忍者像素绘卷 忍者像素绘卷是一款基于Z-Image-Turbo深度优化的图像生成工具&#xff0c;它将传统忍者文化与16-Bit复古游戏美学完美结合。这款工具特别适合创作具有热血动漫风格的像素艺术作…...

Java微服务集成TranslateGemma:企业级翻译中台构建

Java微服务集成TranslateGemma&#xff1a;企业级翻译中台构建 1. 为什么需要企业级翻译中台 最近在给一家跨境电商平台做技术咨询时&#xff0c;客户提到一个很实际的问题&#xff1a;他们的客服系统、商品管理系统、营销内容平台各自维护着不同的翻译逻辑。客服用的是第三方…...

Kandinsky-5.0-I2V-Lite-5s效果展示:建筑图纸→镜头平移漫游视频生成案例

Kandinsky-5.0-I2V-Lite-5s效果展示&#xff1a;建筑图纸→镜头平移漫游视频生成案例 1. 惊艳效果预览 Kandinsky-5.0-I2V-Lite-5s带来的建筑漫游视频生成效果令人印象深刻。想象一下&#xff0c;你有一张静态的建筑设计图纸&#xff0c;通过这个模型&#xff0c;只需简单描述…...

王二明古方草解毒茶商城模式解析

王二明古方草解毒茶商城模式解析&#xff1a;架构、争议与合规思考在社交电商与大健康产业的交叉赛道中&#xff0c;“王二明古方草解毒茶”凭借其独特的草本茶饮定位与多级分销模式&#xff0c;曾一度引发市场关注。该模式以产品为核心&#xff0c;通过数字化商城系统构建了一…...

从Stable Diffusion到多模态大模型:图文交错数据如何让AI学会‘边想边画’?

图文交错数据&#xff1a;多模态大模型实现"边想边画"的关键突破 当Stable Diffusion以惊艳的画质震惊世界时&#xff0c;人们很快发现它存在一个根本局限——这个能画出精美图像的模型&#xff0c;却无法理解自己笔下的内容。与此同时&#xff0c;擅长理解图像的多模…...

精通ComfyUI-BrushNet:专业图像修复全流程指南

精通ComfyUI-BrushNet&#xff1a;专业图像修复全流程指南 【免费下载链接】ComfyUI-BrushNet ComfyUI BrushNet nodes 项目地址: https://gitcode.com/gh_mirrors/co/ComfyUI-BrushNet ComfyUI-BrushNet是一款功能强大的图像修复工具&#xff0c;通过节点式工作流实现专…...