当前位置: 首页 > news >正文

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

相关文章:

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型 随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,…...

GitHub Actions

GitHub Actions GitHub Actions 是 GitHub 提供的一种持续集成(CI)和持续部署(CD)解决方案。它可以让你在 GitHub 仓库中直接自动化、定制化和执行软件开发工作流程。 比如,当有新的推送到仓库或者新的 Pull Request…...

harmony 鸿蒙系统学习 安装ohpm报错 ohpm install failed

一. 安装配置 DevEco Studio 安装包时报错 execute ohpm install failed. Install task failed: ArkTS 3.2.12.5. Install ArkTS dependencies failed. 解决办法 找原因,首先,我的电脑中之前安装过node,也许是因为这个。(其实…...

MySQL Replication

0 序言 MySQL Replication 是 MySQL 中的一个功能,允许从一个 MySQL 数据库服务器(称为主服务器或 master)复制数据和数据库结构到另一个服务器(称为从服务器或 slave)。这种复制是异步的,意味着从服务器不…...

redis分布式锁redisson

文章目录 1. 分布式锁1.1 基本原理和实现方式对比synchronized锁在集群模式下的问题多jvm使用同一个锁监视器分布式锁概念分布式锁须满足的条件分布式锁的实现 1.2 基于Redis的分布式锁获取锁&释放锁操作示例 基于Redis实现分布式锁初级版本ILock接口SimpleRedisLock使用示…...

制作一个简单的html网页

1. 特效按钮 2 可以独立使用的一个页面 3 底部小时钟 <!DOCTYPE html> <html> <head><title>Simple Webpage</title><style>/* 禁止鼠标右键 */body {-webkit-touch-callout: none; /* iOS Safari */-webkit-user-select: none; …...

js filter,every,includes 过滤数组

背景&#xff1a; 页面&#xff1a;在项目中遇到的&#xff0c;前端页面显示为&#xff0c;顶部是下拉搜索条件,下面是一个表格&#xff1b; 数据&#xff1a;接口请求一次性拿到所有&#xff1a;搜索条件里的下拉选项和表格中的数据&#xff1b; 现状&#xff1a;需要前端在搜…...

jenkins自动化部署

Jenkins安装 安装前提条件 yum install java-1.8.0-openjdk* git maven -y ​ 1.下载jenkins wget https://mirrors.tuna.tsinghua.edu.cn/jenkins/redhat/jenkins-2.346-1.1.noarch.rpm --no-check-certificate ​ jenkins的安装路径&#xff1a; /var/lib/jenkins/ ​ …...

【JavaScript】分支语句

目录 一、if语句 二、三元运算符 三、switch语句 JS中分支语句可以分为三种&#xff0c;分别是if语句、三元运算符、switch语句。 一、if语句 let num 10 if (num > 20) {console.log("大于20"); } else if (num < 20) {console.log("小于20");…...

【开源】SpringBoot框架开发农家乐订餐系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核心代码4.1 查询菜品类型4.2 查询菜品4.3 加购菜品4.4 新增菜品收藏4.5 新增菜品留言 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的农家乐订餐系统&#xff0c…...

OSQP文档学习

OSQP官方文档 1 QSQP简介 OSQP求解形式为的凸二次规划&#xff1a; x ∈ R n x∈R^n x∈Rn&#xff1a;优化变量 P ∈ S n P∈S^n_ P∈Sn​&#xff1a;半正定矩阵 特征 &#xff08;1&#xff09;高效&#xff1a;使用了一种自定义的基于ADMM的一阶方法&#xff0c;只需…...

ONLYOFFICE 8.0:引领数字化办公新纪元

目录 前言 软件安装 软件启动 软件新版本特性 个人评价 总结 前言 在当今快节奏的数字化世界中&#xff0c;高效的办公软件已成为企业竞争力的关键因素。ONLYOFFICE&#xff0c;作为全球领先的办公解决方案提供商&#xff0c;始终致力于通过技术创新来优化用户体验。如今…...

「Linux」基础命令

目录结构 Linux只有1个顶级目录&#xff0c;称为“根目录”路径之间的层级关系&#xff0c;使用/来表示&#xff0c;例如&#xff1a;/usr/local/hello.txt 开头的/表示根目录后面的/表示层级关系 命令入门 命令的通用格式&#xff1a;command [ -options ] [ parameter] c…...

三防平板丨平板终端丨加固平板丨户外勘测应用

随着科技的不断发展&#xff0c;现代勘测业也在不断升级。相较于传统的勘测设备&#xff0c;三防平板在户外勘测中有着广泛的应用。那么&#xff0c;三防平板在户外勘测中究竟有哪些优势呢&#xff1f; 首先&#xff0c;三防平板具备极强的防水、防尘、防摔能力。在野外勘测中&…...

npm ERR! code CERT_HAS_EXPIRED:解决证书过期问题

转载&#xff1a;npm ERR! code CERT_HAS_EXPIRED&#xff1a;解决证书过期问题_npm err! code cert_has_expired npm err! errno cert-CSDN博客 npm config set registry http://registry.cnpmjs.org npm config set registry http://registry.npm.taobao.org...

npm报错之package-lock.json found. 问题和淘宝镜像源过期问题

1、package-lock.json found. 问题的解决 在执行yarn add react-transition-group -S 安装react-transition-group时出现package-lock.json found. Your project contains lock files generated by tools other than Yarn. It is advised not to mix package managers in orde…...

大模型提示学习、Prompting微调知识

为什么需要提示学习&#xff1f; 提示学习是一种在自然语言处理任务中引入人类编写的提示或示例来辅助模型生成更准确和有意义的输出的技术。以下是一些使用提示学习的原因&#xff1a; 解决模糊性&#xff1a;在某些任务中&#xff0c;输入可能存在歧义或模糊性&#xff0c;通…...

vue 导出,下载错误提示、blob与json数据转换

一、成功/失败 - 页面展示 失败 成功 二、成功/失败 - 接口请求/响应展示成功 2. 失败 三、解决 // 导出列表exportReceivedExcel() {if (this.tableCheckedValue) {this.form.ids this.tableCheckedValue.map(v > {return v.id || null})}this.loadingReceivedExcel …...

代码随想录算法训练营|二叉树总结

二叉树的定义&#xff1a; struct TreeNode {int val;TreeNode* left;TreeNode* right;TreeNode():val(0),left(nullptr),right(nullptr){}TreeNode(int val):val(val),left(nullptr),right(nullptr){}TreeNode(int val,TreeNode* left,TreeNode* right):val(val),left(left),…...

rtt的io设备框架面向对象学习-uart设备

目录 1.uart设备基类2.uart设备基类的子类3.初始化/构造流程3.1设备驱动层3.2 设备驱动框架层3.3 设备io管理层 4.总结5.使用 1.uart设备基类 此层处于设备驱动框架层。也是抽象类。 在/ components / drivers / include / drivers 下的serial.h定义了如下uart设备基类 struc…...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

云原生玩法三问:构建自定义开发环境

云原生玩法三问&#xff1a;构建自定义开发环境 引言 临时运维一个古董项目&#xff0c;无文档&#xff0c;无环境&#xff0c;无交接人&#xff0c;俗称三无。 运行设备的环境老&#xff0c;本地环境版本高&#xff0c;ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

力扣热题100 k个一组反转链表题解

题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

(一)单例模式

一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...