当前位置: 首页 > news >正文

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。

一、LSTM网络简单介绍

LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。对于相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

引入LSTM网络的原因:由于 RNN 网络主要问题是长期依赖,即隐藏状态在时间上传递过程中可能会丢失之前的信息。为了解决这个问题,引入了长短时记忆网络 (LSTM) 和门控循环单元 (GRU)。这两种网络结构在隐藏层中增加了门控机制,能够更好地控制信息的传递。

 其中符号及表示意思如下:

 LSTM中有三个门:
(1)遗忘门f:决定上一个时刻的记忆单元状态需要遗忘多少信息,保留多少信息到当前记忆单元状态。
(2)输入门i:控制当前时刻输入信息候选状态有多少信息需要保存到当前记忆单元状态。
(3)输出门o:控制当前时刻的记忆单元状态有多少信息需要输出给外部状态。

形象的例子让我们更好的理解LSTM的原理:

假设你是一个梦想远大的学生,你想通过学习一门课程获得更多的知识。在学习过程中,LSTM模型帮助你,它就像是一个老师,它的遗忘门就像是老师的提醒,它让你挑出不用的知识,以保持你对重要知识的清晰记忆。它的输入门就像是老师的指导,它会重新审视你学习过的知识,按照自己的逻辑把知识结合起来,进化出更多有用的知识。最后,它的输出门就像老师的监督,它会确保你学习到了有用的知识,不要浪费时间去学习无用的知识。

二、LSTM网络运用-预测上证指数走势

