[PyTorch][chapter 44][RNN]
简介
循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network) [1] 。
对循环神经网络的研究始于二十世纪80-90年代,并在二十一世纪初发展为深度学习(deep learning)算法之一 [2] ,其中双向循环神经网络(Bidirectional RNN, Bi-RNN)和长短期记忆网络(Long Short-Term Memory networks,LSTM)是常见的循环神经网络 [3] 。
目录:
- 模型
- Forward
- Backward
- nn.RNN
- nn.RNNCell
一 模型

: t 时刻样本输入
: t 时刻样本隐藏状态
t时刻输出
: t时刻样本预测类别(只有分类算法才有)
: t 时刻损失函数
二 RNN 前向传播算法 Forward
2.1 t 时刻隐藏值 更新
其中激活函数通常用tanh
2.2 t 时刻输出
其中激活函数 为softmax
三 RNN 反向传播算法 BPTT(back-propagation through time)
3.1 输出层参数v,c梯度
3.2 隐藏层参数更新
定义
证明:
对于最后一个时刻T
3.3 计算权重系数U,W,b
四 nn.RNN
这里面介绍PyTorch 使用RNN 类
4.1 更新规则:

| 参数 | 说明 |
| L | 时间序列长度T or 句子长度为 L |
| N | batch_size |
| d | 输入特征维度 |
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 15:30:01 2023@author: chengxf2
"""import torch
import torch.nn as nnrnn = nn.RNN(input_size=100, hidden_size=5)
param = rnn._parametersprint("\n 权重系数",param.keys())print(rnn.weight_ih_l0.shape)
输出:

RNN参数说明:
| 参数 | 说明 |
| input_size =d | 输入维度 |
| hidden_size=h: | 隐藏层维度 |
| num_layers | RNN默认是 1 层。该参数大于 1 时,会形成 Stacked RNN,又称多层RNN或深度RNN; |
| nonlinearity | 非线性激活函数。可以选择 tanh 或 relu |
| bias | 即偏置。默认启用 |
| batch_first | 选择让 batch_size=N 作为输入的形状中的第一个参数。默认是 False,L × N × d 形状; 当 batch_first=True 时, N × L × d |
| dropout | 即是否启用 dropout。如要启用,则应设置 dropout 的概率,此时除最后一层外,RNN的每一层后面都会加上一个dropout层。默认是 0,即不启用 |
| bidirectional | 即是否启用双向RNN,默认关闭 |
4.2 单层例子
import torch.nn as nn
import torchrnn = nn.RNN(input_size= 100, hidden_size=20, num_layers=1)X = torch.randn(10,3,100)h_0 = torch.zeros(1,3,20)out,h = rnn(X,h_0)print("\n out.shape",out.shape)print("\n h.shape",h.shape)
out: 包含每个时刻的 隐藏值
h : 最后一个时刻的隐藏值
4.3 多层RNN

把当前的隐藏层输出,作为下一层的输入
第一个隐藏层输出:
第二个隐藏层输出
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023@author: chengxf2
"""import torch.nn as nn
import torch
rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=2)
print(rnn)x = torch.randn(10,3,100) #默认是[L,N,d]结构out,h =rnn(x)print(out.shape, h.shape)

5 nn.RNNCell
nn.RNN封装了整个RNN实现的过程, PyTorch 还提供了 nn.RNNCell 可以
自己实现RNN

