从0开始基于transformer进行股价预测(pytorch版本)
目录
- 数据阶段
- 两个问题
- 开始利用我们的代码进行切分
- backbone网络
- 训练
- 效果 感觉还行,没有调参数。
- 源码比较长,如果需要我后续会发(因为太长了!!)
数据阶段
!!!注意!!! , 本文不会讲原理,因为之前两篇文章已经讲过了,只会解释一些结构性问题,和思路问题。
所谓工欲善其事,必先利其器做量化分析的股价预测,完美必须要先把数据处理好。
那么本人的数据下载是在聚宽平台股票代码为601398的数据2014-3 到 2024-3年的默认数据。如何下载可以按照我的方式
进入研究环境后随便创建一个ipynb文件进行数据下载 ,运行以下代码就行
# 1.获取数据
data = get_price('601398.XSHG', start_date='2014-01-01', end_date='2024-01-01', frequency='daily', fields=None, skip_paused=False, fq='pre', panel=True)
# 2.保存数据
data.to_csv('data_沪深300/601398.XSHG(工商银行14-24).csv')
两个问题
1.为什么我们只需要用encoder部分去预测就行而不需要decoder部分?
答: 编码器用于将输入序列编码成一个上下文表示(contextual representation),然后解码器根据该上下文表示生成目标序列在时间序列预测任务中,我们不需要生成一个序列,而是预测单个或少量几个未来数据点。因此,编码器的上下文表示已经包含了足够的信息来进行预测,无需使用解码器。还有我觉得使用解码器的意思是,你用上一天的数据去预测下一天的数据,我感觉这样就没意思了,这和我们个人看有什么区别。而且对最后的结果也会造成不精准的效果。为什么这么说呢,你看解码器的mask编码部分应该可以理解了。
2.我们的维度为什么不是[batch, len, feature]? 因为这是pytorch要求,自己能实现的话,自己改吧。
开始利用我们的代码进行切分
我的思路用的是用五天的数据去预测下一天,数据集和测试及8/2分
但是我们要记住一点,就是我们必须要理解我们这么做的思路,就比如我们的特征有6列分别是,open,close,high,low,volume,money,我们可以通过训练得到我们想预测的某一特征。OK,我们这就开始。
说起数据分割里面的代码不难,最难的是
for i in range(len(X_CONVERT) - seq_length):
X_data.append(X_CONVERT[i:i+seq_length, :])
y_data.append(X_CONVERT[i+seq_length, 1])
你要知道我在干什么,就是用8成的数据集去预测得到我们所需要的train数据集和我们对应train数据集的label,举个例子就是,我们要炒菜,我们拿上原料后我们要知道炒的什么菜,那么菜单必须要知道。是吧,不然你炒完菜后说是红烧肉,但是没有菜单图片对比你怎么知道这是红烧肉?这也就是这一步的意义。
def split_data(batch_size,seq_length, pred_length, train_ratio):data_all = pd.read_csv(data_path)data_ha = []length = len(data_all)# 将数据转换为numpy数组,并添加到列表中for element in elements:data_element = data_all[element].values.astype(np.float32)data_element = data_element.reshape(length, 1)data_ha.append(data_element)X_hat = np.concatenate(data_ha, axis=1)X_CONVERT = torch.from_numpy(X_hat).float()X_CONVERT = X_CONVERT.flip(dims=[0])# 进行归一化min_val = np.min(X_hat, axis=0)max_val = np.max(X_hat, axis=0)X_normalized = (X_hat - min_val) / (max_val - min_val)X_CONVERT = torch.from_numpy(X_normalized).float()X_CONVERT = X_CONVERT.flip(dims=[0])#数据翻转# 划分训练集和验证集X_data = []y_data = []for i in range(len(X_CONVERT) - seq_length):#划分的时候是用8成的训练集去训练然后label是某##一列X_data.append(X_CONVERT[i:i+seq_length, :])y_data.append(X_CONVERT[i+seq_length, 1])X_data = torch.stack(X_data)y_data = torch.stack(y_data).squeeze(-1)print(X_data.shape, y_data.shape)dataset = TensorDataset(X_data, y_data)train_size = int(len(dataset) * train_ratio)val_size = len(dataset) - train_sizetrain_dataset, val_dataset = random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size, shuffle=False)return train_loader, val_loader,min_val, max_val
backbone网络
如其名,我们都知道这是这是transformer当然是用的transformer的结构。但是我们用,但是只用一部分,具体用什么部分开头说了,只用encoder
**但是具体操作起来的时候encoder里面的embadding部分我们需要修改,因为我们不是机器翻译,所以我们不需要把他变成词向量,我们时间序列数据,输入通常是连续的数值特征,使用线性层更直接地将这些数值特征映射到高维空间。并且我们的embadding嵌入层,适用于离散的输入,输出是固定维度的嵌入向量。而线性层,适用于连续的输入,可以灵活处理不同维度的输入特征,将其映射到高维表示。**具体看下面代码
class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Linear(feature, d_model)#这里替换了self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]enc_outputs = self.pos_emb(enc_outputs) # [batch_size, src_len, d_model]enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns
当然其他的部分和我上一篇的一样,但是就是decode不要了,当然也可以换成其他结果,或者加个注意力机制
讲下各个参数
d_model = 512 # linnerer的输入维度 也就是字embedding的维度
d_ff = 2048 # 前向传播隐藏层维度
d_k = d_v = 64 # K(=Q), V的维度
n_layers = 6 # 有多少个encoder和decoder
n_heads = 8 # Multi-Head Attention设置为8
feature=6 # 输入特征维度
当然主体还是要看一下的最重要的是通过encoder后的维度转换比较繁琐,要和我们之前split的数据集得到的y_train一致这样才能计算损失
class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.Encoder = Encoder()self.projection = nn.Linear(d_model, 1, bias=False)def forward(self, enc_inputs): # enc_inputs: [batch_size, src_len, feature]enc_outputs, enc_self_attns = self.Encoder(enc_inputs) # enc_outputs: [batch_size, src_len, d_model]dec_logits = self.projection(enc_outputs) # dec_logits: [batch_size, src_len, 1]dec_logits = dec_logits.mean(dim=1) # 将每个时间步的预测结果取平均,得到 [batch_size, 1]return dec_logits.squeeze(-1), enc_self_attns # 输出 [batch_size]
训练
先解释参数
batch_size=64#批处理大小
seq_length=7#时间序列长度 也就是通过seq_length天预测后面pred_length天
pred_length=1#预测长度
train_ratio=0.8#训练集比例
epochs = 50 # 训练轮数
lr= 0.001 # 学习率
png_save_path="diytransformers/12.24transformer/picture"#所有的图片保存的地方
loss_history = []# 存储每个 epoch 的损失
训练代码很长,挺简单的
# 训练模型
for epoch in range(epochs):epoch_loss = 0y_pre = []y_true = []# 训练阶段for X, y in train_loader:X = X.float() # 确保输入数据类型为float32y = y.float() # 确保目标数据类型为float32outputs, enc_self_attns = model(X)# 计算损失,确保形状一致loss = criterion(outputs, y)epoch_loss += loss.item()optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()#转换我们的label和训练后得到的训练集的预测值 y_pre.append(outputs.detach())y_true.append(y.detach())avg_loss = epoch_loss / len(train_loader)loss_history.append(avg_loss)#获得最好的lossif avg_loss < best_loss:best_loss = avg_lossbest_epoch = epochbest_model_wts = copy.deepcopy(model.state_dict())torch.save(best_model_wts, path_train)y_pre_concat = torch.cat(y_pre, dim=0)y_true_concat = torch.cat(y_true, dim=0)# 计算并打印评估指标metrics = evaluate(y_pre_concat, y_true_concat, min_val, max_val)print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')# 可视化结果ht(y_true_concat.detach().cpu().numpy(), y_pre_concat.detach().cpu().numpy(), min_val, max_val,png_save_path)
最后是看我们的一些指标效果如何 比如这里我计算的mae,rmse,pcc等
# 加载最佳模型权重
model.load_state_dict(torch.load(train_over_path))# 测试模型并计算评估指标
test_metrics = test_model(model, val_loader, min_val, max_val)print(f'Test Metrics: {test_metrics}')
效果 感觉还行,没有调参数。
源码比较长,如果需要我后续会发(因为太长了!!)
相关文章:

