当前位置: 首页 > news >正文

从0开始基于transformer进行股价预测(pytorch版本)

目录

  • 数据阶段
    • 两个问题
    • 开始利用我们的代码进行切分
  • backbone网络
  • 训练
  • 效果 感觉还行,没有调参数。
  • 源码比较长,如果需要我后续会发(因为太长了!!)

数据阶段

!!!注意!!! , 本文不会讲原理,因为之前两篇文章已经讲过了,只会解释一些结构性问题,和思路问题。

所谓工欲善其事,必先利其器做量化分析的股价预测,完美必须要先把数据处理好。
那么本人的数据下载是在聚宽平台股票代码为601398的数据2014-3 到 2024-3年的默认数据。如何下载可以按照我的方式

在这里插入图片描述
进入研究环境后随便创建一个ipynb文件进行数据下载 ,运行以下代码就行

# 1.获取数据
data = get_price('601398.XSHG', start_date='2014-01-01', end_date='2024-01-01', frequency='daily', fields=None, skip_paused=False, fq='pre', panel=True)
# 2.保存数据
data.to_csv('data_沪深300/601398.XSHG(工商银行14-24).csv')

两个问题

1.为什么我们只需要用encoder部分去预测就行而不需要decoder部分?
答: 编码器用于将输入序列编码成一个上下文表示(contextual representation),然后解码器根据该上下文表示生成目标序列在时间序列预测任务中,我们不需要生成一个序列,而是预测单个或少量几个未来数据点。因此,编码器的上下文表示已经包含了足够的信息来进行预测,无需使用解码器。还有我觉得使用解码器的意思是,你用上一天的数据去预测下一天的数据,我感觉这样就没意思了,这和我们个人看有什么区别。而且对最后的结果也会造成不精准的效果。为什么这么说呢,你看解码器的mask编码部分应该可以理解了。
2.我们的维度为什么不是[batch, len, feature]? 因为这是pytorch要求,自己能实现的话,自己改吧。

开始利用我们的代码进行切分

我的思路用的是用五天的数据去预测下一天,数据集和测试及8/2分
但是我们要记住一点,就是我们必须要理解我们这么做的思路,就比如我们的特征有6列分别是,open,close,high,low,volume,money,我们可以通过训练得到我们想预测的某一特征。OK,我们这就开始。

说起数据分割里面的代码不难,最难的是
for i in range(len(X_CONVERT) - seq_length):
X_data.append(X_CONVERT[i:i+seq_length, :])
y_data.append(X_CONVERT[i+seq_length, 1])
你要知道我在干什么,就是用8成的数据集去预测得到我们所需要的train数据集和我们对应train数据集的label,举个例子就是,我们要炒菜,我们拿上原料后我们要知道炒的什么菜,那么菜单必须要知道。是吧,不然你炒完菜后说是红烧肉,但是没有菜单图片对比你怎么知道这是红烧肉?这也就是这一步的意义。

def split_data(batch_size,seq_length, pred_length, train_ratio):data_all = pd.read_csv(data_path)data_ha = []length = len(data_all)# 将数据转换为numpy数组,并添加到列表中for element in elements:data_element = data_all[element].values.astype(np.float32)data_element = data_element.reshape(length, 1)data_ha.append(data_element)X_hat = np.concatenate(data_ha, axis=1)X_CONVERT = torch.from_numpy(X_hat).float()X_CONVERT = X_CONVERT.flip(dims=[0])# 进行归一化min_val = np.min(X_hat, axis=0)max_val = np.max(X_hat, axis=0)X_normalized = (X_hat - min_val) / (max_val - min_val)X_CONVERT = torch.from_numpy(X_normalized).float()X_CONVERT = X_CONVERT.flip(dims=[0])#数据翻转# 划分训练集和验证集X_data = []y_data = []for i in range(len(X_CONVERT) - seq_length):#划分的时候是用8成的训练集去训练然后label是某##一列X_data.append(X_CONVERT[i:i+seq_length, :])y_data.append(X_CONVERT[i+seq_length, 1])X_data = torch.stack(X_data)y_data = torch.stack(y_data).squeeze(-1)print(X_data.shape, y_data.shape)dataset = TensorDataset(X_data, y_data)train_size = int(len(dataset) * train_ratio)val_size = len(dataset) - train_sizetrain_dataset, val_dataset = random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size, shuffle=False)return train_loader, val_loader,min_val, max_val