# 使用LSTM预测沪市指数
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from pandas import DataFrame
from pandas import concat
from itertools import chain
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']# 转化为可以用于监督学习的数据
def get_train_set(data_set, timesteps_in, timesteps_out=1):train_data_set = np.array(data_set)reframed_train_data_set = np.array(series_to_supervised(train_data_set, timesteps_in, timesteps_out).values)train_x, train_y = reframed_train_data_set[:, :-timesteps_out], reframed_train_data_set[:, -timesteps_out:]# 将数据集重构为符合LSTM要求的数据格式,即 [样本数,时间步,特征]train_x = train_x.reshape((train_x.shape[0], timesteps_in, 1))return train_x, train_y"""
将时间序列数据转换为适用于监督学习的数据
给定输入、输出序列的长度
data: 观察序列
n_in: 观测数据input(X)的步长,范围[1, len(data)], 默认为1
n_out: 观测数据output(y)的步长, 范围为[0, len(data)-1], 默认为1
dropnan: 是否删除NaN行
返回值:适用于监督学习的 DataFrame
"""
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):print(data.shape)n_vars = 1 if type(data) is list else data.shape[1]df = DataFrame(data)cols, names = list(), list()# input sequence (t-n, ... t-1)for i in range(n_in, 0, -1):cols.append(df.shift(i))names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]# 预测序列 (t, t+1, ... t+n)for i in range(0, n_out):cols.append(df.shift(-i))if i == 0:names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]else:names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]# 拼接到一起agg = concat(cols, axis=1)agg.columns = names# 去掉NaN行if dropnan:agg.dropna(inplace=True)return agg# 使用LSTM进行预测
def lstm_model(source_data_set, train_x, label_y, input_epochs, input_batch_size, timesteps_out):model = Sequential()# 第一层, 隐藏层神经元节点个数为128, 返回整个序列model.add(LSTM(128, return_sequences=True, activation='tanh', input_shape=(train_x.shape[1], train_x.shape[2])))# 第二层,隐藏层神经元节点个数为128, 只返回序列最后一个输出model.add(LSTM(128, return_sequences=False))model.add(Dropout(0.5))# 第三层 因为是回归问题所以使用linearmodel.add(Dense(timesteps_out, activation='linear'))model.compile(loss='mean_squared_error', optimizer='adam')# LSTM训练 input_epochs次数res = model.fit(train_x, label_y, epochs=input_epochs, batch_size=input_batch_size, verbose=2, shuffle=False)# 模型预测train_predict = model.predict(train_x)#test_data_list = list(chain(*test_data))train_predict_list = list(chain(*train_predict))plt.plot(res.history['loss'], label='train')plt.show()#print(model.summary())plot_img(source_data_set, train_predict)# 呈现原始数据,训练结果,验证结果,预测结果
def plot_img(source_data_set, train_predict):plt.figure(figsize=(24, 8))# 原始数据蓝色plt.plot(source_data_set[:, -1], c='b',label = '标签')# 训练数据绿色plt.plot([x for x in train_predict], c='g')plt.legend()plt.show()# 设置观测数据input(X)的步长(时间步),epochs,batch_size
timesteps_in = 3
timesteps_out = 3
epochs = 1000
batch_size = 100
data = pd.read_csv('./shanghai_index_1990_12_19_to_2019_12_11.csv')
data_set = data[['Price']].values.astype('float64')
# 转化为可以用于监督学习的数据
train_x, label_y = get_train_set(data_set, timesteps_in=timesteps_in, timesteps_out=timesteps_out)print(train_x, label_y )
print(train_x.shape)
print(train_x.shape[1], train_x.shape[2])# 使用LSTM进行训练、预测
lstm_model(data_set, train_x, label_y, epochs, batch_size, timesteps_out=timesteps_out)

运行结果:

相关文章:

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。 一、LSTM网络简单介绍 LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题…...

内网穿透/组网/设备上云平台EasyNTS上云网关的安装操作指南

EasyNTS上云网关的主要作用是解决异地视频共享/组网/上云的需求,网页对域名进行添加映射时,添加成功后会生成一个外网访问地址,在浏览器中输入外网访问地址,即可查看内网应用。无需开放端口,EasyNTS上云网关平台会向Ea…...

易点天下基于 StarRocks 全面构建实时离线一体的湖仓方案

作者:易点天下数据平台团队易点天下是一家技术驱动发展的企业国际化智能营销服务公司,致力于为客户提供全球营销推广服务,通过效果营销、品牌塑造、垂直行业解决方案等一体化服务,帮助企业在全球范围内高效地获取用户、提升品牌知…...

Tomcat的类加载机制

不遵循双亲委托 在JVM中并不是一次性地把所有的文件都加载到,而是按需加载,加载机制采用 双亲委托原则,如下图所示: BootStrapClassLoader 引导类加载器ExtClassLoader 扩展类加载器AppClassLoader 应用类加载器CustomClassLoad…...

【shell 编程大全】数组,逻辑判断以及循环

数组,逻辑判断以及循环1. 概述 大家好,我又来了。今天呢我们继续学习shell相关的知识。还是老样子我们先回顾下上一次【脚本交互 以及表达式】学习到的知识 登录shell 关联配置文件什么是子shellumask 修改默认权限read 基础表达式 简单计算表达式expr 计…...

Android13 Bluetooth更新

目录 Android 13 版本说明 LE Audio 代码更新 Android 12代码路径 Android 13代码路径 Android 13 版本说明 里面对蓝牙更新的描述较少,一出提到蓝牙的一...

手工测试混了5年,年底接到了被裁员的消息....

大家都比较看好软件测试行业,只是因为表面上看起来:钱多事少加班少。其实这个都是针对个人运气好的童人才会有此待遇。在不同的阶段做好不同阶段的事情,才有可能离这个目标更近,作为一枚软件测试人员,也许下面才是我们…...

Umi框架

什么是 umi umi 是由 dva 的开发者 云谦 编写的一个新的 React 开发框架。umi 既是一个框架也是一个工具,可以将它简单的理解为一个专注性能的类 next.js 前端框架,并通过约定、自动生成和解析代码等方式来辅助开发,减少开发者的代码量。 u…...

教你学git

前言 git是一种用于多人合作写项目。详细说明如下 文章目录前言什么是版本控制?什么是 Git?它就属于人工版本控制器版本控制工具常见版本控制工具怎么工作的?git 文件生命周期状态区域安装配置-- global检查配置创建仓库工作流与基本操作查看…...

【工作笔记】syslog,kern.log大量写入invalid cookie错误信息问题

任务描述 错误出现出现过四五次,应该是诊断单元tf卡读写出问题导致下面这条告警一直高频写入到/var/log/下的syslog、kern.log、messages中 Nov 23 06:25:12 embest kernel: omap_hsmmc 48060000.mmc: [omap_hsmmc_pre_dma_transfer] invalid cookie: data->hos…...

【C++】多线程

多任务处理有两种形式,即:多进程和多线程。 基于进程的多任务处理是程序的并发执行。基于线程的多任务处理是同一程序的片段的并发执行 文章目录1. 多线程介绍2. Windows多线程1. 多线程介绍 每一个进程(可执行程序)都有一个主线…...

0202插入删除-算法第四版红黑树-红黑树-数据结构和算法(Java)

文章目录4 插入4.1 序4.2 向单个2-结点插入新键4.3 向树底部的2-结点插入新键4.4 向一棵双键树(3-结点)中插入新键4.5 颜色调整4.6 根结点总是黑色4.7 向树底部的3-结点插入新键4.8 将红链接在树中向上传递4.9 实现5 删除5.1 删除最小键5.2 删除6 有序性…...

vue 生成二维码插件 vue-qr使用方法

一、安装 npm install vue-qr --save二、引入 import VueQr from vue-qrcomponents:{VueQr,},三、使用 <vue-qr:text"dyQrcode":size"170":logoSrc"logo":margin"6":logoScale"0.2"></vue-qr>四、属性说明 …...

网络工程课(二)

ensp配置vlan 一、配置计算机ip地址和子网掩码 二、配置交换机LSW1 system-view [Huawei]sysname SW1 [SW1]vlan batch 10 20 [SW1]interface Ethernet0/0/1 [SW1-Ethernet0/0/1]port link-type access 将接口设为access接口 [SW1-Ethernet0/0/1]port default vlan 10 [SW1-E…...

Pytorch并行计算(三): 梯度累加

梯度累加 梯度累加&#xff08;Gradient Accmulation&#xff09;是一种增大训练时batch size的技巧。当batch size在一张卡放不下时&#xff0c;可以将很大的batch size分解为一个个小的mini batch&#xff0c;分别计算每一个mini batch的梯度&#xff0c;然后将其累加起来优…...

蓝桥杯入门即劝退(十八)最小覆盖子串(滑动窗口解法)

欢迎关注点赞评论&#xff0c;共同学习&#xff0c;共同进步&#xff01; ------持续更新蓝桥杯入门系列算法实例-------- 如果你也喜欢Java和算法&#xff0c;欢迎订阅专栏共同学习交流&#xff01; 你的点赞、关注、评论、是我创作的动力&#xff01; -------希望我的文章…...

Android一~

进程和线程的区别https://zhuanlan.zhihu.com/p/60375108https://zhuanlan.zhihu.com/p/138689342线程池的用法和原理tcp三次握手和四次挥手、tcp基础http请求报文格式二叉树中序遍历&#xff08;算法&#xff09;activity启动模式OKhttp源码讲解Java修饰符Java线程同步的方法s…...

一月券商金工精选

✦研报目录✦ ✦简述✦ 按发布时间排序 国盛证券 “薪火”量化分析系列研究&#xff08;二&#xff09;-票据逾期数据中的选股信息 发布日期&#xff1a;2023-01-04 关键词&#xff1a;股票、票据、票据预期 主要内容&#xff1a;本文深入探讨了“票据持续逾期名单”这一…...

UML中常见的9种图

UML是Unified Model Language的缩写&#xff0c;中文是统一建模语言&#xff0c;是由一整套图表组成的标准化建模语言。UML用于帮助系统开发人员阐明&#xff0c;展示&#xff0c;构建和记录软件系统的产出。通过使用UML使得在软件开发之前&#xff0c; 对整个软件设计有更好的…...

使用SpringBoot实现无限级评论回复功能

评论功能已经成为APP和网站开发中的必备功能。本文采用springbootmybatis-plus框架,通过代码主要介绍评论功能的数据库设计和接口数据返回。我们返回的格式可以分三种方案,第一种方案是先返回评论,再根据评论id返回回复信息,第二种方案是将评论回复直接封装成一个类似于树的数据…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)

说明&#xff1a; 想象一下&#xff0c;你正在用eNSP搭建一个虚拟的网络世界&#xff0c;里面有虚拟的路由器、交换机、电脑&#xff08;PC&#xff09;等等。这些设备都在你的电脑里面“运行”&#xff0c;它们之间可以互相通信&#xff0c;就像一个封闭的小王国。 但是&#…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

基于FPGA的PID算法学习———实现PID比例控制算法

基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容&#xff1a;参考网站&#xff1a; PID算法控制 PID即&#xff1a;Proportional&#xff08;比例&#xff09;、Integral&#xff08;积分&…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...

使用LangGraph和LangSmith构建多智能体人工智能系统

现在&#xff0c;通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战&#xff0c;比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...

在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)

考察一般的三次多项式&#xff0c;以r为参数&#xff1a; p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]&#xff1b; 此多项式的根为&#xff1a; 尽管看起来这个多项式是特殊的&#xff0c;其实一般的三次多项式都是可以通过线性变换化为这个形式…...

nnUNet V2修改网络——暴力替换网络为UNet++

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...