当前位置: 首页 > news >正文

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。

一、LSTM网络简单介绍

LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。对于相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

引入LSTM网络的原因:由于 RNN 网络主要问题是长期依赖,即隐藏状态在时间上传递过程中可能会丢失之前的信息。为了解决这个问题,引入了长短时记忆网络 (LSTM) 和门控循环单元 (GRU)。这两种网络结构在隐藏层中增加了门控机制,能够更好地控制信息的传递。

 其中符号及表示意思如下:

 LSTM中有三个门:
(1)遗忘门f:决定上一个时刻的记忆单元状态需要遗忘多少信息,保留多少信息到当前记忆单元状态。
(2)输入门i:控制当前时刻输入信息候选状态有多少信息需要保存到当前记忆单元状态。
(3)输出门o:控制当前时刻的记忆单元状态有多少信息需要输出给外部状态。

形象的例子让我们更好的理解LSTM的原理:

假设你是一个梦想远大的学生,你想通过学习一门课程获得更多的知识。在学习过程中,LSTM模型帮助你,它就像是一个老师,它的遗忘门就像是老师的提醒,它让你挑出不用的知识,以保持你对重要知识的清晰记忆。它的输入门就像是老师的指导,它会重新审视你学习过的知识,按照自己的逻辑把知识结合起来,进化出更多有用的知识。最后,它的输出门就像老师的监督,它会确保你学习到了有用的知识,不要浪费时间去学习无用的知识。

二、LSTM网络运用-预测上证指数走势