backbone网络

如其名,我们都知道这是这是transformer当然是用的transformer的结构。但是我们用,但是只用一部分,具体用什么部分开头说了,只用encoder

**但是具体操作起来的时候encoder里面的embadding部分我们需要修改,因为我们不是机器翻译,所以我们不需要把他变成词向量,我们时间序列数据,输入通常是连续的数值特征,使用线性层更直接地将这些数值特征映射到高维空间。并且我们的embadding嵌入层,适用于离散的输入,输出是固定维度的嵌入向量。而线性层,适用于连续的输入,可以灵活处理不同维度的输入特征,将其映射到高维表示。**具体看下面代码

class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Linear(feature, d_model)#这里替换了self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]enc_outputs = self.pos_emb(enc_outputs)  # [batch_size, src_len, d_model]enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

当然其他的部分和我上一篇的一样,但是就是decode不要了,当然也可以换成其他结果,或者加个注意力机制

讲下各个参数


d_model = 512   # linnerer的输入维度 也就是字embedding的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
feature=6       # 输入特征维度

当然主体还是要看一下的最重要的是通过encoder后的维度转换比较繁琐,要和我们之前split的数据集得到的y_train一致这样才能计算损失


class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.Encoder = Encoder()self.projection = nn.Linear(d_model, 1, bias=False)def forward(self, enc_inputs):  # enc_inputs: [batch_size, src_len, feature]enc_outputs, enc_self_attns = self.Encoder(enc_inputs)  # enc_outputs: [batch_size, src_len, d_model]dec_logits = self.projection(enc_outputs)  # dec_logits: [batch_size, src_len, 1]dec_logits = dec_logits.mean(dim=1)  # 将每个时间步的预测结果取平均,得到 [batch_size, 1]return dec_logits.squeeze(-1), enc_self_attns  # 输出 [batch_size]

训练

先解释参数

batch_size=64#批处理大小
seq_length=7#时间序列长度 也就是通过seq_length天预测后面pred_length天
pred_length=1#预测长度
train_ratio=0.8#训练集比例
epochs = 50 # 训练轮数
lr= 0.001 # 学习率
png_save_path="diytransformers/12.24transformer/picture"#所有的图片保存的地方
loss_history = []# 存储每个 epoch 的损失

训练代码很长,挺简单的


# 训练模型
for epoch in range(epochs):epoch_loss = 0y_pre = []y_true = []# 训练阶段for X, y in train_loader:X = X.float()  # 确保输入数据类型为float32y = y.float()  # 确保目标数据类型为float32outputs, enc_self_attns = model(X)# 计算损失,确保形状一致loss = criterion(outputs, y)epoch_loss += loss.item()optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()#转换我们的label和训练后得到的训练集的预测值 y_pre.append(outputs.detach())y_true.append(y.detach())avg_loss = epoch_loss / len(train_loader)loss_history.append(avg_loss)#获得最好的lossif avg_loss < best_loss:best_loss = avg_lossbest_epoch = epochbest_model_wts = copy.deepcopy(model.state_dict())torch.save(best_model_wts, path_train)y_pre_concat = torch.cat(y_pre, dim=0)y_true_concat = torch.cat(y_true, dim=0)# 计算并打印评估指标metrics = evaluate(y_pre_concat, y_true_concat, min_val, max_val)print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')# 可视化结果ht(y_true_concat.detach().cpu().numpy(), y_pre_concat.detach().cpu().numpy(), min_val, max_val,png_save_path)

最后是看我们的一些指标效果如何 比如这里我计算的mae,rmse,pcc等

# 加载最佳模型权重
model.load_state_dict(torch.load(train_over_path))# 测试模型并计算评估指标
test_metrics = test_model(model, val_loader, min_val, max_val)print(f'Test Metrics: {test_metrics}')

效果 感觉还行,没有调参数。

在这里插入图片描述

源码比较长,如果需要我后续会发(因为太长了!!)

相关文章:

从0开始基于transformer进行股价预测(pytorch版本)

目录 数据阶段两个问题开始利用我们的代码进行切分 backbone网络训练效果 感觉还行&#xff0c;没有调参数。源码比较长&#xff0c;如果需要我后续会发&#xff08;因为太长了&#xff01;&#xff01;&#xff09; 数据阶段 &#xff01;&#xff01;&#xff01;注意&#…...

【多GPU训练方法】

