深度学习--------------------------------门控循环单元GRU
目录
- 门
- 候选隐状态
- 隐状态
- 门控循环单元GRU从零开始实现代码
- 初始化模型参数
- 定义隐藏状态的初始化函数
- 定义门控循环单元模型
- 训练
- 该部分总代码
- 简洁代码实现
做RNN的时候处理不了太长的序列,这是因为把整个序列信息全部放在隐藏状态里面,当时间很长的话,隐藏状态可能就会累计很多东西,所以对于前面很久以前的信息不易从中抽取出来了。
门
R t R_t Rt就是重置, Z t Z_t Zt就是更新
门是跟隐藏状态同样长度的一个向量,计算方式跟RNN的隐藏状态是一样的。


候选隐状态


假设 R t R_t Rt里面的元素靠近零的话,那么 R t R_t Rt点乘 H t − 1 H_{t-1} Ht−1就会变得像零。(就等于是把上一个时刻的隐藏状态忘掉。)
如果全部设成0就变成了初始状态,等于这个时刻开始前面的信息全部不要。
如果全部设成1,就表示所有前面的信息全部拿过来做当前的更新。
隐状态

H t H_t Ht等于 Z t Z_t Zt按元素点乘上一次的隐藏状态+(1- Z t Z_t Zt)按元素点乘候选隐藏状态
Z t Z_t Zt是一个控制单元,叫做update gate。它是在0-1之间的数字。
假设 Z t Z_t Zt都等于1。(就是不更新过去的状态,把过去的状态放到现在)

假设 Z t Z_t Zt都等于0。(不直接拿过去的状态了,基本上看现在的更新状态)
Z t Z_t Zt里面全0,且 R t R_t Rt里面全1的时候就回到我们RNN的情况下。
门控循环单元GRU从零开始实现代码
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 3
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three() # GRU多了这两行,更新门的权重和偏置W_xr, W_hr, b_r = three() # GRU多了这两行,重置门的权重和偏置W_xh, W_hh, b_h = three() # 候选隐藏状态的权重和偏置# 隐藏状态到输出的权重W_hq = normal((num_hiddens, num_outputs))# 输出的偏置b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]# 遍历参数列表中所有参数for param in params:param.requires_grad_(True)return params
定义隐藏状态的初始化函数
定义隐状态的初始化函数init_gru_state。与之前定义的init_rnn_state函数一样,此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )
定义门控循环单元模型
def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
该部分总代码
import torch
from torch import nn
from d2l import torch as d2l# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 初始化GRU中的权重和偏置# 权重和偏置用于控制更新门W_xz, W_hz, b_z = three() # GRU多了这两行# 权重和偏置用于控制重置门W_xr, W_hr, b_r = three() # GRU多了这两行W_xh, W_hh, b_h = three()W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 定义隐藏状态的初始化函数
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),)# 定义门控循环单元模型
def gru(inputs, state, params):# 参数 params 解包为多个变量,分别表示模型中的权重和偏置W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# 遍历输入序列中的每个时间步for X in inputs:# 更新门控机制 ZZ = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 重置门控机制 RR = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)# 将所有输出拼接在一起,并返回拼接后的结果和最终的隐藏状态return torch.cat(outputs, dim=0), (H,)batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()


简洁代码实现
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()


