当前位置: 首页 > news >正文

深度学习--------------------------------门控循环单元GRU

目录

  • 候选隐状态
  • 隐状态
  • 门控循环单元GRU从零开始实现代码
    • 初始化模型参数
    • 定义隐藏状态的初始化函数
    • 定义门控循环单元模型
    • 训练
    • 该部分总代码
    • 简洁代码实现

做RNN的时候处理不了太长的序列,这是因为把整个序列信息全部放在隐藏状态里面,当时间很长的话,隐藏状态可能就会累计很多东西,所以对于前面很久以前的信息不易从中抽取出来了。

R t R_t Rt就是重置, Z t Z_t Zt就是更新
门是跟隐藏状态同样长度的一个向量,计算方式跟RNN的隐藏状态是一样的。
在这里插入图片描述

在这里插入图片描述




候选隐状态

在这里插入图片描述

在这里插入图片描述
假设 R t R_t Rt里面的元素靠近零的话,那么 R t R_t Rt点乘 H t − 1 H_{t-1} Ht1就会变得像零。(就等于是把上一个时刻的隐藏状态忘掉。)
如果全部设成0就变成了初始状态,等于这个时刻开始前面的信息全部不要。
如果全部设成1,就表示所有前面的信息全部拿过来做当前的更新。




隐状态

在这里插入图片描述

H t H_t Ht等于 Z t Z_t Zt按元素点乘上一次的隐藏状态+(1- Z t Z_t Zt)按元素点乘候选隐藏状态

Z t Z_t Zt是一个控制单元,叫做update gate。它是在0-1之间的数字。
假设 Z t Z_t Zt都等于1。(就是不更新过去的状态,把过去的状态放到现在)

在这里插入图片描述
假设 Z t Z_t Zt都等于0。(不直接拿过去的状态了,基本上看现在的更新状态)

Z t Z_t Zt里面全0,且 R t R_t Rt里面全1的时候就回到我们RNN的情况下。




门控循环单元GRU从零开始实现代码

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 3
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # GRU多了这两行,更新门的权重和偏置W_xr, W_hr, b_r = three()  # GRU多了这两行,重置门的权重和偏置W_xh, W_hh, b_h = three()  # 候选隐藏状态的权重和偏置# 隐藏状态到输出的权重W_hq = normal((num_hiddens, num_outputs))# 输出的偏置b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]# 遍历参数列表中所有参数for param in params:param.requires_grad_(True)return params



定义隐藏状态的初始化函数

定义隐状态的初始化函数init_gru_state。与之前定义的init_rnn_state函数一样,此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )



定义门控循环单元模型

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)



训练

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)



该部分总代码

import torch
from torch import nn
from d2l import torch as d2l# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 初始化GRU中的权重和偏置# 权重和偏置用于控制更新门W_xz, W_hz, b_z = three()  # GRU多了这两行# 权重和偏置用于控制重置门W_xr, W_hr, b_r = three()  # GRU多了这两行W_xh, W_hh, b_h = three()W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 定义隐藏状态的初始化函数
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),)# 定义门控循环单元模型
def gru(inputs, state, params):# 参数 params 解包为多个变量,分别表示模型中的权重和偏置W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# 遍历输入序列中的每个时间步for X in inputs:# 更新门控机制 ZZ = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 重置门控机制 RR = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)# 将所有输出拼接在一起,并返回拼接后的结果和最终的隐藏状态return torch.cat(outputs, dim=0), (H,)batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述




简洁代码实现

from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述

相关文章:

深度学习--------------------------------门控循环单元GRU

目录 门候选隐状态隐状态门控循环单元GRU从零开始实现代码初始化模型参数定义隐藏状态的初始化函数定义门控循环单元模型训练该部分总代码简洁代码实现 做RNN的时候处理不了太长的序列,这是因为把整个序列信息全部放在隐藏状态里面,当时间很长的话&#…...

【实战】| X小程序任意用户登录

复现步骤 在登陆时,弹出这个页面时 抓包,观察数据包的内容 会发现有mobile值(密文)和iv值(随机数),拿到密文,肯定时想到解密,想要解密就必须知道密文,…...

计算机毕业设计之:云中e百货微信小程序设计与实现(源码+文档+定制)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…...

CEX上币趋势分析:Infra赛道与Ton生态的未来

在当前的加密市场中,CEX(中心化交易所)上币的选择愈发重要,尤其是对项目方而言。根据 FMG 的整理,结合「杀破狼」的交易所上币信息,显然 Infra 赛道成为了交易所的热门选择,而 Ton 生态也展现出…...

数组基础(c++)

第1题 精挑细选 时限:1s 空间:256m 小王是公司的仓库管理员,一天,他接到了这样一个任务:从仓库中找出一根钢管。这听起来不算什么,但是这根钢管的要求可真是让他犯难了,要求如下&#x…...

第十三届蓝桥杯真题Python c组A.排列字母(持续更新)

博客主页:音符犹如代码系列专栏:蓝桥杯关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 【问题描述】 小蓝要把一个字符串中的字母按其在字母表中的顺序排列。 例如&a…...

IDEA自动清理类中未使用的import包

目录 1.建议清理包的理由 2.清理未使用包的方式 2.1 手动快捷键清理 2.2 设置自动清理 1.建议清理包的理由 有时候项目类文件中会有很多包被引入了,但是并没有被使用,这会增加项目的编译时间并且代码可读性也会变差。在开发过程中,建议设…...

加工零件C++

