当前位置: 首页 > news >正文

TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文import yfinance as yf
yf.pdr_override() #需要调用这个函数# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import dequex = []
y = []seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:deq.append(list(i))if len(deq) == seq_len:x.append(list(deq))x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layersfrom keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

相关文章:

TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍 项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下: LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络)&#xf…...

微信小程序页面传值为对象[Object Object]详解

微信小程序页面传值为对象[Object Object]详解 1、先将传递的对象转化为JSON字符串拼接到url上2、在接受对象页面进行转译3、打印结果 1、先将传递的对象转化为JSON字符串拼接到url上 // info为对象 let stationInfo JSON.stringify(info) uni.navigateTo({url: /pages/statio…...

Redis篇

文章目录 Redis-使用场景1、缓存穿透2、缓存击穿3、缓存雪崩4、双写一致5、Redis持久化6、数据过期策略7、数据淘汰策略 Redis-分布式锁1、redis分布式锁,是如何实现的?2、redisson实现的分布式锁执行流程3、redisson实现的分布式锁-可重入4、redisson实…...

Entity Framework(EF)查询

一、In 查询 var list = dbContext.Users.Where(u => new int[] {1, 2, 3, 5,...

使用Pytest生成HTML测试报告

背景 最近开发有关业务场景的功能时,涉及的API接口比较多,需要自己模拟多个业务场景的自动化测试(暂时不涉及性能测试),并且在每次测试完后能够生成一份测试报告。 考虑到日常使用Python自带的UnitTest,所…...

DSA之图(4):图的应用

文章目录 0 图的应用1 生成树1.1 无向图的生成树1.2 最小生成树1.2.1 构造最小生成树1.2.2 Prim算法构造最小生成树1.2.3 Kruskal算法构造最小生成树1.2.4 两种算法的比较 1.3 最短路径1.3.1 两点间最短路径1.3.2 某源点到其他各点最短路径1.3.3 Dijkstra1.3.4 Floyd 1.4 拓扑排…...

[SQL挖掘机] - 窗口函数 - row_number

介绍: row_number() 是一种常用的窗口函数,它为结果集中的每一行分配一个唯一的数字。这个数字的分配基于指定的排序顺序,并且不会跳过相同的排名。 用法: row_number() 函数的语法如下: row_number() over ([partition by 列名1, 列名2,…...

【论文阅读】通过解缠绕表示学习提升领域泛化能力用于主题感知的作文评分

摘要 本文工作聚焦于从领域泛化的视角提升AES模型的泛化能力,在该情况下,目标主题的数据在训练时不能被获得。本文提出了一个主题感知的神经AES模型(PANN)来抽取用于作文评分的综合的表示,包括主题无关(pr…...

二分查找P1873 [COCI2011-2012#5] EKO / 砍树

P1873 [COCI2011-2012#5] EKO / 砍树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 这个题就是给新手练手的&#xff0c;在那个位置上在进行&#xff0c;寻找合适的砍树高度&#xff0c;下面在介绍一个二分查找的模板 int binarySearch(vector<int>& nums, int t…...

【BOOST程序库】正则表达式相关操作

基本概念这里不解释了&#xff0c;代码中详细解释了BOOST程序库中对于正则表达式常用方法的详细用法。 #include <iostream> #include <string>//正则表达式头文件 #include <boost/xpressive/xpressive.hpp>int main() {//声明正则&#xff1a;boost::pres…...

阿里云国际版在使用过程中应该注意什么呢?

为确保系统稳定性&#xff0c;用户不得进行以下操作。否则&#xff0c;阿里云可能无法解决由以下违规操作引起的问题&#xff1a; 1) Windows系统中的PV Drivers 程序不可删除 PV Drivers程序为服务器虚拟化驱动程序&#xff0c;请不要针对该程序进行任何操作&#xff0c;如果删…...

Flutter Provider 共享状态管理

在使用Provider的时候&#xff0c;我们主要关心三个概念&#xff1a; ChangeNotifier&#xff1a;真正数据&#xff08;状态&#xff09;存放的地方ChangeNotifierProvider&#xff1a;Widget树中提供数据&#xff08;状态&#xff09;的地方&#xff0c;会在其中创建对应的Ch…...

std vector 用法

使用vector&#xff0c;需添加头文件#include&#xff0c;要使用sort或find&#xff0c;则需要添加头文件#include。函数封装在命名空间std中&#xff0c;使用&#xff1a;using namespace std; 1、vector的初始化 std::vector<int> nVec;    // 空对象 std::vecto…...

vue vite ts electron ipc addon-napi c arm64

初始化 因网络问题建议使用 cnpm 代替 npm npm init vue # 全选 yes npm i # 进入项目目录后使用 npm i electron electron-builder -D npm i commander -D # 额外组件electron 新建 plugins、src/electron 文件夹 添加 src/electron/background.ts 属于主进程 ipcMain.o…...

机器人科普--AGILOX 叉车

机器人科普--AGILOX 叉车 1 概述2 导航3 驱动轮组4 叉举参考 1 概述 AGILOX 叉车&#xff0c;不需要画地图路径&#xff0c;很厉害。 2 导航 中间路径自由导航&#xff0c;末端规划出轨迹路线&#xff0c;并使用优良的控制器做轨迹追踪。 AGILOX &#xff5c; 10 Min setu…...

Django的生命周期流程图(补充)、路由层urls.py文件、无名分组和有名分组、反向解析(无名反向解析、有名反向解析)、路由分发、伪静态

一、orm的增删改查方法&#xff08;补充&#xff09; 1. 查询resmodels.表名(类名).objects.all()[0]resmodels.表名(类名).objects.filter(usernameusername, passwordpassword).all()res models.表名(类名).objects.first() # 判断&#xff0c;判断数据是否有# res如果查询…...

selenium交互代码

一&#xff1a;selenium交互 用selenium打开网页后&#xff0c;也可以做一系列真人的操作&#xff0c;也就是利用selenium和浏览器进行交互&#xff0c;可利用以下几个函数进行操作&#xff1a; input.send_keys() 传递输入内容给某输入框button.click() 点击某按钮browser.e…...

下载远程服务器文件

业务需求:下载某云盘的视频文件存储到本地 测试代码 RequestMapping("testVideo")public String test() {try {SimpleDateFormat DATE_FORMAT new SimpleDateFormat("yyyy/MM/dd/");//组装本地保存地址StringBuilder filePath new StringBuilder(StoreP…...

[SQL挖掘机] - 索引

介绍: 当你在数据库中进行查询时&#xff0c;索引是一种用于提高查询性能的重要工具。索引是对表中的一列或多列进行排序的数据结构&#xff0c;它可以快速定位到满足特定条件的记录&#xff0c;从而减少了查询所需的时间和资源。 在数据库中使用索引的主要好处包括&#xff…...

C++STL库中的list

文章目录 list的介绍及使用 list的常用接口 list的模拟实现 list与vector的对比 一、list的介绍及使用 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 2. list的底层是双向带头循环链表结构&#xff0c;双向带头循…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”&#xff0c;无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息&#xff1a; 关注测试号&#xff1a;扫二维码关注测试号。 发送模版消息&#xff1a; import requests da…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例&#xff1a;使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例&#xff1a;使用OpenAI GPT-3进…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案

JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停​​ 1. ​​安全点(Safepoint)阻塞​​ ​​现象​​:JVM暂停但无GC日志,日志显示No GCs detected。​​原因​​:JVM等待所有线程进入安全点(如…...

ip子接口配置及删除

配置永久生效的子接口&#xff0c;2个IP 都可以登录你这一台服务器。重启不失效。 永久的 [应用] vi /etc/sysconfig/network-scripts/ifcfg-eth0修改文件内内容 TYPE"Ethernet" BOOTPROTO"none" NAME"eth0" DEVICE"eth0" ONBOOT&q…...

HashMap中的put方法执行流程(流程图)

1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中&#xff0c;其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下&#xff1a; 初始判断与哈希计算&#xff1a; 首先&#xff0c;putVal 方法会检查当前的 table&#xff08;也就…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

接口自动化测试:HttpRunner基础

相关文档 HttpRunner V3.x中文文档 HttpRunner 用户指南 使用HttpRunner 3.x实现接口自动化测试 HttpRunner介绍 HttpRunner 是一个开源的 API 测试工具&#xff0c;支持 HTTP(S)/HTTP2/WebSocket/RPC 等网络协议&#xff0c;涵盖接口测试、性能测试、数字体验监测等测试类型…...

【p2p、分布式,区块链笔记 MESH】Bluetooth蓝牙通信 BLE Mesh协议的拓扑结构 定向转发机制

目录 节点的功能承载层&#xff08;GATT/Adv&#xff09;局限性&#xff1a; 拓扑关系定向转发机制定向转发意义 CG 节点的功能 节点的功能由节点支持的特性和功能决定。所有节点都能够发送和接收网格消息。节点还可以选择支持一个或多个附加功能&#xff0c;如 Configuration …...