当前位置: 首页 > news >正文

# 深入理解RNN(一):循环神经网络的核心计算机制

深入理解RNN:循环神经网络的核心计算机制

RNN示意图
在这里插入图片描述

引言

在自然语言处理、时间序列预测、语音识别等涉及序列数据的领域,循环神经网络(RNN)一直扮演着核心角色。尽管近年来Transformer等架构逐渐成为主流,RNN的基本原理和思想依然对于理解深度学习处理序列数据的方式至关重要。本文将深入剖析RNN的核心计算机制,通过公式、代码和直观解释,帮助读者真正掌握这一经典算法的内在逻辑。

RNN的基本思想

传统前馈神经网络的主要局限在于:它无法处理序列数据中的时序依赖关系。每个输入被视为独立的个体,网络无法"记住"之前看到的信息。循环神经网络正是为了解决这一问题而设计的。

RNN的核心思想:通过在网络中引入循环连接,使当前时刻的输出不仅依赖于当前的输入,还依赖于之前时刻的"记忆"。这种设计使得RNN能够保持"状态",从而有效处理序列数据。

RNN的核心计算公式

RNN的计算过程可以用两个关键公式表示:

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

先不要被吓到了,脑海中先想到CNN的y=kx+b , CNN里的线性变化。
然后想到y = tan ( kx+b ),引入非线性。
然后就是要引入上一步的信息,所以有了 W_h · h_{t-1} ,所以计算的这里,只是多了上一步的状态信息而已。就想着RNN相比CNN,其实就是多了一个缓冲区,会把上一步的隐藏层的值,加入到这里去计算,

y t = W y ⋅ h t + b y y_t = W_y · h_t + b_y yt=Wyht+by

这是最后输出当前值的步骤,y_t不参与后续的计算,是上一步的隐藏层的信息参与了计算。所以奥秘都在隐藏层里,最后这步的作用你可以理解为和CNN最后的FC层是一个意思

其中:

  • h t h_t ht 是当前时刻t的隐藏状态(即"记忆")
  • h t − 1 h_{t-1} ht1 是前一时刻的隐藏状态
  • x t x_t xt 是当前时刻的输入
  • y t y_t yt 是当前时刻的输出
  • W h W_h Wh, W x W_x Wx, W y W_y Wy 是权重矩阵
  • b h b_h bh, b y b_y by 是偏置项
  • t a n h tanh tanh 是激活函数(也可以使用其他函数如ReLU)

这两个公式理解了还是很简单,基本涵盖了RNN的全部精髓。让我们简要看看每个组成部分的意义。

公式详解:记忆与学习的数学表达

隐藏状态更新(第一个公式)

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

这个公式描述了RNN如何更新其"记忆"。我们可以将其拆解为几个关键部分:

  1. 历史信息: W h ⋅ h t − 1 W_h · h_{t-1} Whht1
    • 这部分将前一时刻的隐藏状态 h t − 1 h_{t-1} ht1与权重矩阵 W h W_h Wh相乘
    • W h W_h Wh决定了保留多少历史信息,以及如何将这些信息与当前输入融合
    • 这正是RNN区别于传统神经网络的关键所在

这里可能不好理解,我们仍然可以把 h t − 1 h_{t-1} ht1看成一个变量x,哎,然后这个上一步的隐藏层的信息,我们是不是也要考虑下它如何影响下一步啊,因为每个数/词对下一个数/词的影响肯定是不同的,所以我们也给上一步的信息搞个 k x + b kx+b kx+b,也就是 W h ⋅ h t − 1 + b h ^ W_h · h_{t-1}+b_{\hat{h}} Whht1+bh^,然后放到公式里

h t = t a n h ( W h ⋅ h t − 1 + b h ^ + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1}+b_{\hat{h}}+ W_x · x_t + b_h) ht=tanh(Whht1+bh^+Wxxt+bh)

你一手常数项合并,咔,公式就出来了

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

  1. 当前输入: W x ⋅ x t W_x · x_t Wxxt

    • 当前时刻的输入 x t x_t xt与权重矩阵 W x W_x Wx相乘
    • W x W_x Wx决定了网络如何解释当前输入的重要性
  2. 非线性变换: t a n h ( . . . ) tanh(...) tanh(...)

    • 将线性组合通过 t a n h tanh tanh激活函数进行非线性变换
    • t a n h tanh tanh将值压缩到[-1,1]范围,帮助稳定网络动态
    • 这种非线性是神经网络表达复杂函数的关键