题目: 样例解释: 样例#1: 编号为 1 的工人想生产第 1 阶段的零件,需要编号为 2 的工人提供原材料。 编号为 2 的工人想生产第 1 阶段的零件,需要编号为 1 和 3 的工人提供原材料。 编号为 3 的工人想生产第 1 阶段的零件&#x…...

Etcd 是一个分布式的键值存储系统,用于共享配置和服务发现

Etcd 是一个分布式的键值存储系统,用于共享配置和服务发现。它最初由 CoreOS 开发,并已成为许多分布式系统中的关键组件之一,特别是在 Kubernetes 中扮演着核心角色。Etcd 的设计目标是简单、可靠、安全,并且易于使用。 Etcd 的特…...

如何帮助我们改造升级原有架构——基于TDengine 平台

一、简介 TDengine 核心是一款高性能、集群开源、云原生的时序数据库(Time Series Database,TSDB),专为物联网IoT平台、工业互联网、电力、IT 运维等场景设计并优化,具有极强的弹性伸缩能力。同时它还带有内建的缓存、…...

MySQl查询分析工具 EXPLAIN ANALYZE

文章目录 EXPLAIN ANALYZE是什么Iterator 输出内容解读EXPLAIN ANALYZE和EXPLAIN FORMATTREE的区别单个 Iterator 内容解读 案例分析案例1 文件排序案例2 简单的JOIN查询 参考资料:https://hackmysql.com/book-2/ EXPLAIN ANALYZE是什么 EXPLAIN ANALYZE是MySQL8.…...

RestClientException异常

什么情况下会抛出RestClientException异常 RestClientException 异常通常在使用 Spring 的 RestTemplate 进行 RESTful API 调用时抛出。以下是一些常见的情况: 网络问题:当无法连接到目标服务器时,例如网络中断或服务器不可达。 HTTP 状态…...

poi如何实现自定义导出Excel-纵向横向合并单元格,自定义填充数据列

前情提要 首先需要明确自己需要导出的excel构成是如何的,比如我需要导出一个自定义表头的excel表格,第一行A到X是标题需要横向合并单元格,第二行和第三行是表头,A到J需要第二行和第三行纵向合并单元格,K到N的第二行需…...

6--苍穹外卖-SpringBoot项目中菜品管理 详解(二)

目录 菜品分页查询 需求分析和设计 代码开发 设计DTO类 设计VO类 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 删除菜品 需求设计和分析 代码开发 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 修改菜品 需求分析和设…...

游戏怎么录制?王者荣耀游戏录制指南:iOS与电脑端全面教程

在王者荣耀的战场上,每一个五杀、每一次极限逃生都可能成为你游戏生涯中的高光时刻。但这些瞬间往往转瞬即逝,如何将它们永久保存,成为你游戏历程中不可磨灭的印记呢?本文将为你揭晓答案。无论你是手持iPhone的iOS用户&#xff0c…...

Vue.js组件开发指南

Vue.js组件开发指南 Vue.js 是一个渐进式的 JavaScript 框架,用于构建用户界面。它的核心是基于组件的开发模式。通过将页面分解为多个独立的、可复用的组件,开发者能够更轻松地构建复杂的应用。本文将深入探讨 Vue.js 组件开发的基础知识,并…...

【流计算】流计算概论

前言 作者在之前写过一个大数据的专栏,包含GFS、BigTable、MapReduce、HDFS、Hadoop、LSM树、HBase、Spark,专栏地址: https://blog.csdn.net/joker_zjn/category_12631789.html?fromshareblogcolumn&sharetypeblogcolumn&sharerI…...

20230819盘锦锦州葫芦岛自驾

2023年08月19日,上午带娃和老人驾车前往朝阳,逛凤凰山,中午吃了免费的素面味道不错。下午开车去鸟化石公园单独买儿童票43元。晚上驾车到盘锦,住红海滩民宿95元。 2023年08月20日,逛盘锦红海滩一天,有稻田画…...

Unity 与虚幻引擎对比:两大游戏开发引擎的优劣分析

在游戏开发领域,Unity 和虚幻引擎(Unreal Engine)是两款最为知名且广泛使用的引擎。它们各有特点,适合不同类型的开发者和项目。在这篇博客中,我们将深入探讨这两大引擎的核心功能、适用场景、优缺点,以及如…...

UDS_4_传输存储的数据功能单元

目录 一. DTC 二. 0x14服务 三. 0x19服务 3.1 0x19服务 3.2 0x01子功能 3.3 0x02子功能 3.4 0x04子功能 3.5 0x06子功能 3.6 0x0A子功能 一. DTC 》DTC-Diagnostic Trouble Code J1939-73 DTCFormat DTC SPN FMI CM OC 8-1位 8-1位 8-6位 5-1位 8位 7-1位 字节1 字节…...

K8S认证|CKS题库+答案| 11. AppArmor

目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...

OpenLayers 分屏对比(地图联动)

注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、👨‍🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨‍&#x1f…...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...

在 Visual Studio Code 中使用驭码 CodeRider 提升开发效率:以冒泡排序为例

目录 前言1 插件安装与配置1.1 安装驭码 CodeRider1.2 初始配置建议 2 示例代码:冒泡排序3 驭码 CodeRider 功能详解3.1 功能概览3.2 代码解释功能3.3 自动注释生成3.4 逻辑修改功能3.5 单元测试自动生成3.6 代码优化建议 4 驭码的实际应用建议5 常见问题与解决建议…...

CppCon 2015 学习:Reactive Stream Processing in Industrial IoT using DDS and Rx

“Reactive Stream Processing in Industrial IoT using DDS and Rx” 是指在工业物联网(IIoT)场景中,结合 DDS(Data Distribution Service) 和 Rx(Reactive Extensions) 技术,实现 …...