一、数据并行 这是最常用的方法。整个模型复制到每个GPU上。训练数据被均匀分割&#xff0c;每个GPU处理一部分数据。所有GPU上的梯度被收集并求平均。通常使用NCCL&#xff08;NVIDIA Collective Communications Library&#xff09;等通信库实现。参数更新 使用同步后的梯度…...

2024年PMP考试备考经验分享

PMP是项目管理领域最重要的认证之一,本身是IT行业比较流行的证书&#xff0c;近几年在临床试验领域也渐渐流行起来&#xff0c;是我周围临床项PM几乎人手一个的证书。 考试时间&#xff1a;PMP认证考试形式为180道选择题&#xff0c;考试时间为3小时50分。 考试计划&#xff…...

MT3046 愤怒的象棚

思路&#xff1a; a[]存愤怒值&#xff1b;b[i]存以i结尾的&#xff0c;窗口里的最大值&#xff1b;c[i]存以i结尾的&#xff0c;窗口里面包含✳的最大值。 &#xff08;✳为新大象的位置&#xff09; 例&#xff1a;1 2 3 4 ✳ 5 6 7 8 9 则ans的计算公式b3b4c4c5c6b7b8b9…...

深入了解代理IP常见协议:区别与选择

代理服务器在网络使用中扮演着重要的角色&#xff0c;是您设备和互联网之间的中间层。它不仅可以增强网络访问的安全性和隐私保护&#xff0c;还可以提供许多灵活的应用。使用代理时&#xff0c;不同的协议类型对数据交换具有不同的规则和特征。常见的代理协议包括HTTP代理、HT…...

【Linux 线程】线程的基本概念、LWP的理解

文章目录 一、ps -L 指令&#x1f34e;二、线程控制 一、ps -L 指令&#x1f34e; &#x1f427; 使用 ps -L 命令查看轻量级进程信息&#xff1b;&#x1f427; pthread_self() 用于获取用户态线程的 tid&#xff0c;而并非轻量级进程ID&#xff1b;&#x1f427; getpid() 用…...

Dify中的工具

Dify中的工具分为内置工具&#xff08;硬编码&#xff09;和第三方工具&#xff08;OpenAPI Swagger/ChatGPT Plugin&#xff09;。工具可被Workflow&#xff08;工作流&#xff09;和Agent使用&#xff0c;当然Workflow也可被发布为工具&#xff0c;这样Workflow&#xff08;工…...

在Visutal Studio 2022中完成D3D12初始化

在Visutal Studio 2022中完成DirectX设备初始化 1 DirectX121.1 DirectX 简介1.2 DirectX SDK安装2 D3D12初始化2.1 创建Windwos桌面项目2.2 修改符合模式2.3 下载d3dx12.h文件2.4 创建一个异常类D3DException,定义抛出异常实例的宏ThrowIfFailed3 D3D12的初始化步骤3.1 初始化…...

MobaXterm工具

MobaXterm 是一个增强型的 Windows 终端。其为 Windows 桌面提供所有重要的远程网络终端工具&#xff08;如 SSH、X11、RDP、VNC、FTP、SFTP、Telnet、Serial、Mosh、WSL 等&#xff09;&#xff0c;和 Unix 命令&#xff08;如 bash、ls、cat、sed、grep、awk、rsync 等&#…...

二分图练习

对于二分图我们可以用染色法 #include<bits/stdc.h> using namespace std;#define int long long const int N 2e65; int e[N],ne[N],h[N],idx 0; int colo[N]; int num 0;void add(int x,int y){e[idx] y;ne[idx] h[x];h[x] idx; } void dfs(int nod,int c){colo…...

创新设计策略:提升大屏幕可视化设计效果的关键方法

随着科技的不断发展和数据量的快速增长&#xff0c;数据可视化大屏在各个行业中的应用越来越广泛&#xff0c;可以帮助人们更好地理解和分析数据&#xff0c;可视化大屏设计也因此成了众多企业的需求。但很多设计师对可视化大屏设计并不了解&#xff0c;也不知道如何制作可视化…...

论文 | Chain-of-Thought Prompting Elicits Reasoningin Large Language Models 思维链

这篇论文研究了如何通过生成一系列中间推理步骤&#xff08;即思维链&#xff09;来显著提高大型语言模型进行复杂推理的能力。论文展示了一种简单的方法&#xff0c;称为思维链提示&#xff0c;通过在提示中提供几个思维链示例来自然地激发这种推理能力。 主要发现&#xff1…...

