当前位置: 首页 > news >正文

LSTM实战之预测股票

📈 用PyTorch搭建LSTM模型,轻松预测股票价格!🚀

Hey小伙伴们,今天给大家带来一个超级实用的项目教程——如何用PyTorch和LSTM模型来预测股票价格!🌟

🔍 项目背景

我们都知道股市是个风云变幻的地方,而预测股价则是很多投资者梦寐以求的能力。今天,我们就来尝试一下用机器学习的方法来预测股价,让数据说话!

📑 准备工作

首先,我们要准备好开发环境,确保安装了以下Python库:

  • numpy: 数组处理
  • pandas: 数据处理
  • matplotlib: 数据可视化
  • scikit-learn: 数据预处理
  • torch: 构建LSTM模型

💻 实战演练

1️⃣ 导入库 & 加载数据

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim# 加载数据
df = pd.read_csv('stock_data.csv')
# 只保留收盘价
data = df.filter(['Close'])
# 将数据转换为numpy数组
dataset = data.values
# 归一化数据
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(dataset)

2️⃣ 创建数据集

# 训练集和测试集划分
training_data_len = int(np.ceil(len(dataset) * .8))# 创建训练数据集
def create_dataset(data, time_step=1):X_train, y_train = [], []for i in range(len(data)-time_step-1):X_train.append(data[i:(i+time_step), 0])y_train.append(data[i + time_step, 0])return np.array(X_train), np.array(y_train)time_step = 60
X_train, y_train = create_dataset(scaled_data[:training_data_len], time_step)
X_test, y_test = create_dataset(scaled_data[training_data_len-time_step:], time_step)# 调整数据形状以适应LSTM
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()

3️⃣ 构建LSTM模型

class LSTMModel(nn.Module):def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):super(LSTMModel, self).__init__()self.hidden_dim = hidden_dimself.layer_dim = layer_dimself.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))out = self.fc(out[:, -1, :]) return outinput_dim = 1
hidden_dim = 50
layer_dim = 1
output_dim = 1model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

4️⃣ 训练模型

num_epochs = 100for epoch in range(num_epochs):outputs = model(X_train)optimizer.zero_grad()# 获取损失loss = criterion(outputs, y_train)# 反向传播和优化loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

5️⃣ 预测和评估

# 获取模型预测值
train_predictions = model(X_train).detach().numpy()
test_predictions = model(X_test).detach().numpy()# 反归一化预测值
train_predictions = scaler.inverse_transform(train_predictions)
test_predictions = scaler.inverse_transform(test_predictions)# 计算均方根误差(RMSE)
rmse = np.sqrt(np.mean(((test_predictions - y_test.numpy()) ** 2)))
print("Root Mean Squared Error:", rmse)# 可视化结果
train = data[:training_data_len+1]
valid = data[training_data_len+1:]
valid['Predictions'] = test_predictionsplt.figure(figsize=(16,8))
plt.title('Model')
plt.xlabel('Date', fontsize=18)
plt.ylabel('Close Price', fontsize=18)
plt.plot(train['Close'])
plt.plot(valid[['Close', 'Predictions']])
plt.legend(['Train', 'Val', 'Predictions'], loc='upper right')
plt.show()

📊 结果展示

最后,我们来看看预测结果。可以看到,我们的模型虽然不是完美无缺,但在一定程度上还是能够捕捉到股价的变化趋势。这为投资者提供了非常有价值的信息哦!👀
在这里插入图片描述

🏆 结语

今天的分享就到这里啦!希望这篇教程能帮到你,也欢迎小伙伴们在评论区分享你的经验或者遇到的问题,我们一起探讨学习!🌟


如果你在运行过程中遇到任何问题,或者想要了解更多细节,随时可以问我哦!💡
如果你喜欢这篇教程,请给我点个赞哦!💖
也可以收藏,关注我了解更多人工智能知识哦!😉


📌 附录:常见问题解答

  • Q: 如何获取股票数据?

  • A: 你可以从雅虎财经、tushare等数据源获取股票数据。

  • Q: 为什么我的模型预测效果不好?

  • A: 可能是因为数据不足、模型结构不够复杂或者超参数设置不当。尝试增加数据量、调整模型架构或优化超参数。

  • Q: 我可以在哪里找到更多关于LSTM的知识?

  • A: 有很多在线资源和书籍可以学习LSTM,比如官方文档、博客文章和教程视频。