输出层计算(第二个公式)

y t = W y ⋅ h t + b y y_t = W_y · h_t + b_y yt=Wyht+by

这个公式描述了RNN如何基于当前隐藏状态生成输出:

  1. 隐藏状态 h t h_t ht包含了直到当前时刻的所有相关信息的"摘要"
  2. 权重矩阵 W y W_y Wy将这个隐藏状态映射到所需的输出维度
  3. 输出 y t y_t yt可以是多种形式,取决于任务类型(如分类概率、预测值等)

没错,另一个 k x + b kx+b kx+b,不是吗?

RNN的维度分析

不用看,实践会告诉你答案,你会在你以后的代码实践中对维度有更深刻的理解

为了更好地理解RNN的计算过程,我们需要明确各个参数的维度:

假设:

  • 输入维度: x t ∈ R d i n x_t \in \mathbb{R}^{d_{in}} xtRdin
  • 隐藏状态维度: h t ∈ R d h h_t \in \mathbb{R}^{d_h} htRdh
  • 输出维度: y t ∈ R d o u t y_t \in \mathbb{R}^{d_{out}} ytRdout

则各权重矩阵的维度为:

  • W x ∈ R d h × d i n W_x \in \mathbb{R}^{d_h \times d_{in}} WxRdh×din
  • W h ∈ R d h × d h W_h \in \mathbb{R}^{d_h \times d_h} WhRdh×dh
  • W y ∈ R d o u t × d h W_y \in \mathbb{R}^{d_{out} \times d_h} WyRdout×dh
  • b h ∈ R d h b_h \in \mathbb{R}^{d_h} bhRdh
  • b y ∈ R d o u t b_y \in \mathbb{R}^{d_{out}} byRdout

这种维度设计确保了矩阵乘法的兼容性,同时也反映了数据在网络中的流动方式。

RNN的直观解释

抛开前面的数学公式,我们可以用更直觉的方式理解RNN的工作原理:

  1. 记忆机制:想象RNN有一个"记事本"(隐藏状态),它会在每个时间步更新这个记事本
  2. 选择性记忆:不是所有信息都同等重要,权重矩阵决定记住什么、忘记什么
  3. 信息混合:RNN将之前的记忆与新的观察结合起来,产生更新的理解
  4. 输出决策:基于当前的"记忆状态",RNN做出当前时刻的判断或预测

Python实现:手写一个简单RNN

让我们通过Python代码实现一个简单的RNN,以更好地理解其计算过程:

import numpy as npclass SimpleRNN:def __init__(self, input_size, hidden_size, output_size):"""初始化RNN参数"""# 初始化权重矩阵(使用随机值)self.Wx = np.random.randn(hidden_size, input_size) * 0.01  # 输入到隐藏self.Wh = np.random.randn(hidden_size, hidden_size) * 0.01  # 隐藏到隐藏self.Wy = np.random.randn(output_size, hidden_size) * 0.01  # 隐藏到输出# 初始化偏置项self.bh = np.zeros((hidden_size, 1))  # 隐藏层偏置self.by = np.zeros((output_size, 1))  # 输出层偏置# 保存尺寸信息self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizedef forward(self, x_sequence, h0=None):"""前向传播过程"""# x_sequence形状: (seq_length, input_size, 1)seq_length = len(x_sequence)# 如果没有提供初始隐藏状态,则初始化为零if h0 is None:h0 = np.zeros((self.hidden_size, 1))# 保存所有时间步的隐藏状态和输出(用于反向传播)h = np.zeros((seq_length+1, self.hidden_size, 1))y = np.zeros((seq_length, self.output_size, 1))h[0] = h0  # 设置初始隐藏状态# 按时间步前向传播for t in range(seq_length):# 更新隐藏状态: h_t = tanh(W_h·h_{t-1} + W_x·x_t + b_h)h[t+1] = np.tanh(np.dot(self.Wh, h[t]) + np.dot(self.Wx, x_sequence[t]) + self.bh)# 计算输出: y_t = W_y·h_t + b_yy[t] = np.dot(self.Wy, h[t+1]) + self.byreturn y, h[1:]  # 返回所有输出和隐藏状态def predict(self, x_sequence):"""使用模型进行预测"""y, _ = self.forward(x_sequence)return y# 示例:如何使用这个RNN
if __name__ == "__main__":# 创建一个输入维度为3,隐藏层大小为5,输出维度为2的RNNrnn = SimpleRNN(input_size=3, hidden_size=5, output_size=2)# 创建一个序列数据:3个时间步,每步是一个3维向量seq_data = [np.array([[0.1], [0.2], [0.3]]),  # x_1np.array([[0.2], [0.3], [0.4]]),  # x_2np.array([[0.3], [0.4], [0.5]])   # x_3]# 前向传播outputs, hidden_states = rnn.forward(seq_data)print("输出序列形状:", len(outputs), "x", outputs[0].shape)print("第一个时间步的输出:\n", outputs[0])print("最后一个时间步的隐藏状态:\n", hidden_states[-1])

