TensorFlow项目练手(三)——基于GRU股票走势预测任务
项目介绍
项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:
- LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
- GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。
一、准备数据
1、获取数据
- 通过命令行安装yfinance
- 通过api获取股票数据
- 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文import yfinance as yf
yf.pdr_override() #需要调用这个函数# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()
2、数据可视化
通过绘图的方式查看当前的数据情况
# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9)) #画面大小,可以修改
ax.plot(close.index, close, label='中国平安') #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend() #图例
plt.show() #绘图
3、数据预处理
取出当前的收盘价,删除无用的日期元素
# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()
进行数据的归一化
# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features
4、构建数据序列
由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列
# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import dequex = []
y = []seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:deq.append(list(i))if len(deq) == seq_len:x.append(list(deq))x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)
二、构建模型
1、搭建GRU模型
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layersfrom keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()
2、优化器和损失函数
# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测
3、开始训练
25%的比例作为验证集,75%的比例作为训练集
# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)
4、模型预测
# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)plt.legend()
plt.show()
5、回归指标评估
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)
源代码
- 源码查看
相关文章:

TensorFlow项目练手(三)——基于GRU股票走势预测任务
项目介绍 项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下: LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络)…...

微信小程序页面传值为对象[Object Object]详解
微信小程序页面传值为对象[Object Object]详解 1、先将传递的对象转化为JSON字符串拼接到url上2、在接受对象页面进行转译3、打印结果 1、先将传递的对象转化为JSON字符串拼接到url上 // info为对象 let stationInfo JSON.stringify(info) uni.navigateTo({url: /pages/statio…...

Redis篇
文章目录 Redis-使用场景1、缓存穿透2、缓存击穿3、缓存雪崩4、双写一致5、Redis持久化6、数据过期策略7、数据淘汰策略 Redis-分布式锁1、redis分布式锁,是如何实现的?2、redisson实现的分布式锁执行流程3、redisson实现的分布式锁-可重入4、redisson实…...

