代码 RNN原理及手写复现
29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili
笔记连接: https://pan.baidu.com/s/1_Sm7ptEiJtTTq3vQWgOTNg?pwd=2rei 提取码: 2rei
import torch
import torch.nn as nn
bs,T=2,3 # 批大小,输入序列长度
input_size,hidden_size = 2,3 # 输入特征大小,隐含层特征大小
input = torch.randn(bs,T,input_size) # 随机初始化一个输入特征序列
h_prev = torch.zeros(bs,hidden_size) # 初始隐含状态
# step1 调用pytorch RNN API
rnn = nn.RNN(input_size,hidden_size,batch_first=True)
rnn_output,state_finall = rnn(input,h_prev.unsqueeze(0))print(rnn_output)
print(state_finall)
# step2 手写 rnn_forward函数,实现RNN的计算原理
def rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim) # 初始化一个输出(状态)矩阵for t in range(T):x = input[:,t,:].unsqueeze(2) # 获取当前时刻的输入特征,bs*input_size*1w_ih_batch = weight_ih.unsqueeze(0).tile(bs,1,1) # bs * h_dim * input_sizew_hh_batch = weight_hh.unsqueeze(0).tile(bs,1,1)# bs * h_dim * h_dimw_times_x = torch.bmm(w_ih_batch,x).squeeze(-1) # bs*h_dimw_times_h = torch.bmm(w_hh_batch,h_prev.unsqueeze(2)).squeeze(-1) # bs*h_himh_prev = torch.tanh(w_times_x + bias_ih + w_times_h + bias_hh)h_out[:,t,:] = h_prevreturn h_out,h_prev.unsqueeze(0)
# 验证结果
custom_rnn_output,custom_state_finall = rnn_forward(input,rnn.weight_ih_l0,rnn.weight_hh_l0,rnn.bias_ih_l0,rnn.bias_hh_l0,h_prev)
print(custom_rnn_output)
print(custom_state_finall)
print(torch.allclose(rnn_output,custom_rnn_output))
print(torch.allclose(state_finall,custom_state_finall))
# step3 手写一个 bidirectional_rnn_forward函数,实现双向RNN的计算原理
def bidirectional_rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev,weight_ih_reverse,weight_hh_reverse,bias_ih_reverse,bias_hh_reverse,h_prev_reverse):bs,T,input_size = input.shapeh_dim = weight_ih.shape[0]h_out = torch.zeros(bs,T,h_dim*2) # 初始化一个输出(状态)矩阵,注意双向是两倍的特征大小forward_output = rnn_forward(input,weight_ih,weight_hh,bias_ih,bias_hh,h_prev)[0] # forward layerbackward_output = rnn_forward(torch.flip(input,[1]),weight_ih_reverse,weight_hh_reverse,bias_ih_reverse, bias_hh_reverse,h_prev_reverse)[0] # backward layer# 将input按照时间的顺序翻转h_out[:,:,:h_dim] = forward_outputh_out[:,:,h_dim:] = torch.flip(backward_output,[1]) #需要再翻转一下 才能和forward output拼接h_n = torch.zeros(bs,2,h_dim) # 要最后的状态连接h_n[:,0,:] = forward_output[:,-1,:]h_n[:,1,:] = backward_output[:,-1,:]h_n = h_n.transpose(0,1)return h_out,h_n# return h_out,h_out[:,-1,:].reshape((bs,2,h_dim)).transpose(0,1)# 验证一下 bidirectional_rnn_forward的正确性
bi_rnn = nn.RNN(input_size,hidden_size,batch_first=True,bidirectional=True)
h_prev = torch.zeros((2,bs,hidden_size))
bi_rnn_output,bi_state_finall = bi_rnn(input,h_prev)for k,v in bi_rnn.named_parameters():print(k,v)
custom_bi_rnn_output,custom_bi_state_finall = bidirectional_rnn_forward(input,bi_rnn.weight_ih_l0,bi_rnn.weight_hh_l0,bi_rnn.bias_ih_l0,bi_rnn.bias_hh_l0,h_prev[0],bi_rnn.weight_ih_l0_reverse,bi_rnn.weight_hh_l0_reverse,bi_rnn.bias_ih_l0_reverse,bi_rnn.bias_hh_l0_reverse,h_prev[1])
print("Pytorch API output")
print(bi_rnn_output)
print(bi_state_finall)print("\n custom bidirectional_rnn_forward function output:")
print(custom_bi_rnn_output)
print(custom_bi_state_finall)
print(torch.allclose(bi_rnn_output,custom_bi_rnn_output))
print(torch.allclose(bi_state_finall,custom_bi_state_finall))
相关文章:
代码 RNN原理及手写复现
29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili 笔记连接: https://pan.baidu.com/s/1_Sm7ptEiJtTTq3vQWgOTNg?pwd2rei 提取码: 2rei import torch import torch.nn as nn bs,T2,3 # 批大小,输入序列长度 input_size,hidden_size 2,3 # 输入特征大小&a…...
企业官网的在线客服,如何提高效果?
企业官网的在线客服,如何提高效果? 作者:开源呼叫中心系统 FreeIPCC,github地址:https://github.com/lihaiya/freeipcc 提高企业官网在线客服的效果,是提升客户体验、增强客户满意度和忠诚度的关键。一个…...
「实战应用」如何可视化 DHTMLX Scheduler 中的资源工作量?
DHTMLX Scheduler是一个全面的 UI 组件,用于处理面向业务的 Web 应用程序中复杂的调度和任务管理需求。但是,某些场景可能需要自定义解决方案。例如,如果项目的资源(即劳动力)有限,则需要确保以更高的精度分…...
论文阅读《BEVFormer》
BEVFormer: Learning Bird’s-Eye-View Representation from Multi-Camera Images via Spatiotemporal Transformers 目录 摘要1 介绍2 相关工作2.1 基于Transformer的2D感知 摘要 3D视觉感知任务对于自动驾驶系统至关重要,包括基于多相机图像的3D检测和地图分割。…...
sql专题 之 sql的执行顺序
文章目录 sql的执行顺序sql语句的格式实际的执行顺序:虚拟表 vs 数据集虚拟表 结果集总结嵌套查询在sql查询中的执行顺序 前文我们了解了sql常用的语句,这次我们对于这些语句来个小思索 戳这里→ sql专题 之 常用命令 sql的执行顺序 SQL语句的执行顺序是…...
Vue3 -- 基于Vue3+TS+Vite项目【项目搭建及初始化】
兼容性注意: Vite 需要 Node.js 版本 18+ 或 20+。然而,有些模板需要依赖更高的 Node 版本才能正常运行,当你的包管理器发出警告时,请注意升级你的 Node 版本。【摘抄自vite官网】 这里我用的node版本是 v18.20.2 创建项目: 创建项目我们可以使用npm、yarn、pnpm、bun …...
CTF-RE: TEA系列解密脚本
// // Created by A5rZ on 2024/10/26. //#ifndef WORK_TEA_H #define WORK_TEA_H#endif //WORK_TEA_H#include <cstdint> #include <cstdio>// 定义TEA加密算法的轮次,一般建议为32轮 #define TEA_ROUNDS 32 #define DELTA 0x9e3779b9// TEA加密函数 v…...
信号量和线程池
1.信号量 POSIX信号量,用与同步操作,达到无冲突的访问共享资源目的,POSIX信号量可以用于线程间同步 初始化信号量 #include <semaphore.h> int sem_init(sem_t *sem, int pshared, unsigned int value); sem:指向sem_t类…...
【人工智能】10分钟解读-深入浅出大语言模型(LLM)——从ChatGPT到未来AI的演进
文章目录 一、前言二、GPT模型的发展历程2.1 自然语言处理的局限2.2 机器学习的崛起2.3 深度学习的兴起2.3.1 神经网络的训练2.3.2 神经网络面临的挑战 2.4 Transformer的革命性突破2.4.1 Transformer的核心组成2.4.2 Transformer的优势 2.5 GPT模型的诞生与发展2.5.1 GPT的核心…...
「QT」几何数据类 之 QPointF 浮点型点类
✨博客主页何曾参静谧的博客📌文章专栏「QT」QT5程序设计📚全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasolid…...
可能是全网第一个MySQL Workbench插件编写技巧
引言 应公司要求,数据库的敏感数据在写入到数据库中要进行加密,但是在测试环境查询数据的时候要手动解密,很不方便,有的时候数据比较多,解密比较麻烦。遂研究了一下如何通过 MySQL Workbench 的插件来实现查询数据一键…...
D62【python 接口自动化学习】- python基础之数据库
day62 SQL 基础 学习日期:20241108 学习目标:MySQL数据库-- 131 SQL基础和DDL 学习笔记: SQL的概述 SQL语言的分类 SQL的语法特征 DDL - 库管理 DDL - 表管理 总结 SQL是结构化查询语言,用于操作数据库,通用于绝大…...
探索美赛:从准备到挑战的详细指南
前言 美国大学生数学建模竞赛(MCM/ICM),简称“美赛”,是全球规模最大的数学建模竞赛之一。它鼓励参赛者通过数学建模来解决现实世界中的复杂问题,广受世界各地大学生的欢迎。本文将详细介绍美赛的全过程,从…...
IP地址查询——IP归属地离线库
自从网络监管部门将现实IP地址列入监管条例,IP地址的离线库变成网络企业发展业务的不可或缺的一部分,那么IP地址离线库是什么,又能够给我们带来什么呢? 什么是IP地址离线库? IP地址离线库是IP地址服务商将通过各种合…...
“倒时差”用英语怎么说?生活英语口语学习柯桥外语培训
“倒时差”用英语怎么说? “倒时差”,这个让无数旅人闻之色变的词汇,在英语中对应的正是“Jet Lag”。"Jet" 指的是喷气式飞机,而 "lag" 指的是落后或延迟。这个短语形象地描述了当人们乘坐喷气式飞机快速穿…...
Linux入门攻坚——37、Linux防火墙-iptables-3
私网地址访问公网地址的问题,请求时,目标地址是公网地址,可以在公网路由器中进行路由,但是响应报文的目的地址是私网地址,此时在公网路由器上就会出现问题。公网地址访问私网地址的问题,需要先访问一个公网…...
微服务架构面试内容整理-安全性-Spring Security
Spring Security 是 Spring 框架中用于实现认证和授权的安全模块,它提供了全面的安全解决方案,可以帮助开发者保护 Web 应用、微服务和 API 免受常见的安全攻击。以下是 Spring Security 的主要特点、工作原理和使用场景: 主要特点 1. 身份认证与授权: 提供多种认证方式,…...
新的服务器Centos7.6 安装基础的环境配置(新服务器可直接粘贴使用配置)
常见的基础服务器配置之Centos命令 正常来说都是安装一个docker基本上很多问题都可以解决了,我基本上都是通过docker去管理一些容器如:mysql、redis、mongoDB等之类的镜像,还有一些中间件如kafka。下面就安装一个 docker 和 nginx 的相关配置…...
深度学习:广播机制
广播机制(Broadcasting)是 PyTorch(以及其他深度学习框架如 NumPy)中的一种强大功能,它允许不同形状的张量进行逐元素操作,而不需要显式地扩展张量的维度。广播机制通过自动扩展较小的张量来匹配较大张量的…...
音视频入门基础:FLV专题(25)——通过FFprobe显示FLV文件每个packet的信息
音视频入门基础:FLV专题系列文章: 音视频入门基础:FLV专题(1)——FLV官方文档下载 音视频入门基础:FLV专题(2)——使用FFmpeg命令生成flv文件 音视频入门基础:FLV专题…...
当00后测试员给CEO系统提了487个缺陷后
在软件测试领域,一个年轻测试员的行动往往能引发行业深思。故事始于一家科技公司新上线的“CEO决策支持系统”——一个旨在为高管提供实时数据分析和战略建议的核心平台。项目团队信心满满地推进上线,却未料到一位00后测试员小陈的介入,彻底改…...
RIFE智能帧插值技术全解析:从原理到实战的视频流畅度提升指南
RIFE智能帧插值技术全解析:从原理到实战的视频流畅度提升指南 【免费下载链接】video2x A machine learning-based video super resolution and frame interpolation framework. Est. Hack the Valley II, 2018. 项目地址: https://gitcode.com/GitHub_Trending/v…...
AI运维管理与安全防护设备功率MOSFET选型方案——高效、可靠与智能驱动系统设计指南
随着智能化运维与主动安全防护需求的爆发式增长,AI边缘计算节点、智能传感器与安全执行单元已成为现代基础设施管理的核心。其电源管理与信号驱动系统作为设备可靠运行与实时响应的基石,直接决定了系统的能效、稳定性及防护等级。功率MOSFET作为该系统中…...
终极指南:一键解决iPhone USB网络共享驱动问题
终极指南:一键解决iPhone USB网络共享驱动问题 【免费下载链接】Apple-Mobile-Drivers-Installer Powershell script to easily install Apple USB and Mobile Device Ethernet (USB Tethering) drivers on Windows! 项目地址: https://gitcode.com/gh_mirrors/ap…...
电子电路中的“心脏”:电源
一、语言特性:Java 26 与模式匹配进化 1.1 Java 26 语言级别支持 IDEA 2026.1 EAP 最引人注目的变化之一,就是新增 Java 26 语言级别支持。这意味着开发者可以提前体验和测试即将在 JDK 26 中正式发布的语言特性。 其中最重要的变化是对 JEP 530 的全面支…...
数据主权时代,企业即时通讯厂商选型推荐
BeeWorks作为企业级私有化 IM,主打安全可控、深度协同、信创适配,在政企、军工、金融等强合规场景口碑突出。BeeWorks 定位为安全专属数字化协作平台,核心是私有化部署 全链路安全 业务深度融合,区别于通用 SaaS IM。1. 核心架构…...
Hunyuan-MT-7B实战教程:Pixel Language Portal与RAG架构结合提升专业翻译
Hunyuan-MT-7B实战教程:Pixel Language Portal与RAG架构结合提升专业翻译 1. 产品概览与核心价值 Pixel Language Portal(像素语言跨维传送门)是一款基于腾讯Hunyuan-MT-7B大模型构建的创新翻译工具。与传统翻译软件不同,它将语…...
从Prompt到成稿|像素剧本圣殿输入剧情大纲→输出标准剧本全流程
从Prompt到成稿|像素剧本圣殿输入剧情大纲→输出标准剧本全流程 1. 工具介绍:像素剧本圣殿 像素剧本圣殿是一款基于Qwen2.5-14B-Instruct大模型深度优化的专业剧本创作工具。它将先进的AI文本生成能力与独特的8-Bit复古视觉风格相结合,为编…...
intv_ai_mk11开源可部署实践:支持Webhook回调,可对接企业微信/钉钉/飞书通知
intv_ai_mk11开源可部署实践:支持Webhook回调,可对接企业微信/钉钉/飞书通知 1. 项目概述 intv_ai_mk11是一款基于Llama架构的AI对话机器人,拥有7B参数规模,能够运行在GPU服务器上。这个开源项目不仅提供了强大的对话能力&#…...
AutoHotkey自动化效率提升指南:从入门到进阶的全场景应用技巧
AutoHotkey自动化效率提升指南:从入门到进阶的全场景应用技巧 【免费下载链接】antimicrox Graphical program used to map keyboard buttons and mouse controls to a gamepad. Useful for playing games with no gamepad support. 项目地址: https://gitcode.co…...
