当前位置: 首页 > news >正文

代码 RNN原理及手写复现

29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili

笔记连接: https://pan.baidu.com/s/1_Sm7ptEiJtTTq3vQWgOTNg?pwd=2rei 提取码: 2rei

import torch
import torch.nn as nn
bs,T=2,3  # 批大小,输入序列长度
input_size,hidden_size = 2,3 # 输入特征大小,隐含层特征大小
input = torch.randn(bs,T,input_size)  # 随机初始化一个输入特征序列
h_prev = torch.zeros(bs,hidden_size) # 初始隐含状态
# step1 调用pytorch RNN API
rnn = nn.RNN(input_size,hidden_size,batch_first=True)
rnn_output,state_finall = rnn(input,h_prev.unsqueeze(0))print(rnn_output)
print(state_finall)
# step2 手写 rnn_forward函数,实现RNN的计算原理
def rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim) # 初始化一个输出(状态)矩阵for t in range(T):x = input[:,t,:].unsqueeze(2)  # 获取当前时刻的输入特征,bs*input_size*1w_ih_batch = weight_ih.unsqueeze(0).tile(bs,1,1) # bs * h_dim * input_sizew_hh_batch = weight_hh.unsqueeze(0).tile(bs,1,1)# bs * h_dim * h_dimw_times_x = torch.bmm(w_ih_batch,x).squeeze(-1) # bs*h_dimw_times_h = torch.bmm(w_hh_batch,h_prev.unsqueeze(2)).squeeze(-1) # bs*h_himh_prev = torch.tanh(w_times_x + bias_ih + w_times_h + bias_hh)h_out[:,t,:] = h_prevreturn h_out,h_prev.unsqueeze(0)
# 验证结果
custom_rnn_output,custom_state_finall = rnn_forward(input,rnn.weight_ih_l0,rnn.weight_hh_l0,rnn.bias_ih_l0,rnn.bias_hh_l0,h_prev)
print(custom_rnn_output)
print(custom_state_finall)
print(torch.allclose(rnn_output,custom_rnn_output))
print(torch.allclose(state_finall,custom_state_finall))
# step3 手写一个 bidirectional_rnn_forward函数,实现双向RNN的计算原理
def bidirectional_rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev,weight_ih_reverse,weight_hh_reverse,bias_ih_reverse,bias_hh_reverse,h_prev_reverse):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim*2) # 初始化一个输出(状态)矩阵,注意双向是两倍的特征大小forward_output = rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev)[0]  # forward layerbackward_output = rnn_forward(torch.flip(input,[1]),weight_ih_reverse,weight_hh_reverse,bias_ih_reverse, bias_hh_reverse,h_prev_reverse)[0] # backward layer# 将input按照时间的顺序翻转h_out[:,:,:h_dim] = forward_outputh_out[:,:,h_dim:] = torch.flip(backward_output,[1]) #需要再翻转一下 才能和forward output拼接h_n = torch.zeros(bs,2,h_dim)  # 要最后的状态连接h_n[:,0,:] = forward_output[:,-1,:]h_n[:,1,:] = backward_output[:,-1,:]h_n = h_n.transpose(0,1)return h_out,h_n# return h_out,h_out[:,-1,:].reshape((bs,2,h_dim)).transpose(0,1)# 验证一下 bidirectional_rnn_forward的正确性
bi_rnn = nn.RNN(input_size,hidden_size,batch_first=True,bidirectional=True)
h_prev = torch.zeros((2,bs,hidden_size))
bi_rnn_output,bi_state_finall = bi_rnn(input,h_prev)for k,v in bi_rnn.named_parameters():print(k,v)
custom_bi_rnn_output,custom_bi_state_finall = bidirectional_rnn_forward(input,bi_rnn.weight_ih_l0,bi_rnn.weight_hh_l0,bi_rnn.bias_ih_l0,bi_rnn.bias_hh_l0,h_prev[0],bi_rnn.weight_ih_l0_reverse,bi_rnn.weight_hh_l0_reverse,bi_rnn.bias_ih_l0_reverse,bi_rnn.bias_hh_l0_reverse,h_prev[1])
print("Pytorch API output")
print(bi_rnn_output)
print(bi_state_finall)print("\n custom bidirectional_rnn_forward function output:")
print(custom_bi_rnn_output)
print(custom_bi_state_finall)
print(torch.allclose(bi_rnn_output,custom_bi_rnn_output))
print(torch.allclose(bi_state_finall,custom_bi_state_finall))

相关文章:

代码 RNN原理及手写复现

29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili 笔记连接: https://pan.baidu.com/s/1_Sm7ptEiJtTTq3vQWgOTNg?pwd2rei 提取码: 2rei import torch import torch.nn as nn bs,T2,3 # 批大小,输入序列长度 input_size,hidden_size 2,3 # 输入特征大小&a…...

企业官网的在线客服,如何提高效果?

企业官网的在线客服,如何提高效果? 作者:开源呼叫中心系统 FreeIPCC,github地址:https://github.com/lihaiya/freeipcc 提高企业官网在线客服的效果,是提升客户体验、增强客户满意度和忠诚度的关键。一个…...

「实战应用」如何可视化 DHTMLX Scheduler 中的资源工作量?

DHTMLX Scheduler是一个全面的 UI 组件,用于处理面向业务的 Web 应用程序中复杂的调度和任务管理需求。但是,某些场景可能需要自定义解决方案。例如,如果项目的资源(即劳动力)有限,则需要确保以更高的精度分…...

论文阅读《BEVFormer》

BEVFormer: Learning Bird’s-Eye-View Representation from Multi-Camera Images via Spatiotemporal Transformers 目录 摘要1 介绍2 相关工作2.1 基于Transformer的2D感知 摘要 3D视觉感知任务对于自动驾驶系统至关重要,包括基于多相机图像的3D检测和地图分割。…...

sql专题 之 sql的执行顺序

文章目录 sql的执行顺序sql语句的格式实际的执行顺序:虚拟表 vs 数据集虚拟表 结果集总结嵌套查询在sql查询中的执行顺序 前文我们了解了sql常用的语句,这次我们对于这些语句来个小思索 戳这里→ sql专题 之 常用命令 sql的执行顺序 SQL语句的执行顺序是…...

Vue3 -- 基于Vue3+TS+Vite项目【项目搭建及初始化】

兼容性注意: Vite 需要 Node.js 版本 18+ 或 20+。然而,有些模板需要依赖更高的 Node 版本才能正常运行,当你的包管理器发出警告时,请注意升级你的 Node 版本。【摘抄自vite官网】 这里我用的node版本是 v18.20.2 创建项目: 创建项目我们可以使用npm、yarn、pnpm、bun …...

CTF-RE: TEA系列解密脚本

// // Created by A5rZ on 2024/10/26. //#ifndef WORK_TEA_H #define WORK_TEA_H#endif //WORK_TEA_H#include <cstdint> #include <cstdio>// 定义TEA加密算法的轮次&#xff0c;一般建议为32轮 #define TEA_ROUNDS 32 #define DELTA 0x9e3779b9// TEA加密函数 v…...

信号量和线程池

1.信号量 POSIX信号量&#xff0c;用与同步操作&#xff0c;达到无冲突的访问共享资源目的&#xff0c;POSIX信号量可以用于线程间同步 初始化信号量 #include <semaphore.h> int sem_init(sem_t *sem, int pshared, unsigned int value); sem&#xff1a;指向sem_t类…...

【人工智能】10分钟解读-深入浅出大语言模型(LLM)——从ChatGPT到未来AI的演进

文章目录 一、前言二、GPT模型的发展历程2.1 自然语言处理的局限2.2 机器学习的崛起2.3 深度学习的兴起2.3.1 神经网络的训练2.3.2 神经网络面临的挑战 2.4 Transformer的革命性突破2.4.1 Transformer的核心组成2.4.2 Transformer的优势 2.5 GPT模型的诞生与发展2.5.1 GPT的核心…...

「QT」几何数据类 之 QPointF 浮点型点类

✨博客主页何曾参静谧的博客&#x1f4cc;文章专栏「QT」QT5程序设计&#x1f4da;全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasolid…...

可能是全网第一个MySQL Workbench插件编写技巧

引言 应公司要求&#xff0c;数据库的敏感数据在写入到数据库中要进行加密&#xff0c;但是在测试环境查询数据的时候要手动解密&#xff0c;很不方便&#xff0c;有的时候数据比较多&#xff0c;解密比较麻烦。遂研究了一下如何通过 MySQL Workbench 的插件来实现查询数据一键…...

D62【python 接口自动化学习】- python基础之数据库

day62 SQL 基础 学习日期&#xff1a;20241108 学习目标&#xff1a;MySQL数据库-- 131 SQL基础和DDL 学习笔记&#xff1a; SQL的概述 SQL语言的分类 SQL的语法特征 DDL - 库管理 DDL - 表管理 总结 SQL是结构化查询语言&#xff0c;用于操作数据库&#xff0c;通用于绝大…...