从0开始基于transformer进行股价预测(pytorch版本)
目录 数据阶段两个问题开始利用我们的代码进行切分 backbone网络训练效果 感觉还行,没有调参数。源码比较长,如果需要我后续会发(因为太长了!!) 数据阶段 !!!注意&#…...

【多GPU训练方法】
一、数据并行 这是最常用的方法。整个模型复制到每个GPU上。训练数据被均匀分割,每个GPU处理一部分数据。所有GPU上的梯度被收集并求平均。通常使用NCCL(NVIDIA Collective Communications Library)等通信库实现。参数更新 使用同步后的梯度…...

2024年PMP考试备考经验分享
PMP是项目管理领域最重要的认证之一,本身是IT行业比较流行的证书,近几年在临床试验领域也渐渐流行起来,是我周围临床项PM几乎人手一个的证书。 考试时间:PMP认证考试形式为180道选择题,考试时间为3小时50分。 考试计划ÿ…...

MT3046 愤怒的象棚
思路: a[]存愤怒值;b[i]存以i结尾的,窗口里的最大值;c[i]存以i结尾的,窗口里面包含✳的最大值。 (✳为新大象的位置) 例:1 2 3 4 ✳ 5 6 7 8 9 则ans的计算公式b3b4c4c5c6b7b8b9…...

