当前位置: 首页 > article >正文

从认识AI开始-----解密门控循环单元(GRU):对LSTM的再优化

前言

在此之前,我已经详细介绍了RNN和LSTM,RNN虽然在处理序列数据中发挥了重要的作用,但它在实际使用中存在长期依赖问题,处理不了长序列,因为RNN对信息的保存只依赖一个隐藏状态,当序列过长,隐藏转态保存的东西过多时,它对于前面的信息的抽取就会变得困难。为了解决这个问题,LSTM被提出,它通过设计复杂的门控机制以及记忆单元,实现了对信息重要性的提取:因为在现实中,对于一个序列来说,并不是序列中所有的信息都是同等重要的,这就意味着模型可以只记住相关的观测信息即可,但LSTM因为过多的门控机制与记忆单元,导致参数过多,训练速度慢。而GRU则是对LSTM的进一步优化,它的结构简单,训练更高效,并且性能同样出色


一、GRU诞生背景:RNN与LSTM的局限性

1. RNN的问题

RNN 依赖于隐藏单元循环结构来记忆序列信息,但在面对较长序列时会遇到:

  • 梯度消失/爆炸问题
  • 长期依赖问题
  • 训练效率低下

2. LSTM的改进

LSTM 通过设计输入门、遗忘门、输出门,以及单独的记忆单元,有效控制信息流,解决了上述问题。但 LSTM 的结构较为复杂,参数量大,训练慢。

二、GRU:结构更简单性能同样优秀的门控循环单元

GRU在2014年被提出来,其思想来源于LSTM的设计,但是对LSTM的进一步简化:

  • 没有单独的记忆单元,只有一个隐藏转态
  • 将LSTM的输入门和忘记门合并为一个更新门
  • 另有一个重置门控制新信息与历史信息的融合程度

其具体结构如下:

如图可以看出,GRU由重置门、更新门、隐藏状态组成,对于每个时间步,GRU都会进行以下操作:

1. 重置门

重置门(R_t)的作用是:决定遗忘多少过去的信息

R_t=\sigma(X_t @ W_{xr} + H_{t-1} @ W_{hr} + b_r)

2. 更新门

更新门(Z_t)的作用是:决定保留多少过去的信息

Z_t=\sigma(X @ W_{xz} + H @ W_{hz} + b_z)

3. 候选隐藏转态

候选隐藏转态(\tilde {H})能控制对前面信息的遗忘程度,因为 R_t 经过 Sigmoid 后的值在 [0,1] 之间,当 R_t 趋近于 0 时,则表示要遗忘之前的信息,趋近于 1 时,要记住前面的信息

\tilde{H}=tanh(X_t@ W_{xh} + (R_t * H_{t-1}) @ W_{hh} + b_h)

4. 真正的隐藏转态

当 Z_t 为 1 时,忽略当前的候选隐藏转态,直接用前面的隐藏转态 H_{t-1} 作为当前的隐藏转态,当 Z_t 为 0 时,GRU就相当于退化成RNN了。

H_t=Z_t*H_{t-1}+(1-Z_t)*\tilde {H}


三、GRU与RNN/LSTM的比较

特性RNNLSTMGRU
是否解决长期依赖
参数量较少
门控机制输入、输出、遗忘重置、更新
记忆单元无(仅隐藏转态)
训练速度快、但性能差
表现一般类似甚至优于LSTM

GRU相比LSTM来说,结构简洁,参数少,训练更快,在多数任务上性能媲美甚至优于LSTM。更少的参数对过拟合更友好。但由于简化了部分结构,缺少了记忆单元的独立控制,无法像LSTM一样分开控制信息流


 四、手写GRU

通过上面的介绍,我们现在已经知道了GRU的实现原理,现在,我们试着手写一个GRU核心层:

首先,与RNN、LSTM一样,我们先初始化所需要的参数:

import torch
import torch.nn as nn
import torch.nn.functional as Fdef params(input_size, output_size, hidden_size):W_xz, W_hz, b_z = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xr, W_hr, b_r = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xh, W_hh, b_h = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_hq = torch.randn(hidden_size, output_size) * 0.1b_q = torch.zeros(output_size)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad = Truereturn params

然后,定义初始隐藏转态: 

import torchdef init_state(batch_size, hidden_size):return (torch.zeros((batch_size, hidden_size)), )