[机器学习]-人工智能对程序员的深远影响——案例分析

机器学习和人工智能对未来程序员的深远影响 目录 机器学习和人工智能对未来程序员的深远影响1. **自动化编码任务**1.1 代码生成1.2 自动调试1.3 测试自动化 2. **提升开发效率**2.1 智能建议2.2 项目管理 3. **改变编程范式**3.1 数据驱动开发 4. **职业发展的新机遇**4.1 AI工…...

AI学习环境 没有更好的替代 - (Google)Drive + Colab

在开始正题前&#xff0c;请容许我做一番回顾&#xff0c;并夹带一点点私货&#xff08;谷歌扛旗的开源精神还没有死&#xff0c;并且会是未来的举足轻重的力量&#xff09; 卧龙凤雏&#xff0c;一时瑜亮。一切的缘起应该是世纪初的门户网站乱战。 彼时&#xff0c;谷歌是从…...

【观成科技】Websocket协议代理隧道加密流量分析与检测

Websocket协议代理隧道加密流量简介 攻防场景下&#xff0c;Websocket协议常被用于代理隧道的搭建&#xff0c;攻击者企图通过Websocket协议来绕过网络限制&#xff0c;搭建一个低延迟、双向实时数据传输的隧道。当前&#xff0c;主流的支持Websocket通信代理的工具有&#xf…...

DangerWind-RPC-framework---三、服务端下机

当一台机器下线时&#xff0c;面临很多问题&#xff1a;如何将其从注册中心下线&#xff1f;如何清理释放资源&#xff1f;客户端拉取服务列表时也使用了本地缓存&#xff0c;如何及时更新本地缓存&#xff1f; 服务端机器的优雅下线需要使用ShutdownHook&#xff0c;这相当于添…...

基于Make的c工程No compilation commands found报错

由于安装gcc时只安装了build-essential&#xff0c;没有将其添加到环境变量中&#xff0c;因此打开Make工程时&#xff0c;CLion会产生如下错误&#xff1a; 要解决这个问题&#xff0c;一个方法是将GCC添加到环境变量中&#xff0c;但是这个方法需要修改至少两个配置文件&…...

c++:面向对象的继承特性

什么是继承 (1)继承是C源生支持的一种语法特性&#xff0c;是C面向对象的一种表现 (2)继承特性可以让派生类“瞬间”拥有基类的所有&#xff08;当然还得考虑权限&#xff09;属性和方法 (3)继承特性本质上是为了代码复用 (4)类在C编译器的内部可以理解为结构体&#xff0c;派…...

skywalking-2-客户端-php的安装与使用

skywalking的客户端支持php&#xff0c;真的很棒。 官方安装文档&#xff1a;https://skywalking.apache.org/docs/skywalking-php/next/en/setup/service-agent/php-agent/readme/ 前置准备 本次使用的php版本是8.2.13: php -v PHP 8.2.13 (cli) (built: Nov 21 2023 09:5…...

图文讲解IDEA如何导入JDBC驱动包

前言 学习JDBC编程,势必要学会如何导入驱动包,这里笔者用图文的方式来介绍 视频版本在这里 50秒教你怎么导入驱动包然后进行JDBC编程的学习_哔哩哔哩_bilibili 忘记录音频了,大伙凑合着看 下载驱动包 https://mvnrepository.com/artifact/mysql/mysql-connector-java 去中…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

Python如何给视频添加音频和字幕

在Python中&#xff0c;给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加&#xff0c;包括必要的代码示例和详细解释。 环境准备 在开始之前&#xff0c;需要安装以下Python库&#xff1a;…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件&#xff0c;这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下&#xff0c;实现高效测试与快速迭代&#xff1f;这一命题正考验着…...

Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信

文章目录 Linux C语言网络编程详细入门教程&#xff1a;如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket&#xff08;服务端和客户端都要&#xff09;2. 绑定本地地址和端口&#x…...

论文笔记——相干体技术在裂缝预测中的应用研究

目录 相关地震知识补充地震数据的认识地震几何属性 相干体算法定义基本原理第一代相干体技术&#xff1a;基于互相关的相干体技术&#xff08;Correlation&#xff09;第二代相干体技术&#xff1a;基于相似的相干体技术&#xff08;Semblance&#xff09;基于多道相似的相干体…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...