时间序列预测(九)——门控循环单元网络(GRU)
目录
一、GRU结构
二、GRU核心思想
1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。
2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。
3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值
三、GRU 分步演练
1、输入与初始化:
2、计算重置门:
3、计算候选隐藏状态:
4、计算更新门:
5、计算当前隐藏状态:
四、代码实现
1、任务:
2、做法:
3、主要修改点:
4、具体代码:
5、结果
GRU是一种循环神经网络(RNN)的变体,由Cho等人在2014年提出。相比于传统的RNN,GRU引入了门控机制,可以通过该机制来确定应该何时更新隐状态,以及应该何时重置隐状态,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。
一、GRU结构
GRU的结构和基础的RNN相比,并没有特别大的不同,都是一种重复神经网络模块的链式结构,由输入层、隐藏层和输出层组成,其中隐藏层是其核心部分,包含了门控机制相关的计算单元。
二、GRU核心思想
与LSTM不同,GRU没有细胞状态,而是直接使用隐藏状态。GRU由两个门控制:更新门(Update Gate)和重置门(Reset Gate)。
1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。
2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。
补充:
3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值
三、GRU 分步演练
1、输入与初始化:
- 假设我们有一个输入序列 X=[x1,x2,...,xT],其中 xt 是第 t 个时间步的输入。
- 初始化隐藏状态 h0,通常为零向量或随机初始化。
2、计算重置门:
- 重置门 rt 决定了前一时间步的隐藏状态 ht−1 对当前候选隐藏状态 h~t 的影响程度。
其中 σ 是sigmoid函数,Wr 和 Ur 是可训练的权重矩阵。
3、计算候选隐藏状态:
- 使用重置门 rt 来控制前一时间步的隐藏状态 ht−1 的影响。
其中 ⊙ 表示元素乘法,tanh 是双曲正切函数,W 和 U 是可训练的权重矩阵。
4、计算更新门:
- 更新门 zt 决定了当前隐藏状态 ht 应该保留多少前一时间步的隐藏状态 ht−1 和多少当前候选隐藏状态 h~t。
其中 Wz 和 Uz 是可训练的权重矩阵。
5、计算当前隐藏状态:
- 使用更新门 zt 来组合前一时间步的隐藏状态 ht−1 和当前候选隐藏状态 h~t。
四、代码实现
1、任务:
根据一个包含道路曲率(Curvature)、车速(Velocity)、侧向加速度(Ay)和方向盘转角(Steering_Angle)真实的数据集,去预测未来的方向盘转角。
2、做法:
提取前5个历史曲率、速度、方向盘转角作为输入特征,同时添加后5个未来曲率(由于车辆的预瞄距离)。目标输出为未来5个方向盘转角。采用GRU网络训练。
3、主要修改点:
- 模型定义:将
LSTM
替换为GRU
,并更新模型类名为GRUModel
。 - 前向传播:
forward
方法中相应地使用 GRU 的输出。
4、具体代码:
# GRU 模型
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error as mae, r2_score
import matplotlib.pyplot as plt# 1. 数据预处理
# 读取数据
data = pd.read_excel('input_data_20241010160240.xlsx') # 替换为你的数据文件路径 # 提取特征和标签
curvature = data['Curvature'].values
velocity = data['Velocity'].values
steering = data['Steering_Angle'].values# 定义历史和未来的窗口大小
history_size = 5
future_size = 5features = []
labels = []
for i in range(history_size, len(data) - future_size):# 提取前5个历史的曲率、速度和方向盘转角history_curvature = curvature[i - history_size:i]history_velocity = velocity[i - history_size:i]history_steering = steering[i - history_size:i]# 提取后5个未来的曲率(用于预测)future_curvature = curvature[i:i + future_size]# 输入特征:历史 + 未来曲率feature = np.hstack((history_curvature, history_velocity, history_steering, future_curvature))features.append(feature)# 输出标签:未来5个方向盘转角label = steering[i:i + future_size]labels.append(label)# 转换为 NumPy 数组
features = np.array(features)
labels = np.array(labels)# 归一化
scaler_x = StandardScaler()
scaler_y = StandardScaler()features = scaler_x.fit_transform(features)
labels = scaler_y.fit_transform(labels)# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.05)# 将特征转换为三维张量,形状为 [样本数, 时间序列长度, 特征数]
input_feature_size = history_size * 3 + future_size # 历史曲率、速度、方向盘转角 + 未来曲率
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).view(-1, 1, input_feature_size) # [batch_size, seq_len=1, input_size]
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, future_size) # 输出未来的5个方向盘转角
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).view(-1, 1, input_feature_size)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, future_size)# 2. 创建GRU模型
class GRUModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, output_size):super(GRUModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True) # 使用GRUself.fc = nn.Linear(hidden_size, output_size) # 输出层def forward(self, x):# 前向传播out, _ = self.gru(x) # GRU输出out = self.fc(out[:, -1, :]) # 只取最后一个时间步的输出return out# 实例化模型
input_size = input_feature_size # 输入特征数
hidden_size = 64 # 隐藏层大小
num_layers = 2 # GRU层数
output_size = future_size # 输出5个未来方向盘转角
model = GRUModel(input_size, hidden_size, num_layers, output_size)# 3. 设置损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 4. 训练模型
num_epochs = 1000
for epoch in range(num_epochs):model.train()# 前向传播outputs = model(x_train_tensor)loss = criterion(outputs, y_train_tensor)# 后向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 5. 预测
model.eval()
with torch.no_grad():y_pred_tensor = model(x_test_tensor)y_pred = scaler_y.inverse_transform(y_pred_tensor.numpy()) # 将预测值逆归一化
y_test = scaler_y.inverse_transform(y_test_tensor.numpy()) # 逆归一化真实值# 评估指标
r2 = r2_score(y_test, y_pred, multioutput='uniform_average') # 多维输出下的R^2
mae_score = mae(y_test, y_pred)
print(f"R^2 score: {r2:.4f}")
print(f"MAE: {mae_score:.4f}")# 支持中文
plt.rcParams['font.sans-serif'] = ['SimSun'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 绘制未来5个方向盘转角的预测和真实值对比
plt.figure(figsize=(10, 6))
for i in range(future_size):plt.plot(range(len(y_test)), y_test[:, i], label=f'真实值 {i+1} 步', color='blue')plt.plot(range(len(y_pred)), y_pred[:, i], label=f'预测值 {i+1} 步', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来5个方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()# 计算预测和真实方向盘转角的平均值
y_pred_mean = np.mean(y_pred, axis=1) # 每个样本的5个预测值取平均
y_test_mean = np.mean(y_test, axis=1) # 每个样本的5个真实值取平均# 绘制平均值的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test_mean)), y_test_mean, label='真实值(平均)', color='blue')
plt.plot(range(len(y_pred_mean)), y_pred_mean, label='预测值(平均)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle (平均)')
plt.title('未来5个方向盘转角的平均值对比图')
plt.legend()
plt.grid(True)
plt.show()# 绘制第1个时间步的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test)), y_test[:, 0], label='真实值 (第1步)', color='blue')
plt.plot(range(len(y_pred)), y_pred[:, 0], label='预测值 (第1步)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来第1步方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()# 计算每个时间步的平均绝对误差
time_steps = y_test.shape[1]
mae_per_step = [mae(y_test[:, i], y_pred[:, i]) for i in range(time_steps)]# 绘制每个时间步的平均绝对误差
plt.figure(figsize=(10, 6))
plt.bar(range(1, time_steps + 1), mae_per_step, color='orange')
plt.xlabel('时间步')
plt.ylabel('MAE')
plt.title('不同时间步的平均绝对误差')
plt.grid(True)
plt.show()
5、结果
五、总结
GRU是LSTM的简化版本,减少了门的数量,使得训练和推理速度更快。它在许多序列建模任务中表现良好,适用于时间序列预测、自然语言处理等领域。
相关文章:

时间序列预测(九)——门控循环单元网络(GRU)
目录 一、GRU结构 二、GRU核心思想 1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。 2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。…...

李东生牵手通力股份IPO注册卡关,三年近10亿“清仓式分红”引关注
《港湾商业观察》施子夫 9月27日,通力科技股份有限公司(以下简称,通力股份)再度提交了注册申请,实际上早在去年11月6日公司已经提交过注册,看起来公司注册环节面临卡关。公开信息显示,通力股份…...
Android13、14特殊权限-应用安装权限适配
Android13、14特殊权限-应用安装权限适配 文章目录 Android13、14特殊权限-应用安装权限适配一、前言二、权限适配三、其他1、特殊权限-应用安装权限适配小结2、dumpsys package查看获取到了应用安装权限3、Android权限系统:应用操作管理类AppOpsManager(…...

DMVPN协议
DMVPN(Dynamic Multipoint VPN)动态多点VPN 对于分公司和分总公司内网实现通信环境下,分公司是很多的。我们不可能每个分公司和总公司都挨个建立ipsec隧道 ,而且如果是分公司和分公司建立隧道,就会很麻烦。此时我们需…...
leetcode动态规划(十八)-零钱兑换II
题目 322.零钱兑换II 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。 计算并返回可以凑成总金额所需的 最少的硬币个数 。如果没有任何一种硬币组合能组成总金额,返回 -1 。 你可以认为每种硬…...
2024 CSP-J 题解
2024 CSP-J题解 扑克牌 题目给出了一整套牌的定义,但是纯粹在扯淡,完全没有必要去判断给出的牌的花色和点数,我们用一个循环来依次读入每一张牌,如果这个牌在之前出现过,我们就让答案减一。这里建议用map、unorde…...

GPU 服务器厂家:中国加速计算服务器市场的前瞻洞察
科技的飞速发展,让 GPU 服务器在加速计算服务器领域的地位愈发凸显。中国加速计算服务器市场正展现出蓬勃的生机,而 GPU 服务器厂家则是这场科技盛宴中的关键角色。 从市场预测的趋势来看,2023 年起,中国加速计算服务器市场便已展…...
Hadoop集群修改yarn队列
1.修改默认的default队列参数 注意: yarn.scheduler.capacity.root.队列名.capacity总和不能超过100 <property><name>yarn.scheduler.capacity.root.queues</name><value>default,hive,spark,flink</value><description>The…...

【GPIO】2.ADC配置错误,还是能得到电压数据
配置ADC功能时,GPIO引脚弄错了,P1写成P2,但还是配置成功,能得到电压数据。 首先一步步排查: 既然引脚弄错了,那引脚改为正确的引脚,能得到数据通过第一步判断,GPIO配置似乎是不起作…...
css-元素居中方式
<section class"wrapper"><div class"content">Content goes here</div> </section>1. 使用 Flexbox Flexbox 是一种现代的布局方法,可以轻松实现居中。 .wrapper {display: flex; /* 使用 Flexbox …...
redis内存打满了怎么办?
1、设置maxmemory的大小 我们需要给 Redis设置maxmemory的大小,如果不设置的话,它会受限于系统的物理内存和系统对内存的管理机制。 2、设置内存的淘汰策略 内存的淘汰策略分为 8 种,从淘汰范围来说分为从所有的key中淘汰和从设置过期时间…...
决策算法的技术分析
系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 TODO:写完再整理 文章目录 系列文章目录前言(1)第一层级:分层状态机、分层决策树的想法(三个臭皮匠胜过一个诸葛亮)基于场景的固定规则化的分层决策核心思想(2)第二层级:数据管理的方…...

【Python爬虫】获取汽车之家车型配置附代码(2024.10)
参考大哥,感谢大哥:https://blog.csdn.net/weixin_43498642/article/details/136896338 【任务目标】 工作需要想更方便地下载汽车之家某车系配置清单;(垃圾汽车之家不给下载导出表格,配置页叉掉了车系要出来还要重新…...

JVM 加载 class 文件的原理机制
JVM 加载 class 文件的原理机制 JVM(Java虚拟机)是一个可以执行Java字节码的虚拟机。它负责执行Java应用程序和应用程序的扩展,如Java库和框架。 文章目录 JVM 加载 class 文件的原理机制1. JVM1.1 类加载器1.2 魔数1.3 元空间 2. 类加载2.1 …...

NumPy学习第九课:字符串相关函数
前言 各位有没有注意到,NumPy从第八课开始其实基本上都是讲的是NumPy的函数,而且其实就是各种函数的调用,因为NumPy是一个很强大的函数库,这对我们以后再处理项目中遇到的问题时会有很大的帮助。我们将常用的函数进行一个列举&am…...
卷积神经网络(CNNs)在处理光谱特征的序列属性时表现不佳
卷积神经网络(CNNs)在处理光谱签名的序列属性时表现不佳,主要是由于其固有网络架构的局限性。具体原因如下: 局部感受野(Local Receptive Field): CNN 的核心操作是卷积,它利用局部感…...
【IC】MCU的Tick和晶振频率
Tick 是指 MCU 内部时钟的一个周期,通常表示为一个固定的时间间隔。每个 tick 代表一个时间单位,通常以毫秒(ms)或微秒(μs)为单位。Tick 通常由 MCU 的定时器或计时器生成,作为系统时钟的一部分…...

从0到1学习node.js(npm)
文章目录 一、NPM的生产环境与开发环境二、全局安装三、npm安装指定版本的包四、删除包 五、用npm发布一个包六、修改和删除npm包1、修改2、删除 一、NPM的生产环境与开发环境 类型命令补充生产依赖npm i -S uniq-S 等效于 --save -S是默认选项npm i -save uniq包的信息保存在…...
【STM32 Blue Pill编程实例】-OLED显示DS18B20传感器数据
OLED显示DS18B20传感器数据 文章目录 OLED显示DS18B20传感器数据1、DS18B20介绍2、硬件准备及接线3、模块配置3.1 定时器配置3.2 DS18B20传感器配置3.3 OLED的I2C接口配置4、代码实现在本文中,我们将介绍如何将 DS18B20 温度传感器与 STM32 Blue Pill 开发板连接,并使用 HAL …...
STM32 从0开始系统学习3 启动流程
目录 写在前面 速通:做了什么: 分析I:分析2011年的startup文件所作 分析II:分析2017年的startup文件所作 Helps 2011 2017 Reference 写在前面 请各位看官看本篇笔记的时候首先了解一下计算机体系架构,了解基本…...

网络编程(Modbus进阶)
思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...
[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解
突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 安全措施依赖问题 GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...

【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...
多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验
一、多模态商品数据接口的技术架构 (一)多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如,当用户上传一张“蓝色连衣裙”的图片时,接口可自动提取图像中的颜色(RGB值&…...
基础测试工具使用经验
背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...
【论文笔记】若干矿井粉尘检测算法概述
总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...