当前位置: 首页 > news >正文

4.5.门控循环单元GRU

门控循环单元GRU

​ 对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况:

  1. 早期观测值对预测所有未来观测值都具有非常重要的意义。

    考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别校验和事否正确,我们希望有某些机制在一个记忆元里存储重要的早期信息。如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度。

  2. 一些词元没有相关的观测值

    在对网页内容进行情感分析时,可能一些辅助的HTML代码与网页传达的情绪无关,我们希望有一些机制来跳过隐状态中的此类词元

  3. 序列的各个部分存在逻辑中断

    书的章节之间可能也会有过渡,证券的熊市,牛市之间可能会有过渡。这种情况下, 最好有一种方法来重置我们的内部状态表示

​ 有很多方法来解决这类问题,最早的方法是"长短期记忆"(long-short-term memory,LSTM)。门控循环单元(gated recurrent unit,GRU)是一个稍微简化的变体,通常能提供同等的效果,并且计算速度更快。

1.门控隐状态

​ 门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。这些机制是可学习的。

1.1 重置门和更新门

在这里插入图片描述

​ 重置门和更新门的输入如图所示。重置门允许我们控制”可能还想记住“的过去状态的数量;更新门将允许我们控制新状态中有多少个是旧状态的副本。

​ 其中输入是由当前时间步的输入和前一时间步的隐状态给出,两个门的输出由使用sigmoid激活函数的两个全连接层给出。