希望这篇文章对你有所帮助!如果有任何疑问,记得留言哦!👋

#PyTorch #LSTM #股票预测 #时间序列分析 #机器学习 #数据科学 #Python编程

相关文章:

LSTM实战之预测股票

📈 用PyTorch搭建LSTM模型,轻松预测股票价格!🚀 Hey小伙伴们,今天给大家带来一个超级实用的项目教程——如何用PyTorch和LSTM模型来预测股票价格!🌟 🔍 项目背景 我们都知道股市是…...

30-50K|抖音大模型|社招3轮面经

情况介绍:我主要做nlp,也涉及到多模态和强化学习。现在大环境比较差,能投的公司不是很多,比如腾讯,主要还是高级别的,所以腾讯我就没投 抖音一面 1、聊项目。 2、AUC的两种公式是?你能证明这…...

ChatGPT首次被植入人类大脑:帮助残障人士开启对话

马斯克在脑机接口中最强大的竞争对手Synchron有了新的技术进展,他们首次将ChatGPT整合到其脑机系统中,以使瘫痪患者更容易控制他们的数字设备。Synchron凭借其独特的脑机接口(BCI)技术脱颖而出,该技术巧妙地运用了成熟…...

数据结构-常见排序的七大排序

1.排序的概念及其运用 1.1排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性:假定在待排序的记录序列中,存在多个具有相同的关键字的记录…...

程序员学CFA——财务报告与分析(四)

财务报告与分析(四) 资产负债表资产负债表的构成和格式资产负债表的要素资产负债所有者权益 资产负债表的格式分层的资产负债表基于流动性的资产负债表 资产的计量属性资产负债表科目金融资产持有至到期投资交易性金融资产可供出售金融资产 商誉少数股东…...

【消息队列】kafka如何保证消息不丢失?

👏大家好!我是和风coding,希望我的文章能给你带来帮助! 🔥如果感觉博主的文章还不错的话,请👍三连支持👍一下博主哦 📝点击 我的主页 还可以看到和风的其他内容噢&#x…...

不同随机数生成的含义

torch.manual_seed(all_args.seed) torch.cuda.manual_seed(all_args.seed) torch.cuda.manual_seed_all(all_args.seed) np.random.seed(all_args.seed) random.seed(all_args.seed) 这几种随机种子设置的含义如下: torch.manual_seed(all_args.seed): 设置PyTor…...

Jar工具完全指南:从入门到精通

Jar工具完全指南:从入门到精通的详尽教程 前言 欢迎来到Jar工具的完全指南!无论你是Java编程的初学者,还是经验丰富的开发者,掌握Jar工具都是必不可少的。Jar(Java Archive)是Java生态系统中的一个核心组…...

前端使用docx-preview展示docx + 后端doc转docx

文章目录 后端 doc 转 docxdcox - preview安装导入使用注意 最近菜鸟刚搞完签字,结果需求就加了,如果合同有附件(.doc.docx),签名就是签到附件里面,没有附件才是签到那个html里面! 这里附件签名…...

Vue3 组件通信

目录 create-vue创建项目 一. 父子通信 1. 父传子 2. 子传父 二. 模版引用(通过ref获取实例对象) 1.基本使用 2.defineExpose 三. 跨层通信 - provide和inject 1. 作用和场景 2. 跨层传递普通数据 3. 跨层传递响应式数据 4. 跨层传递方法 create-vue创建项目 npm ini…...

如何在Ubuntu 14.04上安装、配置和部署Rocket.Chat

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 简介 Rocket.Chat 是一个使用 Meteor 构建的开源消息应用程序。它支持视频会议、文件共享、语音消息,具有完整的 API 等功能…...

ISO 26262中的失效率计算:IEC TR 62380-Section 15-Switches and keyboards

目录 概要 1 开关和键盘的分类 2 开关和键盘失效率的计算 2.1 Switches and keyboards 2.1.1 Base失效率 2.1.2 接触数量 2.1.3 温度循环De-rating系数 概要 IEC TR 62380《电子组件、PCBs和设备的可靠性预计通用模型》是涵盖电路、半导体分立器件、光电组件、电阻器、电…...