RNN的缺点与改进版本

尽管RNN的设计非常优雅,但它存在一些严重的局限性:

  1. 梯度消失/爆炸问题:在长序列上,梯度要么趋近于零(无法学习),要么爆炸(不稳定)
  2. 长期依赖问题:基本RNN难以捕捉长距离的依赖关系
  3. 信息覆盖:新信息可能完全覆盖旧信息,导致"遗忘"重要的历史信息

为了解决这些问题,研究者提出了多种RNN的变体:

  1. LSTM (长短期记忆网络):引入了"门"机制,可以选择性地记住或忘记信息
  2. GRU (门控循环单元):LSTM的简化版本,性能相近但计算更高效
  3. 双向RNN:同时考虑过去和未来的信息,适用于有完整序列的场景

这些改进版本的核心计算公式更为复杂,后面有机会我们都摸一下,但基本思想与原始RNN相同:通过更新隐藏状态来保持对序列的"记忆"。

RNN在实际项目中的应用

RNN及其变体广泛应用于各种序列处理任务,至今RNN都在时序任务上仍有一席之地,但是那是另一个故事了。

总结:RNN的核心要点

  1. RNN的本质是一种带有循环连接的神经网络,使其能够处理序列数据
  2. 核心计算公式体现了RNN如何结合历史信息和当前输入
  3. 隐藏状态是RNN的"记忆",它随着序列处理不断更新
  4. 权重共享是RNN的关键特性,使其能够处理任意长度的序列
  5. 梯度问题是基本RNN的主要缺陷,导致了LSTM等改进版本的出现

尽管Transformer等新型架构在许多任务上已经超越了RNN,理解RNN的核心计算机制仍然是掌握序列模型的重要基础。RNN简洁的设计和直观的计算过程,体现了序列学习的基本原理,这些原理在更复杂的模型中依然适用。

哎,我上来就是一手 k x + b kx+b kx+b

参考资源

  1. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
  2. Karpathy, A. The Unreasonable Effectiveness of Recurrent Neural Networks. http://karpathy.github.io/2015/05/21/rnn-effectiveness/
  3. Olah, C. Understanding LSTM Networks. http://colah.github.io/posts/2015-08-Understanding-LSTMs/

关于作者:是个逗比

相关文章:

# 深入理解RNN(一):循环神经网络的核心计算机制

深入理解RNN:循环神经网络的核心计算机制 RNN示意图 引言 在自然语言处理、时间序列预测、语音识别等涉及序列数据的领域,循环神经网络(RNN)一直扮演着核心角色。尽管近年来Transformer等架构逐渐成为主流,RNN的基本原理和思想依然对于理…...

分布式锁—6.Redisson的同步器组件

大纲 1.Redisson的分布式锁简单总结 2.Redisson的Semaphore简介 3.Redisson的Semaphore源码剖析 4.Redisson的CountDownLatch简介 5.Redisson的CountDownLatch源码剖析 1.Redisson的分布式锁简单总结 (1)可重入锁RedissonLock (2)公平锁RedissonFairLock (3)联锁MultiL…...

同步 Fork 仓库的命令

同步 Fork 仓库的命令 要将您 fork 的仓库的 main 分支与原始仓库(fork 源)同步,您可以使用以下命令: 首先,确保您已经添加了原始仓库作为远程仓库(如果尚未添加): git remote add…...

基于PySide6的CATIA零件自动化着色工具开发实践

引言 在汽车及航空制造领域,CATIA作为核心的CAD设计软件,其二次开发能力对提升设计效率具有重要意义。本文介绍一种基于Python的CATIA零件着色工具开发方案,通过PySide6实现GUI交互,结合COM接口操作实现零件着色自动化。该方案成…...

OpenManus 的提示词

OpenManus 的提示词 引言英文提示词的详细内容工具集的详细说明中文翻译的详细内容GitHub 仓库信息背景分析总结 引言 OpenManus 是一个全能 AI 助手,旨在通过多种工具高效地完成用户提出的各种任务,包括编程、信息检索、文件处理和网页浏览等。其系统提…...

