单向/双向,单层/多层RNN输入输出维度问题
单向/双向,单层/多层RNN输入输出维度问题
- RNN
- 单层单向RNN
- Rnn Cell
- Rnn
- 双层单向RNN
- 单层双向RNN
- 双层双向RNN
RNN
单层单向RNN
Rnn Cell
循环神经网络最原始的Simple RNN实现如下图所示:

下面写出单个时间步对应的Rnn Cell计算公式:

如果用矩阵运算视角来看待的话,单个时间步对应的Rnn Cell计算过程如下图所示:

代码示例如下:
import torchbatch_size = 2 # 批量大小
seq_len = 3 # 序列长度(单个RNN层包含的Time Step数量)
input_size = 4 # 输入数据的维度
hidden_size = 2 # 隐藏层的维度cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)
dataset = torch.rand(seq_len, batch_size, input_size)
hidden = torch.zeros(batch_size, hidden_size)# 使用Rnn Cell加循环实现一个Rnn层处理一批输入序列的过程
for idx, input in enumerate(dataset):print('='*10,idx,'='*10)print('input size:', input.shape)hidden = cell(input, hidden)print('outputs size:', hidden.shape)print(hidden)
输出:
========== 0 ==========
input size: torch.Size([2, 4])
outputs size: torch.Size([2, 2])
tensor([[ 0.1077, -0.7972],[-0.1070, -0.9517]], grad_fn=<TanhBackward0>)
========== 1 ==========
input size: torch.Size([2, 4])
outputs size: torch.Size([2, 2])
tensor([[-0.4524, -0.8874],[ 0.0443, -0.9716]], grad_fn=<TanhBackward0>)
========== 2 ==========
input size: torch.Size([2, 4])
outputs size: torch.Size([2, 2])
tensor([[-0.2941, -0.9751],[-0.3554, -0.9399]], grad_fn=<TanhBackward0>)
图解运算过程:
- 输入数据采用三维来表示 (序列长度,批量大小,输入向量维度) , 将序列长度作为第0维是因为这样可以很方便的将同属于一个Time Step处理的批次序列数据第i个词一次性送入Rnn Cell中处理。


- 隐藏层输出的向量为Rnn Cell对当前Time Step输入的 x t x_{t} xt和 h t − 1 h_{t-1} ht−1做完信息融合和空间变换后得到的结果,其维度可以保持与输入 x t x_{t} xt一致或者降低维度。

- 由Rnn Cell内的权重矩阵 W i h W_{ih} Wih负责完成对输入 x t x_{t} xt的线性变换,使其由输入维度input_size变为hidden_size; 然后通过加法运算与同样处理的 h t − 1 h_{t-1} ht−1做信息融合,再通过tanh激活做一次非线性变换,增强模型拟合数据的能力。

大家重点关注Rnn Cell如何在一个Time Step内同时处理一批数据的过程。
关于Rnn Cell内部具体的运算过程还可以简化为一次矩阵乘法运算完成,如下图所示:

Rnn
我们将上文用Rnn Cell加循环实现的例子通过Rnn层来实现一遍:
import torchbatch_size = 2 # 批量大小
seq_len = 3 # 序列长度(单个RNN层包含的Time Step数量)
input_size = 4 # 输入数据的维度
hidden_size = 2 # 隐藏层的维度rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,num_layers=1)
dataset = torch.rand(seq_len, batch_size, input_size)
hidden = torch.zeros(1,batch_size, hidden_size) # 此处1为RNN层数,下文会讲out, hidden = rnn(dataset, hidden)
print('output size:', out.shape)
print('Hidden size:', hidden.shape)
输出:
output size: torch.Size([3, 2, 2])
Hidden size: torch.Size([1, 2, 2])
Rnn可以一次性处理完输入的批量序列数据,并返回处理后的结果和最后一个Time Step隐藏层的输出结果,返回结果的维度解释如下:
- out:(序列长度,批量大小,隐藏层维度)
- hidden: (隐藏层层数,批量大小,隐藏层维度)

上图中按照Time Step维度展开的Rnn Cell为同一个实例对象,也就是上文中最开始给出的Rnn Cell加循环的图解过程,所以不同Time Step处理时涉及到的权重矩阵是共享的。
双层单向RNN
双层单向RNN处理的流程如下图所示:

我们下面通过代码来验证一下:
import torchbatch_size = 2 # 批量大小
seq_len = 3 # 序列长度(单个RNN层包含的Time Step数量)
input_size = 4 # 输入数据的维度
hidden_size = 2 # 隐藏层的维度rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,num_layers=2)
dataset = torch.rand(seq_len, batch_size, input_size)
hidden = torch.zeros(2,batch_size, hidden_size) # 此处1为RNN层数,下文会讲out, hidden = rnn(dataset, hidden)
print('output size:', out.shape)
print('Hidden size:', hidden.shape)
输出:
output size: torch.Size([3, 2, 2])
Hidden size: torch.Size([2, 2, 2])
改为双层RNN后,我们初始化输入的隐藏层向量要同时为层1的RNN和层2的RNN都提供,因此hidden的第一维num_layers要变为2; 同理,我们最终处理完毕后,会得到层1和层2的最后一个Time Step的输出,因此返回的hidden维度第一维也是2。
- out:(序列长度,批量大小,隐藏层维度)
- hidden: (隐藏层层数,批量大小,隐藏层维度)
单层双向RNN
单层双向RNN处理的流程如下图所示:

我们下面通过代码来验证一下:
import torchbatch_size = 2 # 批量大小
seq_len = 3 # 序列长度(单个RNN层包含的Time Step数量)
input_size = 4 # 输入数据的维度
hidden_size = 2 # 隐藏层的维度rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,num_layers=1,bidirectional=True)
dataset = torch.rand(seq_len, batch_size, input_size)
hidden = torch.zeros(2,batch_size, hidden_size) # 此处1为RNN层数,下文会讲out, hidden = rnn(dataset, hidden)
print('output size:', out.shape)
print('Hidden size:', hidden.shape)
输出:
output size: torch.Size([3, 2, 4])
Hidden size: torch.Size([2, 2, 2])
当我们考虑双向时,最终输出得到的out和hidden维度都会发生改变:
- out:(序列长度,批量大小,隐藏层维度*2)
- hidden: (隐藏层层数*2,批量大小,隐藏层维度)
当设置RNN为双向后,在每一个Time Step我们会得到正向和反向计算的两份输出,因此每个Time Step隐藏层的输出为正向和反向输出的concat,这里的concat运算过程如下所示:
# 维度: (序列长度,批量大小,隐藏层维度) --> [3,2,4]
# 假设正向RNN输出如下
forward_outputs = torch.tensor([[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], # t=0[[0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]], # t=1[[1.7, 1.8, 1.9, 2.0], [2.1, 2.2, 2.3, 2.4]] # t=2
])
# 假设反向RNN输出如下
backward_outputs = torch.tensor([[[2.5, 2.6, 2.7, 2.8], [2.9, 3.0, 3.1, 3.2]], # t=0[[3.3, 3.4, 3.5, 3.6], [3.7, 3.8, 3.9, 4.0]], # t=1[[4.1, 4.2, 4.3, 4.4], [4.5, 4.6, 4.7, 4.8]] # t=2
])
# 拼接(concatenation)后的结果如下
concatenated_outputs = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 2.5, 2.6, 2.7, 2.8], [0.5, 0.6, 0.7, 0.8, 2.9, 3.0, 3.1, 3.2]], # t=0[[0.9, 1.0, 1.1, 1.2, 3.3, 3.4, 3.5, 3.6], [1.3, 1.4, 1.5, 1.6, 3.7, 3.8, 3.9, 4.0]], # t=1[[1.7, 1.8, 1.9, 2.0, 4.1, 4.2, 4.3, 4.4], [2.1, 2.2, 2.3, 2.4, 4.5, 4.6, 4.7, 4.8]] # t=2
])
同理,采用双向RNN后,正向和反向的最后一个Time Step计算完后都会得到一个hidden输出,因此最终计算完返回的hidden第一维是2;我们初始化时也要同时给出正向和反向RNN的hidden,因此初始化输入的hidden第一维也是2。
双层双向RNN
双层双向RNN处理流程如下:
- 第一层双向 RNN:
- 正向 RNN 从序列的开始到结束处理数据。
- 反向 RNN 从序列的结束到开始处理数据。
- 两个方向的输出在每个时间步上进行拼接,形成第一层的输出。
- 第二层双向 RNN:
- 第二层的输入是第一层的输出。
- 第二层的正向 RNN 从第一层的输出序列的开始到结束处理数据。
- 第二层的反向 RNN 从第一层的输出序列的结束到开始处理数据。
- 两个方向的输出在每个时间步上进行拼接,形成第二层的输出。
我们下面通过代码来验证一下:
import torchbatch_size = 2 # 批量大小
seq_len = 3 # 序列长度(单个RNN层包含的Time Step数量)
input_size = 4 # 输入数据的维度
hidden_size = 2 # 隐藏层的维度rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,num_layers=2,bidirectional=True)
dataset = torch.rand(seq_len, batch_size, input_size)
hidden = torch.zeros(2*2,batch_size, hidden_size) # 此处1为RNN层数,下文会讲out, hidden = rnn(dataset, hidden)
print('output size:', out.shape)
print('Hidden size:', hidden.shape)
通过断点debug可以看到rnn对象实例内部各层权重矩阵的维度:

重点关注weight_ih_l1 和 weight_ih_l1_reverse ,这两个权重矩阵分别为第二层RNN的正向和反向层对应的输入数据的权重矩阵,可以看到其维度为[4,2] (X*W.T这里给出W转置后的维度) , 说明此处单个Time Step输入第二层Rnn Cell的输入数据维度为[2,4],由此可推断输入的为第一层双向RNN拼接输出后结果。
输出:
output size: torch.Size([3, 2, 4])
Hidden size: torch.Size([4, 2, 2])
这里output输出的维度和单层双向RNN的结果一致,这是因为第一层RNN的输出结果维度和初始input数据维度相同。
hidden第一维输出变成了4,这是因为每一个双向RNN层都会在最后一个Time Step提供正向和反向两个hidden输出,而因为我们这里采用了两层双向RNN,因此最终会得到四个hidden输出。
相关文章:
单向/双向,单层/多层RNN输入输出维度问题
单向/双向,单层/多层RNN输入输出维度问题 RNN单层单向RNNRnn CellRnn 双层单向RNN单层双向RNN双层双向RNN RNN 单层单向RNN Rnn Cell 循环神经网络最原始的Simple RNN实现如下图所示: 下面写出单个时间步对应的Rnn Cell计算公式: 如果用矩阵运算视角来看待的话&…...
chromium-mojo
https://chromium.googlesource.com/chromium/src//refs/heads/main/mojo/README.md 相关类:https://zhuanlan.zhihu.com/p/426069459 Core:https://source.chromium.org/chromium/chromium/src//main:mojo/core/README.md;bpv1;bpt0 embedder:https://source.chr…...
ZooKeeper 的典型应用场景:从概念到实践
引言 在分布式系统的生态中,ZooKeeper 作为一个协调服务框架,扮演着至关重要的角色。它的设计目的是提供一个简单高效的解决方案来处理分布式系统中常见的协调问题。本文将详细探讨 ZooKeeper 的典型应用场景,包括但不限于配置管理、命名服务…...
缓存组件<keep-alive>
缓存组件<keep-alive> 1.组件作用 组件, 默认会缓存内部的所有组件实例,当组件需要缓存时首先考虑使用此组件。 2.使用场景 场景1:tab切换时,对应的组件保持原状态,使用keep-alive组件 使用:KeepAlive | Vu…...
YouBIP 项目
技术方案 难点 成效 项目背景 库存管理涉及大量数据,如何在前端实现高效的数据展示和交互是一个挑战。库存管理系统需要处理大量的入库、出库、盘点等操作,尤其是在大企业或多仓库场景下,高并发操作可能导致数据库锁争用、响应延迟等问题。…...
react概览webpack基础
react概览 课程介绍 webpack 构建依赖图->bundle 首屏渲染: 减少白屏等待时间 数据、结构、样式都返回。需要服务器的支持 性能优化 ***webpack干的事情 模块化开发 优势: 多人团队协作开发 可复用 单例:全局冲突 闭包 模块导入的顺序 req…...
DeepSeek 助力 Vue 开发:打造丝滑的步骤条
前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…...
STM32的HAL库开发---高级定时器---互补输出带死区实验
一、互补输出简介 互补输出:OCx输出高电平,则互补通道OCxN输出低电平。OCx输出低电平,则互补通道OCxN输出高电平。 带死区控制的互补输出:OCx输出高电平时,则互补通道OCxN过一会再输出输出低电平。这个时间里输出的电…...
Vue07
一、Vuex 概述 目标:明确Vuex是什么,应用场景以及优势 1.是什么 Vuex 是一个 Vue 的 状态管理工具,状态就是数据。 大白话:Vuex 是一个插件,可以管理 Vue 通用的数据 (多组件共享的数据)。例如:购物车数…...
【CXX-Qt】2 CXX-Qt #[cxx_qt::bridge] 宏指南
#[cxx_qt::bridge] 宏是用于在 Rust 中创建一个模块,该模块能够桥接 Rust 和 Qt(通过 C)之间的交互。它允许你将 Rust 类型暴露给 Qt 作为 QObject、Q_SIGNAL、Q_PROPERTY 等,同时也能够将 Qt 的特性和类型绑定到 Rust 中…...
鸿蒙接入支付宝SDK后模拟器无法运行,报错error: install parse native so failed.
鸿蒙项目接入支付宝后,运行提示error: install parse native so failed. 该问题可能由于设备支持的 Abi 类型与 C 工程中的不匹配导致. 官网error: install parse native so failed.错误解决办法 根据官网提示在模块build-profile.json5中添加“x86_64”依然报错 问…...
局域网使用Ollama(Linux)
解决局域网无法连接Ollama服务的问题 在搭建和使用Ollama服务的过程中,可能会遇到局域网内无法连接的情况。经过排查发现,若开启了代理软件,尤其是Hiddify,会导致此问题。这一发现耗费了我数小时的排查时间,希望能给大…...
Deepseek系列从v3到R易背面经版
deepseek v3 base要点 MTP : Multi-Token Prediction 训练时: 1. 把前一个block中input tokens经过embedding layer和transformer block的输出,进入output head之前的内容记为h,与下一个block的input tokens经过embedding layer输出的内容都…...
Redis深入学习
目录 Redis是什么? Redis使用场景 Redis线程模型 Redis执行命令是单线程的为什么还这么快? Redis持久化 Redis 事务 Key 过期策略 Redis 和 mysql 如何保证数据一致? 缓存穿透 缓存击穿 缓存雪崩 Redis是什么? redis是一…...
《从入门到精通:蓝桥杯编程大赛知识点全攻略》(十一)-回文日期、移动距离、日期问题
前言 在这篇博客中,我们将通过模拟的方法来解决三道经典的算法题:回文日期、移动距离和日期问题。这些题目不仅考察了我们的基础编程能力,还挑战了我们对日期处理和数学推理的理解。通过模拟算法,我们能够深入探索每个问题的核心…...
在Uniapp中使用阿里云OSS插件实现文件上传
在开发小程序时,文件上传是一个常见的需求。阿里云OSS(Object Storage Service)是一个强大的云存储服务,可以帮助我们高效地存储和管理文件。本文将介绍如何在Uniapp小程序中使用阿里云OSS插件实现文件上传功能。 1. 准备工作 首…...
9 数据流图
9 数据流图 9.1数据平衡原则 子图缺少处理后的数据操作结果返回前端应用以及后端数据库返回操作结果到数据管理中间件。 9.2解题技巧 实件名 存储名 加工名 数据流...
IDEA查看项目依赖包及其版本
一.IDEA将现有项目转换为Maven项目 在IntelliJ IDEA中,将现有项目转换为Maven项目是一个常见的需求,可以通过几种不同的方法来实现。Maven是一个强大的构建工具,它可以帮助自动化项目的构建过程,管理依赖关系,以及其他许多方面。 添加Maven支持 如果你的项目还没有pom.xm…...
【数据结构】_栈与队列经典算法OJ:栈与队列的互相实现
目录 1. 用队列实现栈 1.1 题目链接及描述 1.2 解题思路 1.3 程序 2. 用栈实现队列 2.1 题目链接及描述 2.2 解题思路 2.3 程序 1. 用队列实现栈 1.1 题目链接及描述 1. 题目链接 : 225. 用队列实现栈 - 力扣(LeetCode) 2. 题目描…...
SAP-ABAP:ROLLBACK WORK使用详解
在SAP ABAP 中,ROLLBACK WORK 语句用于回滚当前事务(LUW,Logical Unit of Work),撤销自上次提交或回滚以来的所有数据库更改。它通常与 COMMIT WORK 配合使用,确保数据一致性。 关键点: 回滚作…...
.Net框架,除了EF还有很多很多......
文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...
前端倒计时误差!
提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
srs linux
下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
Mysql中select查询语句的执行过程
目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析(Parser) 2.4、执行sql 1. 预处理(Preprocessor) 2. 查询优化器(Optimizer) 3. 执行器…...
深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用
文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...
Golang——9、反射和文件操作
反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一:使用Read()读取文件2.3、方式二:bufio读取文件2.4、方式三:os.ReadFile读取2.5、写…...
日常一水C
多态 言简意赅:就是一个对象面对同一事件时做出的不同反应 而之前的继承中说过,当子类和父类的函数名相同时,会隐藏父类的同名函数转而调用子类的同名函数,如果要调用父类的同名函数,那么就需要对父类进行引用&#…...