5.1 单层RNN
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023@author: chengxf2
"""
import torch
from torch import nndef main():model = nn.RNNCell(input_size=10, hidden_size=20)h1= torch.zeros(3,20)trainData = torch.randn(8,3,10)for xt in trainData:h1= model(xt,h1)print(h1.shape)if __name__ == "__main__":main()

6.2 多层RNN
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023@author: chengxf2
"""
import torch
from torch import nndef main():layer1 = nn.RNNCell(input_size=40, hidden_size=30)layer2 = nn.RNNCell(input_size=30, hidden_size=20)h1= torch.zeros(3,30)h2= torch.zeros(3,20)trainData = torch.randn(8,3,40)for xt in trainData:h1= layer1(xt,h1)h2 = layer2(h1,h2)print(h1.shape)print(h2.shape)if __name__ == "__main__":main()
参考:
Pytorch 循环神经网络 nn.RNN() nn.RNNCell() nn.Parameter()不同方法实现_老光头_ME2CS的博客-CSDN博客
相关文章:
[PyTorch][chapter 44][RNN]
简介 循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网…...
20230726----重返学习-vue3项目实战-知乎日报第3天-TS-简历
day-121-one-hundred-and-twenty-one-20230726-vue3项目实战-知乎日报第3天-TS-简历 vue3项目实战-知乎日报第3天 封装按钮组件 jsx函数式组件 只能做静态页面,内部没有方法让它自动更新。 封装第三方按钮-非计算属性版 封装第三方按钮-不使用计算属性 src/c…...
TypeScript 在前端开发中的应用实践
TypeScript 在前端开发中的应用实践 TypeScript 已经成为前端开发领域越来越多开发者的首选工具。它是一种静态类型的超集,由 Microsoft 推出,为开发者提供了强大的静态类型检查、面向对象编程和模块化开发的特性,解决了 JavaScript 的动态类…...
商业密码应用安全性评估量化评估规则2023版更新点
《商用密码应用安全性评估量化评估规则》(2023版)已于2023年7月发布,将在8月1日正式执行。相比较2021版,新版本有多处内容更新,具体包括5处微调和5处较大更新。 微调部分(5处) 序号2021版本202…...
【软件测试】单元测试工具---Junit详解
1.junit 1.1 junit是什么 JUnit是一个Java语言的单元测试框架。 虽然我们已经学习了selenium测试框架,但是有的时候测试用例很多,我们需要一个测试工具来管理这些测试用例,Junit就是一个很好的管理工具,简单来说Junit是一个针对…...
【算法基础:搜索与图论】3.4 求最短路算法(Dijkstrabellman-fordspfaFloyd)
文章目录 求最短路算法总览Dijkstra朴素 Dijkstra 算法(⭐原理讲解!⭐重要!)(用于稠密图)例题:849. Dijkstra求最短路 I代码1——使用邻接表代码2——使用邻接矩阵 补充:稠密图和稀疏…...
【Matlab】基于卷积神经网络的数据分类预测(Excel可直接替换数据)
【Matlab】基于卷积神经网络的数据分类预测(Excel可直接替换数据) 1.模型原理2.数学公式3.文件结构4.Excel数据5.分块代码6.完整代码7.运行结果1.模型原理 基于卷积神经网络(Convolutional Neural Network,CNN)的数据分类预测是一种常见的深度学习方法,广泛应用于图像识…...
【C++ 重要知识点总结】自定义类型-枚举和联合
复杂类型 除了类之外还有Union、Enum连个特殊的类型。 Union 概念 union即为联合,它是一种特殊的类。通过关键字union进行定义,一个union可以有多个数据成员。 union Token{char cval;int ival;double dval; };用法 互斥赋值。在任意时刻,…...
Centos MySql安装,手动安装保姆级教程
1.删除原有的mariadb,不然mysql装不进去 查询MAriaDB命令 rpm -qa|grep mariadb 删除 rpm -e --nodeps mariadb-libs-5.5.60-1.el7_5.x86_64 (yum -y remove mysql 如需要清除服务器上以前安装过的MySQL可执行此命令,执行前一…...
电脑C盘空间大小调整 --- 扩容(扩大/缩小)--磁盘分区大小调整/移动
概述: 此方法适合C盘右边没有可分配空间(空闲空间)的情况,D盘有数据不方便删除D盘分区的情况下,可以使用傲梅分区助手软件进行跨分区调整分区大小,不会损坏数据。反之可直接使用系统的磁盘管理工具进行调整…...
centos7设置网桥网卡
安装bridge-utils yum install bridge-utils修改ens33 网卡 TYPEEthernet BOOTPROTOnone DEFROUTEyes IPV4_FAILURE_FATALno IPV6INITyes IPV6_AUTOCONFyes IPV6_DEFROUTEyes IPV6_FAILURE_FATALno NAMEens33 UUID04b97484-25c8-45c7-8c8c-e335e8080e10 DEVICEens33 ONBOOTye…...
TCP模型和工作沟通方式
我们如何与客户沟通?理科生和技术人员可能在沟通技巧方面有所欠缺。 那么我们如何理解和掌握沟通的原则和技巧呢?我发现TCP网络交互模型很好的描述了沟通的原则和要点。下面我们就从TCP来讲沟通的过程。 TCP的客户端就像客户(甲方ÿ…...
Langchain 的 ConversationSummaryBufferMemory
Langchain 的 ConversationSummaryBufferMemory ConversationSummaryBufferMemory 在内存中保留最近交互的缓冲区,但不仅仅是完全刷新旧的交互,而是将它们编译成摘要并使用两者。但与之前的实现不同的是,它使用令牌长度而不是交互次数来确定何…...
【Rust 基础篇】Rust 通道实现单个消费者多个生产者模式
导言 在 Rust 中,我们可以使用通道(Channel)来实现单个消费者多个生产者模式,简称为 MPMC。MPMC 是一种常见的并发模式,适用于多个线程同时向一个通道发送数据,而另一个线程从通道中消费数据的场景。本篇博…...
HTTP协议各版本介绍
HTTP协议是一种用于传输Web页面和其他资源的协议。 下面详细介绍一下HTTP的各个版本: 1.HTTP/0.9 这是最早的HTTP版本,于1991年发布。它非常简单,只能传输HTML格式的文本,并且不支持其他类型的资源、请求头和状态码。 2.HTTP/1…...
玩转ChatGPT:Custom instructions (vol. 1)
一、写在前面 据说GPT-4又被削了,前几天让TA改代码,来来回回好几次才成功。 可以看到之前3小时25条的限制,现在改成了3小时50条,可不可以理解为:以前一个指令能完成的任务,现在得两条指令? 可…...
黄东旭:The Future of Database,掀开 TiDB Serverless 的引擎盖
在 PingCAP 用户峰会 2023 上, PingCAP 联合创始人兼 CTO 黄东旭 分享了“The Future of Database”为主题的演讲, 介绍了 TiDB Serverless 作为未来一代数据库的核心设计理念。黄东旭 通过分享个人经历和示例,强调了数据库的服务化而非服务化…...
Linux环境搭建(XShell+云服务器)
好久不见啊,放假也有一周左右了,简单休息了下(就是玩了几天~~),最近也是在学习Linux,现在正在初步的学习阶段,本篇将会简单的介绍一下Linux操作系统和介绍Linux环境的安装与配置,来帮…...
-bash: /bin/rm: Argument list too long
有套数据库环境,.aud文件太多导致/u01分区使用率过高,rm清理时发现报错如下 [rootdb1 audit]# rm -rf ASM1_ora_*202*.aud -bash: /bin/rm: Argument list too long [rootdb1 audit]# rm -rf ASM1_ora_*20200*.aud -bash: /bin/rm: Argument list too…...
5个步骤完成Linux 搭建Jdk1.8环境
1:首先,在Linux系统中创建一个目录,用于存放JDK文件。可以选择在/opt目录下创建一个新的文件夹,例如/opt/jdk。 sudo mkdir /opt/jdk 2:将下载的jdk-8u381-linux-x64.tar.gz文件复制到新创建的目录中。 sudo cp jdk…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
idea大量爆红问题解决
问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...
.Net框架,除了EF还有很多很多......
文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
1.3 VSCode安装与环境配置
进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...
uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...
保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!
目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...