最后,是GRU的核心操作:

import torch
import torch.nn as nn
def gru(X, state, params):[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] = params(H, C) = stateoutputs = []for x in X:Z = torch.sigmoid(torch.mm(x, W_xz) + torch.mm(H, W_hz) + b_z)R = torch.sigmoid(torch.mm(x, W_xr) + torch.mm(H, W_hr) + b_r)H_tilde = torch.tanh(torch.mm(x, W_xh) + torch.mm((R * H), W_hh) + b_h)H = Z * H + (1 - Z) * H_tildeY = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=1), (H,)

四、使用Pytroch实现简单的LSTM

在Pytroch中,已经内置了gru函数,我们只需要调用就可以实现上述操作:

import torch
import torch.nn as nnclass mygru(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(mygru, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, h0):out, hn = self.gru(x, h0)out = self.fc(out)return out, hn# 示例
# 参数定义
input_size = 10
hidden_size = 20
output_size = 10
seq_len = 5
batch_size = 1
num_layers = 1model = mygru(input_size, hidden_size, output_size, num_layers)
inputs = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)output, hn = model(inputs, h0)
print(output.shape)

总结

以上就是本文的全部内容,算上本篇,我们已经系统性的讲述了RNN、RNN的进化版 LSTM、LSTM的优化版 GRU,相信小伙伴们已经对序列模型有了相当深刻的认识。GRU是一种比LSTM更轻量的门控循环单元,保留了长距离依赖建模能力,同时减少了参数量和计算复杂度。对于大多数NLP和时间序列任务来说,GRU提供了一个在性能与效率之间平衡良好的选择。


如果小伙伴们觉得本文对各位有帮助,欢迎:👍点赞 | ⭐ 收藏 |  🔔 关注。我将持续在专栏《人工智能》中更新人工智能知识,帮助各位小伙伴们打好扎实的理论与操作基础,欢迎🔔订阅本专栏,向AI工程师进阶!

相关文章:

从认识AI开始-----解密门控循环单元(GRU):对LSTM的再优化

前言 在此之前,我已经详细介绍了RNN和LSTM,RNN虽然在处理序列数据中发挥了重要的作用,但它在实际使用中存在长期依赖问题,处理不了长序列,因为RNN对信息的保存只依赖一个隐藏状态,当序列过长,隐…...

Docker系列(五):ROS容器化三叉戟 --- 从X11、Wayland到DockerFile实战全解析

引言 随着机器人操作系统(ROS)在机器人领域的广泛应用,容器化技术成为提高开发效率和简化部署的关键。在多种容器化方案中,基于X11、Wayland和标准Dockerfile的ROS容器化方式各有特点,它们在容器内安装ROS1和ROS2的实…...

【位运算】常见位运算总结

