当前位置: 首页 > news >正文

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

相关文章:

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型 随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,…...

GitHub Actions

GitHub Actions GitHub Actions 是 GitHub 提供的一种持续集成(CI)和持续部署(CD)解决方案。它可以让你在 GitHub 仓库中直接自动化、定制化和执行软件开发工作流程。 比如,当有新的推送到仓库或者新的 Pull Request…...

harmony 鸿蒙系统学习 安装ohpm报错 ohpm install failed

一. 安装配置 DevEco Studio 安装包时报错 execute ohpm install failed. Install task failed: ArkTS 3.2.12.5. Install ArkTS dependencies failed. 解决办法 找原因,首先,我的电脑中之前安装过node,也许是因为这个。(其实…...

MySQL Replication

0 序言 MySQL Replication 是 MySQL 中的一个功能,允许从一个 MySQL 数据库服务器(称为主服务器或 master)复制数据和数据库结构到另一个服务器(称为从服务器或 slave)。这种复制是异步的,意味着从服务器不…...

redis分布式锁redisson

文章目录 1. 分布式锁1.1 基本原理和实现方式对比synchronized锁在集群模式下的问题多jvm使用同一个锁监视器分布式锁概念分布式锁须满足的条件分布式锁的实现 1.2 基于Redis的分布式锁获取锁&释放锁操作示例 基于Redis实现分布式锁初级版本ILock接口SimpleRedisLock使用示…...

制作一个简单的html网页

1. 特效按钮 2 可以独立使用的一个页面 3 底部小时钟 <!DOCTYPE html> <html> <head><title>Simple Webpage</title><style>/* 禁止鼠标右键 */body {-webkit-touch-callout: none; /* iOS Safari */-webkit-user-select: none; …...

js filter,every,includes 过滤数组

背景&#xff1a; 页面&#xff1a;在项目中遇到的&#xff0c;前端页面显示为&#xff0c;顶部是下拉搜索条件,下面是一个表格&#xff1b; 数据&#xff1a;接口请求一次性拿到所有&#xff1a;搜索条件里的下拉选项和表格中的数据&#xff1b; 现状&#xff1a;需要前端在搜…...

jenkins自动化部署

Jenkins安装 安装前提条件 yum install java-1.8.0-openjdk* git maven -y ​ 1.下载jenkins wget https://mirrors.tuna.tsinghua.edu.cn/jenkins/redhat/jenkins-2.346-1.1.noarch.rpm --no-check-certificate ​ jenkins的安装路径&#xff1a; /var/lib/jenkins/ ​ …...

【JavaScript】分支语句

目录 一、if语句 二、三元运算符 三、switch语句 JS中分支语句可以分为三种&#xff0c;分别是if语句、三元运算符、switch语句。 一、if语句 let num 10 if (num > 20) {console.log("大于20"); } else if (num < 20) {console.log("小于20");…...

【开源】SpringBoot框架开发农家乐订餐系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核心代码4.1 查询菜品类型4.2 查询菜品4.3 加购菜品4.4 新增菜品收藏4.5 新增菜品留言 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的农家乐订餐系统&#xff0c…...

OSQP文档学习

OSQP官方文档 1 QSQP简介 OSQP求解形式为的凸二次规划&#xff1a; x ∈ R n x∈R^n x∈Rn&#xff1a;优化变量 P ∈ S n P∈S^n_ P∈Sn​&#xff1a;半正定矩阵 特征 &#xff08;1&#xff09;高效&#xff1a;使用了一种自定义的基于ADMM的一阶方法&#xff0c;只需…...

ONLYOFFICE 8.0:引领数字化办公新纪元

目录 前言 软件安装 软件启动 软件新版本特性 个人评价 总结 前言 在当今快节奏的数字化世界中&#xff0c;高效的办公软件已成为企业竞争力的关键因素。ONLYOFFICE&#xff0c;作为全球领先的办公解决方案提供商&#xff0c;始终致力于通过技术创新来优化用户体验。如今…...

「Linux」基础命令

目录结构 Linux只有1个顶级目录&#xff0c;称为“根目录”路径之间的层级关系&#xff0c;使用/来表示&#xff0c;例如&#xff1a;/usr/local/hello.txt 开头的/表示根目录后面的/表示层级关系 命令入门 命令的通用格式&#xff1a;command [ -options ] [ parameter] c…...

三防平板丨平板终端丨加固平板丨户外勘测应用

随着科技的不断发展&#xff0c;现代勘测业也在不断升级。相较于传统的勘测设备&#xff0c;三防平板在户外勘测中有着广泛的应用。那么&#xff0c;三防平板在户外勘测中究竟有哪些优势呢&#xff1f; 首先&#xff0c;三防平板具备极强的防水、防尘、防摔能力。在野外勘测中&…...

npm ERR! code CERT_HAS_EXPIRED:解决证书过期问题

转载&#xff1a;npm ERR! code CERT_HAS_EXPIRED&#xff1a;解决证书过期问题_npm err! code cert_has_expired npm err! errno cert-CSDN博客 npm config set registry http://registry.cnpmjs.org npm config set registry http://registry.npm.taobao.org...

npm报错之package-lock.json found. 问题和淘宝镜像源过期问题

1、package-lock.json found. 问题的解决 在执行yarn add react-transition-group -S 安装react-transition-group时出现package-lock.json found. Your project contains lock files generated by tools other than Yarn. It is advised not to mix package managers in orde…...

大模型提示学习、Prompting微调知识

为什么需要提示学习&#xff1f; 提示学习是一种在自然语言处理任务中引入人类编写的提示或示例来辅助模型生成更准确和有意义的输出的技术。以下是一些使用提示学习的原因&#xff1a; 解决模糊性&#xff1a;在某些任务中&#xff0c;输入可能存在歧义或模糊性&#xff0c;通…...

vue 导出,下载错误提示、blob与json数据转换

一、成功/失败 - 页面展示 失败 成功 二、成功/失败 - 接口请求/响应展示成功 2. 失败 三、解决 // 导出列表exportReceivedExcel() {if (this.tableCheckedValue) {this.form.ids this.tableCheckedValue.map(v > {return v.id || null})}this.loadingReceivedExcel …...

代码随想录算法训练营|二叉树总结

二叉树的定义&#xff1a; struct TreeNode {int val;TreeNode* left;TreeNode* right;TreeNode():val(0),left(nullptr),right(nullptr){}TreeNode(int val):val(val),left(nullptr),right(nullptr){}TreeNode(int val,TreeNode* left,TreeNode* right):val(val),left(left),…...

rtt的io设备框架面向对象学习-uart设备

目录 1.uart设备基类2.uart设备基类的子类3.初始化/构造流程3.1设备驱动层3.2 设备驱动框架层3.3 设备io管理层 4.总结5.使用 1.uart设备基类 此层处于设备驱动框架层。也是抽象类。 在/ components / drivers / include / drivers 下的serial.h定义了如下uart设备基类 struc…...

PyCharm - Script parameters (脚本参数)

PyCharm - Script parameters [脚本参数] References Run -> Edit Configurations… -> Run/Debug Configurations -> Configuration -> Script parameters 命令行&#xff1a; python display_yolo_log.py ./person_training_log/person_train_log_DIMM40_stdout…...

Security6.2 中的SpEL 表达式应用(权限注解使用)

最近学习若依框架&#xff0c;里面的权限注解涉及到了SpEL表达式 PreAuthorize("ss.hasPermi(system:user:list)")&#xff0c;若依项目中用的是自己写的方法进行权限处理&#xff0c; 也可以只用security 来实现权限逻辑代码&#xff0c;下面写如何用security 实现。…...

软考笔记--信息系统开发方法(下)

信息系统是一个极其复杂的人机交互系统&#xff0c;它不仅包含计算机技术&#xff0c;通信技术和网络规划以及其他的工程技术&#xff0c;而且&#xff0c;它还是一个复杂的管理系统&#xff0c;需要管理理论和方法的支持&#xff0c;因此&#xff0c;与其他工程项目相比&#…...

从 AGP 4.1.2 到 7.5.1——XmlParser、GPathResult、QName 过时

新年首发&#xff0c; 去年的问题&#xff0c;今年解决~ 问题 & 排查 1: Task failed with an exception. ----------- * What went wrong: Execution failed for task :app:processCommonReleaseManifest. > org.xml.sax.SAXParseException; lineNumber: 1; columnNu…...

spring boot 使用AOP实现是否已登录检测

前后端分离的开发中&#xff0c;用户http请求应用服务的接口时, 如果要求检测该用户是否已登录。可以实现的方法有多种&#xff0c; 本示例是通过aop 的方式实现&#xff0c;简单有效。 约定&#xff1a;前端http的post 请求 export async function request(url,data) {const …...

为什么从没有负值的数据中绘制的小提琴图(Violin Plot)会出现负值部分?

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 小提琴图&#xff08;Violin Plot&#xff09; 是一种用于展示和比较数据分布的可视化工具。它结合了箱形图&#xff08;Box Plot&#xff09;和密度图&#xff08;Kernel Density Plot&#xff09;的特…...

有哪几种行为会导致服务器被入侵

导致服务器被入侵的行为有很多种&#xff0c;以下是一些常见的行为&#xff1a; 系统漏洞&#xff1a;服务器操作系统或软件存在漏洞&#xff0c;攻击者可以通过利用这些漏洞获取系统权限&#xff0c;从而入侵服务器。 弱口令&#xff1a;服务器的账号密码过于简单或者未及时更…...

Redis RabbitMQ

Redis&#xff1a;轻量级&#xff0c;NoSQL数据库 redis是一个key-value存储系统。和Memcached类似&#xff0c;它支持存储的value类型相对更多&#xff0c;包括string(字符串)、list(链表)、set(集合)、zset(sorted set --有序集合)和hash&#xff08;哈希类型&#xff09;。这…...

http 和 https 的区别?

目录 1.http 和 https 的基本概念 2.http 和 https 的区别 3.https 协议的工作原理 4.https 协议的优点 5.https 协议的缺点 1.http 和 https 的基本概念 http: 超文本传输协议&#xff0c;是互联网上应用最为广泛的一种网络协议&#xff0c;是一个客户端和服务器端请求和…...

C++中线程的创建

线程创建 引言为什么要使用线程线程的创建使用函数指针示例运行结果使用类对象示例运行结果使用lambda表达式示例运行结果使用带参数的函数作为线程处理函数示例运行结果使用类成员函数示例运行结果引言 在学习C++的过程中,线程的使用作为一个非常重要的部分,也是在复杂项目…...