当前位置: 首页 > news >正文

TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文import yfinance as yf
yf.pdr_override() #需要调用这个函数# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import dequex = []
y = []seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:deq.append(list(i))if len(deq) == seq_len:x.append(list(deq))x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layersfrom keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

相关文章:

TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍 项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下: LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络)&#xf…...

微信小程序页面传值为对象[Object Object]详解

微信小程序页面传值为对象[Object Object]详解 1、先将传递的对象转化为JSON字符串拼接到url上2、在接受对象页面进行转译3、打印结果 1、先将传递的对象转化为JSON字符串拼接到url上 // info为对象 let stationInfo JSON.stringify(info) uni.navigateTo({url: /pages/statio…...

Redis篇

文章目录 Redis-使用场景1、缓存穿透2、缓存击穿3、缓存雪崩4、双写一致5、Redis持久化6、数据过期策略7、数据淘汰策略 Redis-分布式锁1、redis分布式锁,是如何实现的?2、redisson实现的分布式锁执行流程3、redisson实现的分布式锁-可重入4、redisson实…...

Entity Framework(EF)查询

一、In 查询 var list = dbContext.Users.Where(u => new int[] {1, 2, 3, 5,...

使用Pytest生成HTML测试报告

背景 最近开发有关业务场景的功能时,涉及的API接口比较多,需要自己模拟多个业务场景的自动化测试(暂时不涉及性能测试),并且在每次测试完后能够生成一份测试报告。 考虑到日常使用Python自带的UnitTest,所…...

DSA之图(4):图的应用

文章目录 0 图的应用1 生成树1.1 无向图的生成树1.2 最小生成树1.2.1 构造最小生成树1.2.2 Prim算法构造最小生成树1.2.3 Kruskal算法构造最小生成树1.2.4 两种算法的比较 1.3 最短路径1.3.1 两点间最短路径1.3.2 某源点到其他各点最短路径1.3.3 Dijkstra1.3.4 Floyd 1.4 拓扑排…...

[SQL挖掘机] - 窗口函数 - row_number

介绍: row_number() 是一种常用的窗口函数,它为结果集中的每一行分配一个唯一的数字。这个数字的分配基于指定的排序顺序,并且不会跳过相同的排名。 用法: row_number() 函数的语法如下: row_number() over ([partition by 列名1, 列名2,…...

【论文阅读】通过解缠绕表示学习提升领域泛化能力用于主题感知的作文评分

摘要 本文工作聚焦于从领域泛化的视角提升AES模型的泛化能力,在该情况下,目标主题的数据在训练时不能被获得。本文提出了一个主题感知的神经AES模型(PANN)来抽取用于作文评分的综合的表示,包括主题无关(pr…...

二分查找P1873 [COCI2011-2012#5] EKO / 砍树

P1873 [COCI2011-2012#5] EKO / 砍树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 这个题就是给新手练手的&#xff0c;在那个位置上在进行&#xff0c;寻找合适的砍树高度&#xff0c;下面在介绍一个二分查找的模板 int binarySearch(vector<int>& nums, int t…...

【BOOST程序库】正则表达式相关操作

基本概念这里不解释了&#xff0c;代码中详细解释了BOOST程序库中对于正则表达式常用方法的详细用法。 #include <iostream> #include <string>//正则表达式头文件 #include <boost/xpressive/xpressive.hpp>int main() {//声明正则&#xff1a;boost::pres…...

阿里云国际版在使用过程中应该注意什么呢?

为确保系统稳定性&#xff0c;用户不得进行以下操作。否则&#xff0c;阿里云可能无法解决由以下违规操作引起的问题&#xff1a; 1) Windows系统中的PV Drivers 程序不可删除 PV Drivers程序为服务器虚拟化驱动程序&#xff0c;请不要针对该程序进行任何操作&#xff0c;如果删…...

Flutter Provider 共享状态管理

在使用Provider的时候&#xff0c;我们主要关心三个概念&#xff1a; ChangeNotifier&#xff1a;真正数据&#xff08;状态&#xff09;存放的地方ChangeNotifierProvider&#xff1a;Widget树中提供数据&#xff08;状态&#xff09;的地方&#xff0c;会在其中创建对应的Ch…...

std vector 用法

使用vector&#xff0c;需添加头文件#include&#xff0c;要使用sort或find&#xff0c;则需要添加头文件#include。函数封装在命名空间std中&#xff0c;使用&#xff1a;using namespace std; 1、vector的初始化 std::vector<int> nVec;    // 空对象 std::vecto…...

vue vite ts electron ipc addon-napi c arm64

初始化 因网络问题建议使用 cnpm 代替 npm npm init vue # 全选 yes npm i # 进入项目目录后使用 npm i electron electron-builder -D npm i commander -D # 额外组件electron 新建 plugins、src/electron 文件夹 添加 src/electron/background.ts 属于主进程 ipcMain.o…...

机器人科普--AGILOX 叉车

机器人科普--AGILOX 叉车 1 概述2 导航3 驱动轮组4 叉举参考 1 概述 AGILOX 叉车&#xff0c;不需要画地图路径&#xff0c;很厉害。 2 导航 中间路径自由导航&#xff0c;末端规划出轨迹路线&#xff0c;并使用优良的控制器做轨迹追踪。 AGILOX &#xff5c; 10 Min setu…...

Django的生命周期流程图(补充)、路由层urls.py文件、无名分组和有名分组、反向解析(无名反向解析、有名反向解析)、路由分发、伪静态

一、orm的增删改查方法&#xff08;补充&#xff09; 1. 查询resmodels.表名(类名).objects.all()[0]resmodels.表名(类名).objects.filter(usernameusername, passwordpassword).all()res models.表名(类名).objects.first() # 判断&#xff0c;判断数据是否有# res如果查询…...

selenium交互代码

一&#xff1a;selenium交互 用selenium打开网页后&#xff0c;也可以做一系列真人的操作&#xff0c;也就是利用selenium和浏览器进行交互&#xff0c;可利用以下几个函数进行操作&#xff1a; input.send_keys() 传递输入内容给某输入框button.click() 点击某按钮browser.e…...

下载远程服务器文件

业务需求:下载某云盘的视频文件存储到本地 测试代码 RequestMapping("testVideo")public String test() {try {SimpleDateFormat DATE_FORMAT new SimpleDateFormat("yyyy/MM/dd/");//组装本地保存地址StringBuilder filePath new StringBuilder(StoreP…...

[SQL挖掘机] - 索引

介绍: 当你在数据库中进行查询时&#xff0c;索引是一种用于提高查询性能的重要工具。索引是对表中的一列或多列进行排序的数据结构&#xff0c;它可以快速定位到满足特定条件的记录&#xff0c;从而减少了查询所需的时间和资源。 在数据库中使用索引的主要好处包括&#xff…...

C++STL库中的list

文章目录 list的介绍及使用 list的常用接口 list的模拟实现 list与vector的对比 一、list的介绍及使用 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 2. list的底层是双向带头循环链表结构&#xff0c;双向带头循…...

【LeetCode 75】第十七题(1493)删掉一个元素以后全为1的最长子数组

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码运行结果&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给一个数组&#xff0c;求删除一个元素以后能得到的连续的最长的全是1的子数组。 我们可以先单独统计出连续为1的子数组分别长度…...

配置IPv6 over IPv4 GRE隧道示例

组网需求 如图1&#xff0c;两个IPv6网络分别通过SwitchA和SwitchC与IPv4公网中的SwitchB连接&#xff0c;客户希望两个IPv6网络中的PC1和PC2实现互通。 其中PC1和PC2上分别指定SwitchA和SwitchC为自己的缺省网关。 图1 配置IPv6 over IPv4 GRE隧道组网图 配置思路 要实现I…...

Google Earth Engine谷歌地球引擎提取多波段长期反射率数据后绘制折线图并导出为Excel

本文介绍在谷歌地球引擎GEE中&#xff0c;提取多年遥感影像多个不同波段的反射率数据&#xff0c;在GEE内绘制各波段的长时间序列走势曲线图&#xff0c;并将各波段的反射率数据与其对应的成像日期一起导出为.csv文件的方法。 本文是谷歌地球引擎&#xff08;Google Earth Engi…...

第三大的数

414、第三大的数 class Solution {public int thirdMax(int[] nums) {Arrays.sort(nums);int tempnums[0];int ansnums[0];int count 0;// if(nums.length<3){// return nums[nums.length-1];// }// else {for(int inums.length-1;i>0;i--){if (nums[i]>nums[i…...

正则表达式中的方括号[]有什么用?

在正则表达式中&#xff0c;方括号 [] 是用于定义字符集合的元字符。它在正则表达式中有以下作用&#xff1a; 匹配字符集合中的任意一个字符&#xff1a;方括号中列出的字符&#xff0c;表示在这个位置可以匹配这些字符中的任意一个。例如&#xff0c;[abc] 将匹配任意一个字符…...

SQL编写规范

文章目录 1.命名规范&#xff1a;2.库表设计&#xff1a;3.查询数据&#xff1a;4.修改数据&#xff1a;5.索引创建&#xff1a; 1.命名规范&#xff1a; 1.库名、表名、字段名&#xff0c;必须使用小写字母或数字&#xff0c;不得超过30个字符。 2.库名、表名、字段名&#…...

Azure pipeline自动化打包发布

pipeline自动化&#xff0c;提交代码后&#xff0c;就自动打包&#xff0c;打包成功后自动发布 第一步 pipeline提交代码后&#xff0c;自动打包。 1 在Repos,分支里选择要触发的分支&#xff0c;这里选择cn_china,对该分支设置分支策略 2 在生产验证中增加新的策略 3 在分支安…...

【算法提高:动态规划】1.4 状态机模型 TODO

文章目录 例题列表1049. 大盗阿福&#xff08;其实就是打家劫舍&#xff09;1057. 股票买卖 IV&#xff08;k笔交易&#xff09;1058. 股票买卖 V&#xff08;冷冻期&#xff09;1052. 设计密码⭐⭐⭐&#x1f6b9;&#x1f6b9;&#x1f6b9;&#xff08;TODO&#xff09;1053…...

ip link add 命令

ip link add veth0 type veth peer name veth1 这条命令主要用于在 Linux 操作系统中创建一个新的 veth(虚拟以太网) 对&#xff0c;这是一种虚拟网络设备&#xff0c;用于在 Linux 命名空间&#xff08;namespaces&#xff09;之间创建网络连接。此命令将创建两个设备&#xf…...

unity事件处理

方法调用 //发送事件 【发送事件码&#xff0c;发送消息内容】 EventCenterUtil.Broadcast(EventCenterUtil.EventType.Joystick, ui);//监听无参事件 EventCenterUtil.AddListener(EventCenterUtil.EventType.Joystick, show); public void show(){}//发送事件 有参事件 Eve…...