# 使用LSTM预测沪市指数
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from pandas import DataFrame
from pandas import concat
from itertools import chain
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']# 转化为可以用于监督学习的数据
def get_train_set(data_set, timesteps_in, timesteps_out=1):train_data_set = np.array(data_set)reframed_train_data_set = np.array(series_to_supervised(train_data_set, timesteps_in, timesteps_out).values)train_x, train_y = reframed_train_data_set[:, :-timesteps_out], reframed_train_data_set[:, -timesteps_out:]# 将数据集重构为符合LSTM要求的数据格式,即 [样本数,时间步,特征]train_x = train_x.reshape((train_x.shape[0], timesteps_in, 1))return train_x, train_y"""
将时间序列数据转换为适用于监督学习的数据
给定输入、输出序列的长度
data: 观察序列
n_in: 观测数据input(X)的步长,范围[1, len(data)], 默认为1
n_out: 观测数据output(y)的步长, 范围为[0, len(data)-1], 默认为1
dropnan: 是否删除NaN行
返回值:适用于监督学习的 DataFrame
"""
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):print(data.shape)n_vars = 1 if type(data) is list else data.shape[1]df = DataFrame(data)cols, names = list(), list()# input sequence (t-n, ... t-1)for i in range(n_in, 0, -1):cols.append(df.shift(i))names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]# 预测序列 (t, t+1, ... t+n)for i in range(0, n_out):cols.append(df.shift(-i))if i == 0:names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]else:names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]# 拼接到一起agg = concat(cols, axis=1)agg.columns = names# 去掉NaN行if dropnan:agg.dropna(inplace=True)return agg# 使用LSTM进行预测
def lstm_model(source_data_set, train_x, label_y, input_epochs, input_batch_size, timesteps_out):model = Sequential()# 第一层, 隐藏层神经元节点个数为128, 返回整个序列model.add(LSTM(128, return_sequences=True, activation='tanh', input_shape=(train_x.shape[1], train_x.shape[2])))# 第二层,隐藏层神经元节点个数为128, 只返回序列最后一个输出model.add(LSTM(128, return_sequences=False))model.add(Dropout(0.5))# 第三层 因为是回归问题所以使用linearmodel.add(Dense(timesteps_out, activation='linear'))model.compile(loss='mean_squared_error', optimizer='adam')# LSTM训练 input_epochs次数res = model.fit(train_x, label_y, epochs=input_epochs, batch_size=input_batch_size, verbose=2, shuffle=False)# 模型预测train_predict = model.predict(train_x)#test_data_list = list(chain(*test_data))train_predict_list = list(chain(*train_predict))plt.plot(res.history['loss'], label='train')plt.show()#print(model.summary())plot_img(source_data_set, train_predict)# 呈现原始数据,训练结果,验证结果,预测结果
def plot_img(source_data_set, train_predict):plt.figure(figsize=(24, 8))# 原始数据蓝色plt.plot(source_data_set[:, -1], c='b',label = '标签')# 训练数据绿色plt.plot([x for x in train_predict], c='g')plt.legend()plt.show()# 设置观测数据input(X)的步长(时间步),epochs,batch_size
timesteps_in = 3
timesteps_out = 3
epochs = 1000
batch_size = 100
data = pd.read_csv('./shanghai_index_1990_12_19_to_2019_12_11.csv')
data_set = data[['Price']].values.astype('float64')
# 转化为可以用于监督学习的数据
train_x, label_y = get_train_set(data_set, timesteps_in=timesteps_in, timesteps_out=timesteps_out)print(train_x, label_y )
print(train_x.shape)
print(train_x.shape[1], train_x.shape[2])# 使用LSTM进行训练、预测
lstm_model(data_set, train_x, label_y, epochs, batch_size, timesteps_out=timesteps_out)

运行结果:

相关文章:

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。 一、LSTM网络简单介绍 LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题…...

内网穿透/组网/设备上云平台EasyNTS上云网关的安装操作指南

EasyNTS上云网关的主要作用是解决异地视频共享/组网/上云的需求,网页对域名进行添加映射时,添加成功后会生成一个外网访问地址,在浏览器中输入外网访问地址,即可查看内网应用。无需开放端口,EasyNTS上云网关平台会向Ea…...

易点天下基于 StarRocks 全面构建实时离线一体的湖仓方案

作者:易点天下数据平台团队易点天下是一家技术驱动发展的企业国际化智能营销服务公司,致力于为客户提供全球营销推广服务,通过效果营销、品牌塑造、垂直行业解决方案等一体化服务,帮助企业在全球范围内高效地获取用户、提升品牌知…...

Tomcat的类加载机制

不遵循双亲委托 在JVM中并不是一次性地把所有的文件都加载到,而是按需加载,加载机制采用 双亲委托原则,如下图所示: BootStrapClassLoader 引导类加载器ExtClassLoader 扩展类加载器AppClassLoader 应用类加载器CustomClassLoad…...

【shell 编程大全】数组,逻辑判断以及循环

数组,逻辑判断以及循环1. 概述 大家好,我又来了。今天呢我们继续学习shell相关的知识。还是老样子我们先回顾下上一次【脚本交互 以及表达式】学习到的知识 登录shell 关联配置文件什么是子shellumask 修改默认权限read 基础表达式 简单计算表达式expr 计…...

Android13 Bluetooth更新

目录 Android 13 版本说明 LE Audio 代码更新 Android 12代码路径 Android 13代码路径 Android 13 版本说明 里面对蓝牙更新的描述较少,一出提到蓝牙的一...

手工测试混了5年,年底接到了被裁员的消息....

大家都比较看好软件测试行业,只是因为表面上看起来:钱多事少加班少。其实这个都是针对个人运气好的童人才会有此待遇。在不同的阶段做好不同阶段的事情,才有可能离这个目标更近,作为一枚软件测试人员,也许下面才是我们…...

Umi框架

什么是 umi umi 是由 dva 的开发者 云谦 编写的一个新的 React 开发框架。umi 既是一个框架也是一个工具,可以将它简单的理解为一个专注性能的类 next.js 前端框架,并通过约定、自动生成和解析代码等方式来辅助开发,减少开发者的代码量。 u…...

教你学git

前言 git是一种用于多人合作写项目。详细说明如下 文章目录前言什么是版本控制?什么是 Git?它就属于人工版本控制器版本控制工具常见版本控制工具怎么工作的?git 文件生命周期状态区域安装配置-- global检查配置创建仓库工作流与基本操作查看…...

【工作笔记】syslog,kern.log大量写入invalid cookie错误信息问题

任务描述 错误出现出现过四五次,应该是诊断单元tf卡读写出问题导致下面这条告警一直高频写入到/var/log/下的syslog、kern.log、messages中 Nov 23 06:25:12 embest kernel: omap_hsmmc 48060000.mmc: [omap_hsmmc_pre_dma_transfer] invalid cookie: data->hos…...

【C++】多线程

多任务处理有两种形式,即:多进程和多线程。 基于进程的多任务处理是程序的并发执行。基于线程的多任务处理是同一程序的片段的并发执行 文章目录1. 多线程介绍2. Windows多线程1. 多线程介绍 每一个进程(可执行程序)都有一个主线…...

0202插入删除-算法第四版红黑树-红黑树-数据结构和算法(Java)

文章目录4 插入4.1 序4.2 向单个2-结点插入新键4.3 向树底部的2-结点插入新键4.4 向一棵双键树(3-结点)中插入新键4.5 颜色调整4.6 根结点总是黑色4.7 向树底部的3-结点插入新键4.8 将红链接在树中向上传递4.9 实现5 删除5.1 删除最小键5.2 删除6 有序性…...

vue 生成二维码插件 vue-qr使用方法

一、安装 npm install vue-qr --save二、引入 import VueQr from vue-qrcomponents:{VueQr,},三、使用 <vue-qr:text"dyQrcode":size"170":logoSrc"logo":margin"6":logoScale"0.2"></vue-qr>四、属性说明 …...

网络工程课(二)

ensp配置vlan 一、配置计算机ip地址和子网掩码 二、配置交换机LSW1 system-view [Huawei]sysname SW1 [SW1]vlan batch 10 20 [SW1]interface Ethernet0/0/1 [SW1-Ethernet0/0/1]port link-type access 将接口设为access接口 [SW1-Ethernet0/0/1]port default vlan 10 [SW1-E…...

Pytorch并行计算(三): 梯度累加

梯度累加 梯度累加&#xff08;Gradient Accmulation&#xff09;是一种增大训练时batch size的技巧。当batch size在一张卡放不下时&#xff0c;可以将很大的batch size分解为一个个小的mini batch&#xff0c;分别计算每一个mini batch的梯度&#xff0c;然后将其累加起来优…...

蓝桥杯入门即劝退(十八)最小覆盖子串(滑动窗口解法)

欢迎关注点赞评论&#xff0c;共同学习&#xff0c;共同进步&#xff01; ------持续更新蓝桥杯入门系列算法实例-------- 如果你也喜欢Java和算法&#xff0c;欢迎订阅专栏共同学习交流&#xff01; 你的点赞、关注、评论、是我创作的动力&#xff01; -------希望我的文章…...

Android一~

进程和线程的区别https://zhuanlan.zhihu.com/p/60375108https://zhuanlan.zhihu.com/p/138689342线程池的用法和原理tcp三次握手和四次挥手、tcp基础http请求报文格式二叉树中序遍历&#xff08;算法&#xff09;activity启动模式OKhttp源码讲解Java修饰符Java线程同步的方法s…...

一月券商金工精选

✦研报目录✦ ✦简述✦ 按发布时间排序 国盛证券 “薪火”量化分析系列研究&#xff08;二&#xff09;-票据逾期数据中的选股信息 发布日期&#xff1a;2023-01-04 关键词&#xff1a;股票、票据、票据预期 主要内容&#xff1a;本文深入探讨了“票据持续逾期名单”这一…...

UML中常见的9种图

UML是Unified Model Language的缩写&#xff0c;中文是统一建模语言&#xff0c;是由一整套图表组成的标准化建模语言。UML用于帮助系统开发人员阐明&#xff0c;展示&#xff0c;构建和记录软件系统的产出。通过使用UML使得在软件开发之前&#xff0c; 对整个软件设计有更好的…...

使用SpringBoot实现无限级评论回复功能

评论功能已经成为APP和网站开发中的必备功能。本文采用springbootmybatis-plus框架,通过代码主要介绍评论功能的数据库设计和接口数据返回。我们返回的格式可以分三种方案,第一种方案是先返回评论,再根据评论id返回回复信息,第二种方案是将评论回复直接封装成一个类似于树的数据…...

Kafka 介绍和使用

文章目录前言1、Kafka 系统架构1.1、Producer 生产者1.2、Consumer 消费者1.3、Consumer Group 消费者群组1.4、Topic 主题1.5、Partition 分区1.6、Log 日志存储1.7、Broker 服务器1.8、Offset 偏移量1.9、Replication 副本1.10、Zookeeper2、Kafka 环境搭建2.1、下载 Kafka2.…...

[学习笔记]Rocket.Chat业务数据备份

Rocket.Chat 的业务数据主要存储于mongodb数据库的rocketchat库中&#xff0c;聊天中通过发送文件功能产生的文件储存于/app/uploads中&#xff08;文件方式设置为"FileSystem"&#xff09;&#xff0c;因此在对Rocket.Chat做数据移动或备份主要分为两步&#xff0c;…...

【ZOJ 1090】The Circumference of the Circle 题解(海伦公式+正弦定理推论)

计算圆的周长似乎是一项简单的任务——只要你知道它的直径。但如果你没有呢&#xff1f; 我们给出了平面中三个非共线点的笛卡尔坐标。 您的工作是计算与所有三个点相交的唯一圆的周长。 输入规范 输入文件将包含一个或多个测试用例。每个测试用例由一条包含六个实数x1、y1、x…...

【go】slice原理

slice包含3个部分&#xff1a; 1.内存的起始位置 2.切片的大小(已经存放的元素数量) 3.容量(可以存放的元素数量) 使用make初始化切片会开辟底层内存&#xff0c;并初始化元素值为默认值&#xff0c;如数字为0&#xff0c;字符串为空 使用New初始化切片不会开辟底层数组&…...

【数据库】MySQL概念知识语法-基础篇(DQL),真的很详细,一篇文章你就会了

目录通用语法及分类DQL&#xff08;数据查询语言&#xff09;基础查询条件查询聚合查询&#xff08;聚合函数&#xff09;分组查询排序查询分页查询内连接查询外连接查询自连接查询联合查询子查询列子查询行子查询表子查询总结通用语法及分类 ● DDL: 数据定义语言&#xff0c…...

博客界的至高神:属于自己的WordPress网站,你值得拥有!

【如果暂时没时间安装&#xff0c;可以直接跳转到最后先看展示效果】 很多朋友都想有一个对外展示的窗口&#xff0c;在那里放一些个人的作品或者其他想对外分享的东西。大部分人选择了在微博、公众号等平台&#xff0c;毕竟这些平台流量大&#xff0c;我们可以很轻易地把自己…...

操作系统(day13)-- 虚拟内存;页面分配策略

虚拟内存管理 虚拟内存的基本概念 传统存储管理方式的特征、缺点 一次性&#xff1a; 作业必须一次性全部装入内存后才能开始运行。驻留性&#xff1a;作业一旦被装入内存&#xff0c;就会一直驻留在内存中&#xff0c;直至作业运行结束。事实上&#xff0c;在一个时间段内&…...

SQL零基础入门学习(四)

SQL零基础入门学习&#xff08;三&#xff09; SQL INSERT INTO 语句 INSERT INTO 语句用于向表中插入新记录。 SQL INSERT INTO 语法 INSERT INTO 语句可以有两种编写形式。 第一种形式无需指定要插入数据的列名&#xff0c;只需提供被插入的值即可&#xff1a; INSERT …...

19岁就患老年痴呆!这些前兆别忽视!

在大部分人的印象中&#xff0c;阿尔兹海默症好像是专属于老年人的疾病&#xff0c;而且它的另一个名字就是老年痴呆症。然而&#xff0c;前不久&#xff0c;一位19岁的男生患上了阿尔兹海默症&#xff0c;是迄今为止最年轻的患者。这个男生从17岁开始&#xff0c;就出现了注意…...

【C++】thread|mutex|atomic|condition_variable

本篇博客&#xff0c;让我们来认识一下C中的线程操作 所用编译器&#xff1a;vs2019 阅读本文前&#xff0c;建议先了解线程的概念 &#x1f449; 线程概念 1.基本介绍 在不同的操作系统&#xff0c;windows、linux、mac上&#xff0c;都会对多线程操作提供自己的系统调用接口…...