深入了解代理IP常见协议:区别与选择
代理服务器在网络使用中扮演着重要的角色,是您设备和互联网之间的中间层。它不仅可以增强网络访问的安全性和隐私保护,还可以提供许多灵活的应用。使用代理时,不同的协议类型对数据交换具有不同的规则和特征。常见的代理协议包括HTTP代理、HT…...

【Linux 线程】线程的基本概念、LWP的理解
文章目录 一、ps -L 指令🍎二、线程控制 一、ps -L 指令🍎 🐧 使用 ps -L 命令查看轻量级进程信息;🐧 pthread_self() 用于获取用户态线程的 tid,而并非轻量级进程ID;🐧 getpid() 用…...

Dify中的工具
Dify中的工具分为内置工具(硬编码)和第三方工具(OpenAPI Swagger/ChatGPT Plugin)。工具可被Workflow(工作流)和Agent使用,当然Workflow也可被发布为工具,这样Workflow(工…...

在Visutal Studio 2022中完成D3D12初始化
在Visutal Studio 2022中完成DirectX设备初始化 1 DirectX121.1 DirectX 简介1.2 DirectX SDK安装2 D3D12初始化2.1 创建Windwos桌面项目2.2 修改符合模式2.3 下载d3dx12.h文件2.4 创建一个异常类D3DException,定义抛出异常实例的宏ThrowIfFailed3 D3D12的初始化步骤3.1 初始化…...

MobaXterm工具
MobaXterm 是一个增强型的 Windows 终端。其为 Windows 桌面提供所有重要的远程网络终端工具(如 SSH、X11、RDP、VNC、FTP、SFTP、Telnet、Serial、Mosh、WSL 等),和 Unix 命令(如 bash、ls、cat、sed、grep、awk、rsync 等&#…...

二分图练习
对于二分图我们可以用染色法 #include<bits/stdc.h> using namespace std;#define int long long const int N 2e65; int e[N],ne[N],h[N],idx 0; int colo[N]; int num 0;void add(int x,int y){e[idx] y;ne[idx] h[x];h[x] idx; } void dfs(int nod,int c){colo…...

创新设计策略:提升大屏幕可视化设计效果的关键方法
随着科技的不断发展和数据量的快速增长,数据可视化大屏在各个行业中的应用越来越广泛,可以帮助人们更好地理解和分析数据,可视化大屏设计也因此成了众多企业的需求。但很多设计师对可视化大屏设计并不了解,也不知道如何制作可视化…...

论文 | Chain-of-Thought Prompting Elicits Reasoningin Large Language Models 思维链
这篇论文研究了如何通过生成一系列中间推理步骤(即思维链)来显著提高大型语言模型进行复杂推理的能力。论文展示了一种简单的方法,称为思维链提示,通过在提示中提供几个思维链示例来自然地激发这种推理能力。 主要发现࿱…...

[机器学习]-人工智能对程序员的深远影响——案例分析
机器学习和人工智能对未来程序员的深远影响 目录 机器学习和人工智能对未来程序员的深远影响1. **自动化编码任务**1.1 代码生成1.2 自动调试1.3 测试自动化 2. **提升开发效率**2.1 智能建议2.2 项目管理 3. **改变编程范式**3.1 数据驱动开发 4. **职业发展的新机遇**4.1 AI工…...

AI学习环境 没有更好的替代 - (Google)Drive + Colab
在开始正题前,请容许我做一番回顾,并夹带一点点私货(谷歌扛旗的开源精神还没有死,并且会是未来的举足轻重的力量) 卧龙凤雏,一时瑜亮。一切的缘起应该是世纪初的门户网站乱战。 彼时,谷歌是从…...

【观成科技】Websocket协议代理隧道加密流量分析与检测
Websocket协议代理隧道加密流量简介 攻防场景下,Websocket协议常被用于代理隧道的搭建,攻击者企图通过Websocket协议来绕过网络限制,搭建一个低延迟、双向实时数据传输的隧道。当前,主流的支持Websocket通信代理的工具有…...
DangerWind-RPC-framework---三、服务端下机
当一台机器下线时,面临很多问题:如何将其从注册中心下线?如何清理释放资源?客户端拉取服务列表时也使用了本地缓存,如何及时更新本地缓存? 服务端机器的优雅下线需要使用ShutdownHook,这相当于添…...

基于Make的c工程No compilation commands found报错
由于安装gcc时只安装了build-essential,没有将其添加到环境变量中,因此打开Make工程时,CLion会产生如下错误: 要解决这个问题,一个方法是将GCC添加到环境变量中,但是这个方法需要修改至少两个配置文件&…...

c++:面向对象的继承特性
什么是继承 (1)继承是C源生支持的一种语法特性,是C面向对象的一种表现 (2)继承特性可以让派生类“瞬间”拥有基类的所有(当然还得考虑权限)属性和方法 (3)继承特性本质上是为了代码复用 (4)类在C编译器的内部可以理解为结构体,派…...

skywalking-2-客户端-php的安装与使用
skywalking的客户端支持php,真的很棒。 官方安装文档:https://skywalking.apache.org/docs/skywalking-php/next/en/setup/service-agent/php-agent/readme/ 前置准备 本次使用的php版本是8.2.13: php -v PHP 8.2.13 (cli) (built: Nov 21 2023 09:5…...

图文讲解IDEA如何导入JDBC驱动包
前言 学习JDBC编程,势必要学会如何导入驱动包,这里笔者用图文的方式来介绍 视频版本在这里 50秒教你怎么导入驱动包然后进行JDBC编程的学习_哔哩哔哩_bilibili 忘记录音频了,大伙凑合着看 下载驱动包 https://mvnrepository.com/artifact/mysql/mysql-connector-java 去中…...

Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

linux之kylin系统nginx的安装
一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源(HTML/CSS/图片等),响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址,提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例
使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...

Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

用docker来安装部署freeswitch记录
今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...