当前位置: 首页 > news >正文

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

相关文章:

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型 随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,…...

GitHub Actions

GitHub Actions GitHub Actions 是 GitHub 提供的一种持续集成(CI)和持续部署(CD)解决方案。它可以让你在 GitHub 仓库中直接自动化、定制化和执行软件开发工作流程。 比如,当有新的推送到仓库或者新的 Pull Request…...

harmony 鸿蒙系统学习 安装ohpm报错 ohpm install failed

一. 安装配置 DevEco Studio 安装包时报错 execute ohpm install failed. Install task failed: ArkTS 3.2.12.5. Install ArkTS dependencies failed. 解决办法 找原因,首先,我的电脑中之前安装过node,也许是因为这个。(其实…...

MySQL Replication

0 序言 MySQL Replication 是 MySQL 中的一个功能,允许从一个 MySQL 数据库服务器(称为主服务器或 master)复制数据和数据库结构到另一个服务器(称为从服务器或 slave)。这种复制是异步的,意味着从服务器不…...

redis分布式锁redisson

文章目录 1. 分布式锁1.1 基本原理和实现方式对比synchronized锁在集群模式下的问题多jvm使用同一个锁监视器分布式锁概念分布式锁须满足的条件分布式锁的实现 1.2 基于Redis的分布式锁获取锁&释放锁操作示例 基于Redis实现分布式锁初级版本ILock接口SimpleRedisLock使用示…...

制作一个简单的html网页

1. 特效按钮 2 可以独立使用的一个页面 3 底部小时钟 <!DOCTYPE html> <html> <head><title>Simple Webpage</title><style>/* 禁止鼠标右键 */body {-webkit-touch-callout: none; /* iOS Safari */-webkit-user-select: none; …...

js filter,every,includes 过滤数组

背景&#xff1a; 页面&#xff1a;在项目中遇到的&#xff0c;前端页面显示为&#xff0c;顶部是下拉搜索条件,下面是一个表格&#xff1b; 数据&#xff1a;接口请求一次性拿到所有&#xff1a;搜索条件里的下拉选项和表格中的数据&#xff1b; 现状&#xff1a;需要前端在搜…...

jenkins自动化部署

Jenkins安装 安装前提条件 yum install java-1.8.0-openjdk* git maven -y ​ 1.下载jenkins wget https://mirrors.tuna.tsinghua.edu.cn/jenkins/redhat/jenkins-2.346-1.1.noarch.rpm --no-check-certificate ​ jenkins的安装路径&#xff1a; /var/lib/jenkins/ ​ …...

【JavaScript】分支语句

目录 一、if语句 二、三元运算符 三、switch语句 JS中分支语句可以分为三种&#xff0c;分别是if语句、三元运算符、switch语句。 一、if语句 let num 10 if (num > 20) {console.log("大于20"); } else if (num < 20) {console.log("小于20");…...

【开源】SpringBoot框架开发农家乐订餐系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核心代码4.1 查询菜品类型4.2 查询菜品4.3 加购菜品4.4 新增菜品收藏4.5 新增菜品留言 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的农家乐订餐系统&#xff0c…...

OSQP文档学习

OSQP官方文档 1 QSQP简介 OSQP求解形式为的凸二次规划&#xff1a; x ∈ R n x∈R^n x∈Rn&#xff1a;优化变量 P ∈ S n P∈S^n_ P∈Sn​&#xff1a;半正定矩阵 特征 &#xff08;1&#xff09;高效&#xff1a;使用了一种自定义的基于ADMM的一阶方法&#xff0c;只需…...

ONLYOFFICE 8.0:引领数字化办公新纪元

目录 前言 软件安装 软件启动 软件新版本特性 个人评价 总结 前言 在当今快节奏的数字化世界中&#xff0c;高效的办公软件已成为企业竞争力的关键因素。ONLYOFFICE&#xff0c;作为全球领先的办公解决方案提供商&#xff0c;始终致力于通过技术创新来优化用户体验。如今…...

「Linux」基础命令

目录结构 Linux只有1个顶级目录&#xff0c;称为“根目录”路径之间的层级关系&#xff0c;使用/来表示&#xff0c;例如&#xff1a;/usr/local/hello.txt 开头的/表示根目录后面的/表示层级关系 命令入门 命令的通用格式&#xff1a;command [ -options ] [ parameter] c…...

三防平板丨平板终端丨加固平板丨户外勘测应用

随着科技的不断发展&#xff0c;现代勘测业也在不断升级。相较于传统的勘测设备&#xff0c;三防平板在户外勘测中有着广泛的应用。那么&#xff0c;三防平板在户外勘测中究竟有哪些优势呢&#xff1f; 首先&#xff0c;三防平板具备极强的防水、防尘、防摔能力。在野外勘测中&…...

npm ERR! code CERT_HAS_EXPIRED:解决证书过期问题

转载&#xff1a;npm ERR! code CERT_HAS_EXPIRED&#xff1a;解决证书过期问题_npm err! code cert_has_expired npm err! errno cert-CSDN博客 npm config set registry http://registry.cnpmjs.org npm config set registry http://registry.npm.taobao.org...

npm报错之package-lock.json found. 问题和淘宝镜像源过期问题

1、package-lock.json found. 问题的解决 在执行yarn add react-transition-group -S 安装react-transition-group时出现package-lock.json found. Your project contains lock files generated by tools other than Yarn. It is advised not to mix package managers in orde…...

大模型提示学习、Prompting微调知识

为什么需要提示学习&#xff1f; 提示学习是一种在自然语言处理任务中引入人类编写的提示或示例来辅助模型生成更准确和有意义的输出的技术。以下是一些使用提示学习的原因&#xff1a; 解决模糊性&#xff1a;在某些任务中&#xff0c;输入可能存在歧义或模糊性&#xff0c;通…...

vue 导出,下载错误提示、blob与json数据转换

一、成功/失败 - 页面展示 失败 成功 二、成功/失败 - 接口请求/响应展示成功 2. 失败 三、解决 // 导出列表exportReceivedExcel() {if (this.tableCheckedValue) {this.form.ids this.tableCheckedValue.map(v > {return v.id || null})}this.loadingReceivedExcel …...

代码随想录算法训练营|二叉树总结

二叉树的定义&#xff1a; struct TreeNode {int val;TreeNode* left;TreeNode* right;TreeNode():val(0),left(nullptr),right(nullptr){}TreeNode(int val):val(val),left(nullptr),right(nullptr){}TreeNode(int val,TreeNode* left,TreeNode* right):val(val),left(left),…...

rtt的io设备框架面向对象学习-uart设备

目录 1.uart设备基类2.uart设备基类的子类3.初始化/构造流程3.1设备驱动层3.2 设备驱动框架层3.3 设备io管理层 4.总结5.使用 1.uart设备基类 此层处于设备驱动框架层。也是抽象类。 在/ components / drivers / include / drivers 下的serial.h定义了如下uart设备基类 struc…...

PHP和Node.js哪个更爽?

先说结论&#xff0c;rust完胜。 php&#xff1a;laravel&#xff0c;swoole&#xff0c;webman&#xff0c;最开始在苏宁的时候写了几年php&#xff0c;当时觉得php真的是世界上最好的语言&#xff0c;因为当初活在舒适圈里&#xff0c;不愿意跳出来&#xff0c;就好比当初活在…...

centos 7 部署awstats 网站访问检测

一、基础环境准备&#xff08;两种安装方式都要做&#xff09; bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

Java 加密常用的各种算法及其选择

在数字化时代&#xff0c;数据安全至关重要&#xff0c;Java 作为广泛应用的编程语言&#xff0c;提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景&#xff0c;有助于开发者在不同的业务需求中做出正确的选择。​ 一、对称加密算法…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

&#x1f680; C extern 关键字深度解析&#xff1a;跨文件编程的终极指南 &#x1f4c5; 更新时间&#xff1a;2025年6月5日 &#x1f3f7;️ 标签&#xff1a;C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言&#x1f525;一、extern 是什么&#xff1f;&…...

安卓基础(aar)

重新设置java21的环境&#xff0c;临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的&#xff1a; MyApp/ ├── app/ …...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

基于鸿蒙(HarmonyOS5)的打车小程序

1. 开发环境准备 安装DevEco Studio (鸿蒙官方IDE)配置HarmonyOS SDK申请开发者账号和必要的API密钥 2. 项目结构设计 ├── entry │ ├── src │ │ ├── main │ │ │ ├── ets │ │ │ │ ├── pages │ │ │ │ │ ├── H…...

FFmpeg avformat_open_input函数分析

函数内部的总体流程如下&#xff1a; avformat_open_input 精简后的代码如下&#xff1a; int avformat_open_input(AVFormatContext **ps, const char *filename,ff_const59 AVInputFormat *fmt, AVDictionary **options) {AVFormatContext *s *ps;int i, ret 0;AVDictio…...

【UE5 C++】通过文件对话框获取选择文件的路径

目录 效果 步骤 源码 效果 步骤 1. 在“xxx.Build.cs”中添加需要使用的模块 &#xff0c;这里主要使用“DesktopPlatform”模块 2. 添加后闭UE编辑器&#xff0c;右键点击 .uproject 文件&#xff0c;选择 "Generate Visual Studio project files"&#xff0c;重…...

【WebSocket】SpringBoot项目中使用WebSocket

1. 导入坐标 如果springboot父工程没有加入websocket的起步依赖&#xff0c;添加它的坐标的时候需要带上版本号。 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dep…...