当前位置: 首页 > news >正文

GRU(门控循环单元)的原理与代码实现

1.GRU的原理

1.1重置门和更新门

1.2候选隐藏状态

 

1.3隐状态 

2. GRU的代码实现

#导包
import torch
from torch import nn
import dltools#加载数据
batch_size, num_steps = 32, 35
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps)#封装函数:实现初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 更新门参数W_xz, W_hz, b_z = three()# 重置门W_xr, W_hr, b_r = three()# 候选隐藏状态参数W_xh, W_hh, b_h = three()# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params#定义函数:初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device))#定义函数:构建GRU网络结构
def gru(inputs, state, params):[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, )#训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 5
model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

3.pytorch 简洁实现版_GRU调包实现 

num_inputs = vocab_size
#创建网络层
gru_layer = nn.GRU(num_inputs, num_hiddens)
#建模
model = dltools.RNNModel(gru_layer, len(vocab))
#将模型转到device上
model = model.to(device)
#模型训练
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

4.知识点个人理解

 

 

相关文章:

GRU(门控循环单元)的原理与代码实现

1.GRU的原理 1.1重置门和更新门 1.2候选隐藏状态 1.3隐状态 2. GRU的代码实现 #导包 import torch from torch import nn import dltools#加载数据 batch_size, num_steps 32, 35 train_iter, vocab dltools.load_data_time_machine(batch_size, num_steps)#封装函数&…...

【医疗大数据】医疗保健领域的大数据管理:采用挑战和影响

选自期刊**《International Journal of Information Management》**(IF:21.0) 医疗保健领域的大数据管理:采用挑战和影响 1、研究背景 本研究的目标是调查阻止医疗机构实施成功大数据系统的组织障碍,识别和评估这些障碍,并为管理…...

gevent + flask 接口会卡住

在使用 gevent 和 Flask 处理 CPU 密集型任务时,确实可能会遇到性能瓶颈。这是因为 gevent 主要优化的是 I/O 密集型任务,而不是 CPU 密集型任务。以下是一些可能的原因和解决方案: 原因 Gevent 的协程模型: gevent 使用 greenle…...

SpringCloud Alibaba五大组件之——Sentinel

SpringCloud Alibaba五大组件之——Sentinel(文末附有完整项目GitHub链接) 前言一、什么是Sentinel二、Sentinel控制台1.下载jar包2.自己打包3.启动控制台4.浏览器访问 三、项目中引入Sentinel1.在api-service模块的pom文件引入依赖:2.applic…...

brpc之io事件分发器

结构 #mermaid-svg-v4SjrdGXadMO4udP {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-v4SjrdGXadMO4udP .error-icon{fill:#552222;}#mermaid-svg-v4SjrdGXadMO4udP .error-text{fill:#552222;stroke:#552222;}#merm…...

MySQL | 知识 | 从底层看清 InnoDB 数据结构

文章目录 一、InnoDB 简介InnoDB 行格式COMPACT 行格式CHAR(M) 列的存储格式VARCHAR(M) 最多能存储的数据记录中的数据太多产生的溢出行溢出的临界点 二、表空间文件的结构三、InnoDB 数据页结构页页的概览Infimum 和 Supremum使用Page Directory页的真实面貌 四、B 树是如何进…...

es的封装

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、类和接口介绍0.封装思想1.es的操作分类 二、创建索引1.成员变量2.构造函数2.添加字段3.发送请求4.创建索引总体代码 三.插入数据四.删除数据五.查询数据 前…...

写一个自动化记录鼠标/键盘的动作,然后可以重复执行的python程序

import sys import threading import time from PyQt5.QtWidgets import * from auto_fun import * import pyautogui import pynput from PyQt5.QtCore import pyqtSignal from MouseModule import * from pynput import keyboardlocal_list [] # 保存操作坐标、动作、文本 …...

Spring Boot-热部署问题

Spring Boot 热部署问题分析与解决方案 热部署(Hot Deployment)是指在应用程序运行过程中,无需停止应用就可以动态加载新代码、配置或资源,从而提升开发效率。在 Spring Boot 开发中,热部署是一项非常实用的功能&…...

