跟着问题学15——GRU网络结构详解及代码实战
1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)
前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更“长期”的上下文信息。比如预测最后一个单词“我在中国长大……我说一口流利的**。”“短期”的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境——“我在中国长大”,但这个相关信息与需要它的点之间的距离完全有可能变得非常大。
不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。
从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。
幸运的是,GRU也没有这个问题!
2、GRU
什么是GRU
GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。
GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。
用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。
2.1总体结构框架
前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。
多层感知机(线性连接层)结构

从特征角度考虑:
输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),
隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;
输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。
卷积神经网络结构

从特征角度考虑:
输入特征:是(batch)*channel*width*height的张量,
卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。
输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。
循环神经网络RNN系列结构

从特征角度考虑:
输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).
我们来详细对比一下卷积神经网络的输入特征,
(batch)*T_seq*feature_size
(batch)*channel*width*height,
逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;
再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。
最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)
隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)
以t时刻具体阐述一下:
X_t是t时刻的输入,是一个feature_size*1的向量
W_ih是输入层到隐藏层的权重矩阵
H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量
W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵
Ot是t时刻RNN网络的输出
从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。
2.2 GRU的输入输出结构
GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构
那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!
2.3 GRU的内部结构
不同于LSTM有3个门控,GRU仅有2个门控,
第一个是“重置门”(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1’,即ht-1’=ht-1*r。

再将ht-1’与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的h’。

第一个是“更新门”(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。
2.4 小结
GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力和时间成本,因而很多时候我们也就会选择更加”实用“的GRU。

3代码
import torch
import torch.nn as nndef my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):h_prev=initial_statesbatch_size,T_seq,feature_size=input.shapehidden_size=w_ih.shape[0]//3batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)output=torch.zeros(batch_size,T_seq,hidden_size)for t in range(T_seq):x=input[:,t,:]w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))w_times_x=w_times_x.squeeze(-1)# print(batch_w_hh.shape,h_prev.shape)# 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)# 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,# 对于剩下的则不做要求,输出维度 (b,h,m)# batch_w_hh=batch_size*(3*hidden_size)*hidden_size# h_prev=batch_size*hidden_size*1# w_times_x=batch_size*hidden_size*1##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))w_times_h_prev=w_times_h_prev.squeeze(-1)r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]+b_hh[:hidden_size])z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]+b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]+b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])h_prev=(1-z_t)*n_t+z_t*h_prevoutput[:,t,:]=h_prevreturn output,h_previf __name__=="__main__":fc=nn.Linear(12,6)batch_size=2T_seq=5feature_size=4hidden_size=3# output_feature_size=3input=torch.randn(batch_size,T_seq,feature_size)h_prev=torch.randn(batch_size,hidden_size)gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)output,h_final=gru_layer(input,h_prev.unsqueeze(0))# for k,v in gru_layer.named_parameters():# print(k,v.shape)# print(output,h_final)my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)# print(my_output, my_h_final)# print(torch.allclose(output,my_output))
参考资料
https://zhuanlan.zhihu.com/p/32481747
https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf
https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659
相关文章:
跟着问题学15——GRU网络结构详解及代码实战
1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies) 前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的…...
【uniapp】swiper切换时,v-for重新渲染页面导致文字在视觉上的拉扯问题
问题描述 先用v-for渲染了几个列表,但这几个列表是占同一个位置的,只是通过切换swiper来显示哪个列表显示,也就是为了优化页面切换时候,没有根据swiper的current再更新v-for的数据,但现在就有个问题,怎么隐…...
【Android】Compose初识
文章目录 1.Compose是什么2.Compose优势3.可组合函数4.布局5.配置布局6.Material Design7.列表与动画8.声明式UI9.组合10.重组 1.Compose是什么 Jetpack Compose是谷歌开发的一个现代的、声明式的UI工具包,用于构建原生的Android应用程序界面。它简化了创建复杂用户…...
前端工程化面试题(二)
前端模块化标准 CJS、ESM 和 UMD 的区别 CJS(CommonJS)、ESM(ESModule)和UMD(Universal Module Definition)是前端模块化标准的三种主要形式,它们各自有不同的特点和使用场景: CJS&…...
以攻击者的视角进行软件安全防护
1. 前言 孙子曰:知彼知己者,百战不殆;不知彼而知己,一胜一负,不知彼,不知己,每战必殆。 摘自《 孙子兵法谋攻篇 》在2500 年前的那个波澜壮阔的春秋战国时代,孙子兵法的这段话&…...
008.精读《Apache Paimon Docs - Table w/o PK》
文章目录 1. 引言2. 基本概念2.1 定义2.2 使用场景 3. 流式处理3.1 自动小文件合并3.2 流式查询 4. 数据更新4.1 查询4.2 更新4.3 分桶附加表 5 总结 1. 引言 通过本文,上篇我们了解了Apache Paimon 主键表,本期我们将继续学习附加表(Append…...
C#实时监控指定文件夹中的动态,并将文件夹中生成的新图片显示在界面上(相机采图,并且从本地拿图)
结果展示 此类原理适用于文件夹中自动生成图片,并提取最新生成的图片将其显示, 如果你是相机采图将其保存到本地,可以用这中方法可视化,并将检测的结果和图片匹配 理论上任何文件都是可以监视并显示的,我这里只是做了…...
使用SQLark分析达梦慢SQL执行计划的一次实践
最近刚参加完达梦的 DCP 培训与考试,正好业务系统有个 sql 查询较慢,就想着练练手。 在深入了解达梦的过程中,发现达梦新出了一款叫 SQLark 百灵连接的工具。 我首先去官网大致浏览了下。虽然 SQLark 在功能深度上不如 DM Manager 和 PL/SQ…...
【人工智能】用Python构建高效的自动化数据标注工具:从理论到实现
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 数据标注是构建高质量机器学习模型的关键环节,但其耗时耗力常成为制约因素。本篇文章将介绍如何用Python构建一个自动化数据标注工具,结合机器学习和NLP技术,帮助加速数据标注过程。我们将从需求分析入…...
Java --- 注解(Annotation)
一.什么是注解? 在Java中,注解(Annotation)是一种元数据(metadata),它为程序中的类、方法、字段等提供额外的描述信息。注解本身不直接改变程序的行为,但可以被编译器、开发工具、框…...
nodejs作为provider接入nacos
需求:公司产品一直是nodejs的后台,采用的eggjs框架,也不是最新版本,现有有需求需求将这些应用集成到微服务的注册中心,领导要求用java。 思路:用spring cloud gateway将需要暴露的接口url转发,…...
SpringBoot3+Micormeter监控应用指标
监控内容简介 SpringBoot3项目监控服务 ,可以使用Micormeter度量指标库,帮助我们监控应用程序的度量指标,并将其发送到Prometheus中并用Grafana展示。监控指标有系统负载、内存使用情况、应用程序的响应时间、吞吐量、错误率等。 micromete…...
Mybatis-plus 简单使用,mybatis-plus 分页模糊查询报500 的错
一、mybtis-plus配置下载 MyBatis-Plus 是一个 Mybatis 增强版工具,在 MyBatis 上扩充了其他功能没有改变其基本功能,为了简化开发提交效率而存在。 具体的介绍请参见官方文档。 官网文档地址:mybatis-plus 添加mybatis-plus依赖 <depe…...
2022 年 12 月青少年软编等考 C 语言三级真题解析
目录 T1. 鸡兔同笼思路分析T2. 猴子吃桃思路分析T3. 括号匹配问题T4. 上台阶思路分析T5. 田忌赛马T1. 鸡兔同笼 一个笼子里面关了鸡和兔子(鸡有 2 2 2 只脚,兔子有 4 4 4 只脚,没有例外)。已经知道了笼子里面脚的总数 a a a,问笼子里面至少有多少只动物,至多有多少只…...
webpack 题目
文章目录 webpack 中 chunkHash 和 contentHash 的区别loader和plugin的区别?webpack 处理 image 是用哪个 loader,限制 image 大小的是...;webpack 如何优化打包速度 webpack 中 chunkHash 和 contentHash 的区别 主要从四方面来讲一下区别&…...
【MySQL】视图详解
视图详解 一、视图的概念二、视图的常用操作2.1创建视图2.2查询视图2.3修改视图2.4 删除视图2.5向视图中插入数据 三、视图的检查选项3.1 cascaded(级联 )3.2 local(本地) 四、视图的作用 一、视图的概念 视图(View)是一种虚拟存…...
第一节:ORIN NX介绍与基于sdkmanager的镜像烧录(包含ubuntu文件系统/CUDA/OpenCV/cudnn/TensorRT)
ORIN NX技术参数 Orin NX版本对比 如上图所示,ORIN NX官方发布的版本有两个版本一个版本是70TOPS算力,DDR为8GB的版本低配版本,一个是100TOPS算法,DDR为16GB的高配版本。 Orin NX的外设框图 两个版本除了GPU和DDR的差距外,外设基本上没有区别,丰富的外设接口,后续开发…...
2024-12-04OpenCV视频处理基础
OpenCV视频处理基础 OpenCV的视频教学:https://www.bilibili.com/video/BV14P411D7MH 1-OpenCV视频捕获 在 OpenCV 中,cv2.VideoCapture() 是一个用于捕获视频流的类。它可以用来从摄像头捕获实时视频,或者从视频文件中读取帧。以下是如何使用…...
D89【python 接口自动化学习】- pytest基础用法
day89 pytest的setup,setdown详解 学习日期:20241205 学习目标:pytest基础用法 -- pytest的setup,setdown详解 学习笔记: setup、teardown详解 模块级 setup_module/teardown_module 开始于模块始末,生…...
七、docker registry
七、docker registry 7.1 了解Docker Registry 7.1.1 介绍 registry 用于保存docker 镜像,包括镜像的层次结构和元数据。启动容器时,docker daemon会试图从本地获取相关的镜像;本地镜像不存在时,其将从registry中下载该镜像并保…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...
Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...
基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...
图表类系列各种样式PPT模版分享
图标图表系列PPT模版,柱状图PPT模版,线状图PPT模版,折线图PPT模版,饼状图PPT模版,雷达图PPT模版,树状图PPT模版 图表类系列各种样式PPT模版分享:图表系列PPT模板https://pan.quark.cn/s/20d40aa…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
Java后端检查空条件查询
通过抛出运行异常:throw new RuntimeException("请输入查询条件!");BranchWarehouseServiceImpl.java // 查询试剂交易(入库/出库)记录Overridepublic List<BranchWarehouseTransactions> queryForReagent(Branch…...
rm视觉学习1-自瞄部分
首先先感谢中南大学的开源,提供了很全面的思路,减少了很多基础性的开发研究 我看的阅读的是中南大学FYT战队开源视觉代码 链接:https://github.com/CSU-FYT-Vision/FYT2024_vision.git 1.框架: 代码框架结构:readme有…...
2025-05-08-deepseek本地化部署
title: 2025-05-08-deepseek 本地化部署 tags: 深度学习 程序开发 2025-05-08-deepseek 本地化部署 参考博客 本地部署 DeepSeek:小白也能轻松搞定! 如何给本地部署的 DeepSeek 投喂数据,让他更懂你 [实验目的]:理解系统架构与原…...