Ubuntu-docker安装mysql

只记录执行步骤。 1 手动下载myql镜像(拉去华为云镜像) docker pull swr.cn-east-3.myhuaweicloud.com/library/mysql:latest配置并启动mysql 在opt下创建文件夹 命令:cd /opt/ 命令:mkdir mysql_docker 命令:cd m…...

Electron桌面应用开发:自定义菜单

完成初始应用的创建Electron桌面应用开发:创建应用,随后我们就可以自定义软件的菜单了。菜单可以帮助用户快速找到和执行命令,而不需要记住复杂的快捷键,通过将相关功能组织在一起,用户可以更容易地发现和使用应用程序…...

理解 JavaScript 中的浅拷贝与深拷贝

在 JavaScript 开发中,我们经常需要复制对象或数组。然而,复制的方式不同,可能会导致不同的结果。本文将详细介绍 浅拷贝 和 深拷贝 的概念、区别以及实现方式,帮助你更好地理解和使用它们。 1. 什么是浅拷贝? 定义 …...

【Java开发指南 | 第三十五篇】Maven + Tomcat Web应用程序搭建

读者可订阅专栏:Java开发指南 |【CSDN秋说】 文章目录 前言Maven Tomcat Web应用程序搭建1、使用Maven构建新项目2、单击项目,连续按两次shift键,输入"添加",选择"添加框架支持"3、选择Java Web程序4、点击&…...

从0到1入门Linux

一、常用命令 ls 列出目录内容 cd切换目录mkdir创建新目录rm删除文件或目录cp复制文件或目录mv移动或重命名文件和目录cat查看文件内容grep在文件中查找指定字符串ps查看当前进程状态top查看内存kill终止进程df -h查看磁盘空间存储情况iotop -o直接查看比较高的磁盘读写程序up…...

golang 从零单排 (一) 安装环境

1.下载安装 打开网址The Go Programming Language 直接点击下载go1.24.1.windows-amd64.msi 下载完成 直接双击下一步 下一步 安装完成 环境变量自动设置不必配置 2.验证 win r 输入cmd 打开命令行 输入go version...

如何下载和使用Git:初学者指南

🌟 如何下载和使用Git:初学者指南 在当今的软件开发中,Git已经成为不可或缺的版本控制系统。无论你是独立开发者还是团队成员,掌握Git的基本操作都能帮助你更高效地管理代码。今天,我将详细介绍如何下载和使用Git&…...

SQL_语法