深度学习——管理模型的参数

改编自李沐老师《动手深度学习》5.2. 参数管理 — 动手学深度学习 2.0.0 documentation (d2l.ai) 在深度学习中,一旦我们选择了模型架构并设置了超参数,我们就会进入训练阶段。训练的目标是找到能够最小化损失函数的模型参数。这些参数在训练后用于预测&…...

芯片验证板卡设计原理图:372-基于XC7VX690T的万兆光纤、双FMC扩展的综合计算平台 RISCV 芯片验证平台

基于XC7VX690T的万兆光纤、双FMC扩展的综合计算平台 RISCV 芯片验证平台 一、板卡概述 基于V7的高性能PCIe信号处理板,北京太速科技板卡选用Xilinx 公司Virtex7系列FPGA XC7VX690T-2FFG1761C为处理芯片,板卡提供两个标准FMC插槽,适用于…...

【软设】 系统开发基础

【软设】 系统开发基础 一.软件工程概述 (了解一下大概的流程就行) 1. 可行性分析与项目开发计划 目的:评估项目的经济性、技术性和运营性,判断项目是否值得投资和开发。确定开发时间、预算、所需资源等。 可行性分析&#xff…...

Linux移植之系统烧写

直接参考【正点原子】I.MX6U嵌入式Linux驱动开发指南V1.81 本文仅作为个人笔记使用,方便进一步记录自己的实践总结。 前面我们已经移植好了 uboot 和 linux kernle,制作好了根文件系统。但是我们移植都是通过网络来测试的,在实际的产品开发中…...

【数据结构与算法】LeetCode:双指针法

文章目录 LeetCode:双指针法正序同向而行(快慢指针)移除元素移动零(Hot 100)删除有序数组中的重复项颜色分类(Hot 100)压缩字符串移除链表元素删除排序链表中的重复元素删除排序链表中的重复元素…...

Istio下载及安装

Istio 是一个开源的服务网格,用于连接、管理和保护微服务。以下是下载并安装 Istio 的步骤。 官网文档:https://istio.io/latest/zh/docs/setup/getting-started/ 下载 Istio 前往Istio 发布页面下载适用于您的操作系统的安装文件,或者自动…...

Redis基础数据结构之 Sorted Set 有序集合 源码解读

目录标题 Sorted Set 是什么?Sorted Set 数据结构跳表(skiplist)跳表节点的结构定义跳表的定义跳表节点查询层数设置 Sorted Set 基本操作 Sorted Set 是什么? 有序集合(Sorted Set)是 Redis 中一种重要的数据类型,…...

蓝队技能-应急响应篇Web内存马查杀JVM分析Class提取诊断反编译日志定性

知识点: 1、应急响应-Web内存马-定性&排查 2、应急响应-Web内存马-分析&日志 注:传统WEB类型的内存马只要网站重启后就清除了。 演示案例-蓝队技能-JAVA Web内存马-JVM分析&日志URL&内存查杀 0、环境搭建 参考地址:http…...

递归快速获取机构树型图

一般组织架构都会有层级关系,根部门的parentId一般设置为null或者0等特殊字符,而次级部门及以下的parentId则指向他们父节点的id。 以此为基础,业务上经常会有查询整个组织架构层级关系的需求,返回对象中的children属性用来存储子…...

[Web安全 网络安全]-XSS跨站脚本攻击

文章目录: 一:前言 1.定义 2.漏洞出现的原因 3.鉴别可能存在XSS漏洞的地方 4.攻击原理 5.危害 6.防御 7.环境 7.1 靶场 7.2 自动扫描工具 7.3 手工测试工具 8.payload是什么 二:常用的标签语法 三:XSS的分类 反射…...

数据库数据恢复—SQL Server附加数据库出现“错误823”怎么恢复数据?

SQL Server数据库故障: SQL Server附加数据库出现错误823,附加数据库失败。数据库没有备份,无法通过备份恢复数据库。 SQL Server数据库出现823错误的可能原因有:数据库物理页面损坏、数据库物理页面校验值损坏导致无法识别该页面…...

OpenShamrock:零基础搭建QQ智能交互系统完全指南