Linux安全与高级应用(五)深入探讨Linux Shell脚本应用:从基础到高级

文章目录 深入探讨Linux Shell脚本应用:从基础到高级引言一、Shell脚本基础知识1. Shell的作用与分类2. 编写第一个Shell脚本 二、Shell变量的使用1. 变量的类型与定义2. 引号的使用3. 位置变量与预定义变量 三、重定向与管道操作1. 重定向操作2. 管道操作 四、计划…...

Java中等题-解码方法(力扣)

一条包含字母 A-Z 的消息通过以下映射进行了 编码 : "1" -> A "2" -> B ... "25" -> Y "26" -> Z 然而,在 解码 已编码的消息时,你意识到有许多不同的方式来解码,因为有些…...

【Git】git 从入门到实战系列(二)—— Git 介绍以及安装方法

文章目录 一、前言二、git 是什么三、版本控制系统是什么四、本地 vs 集中式 vs 分布式本地版本控制系统集中式版本控制系统分布式版本控制系统 五、安装 git 一、前言 本系列上一篇文章【Git】git 从入门到实战系列(一)—— Git 的诞生,Lin…...

【QT 5 QT 6 构建工具qmake-cmake-和-软件编译器MSVCxxxvs MinGWxxx说明】

【QT 5报错:/xxx/: error: ‘class Ui::frmMain’ has no member named ‘xxx’-和-软件编译器MSVCxxxvs MinGWxxx说明】 1、前言2 、qt 中 Qmake CMake 和 QBS1-qmake2-Cmake3-QBS4-官网一些说法5-各自特点 3、软件编译套件1-Desktop Qt 6.7.2 llvm-mingw 64-bit2-…...

SD卡参数错误:深度解析与数之寻软件恢复实战

一、SD卡参数错误:数据与设备的隐形杀手 在数字化时代,SD卡作为便携存储设备,广泛应用于相机、手机、无人机及各类电子设备中,承载着人们珍贵的照片、视频、文档等重要数据。然而,SD卡在使用过程中,有时会…...

深入理解和应用RabbitMQ的Work Queues模型

文章目录 1. 场景模拟2. 消息发送3. 消息接收4. 测试5. 能者多劳6. 总结 当你在处理消息时,可能会遇到这样的问题:消息的生产速度远远大于消费速度,导致消息堆积。这时候,Work Queues(工作队列)模型就能派上…...

嵌入式面试八股文(三)·野指针产生原因和解决方法、指针函数和函数指针的区别

目录 1. 野指针产生原因和解决方法 1.1 产生的原因 1.1.1 指针未能初始化 1.1.2 指针指向的内存被释放 1.1.3 指针指向的对象被重复释放 1.2 解决方法 1.2.1 初始化指针 1.2.2 指针空置 1.2.3 避免悬挂指针 2. 指针函数和函数指针的区别 2.1 定义不同 2…...

OpenCV 中 CV_8UC1,CV_32FC3,CV_32S等参数的含义

在OpenCV中,创建图像时需要指定图像的类型,这些类型通常通过常量来表示,例如 CV_8UC1、CV_32FC3、CV_32S 等。这些常量定义了图像的数据类型和通道数,具体含义如下: CV_8UC1: CV_8U 表示每个像素由一个8位无…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

FFmpeg:Windows系统小白安装及其使用

一、安装 1.访问官网 Download FFmpeg 2.点击版本目录 3.选择版本点击安装 注意这里选择的是【release buids】&#xff0c;注意左上角标题 例如我安装在目录 F:\FFmpeg 4.解压 5.添加环境变量 把你解压后的bin目录&#xff08;即exe所在文件夹&#xff09;加入系统变量…...

HTML前端开发:JavaScript 获取元素方法详解

作为前端开发者&#xff0c;高效获取 DOM 元素是必备技能。以下是 JS 中核心的获取元素方法&#xff0c;分为两大系列&#xff1a; 一、getElementBy... 系列 传统方法&#xff0c;直接通过 DOM 接口访问&#xff0c;返回动态集合&#xff08;元素变化会实时更新&#xff09;。…...