1 数据库 1.1 新增 create database [if not exists] 数据库名; 1.2 删除 drop database [if exists] 数据库名; 1.3 查询 (1) 查看所有数据库 show databases; (2) 查看当前数据库下的所有表 show tables; 2 数据表 2.1 新增 (1) 创建表 create table [if not exists…...

基于Python实现的智能旅游推荐系统(Django)

基于Python实现的智能旅游推荐系统(Django) 开发语言:Python 数据库:MySQL所用到的知识:Django框架工具:pycharm、Navicat 系统功能实现 总体设计 系统实现 系统首页模块 统首页页面主要包括首页,旅游资讯,景点信息…...

安孚科技携手政府产业基金、高能时代发力固态电池,开辟南孚电池发展新赛道

安孚科技出手,发力固态电池。 3月7日晚间,安孚科技(603031.SH)发布公告称,公司控股子公司南孚电池拟与南平市绿色产业投资基金有限公司(下称“南平绿色产业基金”)、高能时代(广东横…...

p5.js:模拟 n个彩色小球在一个3D大球体内部弹跳

向 豆包 提问:编写一个 p5.js 脚本,模拟 42 个彩色小球在一个3D大球体内部弹跳。每个小球都应留下一条逐渐消失的轨迹。大球体应缓慢旋转,并显示透明的轮廓线。请确保实现适当的碰撞检测,使小球保持在球体内部。 cd p5-demo copy…...

Kali WebDAV 客户端工具——Cadaver 与 Davtest

1. 工具简介 在 WebDAV 服务器管理和安全测试过程中,Cadaver 和 Davtest 是两款常用的命令行工具。 Cadaver 是一个 Unix/Linux 命令行 WebDAV 客户端,主要用于远程文件管理,支持文件上传、下载、移动、复制、删除等操作。Davtest 则是一款…...

MySQL复习笔记

MySQL复习笔记 1.MySQL 1.1什么是数据库 数据库(DB, DataBase) 概念:数据仓库,软件,安装在操作系统(window、linux、mac…)之上 作用:存储数据,管理数据 1.2 数据库分类 关系型数据库&#…...

六十天前端强化训练之第十四天之深入理解JavaScript异步编程

欢迎来到编程星辰海的博客讲解 目录 一、异步编程的本质与必要性 1.1 单线程的JavaScript运行时 1.2 阻塞与非阻塞的微观区别 1.3 异步操作的性能代价 二、事件循环机制深度解析 2.1 浏览器环境的事件循环架构 核心组件详解: 2.2 执行顺序实战分析 2.3 Nod…...

集合论--形式化语言里的汇编码

如果一阶逻辑是数学这门形式化语言里的机器码,那么集合论就是数学这门形式化语言里的汇编码。 基本思想:从集合出发构建所有其它。 构建自然数构建整数构建有理数构建实数构建有序对、笛卡尔积、关系、函数、序列等构建确定有限自动机(DFA) 全景图 常…...

2025最新群智能优化算法:山羊优化算法(Goat Optimization Algorithm, GOA)求解23个经典函数测试集,MATLAB

一、山羊优化算法 山羊优化算法(Goat Optimization Algorithm, GOA)是2025年提出的一种新型生物启发式元启发式算法,灵感来源于山羊在恶劣和资源有限环境中的适应性行为。该算法旨在通过模拟山羊的觅食策略、移动模式和躲避寄生虫的能力&…...

MySQL数据实时同步至Elasticsearch的高效方案:Java实现+源码解析,一文搞定!

引言:为什么需要实时同步? MySQL擅长事务处理,而Elasticsearch(ES)则专注于搜索与分析。将MySQL数据实时同步到ES,可以充分发挥两者的优势,例如: 构建高性能搜索服务 实时数据分析…...

Spring-事务

Spring 事务 事务的基本概念 🔹 什么是事务? 事务是一组数据库操作,它们作为一个整体,要么全部成功,要么全部回滚。 常见的事务场景: 银行转账(扣款和存款必须同时成功) 订单系统…...

Git系列之git tag和ReleaseMilestone

以下是关于 Git Tag、Release 和 Milestone 的深度融合内容,并补充了关于 Git Tag 的所有命令、详细解释和指令实例,条理清晰,结合实际使用场景和案例。 1. Git Tag 1.1 定义 • Tag 是 Git 中用于标记特定提交(commit&#xf…...

考研机试常见基本题型

1、求100以内的素数 sqrt()函数在cmath头文件中。 #include <iostream> #include <cmath> using namespace std;int main() {int count 0; // 用于统计素数的个数// 遍历 100 到 200 之间的每一个数for (int num 100; num < 200; num) {bool isPrime true…...

Android AudioFlinger(四)—— 揭开PlaybackThread面纱

前言&#xff1a; 继上一篇Android AudioFlinger&#xff08;三&#xff09;—— AndroidAudio Flinger 之设备管理我们知道PlaybackThread继承自Re’fBase&#xff0c; 在被第一次引用的时候就会调用onFirstRef&#xff0c;实现如下&#xff1a; void AudioFlinger::Playbac…...

C语言基础系列【20】内存管理

博主介绍&#xff1a;程序喵大人 35- 资深C/C/Rust/Android/iOS客户端开发10年大厂工作经验嵌入式/人工智能/自动驾驶/音视频/游戏开发入门级选手《C20高级编程》《C23高级编程》等多本书籍著译者更多原创精品文章&#xff0c;首发gzh&#xff0c;见文末&#x1f447;&#x1f…...

JavaScript基础-递增和递减运算符

在JavaScript编程中&#xff0c;递增&#xff08;&#xff09;和递减&#xff08;--&#xff09;运算符是用于对数值进行加一或减一操作的基础工具。它们简洁且强大&#xff0c;但如果不正确地使用&#xff0c;可能会导致混淆或错误。本文将详细介绍这两种运算符的不同形式及其…...

计算机毕业设计SpringBoot+Vue.js社区医疗综合服务平台(源码+文档+PPT+讲解)

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…...

3.6c语言

#define _CRT_SECURE_NO_WARNINGS #include <math.h> #include <stdio.h> int main() {int sum 0,i,j;for (j 1; j < 1000; j){sum 0;for (i 1; i < j; i){if (j % i 0){sum i;} }if (sum j){printf("%d是完数\n", j);}}return 0; }#de…...