探索美赛:从准备到挑战的详细指南

前言 美国大学生数学建模竞赛&#xff08;MCM/ICM&#xff09;&#xff0c;简称“美赛”&#xff0c;是全球规模最大的数学建模竞赛之一。它鼓励参赛者通过数学建模来解决现实世界中的复杂问题&#xff0c;广受世界各地大学生的欢迎。本文将详细介绍美赛的全过程&#xff0c;从…...

IP地址查询——IP归属地离线库

自从网络监管部门将现实IP地址列入监管条例&#xff0c;IP地址的离线库变成网络企业发展业务的不可或缺的一部分&#xff0c;那么IP地址离线库是什么&#xff0c;又能够给我们带来什么呢&#xff1f; 什么是IP地址离线库&#xff1f; IP地址离线库是IP地址服务商将通过各种合…...

“倒时差”用英语怎么说?生活英语口语学习柯桥外语培训

“倒时差”用英语怎么说&#xff1f; “倒时差”&#xff0c;这个让无数旅人闻之色变的词汇&#xff0c;在英语中对应的正是“Jet Lag”。"Jet" 指的是喷气式飞机&#xff0c;而 "lag" 指的是落后或延迟。这个短语形象地描述了当人们乘坐喷气式飞机快速穿…...

Linux入门攻坚——37、Linux防火墙-iptables-3

私网地址访问公网地址的问题&#xff0c;请求时&#xff0c;目标地址是公网地址&#xff0c;可以在公网路由器中进行路由&#xff0c;但是响应报文的目的地址是私网地址&#xff0c;此时在公网路由器上就会出现问题。公网地址访问私网地址的问题&#xff0c;需要先访问一个公网…...

微服务架构面试内容整理-安全性-Spring Security

Spring Security 是 Spring 框架中用于实现认证和授权的安全模块,它提供了全面的安全解决方案,可以帮助开发者保护 Web 应用、微服务和 API 免受常见的安全攻击。以下是 Spring Security 的主要特点、工作原理和使用场景: 主要特点 1. 身份认证与授权: 提供多种认证方式,…...

新的服务器Centos7.6 安装基础的环境配置(新服务器可直接粘贴使用配置)

常见的基础服务器配置之Centos命令 正常来说都是安装一个docker基本上很多问题都可以解决了&#xff0c;我基本上都是通过docker去管理一些容器如&#xff1a;mysql、redis、mongoDB等之类的镜像&#xff0c;还有一些中间件如kafka。下面就安装一个 docker 和 nginx 的相关配置…...

深度学习:广播机制

广播机制&#xff08;Broadcasting&#xff09;是 PyTorch&#xff08;以及其他深度学习框架如 NumPy&#xff09;中的一种强大功能&#xff0c;它允许不同形状的张量进行逐元素操作&#xff0c;而不需要显式地扩展张量的维度。广播机制通过自动扩展较小的张量来匹配较大张量的…...

音视频入门基础:FLV专题(25)——通过FFprobe显示FLV文件每个packet的信息

音视频入门基础&#xff1a;FLV专题系列文章&#xff1a; 音视频入门基础&#xff1a;FLV专题&#xff08;1&#xff09;——FLV官方文档下载 音视频入门基础&#xff1a;FLV专题&#xff08;2&#xff09;——使用FFmpeg命令生成flv文件 音视频入门基础&#xff1a;FLV专题…...

网络编程(Modbus进阶)

思维导图 Modbus RTU&#xff08;先学一点理论&#xff09; 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议&#xff0c;由 Modicon 公司&#xff08;现施耐德电气&#xff09;于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06&#xff08;十亿美元&#xff09;。漏洞扫描服务市场行业预计将从 2024 年的 3.48&#xff08;十亿美元&#xff09;增长到 2032 年的 9.54&#xff08;十亿美元&#xff09;。预测期内漏洞扫描服务市场 CAGR&#xff08;增长率&…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时&#xff0c;你可能需要保留重要的数据&#xff0c;例如通讯录。好在&#xff0c;将通讯录从 iPhone 转移到 Android 手机非常简单&#xff0c;你可以从本文中学习 6 种可靠的方法&#xff0c;确保随时保持连接&#xff0c;不错过任何信息。 第 1…...

ServerTrust 并非唯一

NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...

【Linux】Linux 系统默认的目录及作用说明

博主介绍&#xff1a;✌全网粉丝23W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...