Entity Framework(EF)查询
一、In 查询 var list = dbContext.Users.Where(u => new int[] {1, 2, 3, 5,...

使用Pytest生成HTML测试报告
背景 最近开发有关业务场景的功能时,涉及的API接口比较多,需要自己模拟多个业务场景的自动化测试(暂时不涉及性能测试),并且在每次测试完后能够生成一份测试报告。 考虑到日常使用Python自带的UnitTest,所…...

DSA之图(4):图的应用
文章目录 0 图的应用1 生成树1.1 无向图的生成树1.2 最小生成树1.2.1 构造最小生成树1.2.2 Prim算法构造最小生成树1.2.3 Kruskal算法构造最小生成树1.2.4 两种算法的比较 1.3 最短路径1.3.1 两点间最短路径1.3.2 某源点到其他各点最短路径1.3.3 Dijkstra1.3.4 Floyd 1.4 拓扑排…...

[SQL挖掘机] - 窗口函数 - row_number
介绍: row_number() 是一种常用的窗口函数,它为结果集中的每一行分配一个唯一的数字。这个数字的分配基于指定的排序顺序,并且不会跳过相同的排名。 用法: row_number() 函数的语法如下: row_number() over ([partition by 列名1, 列名2,…...

【论文阅读】通过解缠绕表示学习提升领域泛化能力用于主题感知的作文评分
摘要 本文工作聚焦于从领域泛化的视角提升AES模型的泛化能力,在该情况下,目标主题的数据在训练时不能被获得。本文提出了一个主题感知的神经AES模型(PANN)来抽取用于作文评分的综合的表示,包括主题无关(pr…...

二分查找P1873 [COCI2011-2012#5] EKO / 砍树
P1873 [COCI2011-2012#5] EKO / 砍树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 这个题就是给新手练手的,在那个位置上在进行,寻找合适的砍树高度,下面在介绍一个二分查找的模板 int binarySearch(vector<int>& nums, int t…...

【BOOST程序库】正则表达式相关操作
基本概念这里不解释了,代码中详细解释了BOOST程序库中对于正则表达式常用方法的详细用法。 #include <iostream> #include <string>//正则表达式头文件 #include <boost/xpressive/xpressive.hpp>int main() {//声明正则:boost::pres…...

阿里云国际版在使用过程中应该注意什么呢?
为确保系统稳定性,用户不得进行以下操作。否则,阿里云可能无法解决由以下违规操作引起的问题: 1) Windows系统中的PV Drivers 程序不可删除 PV Drivers程序为服务器虚拟化驱动程序,请不要针对该程序进行任何操作,如果删…...

Flutter Provider 共享状态管理
在使用Provider的时候,我们主要关心三个概念: ChangeNotifier:真正数据(状态)存放的地方ChangeNotifierProvider:Widget树中提供数据(状态)的地方,会在其中创建对应的Ch…...

std vector 用法
使用vector,需添加头文件#include,要使用sort或find,则需要添加头文件#include。函数封装在命名空间std中,使用:using namespace std; 1、vector的初始化 std::vector<int> nVec; // 空对象 std::vecto…...

vue vite ts electron ipc addon-napi c arm64
初始化 因网络问题建议使用 cnpm 代替 npm npm init vue # 全选 yes npm i # 进入项目目录后使用 npm i electron electron-builder -D npm i commander -D # 额外组件electron 新建 plugins、src/electron 文件夹 添加 src/electron/background.ts 属于主进程 ipcMain.o…...

机器人科普--AGILOX 叉车
机器人科普--AGILOX 叉车 1 概述2 导航3 驱动轮组4 叉举参考 1 概述 AGILOX 叉车,不需要画地图路径,很厉害。 2 导航 中间路径自由导航,末端规划出轨迹路线,并使用优良的控制器做轨迹追踪。 AGILOX | 10 Min setu…...

Django的生命周期流程图(补充)、路由层urls.py文件、无名分组和有名分组、反向解析(无名反向解析、有名反向解析)、路由分发、伪静态
一、orm的增删改查方法(补充) 1. 查询resmodels.表名(类名).objects.all()[0]resmodels.表名(类名).objects.filter(usernameusername, passwordpassword).all()res models.表名(类名).objects.first() # 判断,判断数据是否有# res如果查询…...

selenium交互代码
一:selenium交互 用selenium打开网页后,也可以做一系列真人的操作,也就是利用selenium和浏览器进行交互,可利用以下几个函数进行操作: input.send_keys() 传递输入内容给某输入框button.click() 点击某按钮browser.e…...

下载远程服务器文件
业务需求:下载某云盘的视频文件存储到本地 测试代码 RequestMapping("testVideo")public String test() {try {SimpleDateFormat DATE_FORMAT new SimpleDateFormat("yyyy/MM/dd/");//组装本地保存地址StringBuilder filePath new StringBuilder(StoreP…...

[SQL挖掘机] - 索引
介绍: 当你在数据库中进行查询时,索引是一种用于提高查询性能的重要工具。索引是对表中的一列或多列进行排序的数据结构,它可以快速定位到满足特定条件的记录,从而减少了查询所需的时间和资源。 在数据库中使用索引的主要好处包括ÿ…...

C++STL库中的list
文章目录 list的介绍及使用 list的常用接口 list的模拟实现 list与vector的对比 一、list的介绍及使用 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。 2. list的底层是双向带头循环链表结构,双向带头循…...

【LeetCode 75】第十七题(1493)删掉一个元素以后全为1的最长子数组
目录 题目: 示例: 分析: 代码运行结果: 题目: 示例: 分析: 给一个数组,求删除一个元素以后能得到的连续的最长的全是1的子数组。 我们可以先单独统计出连续为1的子数组分别长度…...

配置IPv6 over IPv4 GRE隧道示例
组网需求 如图1,两个IPv6网络分别通过SwitchA和SwitchC与IPv4公网中的SwitchB连接,客户希望两个IPv6网络中的PC1和PC2实现互通。 其中PC1和PC2上分别指定SwitchA和SwitchC为自己的缺省网关。 图1 配置IPv6 over IPv4 GRE隧道组网图 配置思路 要实现I…...

Google Earth Engine谷歌地球引擎提取多波段长期反射率数据后绘制折线图并导出为Excel
本文介绍在谷歌地球引擎GEE中,提取多年遥感影像多个不同波段的反射率数据,在GEE内绘制各波段的长时间序列走势曲线图,并将各波段的反射率数据与其对应的成像日期一起导出为.csv文件的方法。 本文是谷歌地球引擎(Google Earth Engi…...

第三大的数
414、第三大的数 class Solution {public int thirdMax(int[] nums) {Arrays.sort(nums);int tempnums[0];int ansnums[0];int count 0;// if(nums.length<3){// return nums[nums.length-1];// }// else {for(int inums.length-1;i>0;i--){if (nums[i]>nums[i…...

正则表达式中的方括号[]有什么用?
在正则表达式中,方括号 [] 是用于定义字符集合的元字符。它在正则表达式中有以下作用: 匹配字符集合中的任意一个字符:方括号中列出的字符,表示在这个位置可以匹配这些字符中的任意一个。例如,[abc] 将匹配任意一个字符…...

SQL编写规范
文章目录 1.命名规范:2.库表设计:3.查询数据:4.修改数据:5.索引创建: 1.命名规范: 1.库名、表名、字段名,必须使用小写字母或数字,不得超过30个字符。 2.库名、表名、字段名&#…...

Azure pipeline自动化打包发布
pipeline自动化,提交代码后,就自动打包,打包成功后自动发布 第一步 pipeline提交代码后,自动打包。 1 在Repos,分支里选择要触发的分支,这里选择cn_china,对该分支设置分支策略 2 在生产验证中增加新的策略 3 在分支安…...

【算法提高:动态规划】1.4 状态机模型 TODO
文章目录 例题列表1049. 大盗阿福(其实就是打家劫舍)1057. 股票买卖 IV(k笔交易)1058. 股票买卖 V(冷冻期)1052. 设计密码⭐⭐⭐🚹🚹🚹(TODO)1053…...

ip link add 命令
ip link add veth0 type veth peer name veth1 这条命令主要用于在 Linux 操作系统中创建一个新的 veth(虚拟以太网) 对,这是一种虚拟网络设备,用于在 Linux 命名空间(namespaces)之间创建网络连接。此命令将创建两个设备…...

unity事件处理
方法调用 //发送事件 【发送事件码,发送消息内容】 EventCenterUtil.Broadcast(EventCenterUtil.EventType.Joystick, ui);//监听无参事件 EventCenterUtil.AddListener(EventCenterUtil.EventType.Joystick, show); public void show(){}//发送事件 有参事件 Eve…...