OpenShamrock:零基础搭建QQ智能交互系统完全指南 【免费下载链接】OpenShamrock A Bot Framework based on Xposed with OneBot11 项目地址: https://gitcode.com/gh_mirrors/op/OpenShamrock 核心价值解析:为什么选择OpenShamrock构建QQ机器人&a…...

C++vector迭代器失效全解析

深入讲解 C vector 的迭代器失效在 C 中,std::vector 是一个动态数组,它支持随机访问和高效的元素操作。迭代器是 C 中用于遍历容器元素的重要工具,类似于指针。但使用 vector 时,某些操作可能导致迭代器失效(iterator…...

实战指南 — 基于TCGA数据的差异表达分析全流程与可视化呈现

1. TCGA数据获取与准备 第一次接触TCGA数据库时,我被它庞大的数据量震撼到了。作为癌症基因组图谱计划,TCGA收录了33种癌症类型、超过2万例患者的基因组数据。对于肝癌(LIHC)研究来说,这里简直就是一座金矿。 进入TCGA官网后,你会…...

ThinkPHP6(TP6)控制器404问题排查与Nginx伪静态配置指南

1. 为什么你的TP6控制器总是404? 最近帮朋友排查一个ThinkPHP6项目,明明控制器写得没问题,路由也配置了,但一访问就蹦出个404页面。这种问题在新手部署TP6时特别常见,尤其是用Nginx服务器的环境。我自己第一次用TP6时也…...

Boomer:轻量高效的Linux屏幕放大镜工具

Boomer:轻量高效的Linux屏幕放大镜工具 【免费下载链接】boomer Zoomer application for Linux 项目地址: https://gitcode.com/gh_mirrors/boo/boomer 当你需要精准查看屏幕细节时是否常感到操作繁琐?无论是设计工作中的像素级调整、编程时的代码…...

mysql技巧(十六):覆盖索引 vs 回表 —— 让查询效率提升 10 倍的核心技巧

📝 本章学习目标本章聚焦数据库性能优化,帮助读者彻底掌握覆盖索引与回表的核心原理。通过本章学习,你将全面理解覆盖索引 vs 回表这一核心主题,并能在实际工作中应用这些技巧,让查询效率提升 10 倍以上。 一、引言&am…...

别再只记*#*#284#*#*了!揭秘小米手机日志抓取的‘售后模式’:CIT工具(*#*#6484#*#*)的隐藏用法与解读

解锁小米手机CIT工具的隐藏潜能:从硬件诊断到日志深度解析 在智能手机高度普及的今天,用户对设备问题的自主排查需求日益增长。小米手机内置的CIT工具(Customer Interface Test)作为售后服务的核心诊断利器,其实蕴藏着…...

推荐8款提升论文效率的AI工具(含爱毕业aibiye)和简易使用教程

在学术研究领域,AI技术的应用显著提升了论文写作的效率与质量。以下推荐8款功能强大的智能工具,涵盖文献解析、内容生成、文本优化等关键环节,助力研究者高效完成从资料收集到论文润色的全流程工作。这些创新解决方案能够有效简化研究过程&am…...

Vmware系列虚拟机系列【仅供参考】:解决 VMware 嵌套虚拟化提示 关闭“侧通道缓解“

解决 VMware 嵌套虚拟化提示 关闭“侧通道缓解“ 解决 VMware 嵌套虚拟化提示 关闭"侧通道缓解" 解决方法 方法1: 方法2: 完全禁用 Hyper-V 方法3 参考链接: 解决 VMware 嵌套虚拟化提示 关闭"侧通道缓解" 最近给电脑做了新版的 Windows 11 LTSC操作系…...

AI图像增强:3步实现低清图片修复的开源跨平台工具

AI图像增强:3步实现低清图片修复的开源跨平台工具 【免费下载链接】Real-ESRGAN-GUI Lovely Real-ESRGAN / Real-CUGAN GUI Wrapper 项目地址: https://gitcode.com/gh_mirrors/re/Real-ESRGAN-GUI Real-ESRGAN-GUI是一款基于Flutter开发的开源AI图像增强工具…...