当前位置: 首页 > news >正文

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

相关文章:

机器学习入门--门控循环单元(GRU)原理与实践

GRU模型 随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,…...

GitHub Actions

GitHub Actions GitHub Actions 是 GitHub 提供的一种持续集成(CI)和持续部署(CD)解决方案。它可以让你在 GitHub 仓库中直接自动化、定制化和执行软件开发工作流程。 比如,当有新的推送到仓库或者新的 Pull Request…...

harmony 鸿蒙系统学习 安装ohpm报错 ohpm install failed

一. 安装配置 DevEco Studio 安装包时报错 execute ohpm install failed. Install task failed: ArkTS 3.2.12.5. Install ArkTS dependencies failed. 解决办法 找原因,首先,我的电脑中之前安装过node,也许是因为这个。(其实…...

MySQL Replication

0 序言 MySQL Replication 是 MySQL 中的一个功能,允许从一个 MySQL 数据库服务器(称为主服务器或 master)复制数据和数据库结构到另一个服务器(称为从服务器或 slave)。这种复制是异步的,意味着从服务器不…...

redis分布式锁redisson

文章目录 1. 分布式锁1.1 基本原理和实现方式对比synchronized锁在集群模式下的问题多jvm使用同一个锁监视器分布式锁概念分布式锁须满足的条件分布式锁的实现 1.2 基于Redis的分布式锁获取锁&释放锁操作示例 基于Redis实现分布式锁初级版本ILock接口SimpleRedisLock使用示…...

制作一个简单的html网页

1. 特效按钮 2 可以独立使用的一个页面 3 底部小时钟 <!DOCTYPE html> <html> <head><title>Simple Webpage</title><style>/* 禁止鼠标右键 */body {-webkit-touch-callout: none; /* iOS Safari */-webkit-user-select: none; …...

js filter,every,includes 过滤数组

背景&#xff1a; 页面&#xff1a;在项目中遇到的&#xff0c;前端页面显示为&#xff0c;顶部是下拉搜索条件,下面是一个表格&#xff1b; 数据&#xff1a;接口请求一次性拿到所有&#xff1a;搜索条件里的下拉选项和表格中的数据&#xff1b; 现状&#xff1a;需要前端在搜…...

jenkins自动化部署

Jenkins安装 安装前提条件 yum install java-1.8.0-openjdk* git maven -y ​ 1.下载jenkins wget https://mirrors.tuna.tsinghua.edu.cn/jenkins/redhat/jenkins-2.346-1.1.noarch.rpm --no-check-certificate ​ jenkins的安装路径&#xff1a; /var/lib/jenkins/ ​ …...

【JavaScript】分支语句

目录 一、if语句 二、三元运算符 三、switch语句 JS中分支语句可以分为三种&#xff0c;分别是if语句、三元运算符、switch语句。 一、if语句 let num 10 if (num > 20) {console.log("大于20"); } else if (num < 20) {console.log("小于20");…...

【开源】SpringBoot框架开发农家乐订餐系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核心代码4.1 查询菜品类型4.2 查询菜品4.3 加购菜品4.4 新增菜品收藏4.5 新增菜品留言 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的农家乐订餐系统&#xff0c…...

OSQP文档学习

OSQP官方文档 1 QSQP简介 OSQP求解形式为的凸二次规划&#xff1a; x ∈ R n x∈R^n x∈Rn&#xff1a;优化变量 P ∈ S n P∈S^n_ P∈Sn​&#xff1a;半正定矩阵 特征 &#xff08;1&#xff09;高效&#xff1a;使用了一种自定义的基于ADMM的一阶方法&#xff0c;只需…...

ONLYOFFICE 8.0:引领数字化办公新纪元

目录 前言 软件安装 软件启动 软件新版本特性 个人评价 总结 前言 在当今快节奏的数字化世界中&#xff0c;高效的办公软件已成为企业竞争力的关键因素。ONLYOFFICE&#xff0c;作为全球领先的办公解决方案提供商&#xff0c;始终致力于通过技术创新来优化用户体验。如今…...

「Linux」基础命令

目录结构 Linux只有1个顶级目录&#xff0c;称为“根目录”路径之间的层级关系&#xff0c;使用/来表示&#xff0c;例如&#xff1a;/usr/local/hello.txt 开头的/表示根目录后面的/表示层级关系 命令入门 命令的通用格式&#xff1a;command [ -options ] [ parameter] c…...

三防平板丨平板终端丨加固平板丨户外勘测应用

随着科技的不断发展&#xff0c;现代勘测业也在不断升级。相较于传统的勘测设备&#xff0c;三防平板在户外勘测中有着广泛的应用。那么&#xff0c;三防平板在户外勘测中究竟有哪些优势呢&#xff1f; 首先&#xff0c;三防平板具备极强的防水、防尘、防摔能力。在野外勘测中&…...

npm ERR! code CERT_HAS_EXPIRED:解决证书过期问题

转载&#xff1a;npm ERR! code CERT_HAS_EXPIRED&#xff1a;解决证书过期问题_npm err! code cert_has_expired npm err! errno cert-CSDN博客 npm config set registry http://registry.cnpmjs.org npm config set registry http://registry.npm.taobao.org...

npm报错之package-lock.json found. 问题和淘宝镜像源过期问题

1、package-lock.json found. 问题的解决 在执行yarn add react-transition-group -S 安装react-transition-group时出现package-lock.json found. Your project contains lock files generated by tools other than Yarn. It is advised not to mix package managers in orde…...

大模型提示学习、Prompting微调知识

为什么需要提示学习&#xff1f; 提示学习是一种在自然语言处理任务中引入人类编写的提示或示例来辅助模型生成更准确和有意义的输出的技术。以下是一些使用提示学习的原因&#xff1a; 解决模糊性&#xff1a;在某些任务中&#xff0c;输入可能存在歧义或模糊性&#xff0c;通…...

vue 导出,下载错误提示、blob与json数据转换

一、成功/失败 - 页面展示 失败 成功 二、成功/失败 - 接口请求/响应展示成功 2. 失败 三、解决 // 导出列表exportReceivedExcel() {if (this.tableCheckedValue) {this.form.ids this.tableCheckedValue.map(v > {return v.id || null})}this.loadingReceivedExcel …...

代码随想录算法训练营|二叉树总结

二叉树的定义&#xff1a; struct TreeNode {int val;TreeNode* left;TreeNode* right;TreeNode():val(0),left(nullptr),right(nullptr){}TreeNode(int val):val(val),left(nullptr),right(nullptr){}TreeNode(int val,TreeNode* left,TreeNode* right):val(val),left(left),…...

rtt的io设备框架面向对象学习-uart设备

目录 1.uart设备基类2.uart设备基类的子类3.初始化/构造流程3.1设备驱动层3.2 设备驱动框架层3.3 设备io管理层 4.总结5.使用 1.uart设备基类 此层处于设备驱动框架层。也是抽象类。 在/ components / drivers / include / drivers 下的serial.h定义了如下uart设备基类 struc…...

解决Arm Compiler 5与6混合编译的链接警告问题

1. 问题现象解析当使用Arm Compiler 5工具链链接包含Arm Compiler 6构建对象文件的项目时&#xff0c;开发者常会遇到如下警告信息&#xff1a;Warning: L6418W: Tagging symbol __tagsym$$used.0 defined in .obj() is not recognized在包含MDK-Middleware组件的项目中&#x…...

通过curl命令快速测试Taotoken上不同大模型的响应效果

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 通过curl命令快速测试Taotoken上不同大模型的响应效果 对于开发者而言&#xff0c;在集成大模型能力时&#xff0c;快速验证接口连…...

收藏!小白程序员必看:搞定RAG知识库,解锁大模型核心技能!

文章强调知识库是RAG系统的核心&#xff0c;其质量直接影响智能问答效果。构建知识库并非简单处理数据&#xff0c;而是涉及多数据源整合、复杂格式处理、数据更新与版本管理、文档召回优化及系统架构设计等关键环节。作者指出&#xff0c;随着数据量增长&#xff0c;完善的知识…...

GEO生成引擎优化:当品牌竞争从搜索结果页迁移到大模型对话窗口

当生成式AI成为信息的首要分发渠道&#xff0c;你的品牌还只盯着SEO吗&#xff1f;一、用户获取信息的路径&#xff0c;已经变了过去十几年&#xff0c;我们习惯了"搜索关键词 → 浏览结果页 → 点击进入网站"这条线性路径。SEO&#xff08;搜索引擎优化&#xff09;…...

AI时代中小企业还要不要上ERP?2026年最新思考

最近DeepSeek爆火&#xff0c;AI Agent层出不穷&#xff0c;不少老板问我&#xff1a;都2026年了&#xff0c;AI这么厉害&#xff0c;中小企业还有必要上ERP吗&#xff1f;我的答案是&#xff1a;不仅要上&#xff0c;而且要上得更聪明。一、AI再强&#xff0c;也替代不了ERP的…...

量子退火与经典优化算法性能对比研究

1. 量子退火与经典优化算法的性能对比研究在计算科学领域&#xff0c;量子计算一直被视为可能带来革命性突破的技术。其中&#xff0c;量子退火&#xff08;Quantum Annealing&#xff09;作为一种专门用于解决组合优化问题的方法&#xff0c;近年来备受关注。然而&#xff0c;…...

3步彻底解决Windows更新后开始菜单重置难题:ExplorerPatcher深度解析与实战

3步彻底解决Windows更新后开始菜单重置难题&#xff1a;ExplorerPatcher深度解析与实战 【免费下载链接】ExplorerPatcher This project aims to enhance the working environment on Windows 项目地址: https://gitcode.com/GitHub_Trending/ex/ExplorerPatcher 每次Wi…...

三步解锁全网盘极速下载:免登录直链解析完整教程

三步解锁全网盘极速下载&#xff1a;免登录直链解析完整教程 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 &#xff0c;支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼云盘 …...

观察Taotoken在多模型聚合调用下的稳定性与路由表现

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 观察Taotoken在多模型聚合调用下的稳定性与路由表现 1. 引言 在构建依赖大模型能力的应用时&#xff0c;服务的连续性与稳定性是开…...

GAN与密码学的真实接口:从概念纠偏到工程落地

1. 项目概述&#xff1a;这不是密码学&#xff0c;也不是GAN训练指南&#xff0c;而是一场概念误读的深度解剖 “Understanding GAN Cryptography”——这个标题一出现&#xff0c;我就在笔记本上划了三道横线。不是因为难&#xff0c;而是因为它根本不存在。过去三年里&#x…...