位运算 常见位运算总结位1的个数比特位计数汉明距离只出现一次的数字只出现一次的数字 III 常见位运算总结 位1的个数 191. 位1的个数 给定一个正整数 n,编写一个函数,获取一个正整数的二进制形式并返回其二进制表达式中 设置位 的个数(也被…...

Delphi 导入excel

Delphi导入Excel的常见方法可分为两种主流方案:基于OLE自动化操作Excel原生接口和利用第三方组件库。以下为具体实现流程及注意事项: ‌一、OLE自动化方案(推荐基础场景)‌ 该方法通过COM接口调用本地安装的Excel程序&#xff0c…...

5G RedCap是什么-与标准5G的区别及支持路由器推荐

技术背景与重要性 从智能穿戴到工业传感器,物联网设备种类繁多,但并非所有设备都需要标准5G的全部功能。为满足这些中端应用的需求,3GPP在Release 17中引入了5G RedCap(Reduced Capability),也称为5G NR-L…...

纯html,js创建一个类似excel的表格

后台是php,表中数据可编辑,可删除,可提交到数据库 <!DOCTYPE html> <html> <head><meta charset="utf-8"><style>body {font-family: Arial, sans-serif;margin: 20px;background-color: #fff;}.toolbar {margin-bottom: 10px;disp…...

如何使用windows下的vscode连接到本地虚拟机的linux

1.打开windows下的vscode 下载下图所示插件 下载完以后打开首选项选择设置搜索ssh 搜索ssh往下滑对下图打上勾 点击下图或者按ctrl shift P 搜索ssh 选择第一个&#xff0c;双击后 进入这个界面 好的window基本配置差不多 2.打开虚拟机 在终端中输入 sudo apt-get install…...

Vue开发系列——零基础HTML引入 Vue.js 实现页面之间传参

目录 一、实现页面之间传参 二、使用 URL 查询参数实现传参(不需要额外引入vue-router) 一、实现页面之间传参 实现从a.html 向b.html传参param1value1, param2value2 二、使用 URL 查询参数实现传参(不需要额外引入vue-router) a.html页面 a.html代码&#xff1a; <!…...

Ubuntu22.04 重装后,串口无响应

欢迎关注公号&#xff1a;每日早参&#xff0c;获取每日最新资讯&#xff01; 1&#xff1a;确认串口设备文件是否存在 在Ubuntu中&#xff0c;串口通常会映射为以下两种 /dev/ttyS*&#xff08;对于传统的串口&#xff09; /fragistics/dev/ttyUSB*&#xff08;对于USB转串口…...

设计模式-发布订阅

文章目录 发布订阅概念发布订阅 vs 监听者例子代码 发布订阅概念 发布/订阅者模式最大的特点就是实现了松耦合&#xff0c;也就是说你可以让发布者发布消息、订阅者接受消息&#xff0c;而不是寻找一种方式把两个分离 的系统连接在一起。当然这种松耦合也是发布/订阅者模式最大…...

C#学习26天:内存优化的几种方法

1.减少对象创建 使用场景&#xff1a; 在循环或密集计算中频繁创建对象时。涉及大量短生命周期对象的场景&#xff0c;比如日志记录或字符串拼接。游戏开发中&#xff0c;需要频繁更新对象状态时。 说明&#xff1a; 重用对象可以降低内存分配和垃圾回收的开销。使用对象池…...

功能测试向量是个什么概念

在半导体测试领域&#xff0c;功能测试向量&#xff08;Functional Test Vector&#xff09; 是一个非常重要的概念。以下是对其的详细解释&#xff1a; 1. 什么是功能测试向量&#xff1f; 功能测试向量是一组输入信号和预期输出信号的集合&#xff0c;用于验证芯片的功能是否…...

C++之string的模拟实现

string 手写C字符串类类的基本结构与成员变量一、构造函数与析构函数二、赋值运算符重载三、迭代器支持四、内存管理与扩容机制五、字符串操作函数六、运算符重载总结 手写C字符串类 从零实现一个简易版std::string 类的基本结构与成员变量 namespace zzh { class string { …...

Python打卡第38天

浙大疏锦行 作业&#xff1a; 了解下cifar数据集&#xff0c;尝试获取其中一张图片 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision im…...

【网络安全】轻量敏感路径扫描工具

订阅专栏,获取文末项目源码。 文章目录 工具简介工具特点项目结构使用方法1.环境准备2.配置目标URL3.运行扫描4.结果查看5.自定义扩展项目源码工具简介 该工具是一款基于Python的异步敏感路径扫描工具,用于检测目标网站是否存在敏感文件或路径泄露(如配置文件、密钥、版本控…...

K8S查看pod资源占用和物理机器IP对应关系

方法1&#xff1a;使用管道组合多个grep kubectl describe node | grep -E "Resource|InternalIP" -A 3方法2&#xff1a;显示节点名称和IP地址的对应关系 kubectl describe node | grep -E "Name:|InternalIP:"方法3&#xff1a;更清晰的格式化输出 ku…...

Java Spring 之拦截器HandlerInterceptor详解与实战

目录 一、拦截器的作用1.1 请求处理前的拦截1.2 请求处理后的拦截1.3 请求完成后的拦截 二、创建拦截器2.1 实现 HandlerInterceptor 接口2.2 注册拦截器 三、拦截器的使用场景3.1 权限校验3.2 日志记录3.3 性能监控 四、总结 在 Spring 框架中&#xff0c;拦截器&#xff08; …...

开源第三方库发展现状

摘要&#xff1a;当前&#xff0c;开源第三方库生态正呈现爆发式增长趋势。GitHub 目前已托管超过 4.2 亿个代码仓库&#xff0c;远超早期统计的 1 亿规模&#xff0c;显示出开发者社区的活跃度持续攀升。同时&#xff0c;37 个主流包管理器所维护的开源组件数量可能已达到数千…...

JavaSE核心知识点04工具04-02(IDEA)

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 JavaSE核心知识点04工具04-02&#xff08;ID…...

NodeMediaEdge通道管理

NodeMediaEdge任务管理 简介 NodeMediaEdge是一款部署在监控摄像机网络前端中&#xff0c;拉取Onvif或者rtsp/rtmp/http视频流并使用rtmp/kmp推送到公网流媒体服务器的工具。 在未使用NodeMediaServer的情况下&#xff0c;或者对部分视频流需要单独推送的需求&#xff0c;也可…...

25、web场景-【源码分析】-静态资源原理

25、web场景-【源码分析】-静态资源原理 静态资源原理主要涉及Spring Boot如何管理和提供静态文件&#xff0c;如CSS、JavaScript、图片等。以下是详细的分析&#xff1a; #### 默认静态资源目录 Spring Boot默认将以下目录作为静态资源的存放位置&#xff1a; - classpath:/…...

qt结构化绑定的重大缺陷:只能创建局部变量

根据你的描述,问题出现在使用 std::make_tuple 和结构化绑定(structured binding)初始化多个成员变量时。这种初始化方式在C++中是合法的,但可能会导致一些问题,尤其是在类的成员变量初始化中。 问题分析 成员变量初始化顺序: 在C++中,类的成员变量的初始化顺序是按照它…...

历年中南大学计算机保研上机真题

2025中南大学计算机保研上机真题 2024中南大学计算机保研上机真题 2023中南大学计算机保研上机真题 在线测评链接&#xff1a;https://pgcode.cn/school 进制转换 题目描述 请写出一段程序&#xff0c;将十进制数字转为八进制。 输入格式 第一行输入 T T T ( 1 ≤ T ≤…...

端口映射不通的原因有哪些?路由器设置后公网访问本地内网失败分析

本地网络地址通过端口映射出去到公网使用&#xff0c;是较为常用的一种传统方案。然而&#xff0c;很多环境下和很多普通人员在实际使用中&#xff0c;却往往会遇到端口映射不通的问题。端口映射不通的主要原因包括公网IP缺失&#xff08;更换nat123类似映射工具方案&#xff0…...

Vue3 封装el-table组件

封装一个el-table组件&#xff1a;子组件仅负责事件触发&#xff0c;业务逻辑&#xff08;如API调用、状态更新&#xff09;由父组件实现 <template><el-table:data"tableData"borderstripestyle"width: 100%; height: calc(100% - 32px);"class…...

Python爬虫实战:研究Requests-HTML库相关技术

1. 引言 1.1 研究背景与意义 随着互联网数据量的爆炸式增长,网络爬虫已成为数据获取的重要工具,广泛应用于市场调研、舆情分析、学术研究等领域。传统爬虫技术在面对现代 JavaScript 动态渲染网页时面临挑战,而 Requests-HTML 库通过集成浏览器渲染引擎,为解决这一问题提…...

Azure Devops pipeline 技巧和最佳实践

1. 如何显示release pipeline ? 解决方法: 登录devops, 找到organization - pipeline - setting下的Disable creation of classic release pipelines,禁用该选项。 然后在project - pipeline - setting,禁用Disable creation of classic release pipelines 现在可以看到r…...

云原生应用架构设计原则与落地实践:从理念到方法论

&#x1f4dd;个人主页&#x1f339;&#xff1a;慌ZHANG-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 一、云原生&#xff1a;现代架构的起点与范式变革 1.1 什么是云原生&#xff1f; 云原生&#xff08;Cloud Native&#xff09;是一种面…...

一起学数据结构和算法(三)| 字符串(线性结构)

字符串&#xff08;String&#xff09; 字符串是由字符组成的有限序列&#xff0c;在计算机中通常以字符数组形式存储&#xff0c;支持拼接、查找、替换等操作。 简介 字符串是计算机科学中最常用的数据类型之一&#xff0c;由一系列字符组成的有限序列。在大多数编程语言中&…...

udp 传输实时性测量

UDP&#xff08;用户数据报协议&#xff09;是一种无连接的传输协议&#xff0c;适用于实时性要求较高的应用&#xff0c;如视频流、音频传输和游戏等。测量UDP传输的实时性可以通过多种工具和方法实现&#xff0c;以下是一些常见的方法和工具&#xff1a; 1. 使用 iperf 测试…...