​ 假设输入是一个小批量 X t ∈ R n × d X_t\in \R^{n\times d} XtRn×d(样本数量 n n n,输入个数 d d d),上一个时间步的隐状态是 H t − 1 ∈ R n × h H_{t-1}\in \R^{n\times h} Ht1Rn×h(隐藏单元个数 h h h)。那么重置门 R t R_t Rt和更新门 Z t Z_t Zt(均为 R n × h \R^{n\times h} Rn×h)的计算如下所示:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t = \sigma(X_t W_{xz}+H_{t-1}W_{hz}+b_z) Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)
​ 其中 W x r , W x z ∈ R d × h W_{xr},W_{xz}\in \R^{d\times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h W_{hr},W_{hz}\in \R^{h\times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h b_r,b_z\in \R^{1\times h} br,bzR1×h是偏置参数。求和过程中会触发广播机制。 我们使用sigmoid函数将输入值转换到区间¥(0,1)$。

1.2 候选隐状态

在这里插入图片描述

​ 将重置门 R t R_t Rt与常规隐状态更新机制集成,得到在时间步 t t t的候选隐状态 H ^ t ∈ R n × h \hat{H}_t\in\R ^{n\times h} H^tRn×h
H ^ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) \hat{H}_t = tanh(X_tW_{xh}+(R_t\odot H_{t-1})W_{hh}+b_h) H^t=tanh(XtWxh+(RtHt1)Whh+bh)
​ 其中 W x h ∈ R d × h W_{xh}\in\R^{d\times h} WxhRd×h W h h ∈ R h × h W_{hh}\in \R ^{h\times h} WhhRh×h是权重参数, b h ∈ R 1 × h b_h\in \R^{1\times h} bhR1×h是偏置项,符号 ⊙ \odot 是Hadamard积(按元素乘积)运算符,此处使用tanh非线性激活函数确保候选隐状态中的值保持在区间 ( − 1 , 1 ) (-1,1) (1,1)中。。

R t ⊙ H t − 1 R_t\odot H_{t-1} RtHt1的元素相乘可以减少以往状态的影响,每当重置门 R t R_t Rt中的项接近1时,我们恢复一个普通的循环神经网络,如果 R t R_t Rt全为0,则之前的信息全部遗忘。重置门是可以学习的,通过学习,可以根据目前的输入决定哪些东西需要遗忘。

1.3 隐状态

在这里插入图片描述

​ 1.2中得出的是候选隐状态,真正的隐状态需要结合更新门的效果。这一步确定新的隐状态 H t ∈ R n × h H_t\in \R^{n\times h} HtRn×h在多大程度上来自旧的状态 H t − 1 H_{t-1} Ht1和新的候选状态 H t ^ \hat{H_t} Ht^。更新门 Z t Z_t Zt仅需要在 H t − 1 H_{t-1} Ht1 H ^ t \hat{H}_t H^t之间进行按元素的凸组合就可以实现,于是得出了最终的更新公式:
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ^ t H_t =Z_t \odot H_{t-1}+(1-Z_t)\odot \hat{H}_t Ht=ZtHt1+(1Zt)H^t
​ 容易看出,更新门 Z t Z_t Zt越趋近1,模型就倾向只保留旧状态,此时来自输入 X t X_t Xt的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t。相反,当 Z t Z_t Zt接近0时,新的隐状态 H t H_t Ht就会接近候选隐状态 H t ^ \hat {H_t} Ht^

2.代码实现

2.1 从零开始

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

2.2 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

2.3 简洁实现

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

相关文章:

4.5.门控循环单元GRU

门控循环单元GRU ​ 对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况: 早期观测值对预测所有未来观测值都具有非常重要的意义。 考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别…...

10种 Python数据结构,从入门到精通

今天我们将深入探讨 Python 中常用的数据结构,帮助你从基础到精通。每种数据结构都有其独特的特点和适用场景,通过实际代码示例和生活中的比喻,让你更容易理解这些概念。 学习数据结构的三个阶段 1、掌握基本用法:使用这些数据结…...

【AI】人工智能时代,程序员如何保持核心竞争力?

目录 程序员在AI时代的应对策略1. 引言2. AI在编程领域的影响2.1 AI辅助编程工具的现状2.2 AI对编程工作的影响2.3 程序员的机遇与挑战 3. 深耕细作:专注领域的深度学习3.1 专注领域的重要性3.2 深度学习的策略3.2.1 选择合适的领域3.2.2 持续学习和研究3.2.3 实践与…...

WPF学习(3)- WrapPanel控件(瀑布流布局)+DockPanel控件(停靠布局)

WrapPanel控件(瀑布流布局) WrapPanel控件表示将其子控件从左到右的顺序排列,如果第一行显示不了,则自动换至第二行,继续显示剩余的子控件。我们来看看它的结构定义: public class WrapPanel : Panel {pub…...

【python】Python中实现定时任务常见的几种方式原理分析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…...

老公请喝茶,2024年老婆必送老公的养生茶,暖暖的很贴心

在这个快节奏的时代,每个人都在为生活奔波,而家的温馨与关怀,成了我们最坚实的后盾。随着2024年的已经过半,作为妻子,你是否也在寻找一份特别的礼物,来表达对老公深深的爱意与关怀?在这个充满爱…...

3d打印相关资料

模型库 拓竹makerworld爱给...

MySQL1 DDL语言

安装与配置 官网: MySQL :: Download MySQL Installer 阿里云: MySQL8 https://www.alipan.com/s/auhN4pTqpRp 点击链接保存,或者复制本段内容,打开「阿里云盘」APP ,无需下载极速在线查看,视频原画倍速…...

el-tree懒加载状态下实现搜索筛选(纯前端)

1.效果图 &#xff08;1&#xff09;初始状态 &#xff08;2&#xff09;筛选后 2.代码 <template><div><el-inputplaceholder"输入关键字进行过滤"v-model"filterText"input"searchValue"></el-input><el-tree…...

NLP——Transfromer 架构详解

Transformer总体架构图 输入部分&#xff1a;源文本嵌入层及其位置编码器、目标文本嵌入层及其位置编码器 编码器部分 由N个编码器层堆叠而成 每个编码器层由两个子层连接结构组成 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接 第二个子层连接结构包…...

大模型算法面试题(二十)

本系列收纳各种大模型面试题及答案。 1、描述Encoder和Decoder中Attention机制的不同之处 Encoder和Decoder中的Attention机制在自然语言处理&#xff08;NLP&#xff09;和序列到序列&#xff08;Seq2Seq&#xff09;模型中扮演着重要角色&#xff0c;它们虽然都利用了Attent…...

2024最新最全面的Selenium 3.0 + Python自动化测试框架

文档说明 Selenium是一个用于Web应用程序自动化测试的工具。Selenium测试直接运行在浏览器中&#xff0c;就像真正的用户在操作一样。 Selenium测试的主要功能包括&#xff1a; 测试与浏览器的兼容性&#xff1a;测试应用程序是否能很好的工作在不同的浏览器和操作系统之上。…...

海运中的甩柜是怎么回事❓怎么才能避免❓

什么是甩柜&#xff1f; 甩柜又叫甩箱&#xff0c;是指集装箱船在起运离港时&#xff0c;船公司没有将此前计划装船的集装箱装运上船&#xff0c;导致部分货物滞留港口。多出现在海运旺季。 为什么会甩柜&#xff1f; 甩箱是集装箱物流中常见的事件&#xff0c;主要因为承运…...

Win11+docker+gpu+vscode+pytorch配置anomalib(2)

在上一篇文章中,我在Win11上通过Docker配置了pytorch,并顺利调用了GPU。在这篇文章中,我将继续完成anomalib的配置。 anomalib是一个非常完善的异常检测框架,我希望通过它来学习经典异常检测算法,并且测试这些算法在我自己的数据集上的效果。 步骤如下: 1. 从docker Hub上…...

AI在招聘市场趋势分析中的应用

一、引言 在数字化、智能化的时代背景下&#xff0c;人工智能&#xff08;AI&#xff09;技术正逐步渗透到各行各业&#xff0c;其中招聘市场也不例外。AI技术的运用不仅极大地提高了招聘的效率和精准度&#xff0c;还在招聘市场趋势分析方面展现出巨大的潜力。本文旨在探讨AI在…...

AMEYA360:太阳诱电应对 165℃的叠层金属类功率电感器实现商品化!

太阳诱电株式会社实现了可以满足车载被动部件认定的可靠性试验规格“AEC-Q200”的叠层金属类功率电感器 MCOIL™“LACNF2012KKTR24MAB”(2.0x1.25x1.0mm&#xff0c;高度为最大值)等 4 个产品的商品化。通过本公司独有的金属类材料和叠层工艺的提高&#xff0c;在叠层金属类功率…...

Nginx进阶-常见配置(三)

nginx 变量 Nginx的配置文件使用的语法的就是一门微型的编程语言。既然是编程语言&#xff0c;一般也就少不了“变量”这种东西。 Nginx配置文件使用的语法主要包括以下几个方面&#xff1a; &#xff08;1&#xff09;配置块 (Block Directives): Nginx配置文件由多个嵌套的…...

开源协作式书签管理器推荐

不知道有没有人和我一样&#xff0c;不怎么爱用app&#xff0c;反而喜欢保留用古老的浏览器浏览新闻和知识的习惯。那么归档网页和书签一定是你非常头疼的事情。 推荐一款开源软件&#xff1a;Linkwarden ,这是一款独立的开源协作式书签管理器。 Linkwarden 允许用户收集、组…...

【线性代数】【二】2.2极大线性无关组与向量空间的基

文章目录 前言一、极大线性无关组二、向量空间的基三、向量维数与向量空间维数总结 前言 上一篇中我们介绍了向量空间的概念&#xff0c;并且学习了对任意给出的一组向量&#xff0c;如果构造一个向量空间。本文将更加细致的去分析张成一个向量空间&#xff0c;具有哪些性质。…...

STM32常见的下载方式有三种

经过对比&#xff0c;推荐使用 SWD下载&#xff0c;只需要一个仿真器&#xff08;如jLINK、ST LINK、 CMSIS DAP 等&#xff09;&#xff0c;比较方便。 不推荐使用串口下载&#xff08;速度慢、无法仿真和调试&#xff09;和 JTAG 下载&#xff08;占用 IO 多&#xff09;。...

避坑指南:思科模拟器做链路聚合时,你可能会遇到的5个报错及解决方法

思科模拟器链路聚合实战&#xff1a;5个典型报错分析与精准排错指南 在Packet Tracer中配置链路聚合时&#xff0c;最令人头疼的往往不是基础配置步骤&#xff0c;而是那些突如其来的报错信息。上周有位学员在CCNA备考群里发了一张截图&#xff1a;%EC-5-CANNOT_BUNDLE2: Fa0/2…...

SHA-3:从海绵结构到抗量子密码学的基石

1. SHA-3的诞生背景与核心价值 2004年&#xff0c;密码学界发现SHA-1存在理论漏洞&#xff0c;这直接推动了NIST启动新一代哈希算法竞赛。经过5年激烈角逐&#xff0c;Keccak团队提出的海绵结构方案最终胜出。与传统哈希算法不同&#xff0c;SHA-3不是对SHA-2的简单升级&#x…...

WindowsCleaner:智能化解救C盘空间危机的开源解决方案

WindowsCleaner&#xff1a;智能化解救C盘空间危机的开源解决方案 【免费下载链接】WindowsCleaner Windows Cleaner——专治C盘爆红及各种不服&#xff01; 项目地址: https://gitcode.com/gh_mirrors/wi/WindowsCleaner 一、痛点剖析&#xff1a;C盘空间管理的深层困境…...

低成本搭建AI助手:OpenClaw+nanobot镜像每月节省80%Token费用

低成本搭建AI助手&#xff1a;OpenClawnanobot镜像每月节省80%Token费用 1. 为什么选择OpenClawnanobot组合 作为一个长期关注AI自动化工具的技术爱好者&#xff0c;我一直在寻找一个既经济实惠又能满足个人需求的AI助手方案。市面上大多数解决方案要么价格昂贵&#xff0c;要…...

C语言嵌入式开发核心技术难点解析

C语言嵌入式开发中的三大核心技术难点解析 1. 指针&#xff1a;内存操作的艺术 指针是C语言中最具挑战性的概念&#xff0c;也是嵌入式系统开发中不可或缺的核心技术。指针本质上是一个存储内存地址的特殊变量&#xff0c;其设计哲学直接映射了计算机底层的内存管理机制。 1…...

HIL测试入门避坑指南:从CANoe配置到故障注入的完整踩坑实录

HIL测试实战避坑手册&#xff1a;从零搭建车窗ECU测试台架的12个关键陷阱 第一次接触HIL测试时&#xff0c;我盯着实验室里那些闪烁的指示灯和缠绕的线缆&#xff0c;仿佛面对着一个未知的宇宙。作为车载测试领域最具挑战性的环节之一&#xff0c;HIL测试既是验证ECU可靠性的终…...

航拍小目标检测入门必看:YOLOv8 VisDrone实战第一阶段,基线mAP从32%提升至58%

本文是YOLOv8 VisDrone航拍目标检测全系列实战的第一阶段,基于我3年智慧城市、无人机安防项目的一线落地经验,针对VisDrone航拍场景最核心的「小目标密集、尺度变化大、类别分布不均、遮挡严重」四大痛点,完整拆解从0到1搭建基线模型的全流程。 本文全程配套VisDrone数据集…...

MSE、MAE、Binary/Categorical Cross-Entropy、HingeLoss五种损失函数的典型应用场景

目录第一类&#xff1a;回归任务&#xff08;预测具体数值&#xff09;&#x1f453;1. MSE (均方误差) —— 重罚离群点&#x1f453;2. MAE (平均绝对误差) —— 鲁棒性强第二类&#xff1a;分类任务&#xff08;判断属于哪一类&#xff09;&#x1f453;3. Binary Cross-Ent…...

腾讯验证码攻防新篇:六宫格、滑块与文字识别的毫秒级破解实战

1. 腾讯验证码体系深度解析 腾讯验证码作为当前互联网安全防护的重要组成部分&#xff0c;已经发展出包括六宫格、图标点选、滑块验证和文字识别在内的多种形式。这些验证码在设计时充分考虑了人机交互的特点&#xff0c;通过视觉识别和行为分析双重机制来区分真实用户和自动化…...

Netgear路由器Telnet功能启用工具:技术解析与实践指南

Netgear路由器Telnet功能启用工具&#xff1a;技术解析与实践指南 【免费下载链接】netgear_telnet Netgear Enable Telnet (New Crypto) 项目地址: https://gitcode.com/gh_mirrors/ne/netgear_telnet 一、功能价值&#xff1a;技术突破点与应用场景 1.1 核心功能概述…...