当前位置: 首页 > news >正文

4.5.门控循环单元GRU

门控循环单元GRU

​ 对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况:

  1. 早期观测值对预测所有未来观测值都具有非常重要的意义。

    考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别校验和事否正确,我们希望有某些机制在一个记忆元里存储重要的早期信息。如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度。

  2. 一些词元没有相关的观测值

    在对网页内容进行情感分析时,可能一些辅助的HTML代码与网页传达的情绪无关,我们希望有一些机制来跳过隐状态中的此类词元

  3. 序列的各个部分存在逻辑中断

    书的章节之间可能也会有过渡,证券的熊市,牛市之间可能会有过渡。这种情况下, 最好有一种方法来重置我们的内部状态表示

​ 有很多方法来解决这类问题,最早的方法是"长短期记忆"(long-short-term memory,LSTM)。门控循环单元(gated recurrent unit,GRU)是一个稍微简化的变体,通常能提供同等的效果,并且计算速度更快。

1.门控隐状态

​ 门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。这些机制是可学习的。

1.1 重置门和更新门

在这里插入图片描述

​ 重置门和更新门的输入如图所示。重置门允许我们控制”可能还想记住“的过去状态的数量;更新门将允许我们控制新状态中有多少个是旧状态的副本。

​ 其中输入是由当前时间步的输入和前一时间步的隐状态给出,两个门的输出由使用sigmoid激活函数的两个全连接层给出。