相关文章:
深度学习--------------------------------门控循环单元GRU
目录 门候选隐状态隐状态门控循环单元GRU从零开始实现代码初始化模型参数定义隐藏状态的初始化函数定义门控循环单元模型训练该部分总代码简洁代码实现 做RNN的时候处理不了太长的序列,这是因为把整个序列信息全部放在隐藏状态里面,当时间很长的话&#…...
【实战】| X小程序任意用户登录
复现步骤 在登陆时,弹出这个页面时 抓包,观察数据包的内容 会发现有mobile值(密文)和iv值(随机数),拿到密文,肯定时想到解密,想要解密就必须知道密文,…...
计算机毕业设计之:云中e百货微信小程序设计与实现(源码+文档+定制)
博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…...
CEX上币趋势分析:Infra赛道与Ton生态的未来
在当前的加密市场中,CEX(中心化交易所)上币的选择愈发重要,尤其是对项目方而言。根据 FMG 的整理,结合「杀破狼」的交易所上币信息,显然 Infra 赛道成为了交易所的热门选择,而 Ton 生态也展现出…...
数组基础(c++)
第1题 精挑细选 时限:1s 空间:256m 小王是公司的仓库管理员,一天,他接到了这样一个任务:从仓库中找出一根钢管。这听起来不算什么,但是这根钢管的要求可真是让他犯难了,要求如下&#x…...
第十三届蓝桥杯真题Python c组A.排列字母(持续更新)
博客主页:音符犹如代码系列专栏:蓝桥杯关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 【问题描述】 小蓝要把一个字符串中的字母按其在字母表中的顺序排列。 例如&a…...
IDEA自动清理类中未使用的import包
目录 1.建议清理包的理由 2.清理未使用包的方式 2.1 手动快捷键清理 2.2 设置自动清理 1.建议清理包的理由 有时候项目类文件中会有很多包被引入了,但是并没有被使用,这会增加项目的编译时间并且代码可读性也会变差。在开发过程中,建议设…...
加工零件C++
题目: 样例解释: 样例#1: 编号为 1 的工人想生产第 1 阶段的零件,需要编号为 2 的工人提供原材料。 编号为 2 的工人想生产第 1 阶段的零件,需要编号为 1 和 3 的工人提供原材料。 编号为 3 的工人想生产第 1 阶段的零件&#x…...
Etcd 是一个分布式的键值存储系统,用于共享配置和服务发现
Etcd 是一个分布式的键值存储系统,用于共享配置和服务发现。它最初由 CoreOS 开发,并已成为许多分布式系统中的关键组件之一,特别是在 Kubernetes 中扮演着核心角色。Etcd 的设计目标是简单、可靠、安全,并且易于使用。 Etcd 的特…...
如何帮助我们改造升级原有架构——基于TDengine 平台
一、简介 TDengine 核心是一款高性能、集群开源、云原生的时序数据库(Time Series Database,TSDB),专为物联网IoT平台、工业互联网、电力、IT 运维等场景设计并优化,具有极强的弹性伸缩能力。同时它还带有内建的缓存、…...
MySQl查询分析工具 EXPLAIN ANALYZE
文章目录 EXPLAIN ANALYZE是什么Iterator 输出内容解读EXPLAIN ANALYZE和EXPLAIN FORMATTREE的区别单个 Iterator 内容解读 案例分析案例1 文件排序案例2 简单的JOIN查询 参考资料:https://hackmysql.com/book-2/ EXPLAIN ANALYZE是什么 EXPLAIN ANALYZE是MySQL8.…...
RestClientException异常
什么情况下会抛出RestClientException异常 RestClientException 异常通常在使用 Spring 的 RestTemplate 进行 RESTful API 调用时抛出。以下是一些常见的情况: 网络问题:当无法连接到目标服务器时,例如网络中断或服务器不可达。 HTTP 状态…...
poi如何实现自定义导出Excel-纵向横向合并单元格,自定义填充数据列
前情提要 首先需要明确自己需要导出的excel构成是如何的,比如我需要导出一个自定义表头的excel表格,第一行A到X是标题需要横向合并单元格,第二行和第三行是表头,A到J需要第二行和第三行纵向合并单元格,K到N的第二行需…...
6--苍穹外卖-SpringBoot项目中菜品管理 详解(二)
目录 菜品分页查询 需求分析和设计 代码开发 设计DTO类 设计VO类 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 删除菜品 需求设计和分析 代码开发 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 修改菜品 需求分析和设…...
游戏怎么录制?王者荣耀游戏录制指南:iOS与电脑端全面教程
在王者荣耀的战场上,每一个五杀、每一次极限逃生都可能成为你游戏生涯中的高光时刻。但这些瞬间往往转瞬即逝,如何将它们永久保存,成为你游戏历程中不可磨灭的印记呢?本文将为你揭晓答案。无论你是手持iPhone的iOS用户,…...
Vue.js组件开发指南
Vue.js组件开发指南 Vue.js 是一个渐进式的 JavaScript 框架,用于构建用户界面。它的核心是基于组件的开发模式。通过将页面分解为多个独立的、可复用的组件,开发者能够更轻松地构建复杂的应用。本文将深入探讨 Vue.js 组件开发的基础知识,并…...
【流计算】流计算概论
前言 作者在之前写过一个大数据的专栏,包含GFS、BigTable、MapReduce、HDFS、Hadoop、LSM树、HBase、Spark,专栏地址: https://blog.csdn.net/joker_zjn/category_12631789.html?fromshareblogcolumn&sharetypeblogcolumn&sharerI…...
20230819盘锦锦州葫芦岛自驾
2023年08月19日,上午带娃和老人驾车前往朝阳,逛凤凰山,中午吃了免费的素面味道不错。下午开车去鸟化石公园单独买儿童票43元。晚上驾车到盘锦,住红海滩民宿95元。 2023年08月20日,逛盘锦红海滩一天,有稻田画…...
Unity 与虚幻引擎对比:两大游戏开发引擎的优劣分析
在游戏开发领域,Unity 和虚幻引擎(Unreal Engine)是两款最为知名且广泛使用的引擎。它们各有特点,适合不同类型的开发者和项目。在这篇博客中,我们将深入探讨这两大引擎的核心功能、适用场景、优缺点,以及如…...
UDS_4_传输存储的数据功能单元
目录 一. DTC 二. 0x14服务 三. 0x19服务 3.1 0x19服务 3.2 0x01子功能 3.3 0x02子功能 3.4 0x04子功能 3.5 0x06子功能 3.6 0x0A子功能 一. DTC 》DTC-Diagnostic Trouble Code J1939-73 DTCFormat DTC SPN FMI CM OC 8-1位 8-1位 8-6位 5-1位 8位 7-1位 字节1 字节…...
基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
简易版抽奖活动的设计技术方案
1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...
微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】
微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...
java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别
UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...
抖音增长新引擎:品融电商,一站式全案代运营领跑者
抖音增长新引擎:品融电商,一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中,品牌如何破浪前行?自建团队成本高、效果难控;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
Robots.txt 文件
什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
ABAP设计模式之---“简单设计原则(Simple Design)”
“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...