​ 假设输入是一个小批量 X t ∈ R n × d X_t\in \R^{n\times d} XtRn×d(样本数量 n n n,输入个数 d d d),上一个时间步的隐状态是 H t − 1 ∈ R n × h H_{t-1}\in \R^{n\times h} Ht1Rn×h(隐藏单元个数 h h h)。那么重置门 R t R_t Rt和更新门 Z t Z_t Zt(均为 R n × h \R^{n\times h} Rn×h)的计算如下所示:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t = \sigma(X_t W_{xz}+H_{t-1}W_{hz}+b_z) Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)
​ 其中 W x r , W x z ∈ R d × h W_{xr},W_{xz}\in \R^{d\times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h W_{hr},W_{hz}\in \R^{h\times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h b_r,b_z\in \R^{1\times h} br,bzR1×h是偏置参数。求和过程中会触发广播机制。 我们使用sigmoid函数将输入值转换到区间¥(0,1)$。

1.2 候选隐状态

在这里插入图片描述

​ 将重置门 R t R_t Rt与常规隐状态更新机制集成,得到在时间步 t t t的候选隐状态 H ^ t ∈ R n × h \hat{H}_t\in\R ^{n\times h} H^tRn×h
H ^ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) \hat{H}_t = tanh(X_tW_{xh}+(R_t\odot H_{t-1})W_{hh}+b_h) H^t=tanh(XtWxh+(RtHt1)Whh+bh)
​ 其中 W x h ∈ R d × h W_{xh}\in\R^{d\times h} WxhRd×h W h h ∈ R h × h W_{hh}\in \R ^{h\times h} WhhRh×h是权重参数, b h ∈ R 1 × h b_h\in \R^{1\times h} bhR1×h是偏置项,符号 ⊙ \odot 是Hadamard积(按元素乘积)运算符,此处使用tanh非线性激活函数确保候选隐状态中的值保持在区间 ( − 1 , 1 ) (-1,1) (1,1)中。。

R t ⊙ H t − 1 R_t\odot H_{t-1} RtHt1的元素相乘可以减少以往状态的影响,每当重置门 R t R_t Rt中的项接近1时,我们恢复一个普通的循环神经网络,如果 R t R_t Rt全为0,则之前的信息全部遗忘。重置门是可以学习的,通过学习,可以根据目前的输入决定哪些东西需要遗忘。

1.3 隐状态

在这里插入图片描述

​ 1.2中得出的是候选隐状态,真正的隐状态需要结合更新门的效果。这一步确定新的隐状态 H t ∈ R n × h H_t\in \R^{n\times h} HtRn×h在多大程度上来自旧的状态 H t − 1 H_{t-1} Ht1和新的候选状态 H t ^ \hat{H_t} Ht^。更新门 Z t Z_t Zt仅需要在 H t − 1 H_{t-1} Ht1 H ^ t \hat{H}_t H^t之间进行按元素的凸组合就可以实现,于是得出了最终的更新公式:
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ^ t H_t =Z_t \odot H_{t-1}+(1-Z_t)\odot \hat{H}_t Ht=ZtHt1+(1Zt)H^t
​ 容易看出,更新门 Z t Z_t Zt越趋近1,模型就倾向只保留旧状态,此时来自输入 X t X_t Xt的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t。相反,当 Z t Z_t Zt接近0时,新的隐状态 H t H_t Ht就会接近候选隐状态 H t ^ \hat {H_t} Ht^

2.代码实现

2.1 从零开始

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

2.2 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

2.3 简洁实现

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

相关文章:

4.5.门控循环单元GRU

门控循环单元GRU ​ 对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况: 早期观测值对预测所有未来观测值都具有非常重要的意义。 考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别…...

10种 Python数据结构,从入门到精通

今天我们将深入探讨 Python 中常用的数据结构,帮助你从基础到精通。每种数据结构都有其独特的特点和适用场景,通过实际代码示例和生活中的比喻,让你更容易理解这些概念。 学习数据结构的三个阶段 1、掌握基本用法:使用这些数据结…...

【AI】人工智能时代,程序员如何保持核心竞争力?

目录 程序员在AI时代的应对策略1. 引言2. AI在编程领域的影响2.1 AI辅助编程工具的现状2.2 AI对编程工作的影响2.3 程序员的机遇与挑战 3. 深耕细作:专注领域的深度学习3.1 专注领域的重要性3.2 深度学习的策略3.2.1 选择合适的领域3.2.2 持续学习和研究3.2.3 实践与…...

WPF学习(3)- WrapPanel控件(瀑布流布局)+DockPanel控件(停靠布局)

WrapPanel控件(瀑布流布局) WrapPanel控件表示将其子控件从左到右的顺序排列,如果第一行显示不了,则自动换至第二行,继续显示剩余的子控件。我们来看看它的结构定义: public class WrapPanel : Panel {pub…...

【python】Python中实现定时任务常见的几种方式原理分析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…...

老公请喝茶,2024年老婆必送老公的养生茶,暖暖的很贴心

在这个快节奏的时代,每个人都在为生活奔波,而家的温馨与关怀,成了我们最坚实的后盾。随着2024年的已经过半,作为妻子,你是否也在寻找一份特别的礼物,来表达对老公深深的爱意与关怀?在这个充满爱…...

3d打印相关资料

模型库 拓竹makerworld爱给...

MySQL1 DDL语言

安装与配置 官网: MySQL :: Download MySQL Installer 阿里云: MySQL8 https://www.alipan.com/s/auhN4pTqpRp 点击链接保存,或者复制本段内容,打开「阿里云盘」APP ,无需下载极速在线查看,视频原画倍速…...

el-tree懒加载状态下实现搜索筛选(纯前端)

1.效果图 &#xff08;1&#xff09;初始状态 &#xff08;2&#xff09;筛选后 2.代码 <template><div><el-inputplaceholder"输入关键字进行过滤"v-model"filterText"input"searchValue"></el-input><el-tree…...

NLP——Transfromer 架构详解

Transformer总体架构图 输入部分&#xff1a;源文本嵌入层及其位置编码器、目标文本嵌入层及其位置编码器 编码器部分 由N个编码器层堆叠而成 每个编码器层由两个子层连接结构组成 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接 第二个子层连接结构包…...

大模型算法面试题(二十)

本系列收纳各种大模型面试题及答案。 1、描述Encoder和Decoder中Attention机制的不同之处 Encoder和Decoder中的Attention机制在自然语言处理&#xff08;NLP&#xff09;和序列到序列&#xff08;Seq2Seq&#xff09;模型中扮演着重要角色&#xff0c;它们虽然都利用了Attent…...

2024最新最全面的Selenium 3.0 + Python自动化测试框架

文档说明 Selenium是一个用于Web应用程序自动化测试的工具。Selenium测试直接运行在浏览器中&#xff0c;就像真正的用户在操作一样。 Selenium测试的主要功能包括&#xff1a; 测试与浏览器的兼容性&#xff1a;测试应用程序是否能很好的工作在不同的浏览器和操作系统之上。…...

海运中的甩柜是怎么回事❓怎么才能避免❓

什么是甩柜&#xff1f; 甩柜又叫甩箱&#xff0c;是指集装箱船在起运离港时&#xff0c;船公司没有将此前计划装船的集装箱装运上船&#xff0c;导致部分货物滞留港口。多出现在海运旺季。 为什么会甩柜&#xff1f; 甩箱是集装箱物流中常见的事件&#xff0c;主要因为承运…...

Win11+docker+gpu+vscode+pytorch配置anomalib(2)

在上一篇文章中,我在Win11上通过Docker配置了pytorch,并顺利调用了GPU。在这篇文章中,我将继续完成anomalib的配置。 anomalib是一个非常完善的异常检测框架,我希望通过它来学习经典异常检测算法,并且测试这些算法在我自己的数据集上的效果。 步骤如下: 1. 从docker Hub上…...

AI在招聘市场趋势分析中的应用

一、引言 在数字化、智能化的时代背景下&#xff0c;人工智能&#xff08;AI&#xff09;技术正逐步渗透到各行各业&#xff0c;其中招聘市场也不例外。AI技术的运用不仅极大地提高了招聘的效率和精准度&#xff0c;还在招聘市场趋势分析方面展现出巨大的潜力。本文旨在探讨AI在…...

AMEYA360:太阳诱电应对 165℃的叠层金属类功率电感器实现商品化!

太阳诱电株式会社实现了可以满足车载被动部件认定的可靠性试验规格“AEC-Q200”的叠层金属类功率电感器 MCOIL™“LACNF2012KKTR24MAB”(2.0x1.25x1.0mm&#xff0c;高度为最大值)等 4 个产品的商品化。通过本公司独有的金属类材料和叠层工艺的提高&#xff0c;在叠层金属类功率…...

Nginx进阶-常见配置(三)

nginx 变量 Nginx的配置文件使用的语法的就是一门微型的编程语言。既然是编程语言&#xff0c;一般也就少不了“变量”这种东西。 Nginx配置文件使用的语法主要包括以下几个方面&#xff1a; &#xff08;1&#xff09;配置块 (Block Directives): Nginx配置文件由多个嵌套的…...

开源协作式书签管理器推荐

不知道有没有人和我一样&#xff0c;不怎么爱用app&#xff0c;反而喜欢保留用古老的浏览器浏览新闻和知识的习惯。那么归档网页和书签一定是你非常头疼的事情。 推荐一款开源软件&#xff1a;Linkwarden ,这是一款独立的开源协作式书签管理器。 Linkwarden 允许用户收集、组…...

【线性代数】【二】2.2极大线性无关组与向量空间的基

文章目录 前言一、极大线性无关组二、向量空间的基三、向量维数与向量空间维数总结 前言 上一篇中我们介绍了向量空间的概念&#xff0c;并且学习了对任意给出的一组向量&#xff0c;如果构造一个向量空间。本文将更加细致的去分析张成一个向量空间&#xff0c;具有哪些性质。…...

STM32常见的下载方式有三种

经过对比&#xff0c;推荐使用 SWD下载&#xff0c;只需要一个仿真器&#xff08;如jLINK、ST LINK、 CMSIS DAP 等&#xff09;&#xff0c;比较方便。 不推荐使用串口下载&#xff08;速度慢、无法仿真和调试&#xff09;和 JTAG 下载&#xff08;占用 IO 多&#xff09;。...

RK3568-npu模型转换推理

1. rknn-toolkit2-1.4.0进行模型转换和模型推理 1.1 虚拟机转换和模拟器推理(要求ubuntu18+python3.6) sudo apt-get install python3 python3-dev python3-pip sudo apt-get install libxslt1-dev zlib1g-dev libglib2.0 libsm6 libgl1-mesa-glx libprotobuf-dev gcc cd ~…...

《C语言程序设计 第4版》笔记和代码 第十二章 数据体和数据结构基础

12.1从基本数据类型到抽象数据类型 1 所有的程序设计语言都不能将所有复杂数据对象作为其基本数据类型&#xff0c;因此需要允许用户自定义数据类型&#xff0c;在C语言中&#xff0c;就存在构造数据类型&#xff08;复合数据类型&#xff09;。 2 结构体是构造数据类型的一种…...

学习记录——day26 进程间的通信 无名管道 无名管道 信号通信 特殊的信号处理

目录 一、进程间通信引入 二、无名管道 1、无名管道相关概念 2、无名管道的API接口函数 pipe(int pipefd[2]); 3、管道通信的特点 4、管道的读写特点 三、有名管道 1、有名管道&#xff1a;有名字的管道文件&#xff0c;其他进程可以调用 2、可以用于亲缘进程间的通信&…...

WHAT - xmlhttprequest vs fetch vs wretch

目录 前言1. XMLHttpRequest (XHR)2. fetch3. wretch总结 fetch1. 简洁性和易用性2. 错误处理3. 默认行为和功能扩展4. 请求和响应的处理5. 跨域请求和 CORS6. 现代 Web 开发需求 fetch vs xhr 代码示例使用 XMLHttpRequest使用 fetch代码对比 前言 根据标题我们可以知道今天主…...

吴恩达老师机器学习作业-ex7(聚类)

导入库&#xff0c;读取数据&#xff0c;查看数据类型等进行分析&#xff0c;可视化数据 import matplotlib.pyplot as plt import numpy as np import scipy.io as sio#读取数据 path "./ex7data2.mat" data sio.loadmat(path) # print(type(data)) # print(data…...

lombok 驼峰命名缺陷,导致后台获取参数为null的解决办法

1.问题&#xff1a; 下面是我定义一个请求类的属性&#xff0c;采用Lombok注解&#xff0c;自动构建get和set方法。 Schema(description "父组织编码", requiredMode Schema.RequiredMode.REQUIRED) private String pOrgCode; 遇到这种命名&#xff0c;你会发现在…...

【dockerpython】亲测有效!适合新手!docker创建conda镜像+容器使用(挂载、端口映射、gpu使用)+云镜像仓库教程

文章目录 docker基本概念简介配置镜像加速源创建conda镜像1. 写 Dockerfile文件2. 创建镜像3. 创建容器并测试 容器的使用1. wsl挂载2. 端口映射3. 补充-gpu 云镜像仓库使用1. 登录2. 将本地镜像上传至云镜像仓库3. 从云镜像仓库下载镜像到本地 docker基本概念简介 简单来讲&a…...

矩阵,求矩阵秩、逆矩阵

求矩阵秩的方法&#xff1a; 高斯消元法&#xff1a;通过行变换将矩阵化为行阶梯形矩阵&#xff0c;然后数非零行的数量。LU分解&#xff1a;通过分解矩阵成上下三角矩阵&#xff0c;计算非零对角元素的数量。SVD分解&#xff1a;通过奇异值分解&#xff0c;计算非零奇异值的数…...

指针和const

const int* ptr&#xff0c;int* const ptr&#xff0c;const int* const ptr 这三种指针定义有什么区别&#xff1f;用法有什么不同&#xff1f; 指向的地址是否可变指向的地址上存储的内容是否可变const属性const int* ptr可改变不可改*ptr具有const属性int* const pts不可改…...

基于C#调用文心一言大模型制作桌面软件(可改装接口)

目录 开发前的准备账号注册应用创建应用接入 开始开发创建项目设计界面使用 AK&#xff0c;SK 生成鉴权签名窗体代码 百度智能云千帆大模型平台什么是百度智能云千帆大模型平台模型更新记录 开发前的准备 账号注册 访问百度智能云平台&#xff0c;通过百度账号登录或手机号验证…...