当前位置: 首页 > news >正文

【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN) 详细理解并附实现代码。

【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN)

【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN)


文章目录

  • 【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN)
  • 1.算法原理介绍:递归神经网络 (Recurrent Neural Networks, RNN)
    • 1.1 递归神经网络 (RNN) 概述
    • 1.2 RNN 的关键特性
    • 1.3 RNN 的工作原理
    • 1.4 RNN 的问题:梯度消失与梯度爆炸
    • 1.5 RNN 的应用
  • 2.Python 实现 RNN 的应用实例
    • 2.1代码实现:递归神经网络的实现及文本分类应用
    • 2.2代码解释
  • 3.总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://www.sciencedirect.com/science/article/abs/pii/036402139090002E

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1.算法原理介绍:递归神经网络 (Recurrent Neural Networks, RNN)

1.1 递归神经网络 (RNN) 概述

递归神经网络 (RNN) 是前馈神经网络 (FFNN) 的扩展,其主要特点是引入了时间维度上的依赖关系,使得网络具有记忆功能

RNN 通过在不同时间步上共享权重,并通过递归连接在时间序列上传播信息,从而能够处理时间序列数据或顺序依赖性任务

相比于传统的 FFNN,RNN 可以理解输入序列的顺序,并根据序列中前面的信息来调整当前的输出。

1.2 RNN 的关键特性

  • 时间依赖性:RNN 中的每个神经元不仅接收当前时刻的输入,还接收上一个时刻的隐藏状态。这个特性使得 RNN 可以在序列数据中传播信息。
  • 权重共享:所有时刻的隐藏层共享相同的权重矩阵,使得RNN能够处理可变长度的输入序列。
  • 递归连接:通过递归连接,RNN 可以将前一时刻的隐藏状态作为当前时刻的输入,从而在时间步之间传播信息。

1.3 RNN 的工作原理

在一个标准的RNN 中,给定输入序列 X = ( x 1 , x 2 , … , x T ) X=(x_1,x_2,…,x_T) X=(x1,x2,,xT),隐藏状态 h t h_t ht的更新公式如下:
h t = σ ( W i h x t + W h h h t − 1 + b h ) h_t=σ(W_{ih}x_t+W_{hh}h_{t-1}+b_h) ht=σ(Wihxt+Whhht1+bh
其中:

  • h t h_t ht是当前时刻的隐藏状态。
  • W i h W_{ih} Wih是输入到隐藏状态的权重矩阵。
  • W h h W_{hh} Whh是隐藏状态之间的递归权重矩阵。
  • x t x_t xt是当前时间步的输入。
  • b h b_h bh是偏置项。
  • σ σ σ是激活函数(如 tanh 或 ReLU)。

最终,输出层的输出 y t y_t yt计算如下:
y t = σ ( W h o h t + b o ) y_t=σ(W_{ho}h_t+b_o) yt=σ(Whoht+bo

1.4 RNN 的问题:梯度消失与梯度爆炸

由于 RNN 在每个时间步上进行梯度传播,如果序列较长,梯度在反向传播时会呈指数增长或减少,导致梯度爆炸或梯度消失问题。这限制了 RNN 处理长期依赖关系的能力。在实际应用中,长短期记忆网络 (LSTM) 和门控循环单元 (GRU) 被提出用来解决这个问题

1.5 RNN 的应用

RNN 广泛用于处理与序列相关的任务,常见应用包括:

  • 自然语言处理 (NLP):如文本生成、机器翻译、文本分类、情感分析等。
  • 语音识别:将语音信号作为输入序列,通过 RNN 处理时间依赖性。
  • 时间序列预测:例如股票价格预测、天气预报等。

2.Python 实现 RNN 的应用实例

我们将使用 Python 和深度学习框架 PyTorch 实现一个简单的 RNN,用于文本分类任务。

2.1代码实现:递归神经网络的实现及文本分类应用

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np# 创建简单的文本数据
# 假设我们有两个类别的文本数据,每个句子都是单词的索引表示
# 类别 0: "I love machine learning", "deep learning is great"
# 类别 1: "I hate spam emails", "phishing attacks are bad"
X = [[1, 2, 3, 4],     # "I love machine learning"[5, 6, 7, 8],     # "deep learning is great"[1, 9, 10, 11],   # "I hate spam emails"[12, 13, 14, 15]  # "phishing attacks are bad"
]
y = [0, 0, 1, 1]  # 标签,0表示积极类别,1表示消极类别# 转换为Tensor格式
X = torch.tensor(X, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)# 定义数据集和数据加载器
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)# 定义 RNN 模型
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.embedding = nn.Embedding(input_size, hidden_size)  # 嵌入层,将输入单词的索引转换为向量self.rnn = nn.RNN(hidden_size, hidden_size, num_layers, batch_first=True)  # RNN层self.fc = nn.Linear(hidden_size, output_size)  # 全连接层,用于分类输出def forward(self, x):# 初始化隐藏状态,形状为 (num_layers, batch_size, hidden_size)h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)# 通过嵌入层转换输入out = self.embedding(x)# 通过 RNN 层out, _ = self.rnn(out, h0)# 取最后一个时间步的隐藏状态作为输出out = out[:, -1, :]# 通过全连接层得到最终的分类输出out = self.fc(out)return out# 模型参数
input_size = 16  # 假设我们有16个不同的单词
hidden_size = 8  # 隐藏层大小
output_size = 2  # 二分类问题
num_layers = 1  # RNN层数# 创建 RNN 模型
model = RNN(input_size, hidden_size, output_size, num_layers)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 分类任务使用交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20
for epoch in range(num_epochs):for data, labels in dataloader:# 前向传播outputs = model(data)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 5 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
with torch.no_grad():test_sentence = torch.tensor([[1, 2, 3, 4]])  # 测试句子 "I love machine learning"prediction = model(test_sentence)predicted_class = torch.argmax(prediction, dim=1)print(f'Predicted class: {predicted_class.item()}')

2.2代码解释

1.定义 RNN 模型:

  • self.embedding = nn.Embedding(input_size, hidden_size):这是嵌入层,用于将输入的单词索引(如“1, 2, 3, 4”)转换为高维向量表示。
  • self.rnn = nn.RNN(hidden_size, hidden_size, num_layers, batch_first=True):定义 RNN 层,输入大小为 hidden_size,输出也为 hidden_size,序列是按批次为第一维度(batch_first=True)。
  • self.fc = nn.Linear(hidden_size, output_size):全连接层,将最后一个时间步的 RNN 输出映射为类别预测(输出大小为2,表示二分类任务)。

2.RNN 的前向传播:

  • h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device):初始化隐藏状态为0。
  • out, _ = self.rnn(out, h0):RNN 前向传播,得到每个时间步的隐藏状态。
  • out = out[:, -1, :]:取最后一个时间步的输出,作为最终输出。
  • out = self.fc(out):通过全连接层进行分类。

3.数据集生成与加载:

  • 我们使用了简单的二分类文本数据(表示为单词索引序列),并转换为 PyTorch 的 TensorDatasetDataLoader

4.训练与测试:

  • 使用 Adam 优化器和交叉熵损失函数训练模型,并在每 5 个 epoch 打印一次损失。
  • 在测试阶段,我们用一个测试句子进行分类预测,并打印出预测的类别。

3.总结

递归神经网络 (RNN) 是处理序列数据的重要工具,适用于自然语言处理、语音识别、时间序列预测等任务。然而,RNN 存在梯度消失与梯度爆炸问题,尤其是在处理长序列时

在实际应用中,RNN 已被改进为 LSTM 和 GRU 等架构,解决了这些问题。通过 PyTorch 实现的 RNN 示例展示了其在文本分类中的应用。

相关文章:

【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN) 详细理解并附实现代码。

【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN) 【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN) 文章目录 【深度学习基础模型】递归神经网络 (Recurrent Neural Networks, RNN)1.算法原理介绍:递归神经网络 (Recurrent…...

python全栈学习记录(十九) hashlib、shutil和tarfile、configparser

hashlib、shutil和tarfile、configparser 文章目录 hashlib、shutil和tarfile、configparser一、hashlib二、shutil和tarfile1.shutil2.tarfile 三、configparser 一、hashlib hash是一种算法,该算法接受传入的内容,经过运算得到一串hash值。如果把hash…...

RL进阶(一):变分推断、生成模型、SAC

参考资料: 视频课程《CS285: Deep Reinforcement Learning, Decision Making, and Control》第18讲、第19讲,Sergey Levine,UCerkeley课件PDF下载:https://rail.eecs.berkeley.edu/deeprlcourse/主要内容:变分推断、生成模型、以及Soft Actor-Critic。变分推断在model-bas…...

WPF 绑定 DataGrid 里面 Button点击事件 TextBlock 双击事件

TextBlock双击事件 <DataGridTemplateColumn Width"*" Header"内标"><DataGridTemplateColumn.CellTemplate><DataTemplate><Grid><TextBlockBackground"Transparent"Tag"{Binding InternalId}"Text"…...

828华为云征文|华为云Flexus云服务器X实例Windows系统部署一键短视频生成AI工具moneyprinter

在追求创新与效率并重的今天&#xff0c;我们公司迎难而上&#xff0c;决定自主搭建一款短视频生成AI工具——MoneyPrinter&#xff0c;旨在为市场带来前所未有的创意风暴。面对服务器选择的难题&#xff0c;我们经过深思熟虑与多方比较&#xff0c;最终将信任票投给了华为云Fl…...

非标精密五金加工的技术要求

非标精密五金加工在现代制造业中占据着重要地位&#xff0c;其对于产品的精度、质量和性能有着较高的要求。以下是时利和整理的其具体的技术要求&#xff1a; 一、高精度的加工设备 非标精密五金加工需要先进的加工设备来保证加工精度。例如&#xff0c;高精度的数控机床是必不…...

新手小白怎么通过云服务器跑pytorch?

新手小白怎么通过云服务器跑pytorch&#xff1f;安装PyTorch的步骤可以根据不同的操作系统和需求有所差异&#xff0c;通过云服务器运行PyTorch的过程主要包括选择GPU云服务器平台、配置服务器环境、部署和运行PyTorch模型、优化性能等步骤。具体步骤如下&#xff1a; 第一步&a…...

Spring 全家桶使用教程

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

Spark SQL性能优化高频面试题及答案

目录 高频面试题及答案1. 如何通过分区&#xff08;Partitioning&#xff09;优化Spark SQL查询性能&#xff1f;2. 什么是数据倾斜&#xff08;Data Skew&#xff09;&#xff1f;如何优化&#xff1f;3. 如何使用广播&#xff08;Broadcast&#xff09;优化Join操作&#xff…...

云原生链路观测平台 openobserve + fluent-bit,日志收集

grpc-opentracing https://github.com/grpc-ecosystem/grpc-opentracing openobserve fluent-bit 为啥会选择这个组合 一个 rust 写的一个是c写的&#xff0c;性能和内存方面不用担心&#xff0c;比java 那套好太多了 openobserve 文档 &#xff1a;https://openobserve.ai/…...

Android 车载应用开发指南 - CarService 详解(下)

车载应用正在改变人们的出行体验。从导航到娱乐、从安全到信息服务&#xff0c;车载应用的开发已成为汽车智能化发展的重要组成部分。而对于开发者来说&#xff0c;如何将自己的应用程序无缝集成到车载系统中&#xff0c;利用汽车的硬件和服务能力&#xff0c;是一个极具挑战性…...

【Linux网络 —— 网络基础概念】

Linux网络 —— 网络基础概念 计算机网络背景网络发展 初始协议协议分层协议分层的好处 OSI七层模型TCP/IP五层(或四层)模型 再识协议为什么要有TCP/IP协议&#xff1f;什么是TCP/IP协议&#xff1f;TCP/IP协议与操作系统的关系所以究竟什么是协议&#xff1f; 网络传输基本流程…...

el-form动态标题和输入值,并且最后一个输入框不校验

需求&#xff1a;给了固定的label&#xff0c;叫xx单位&#xff0c;要输入单位的信息&#xff0c;但是属性名称都一样的&#xff0c;UI画图也是表单的形式&#xff0c;所以改为动态添加的形式&#xff0c;实现方式也很简单&#xff0c;循环就完事了&#xff0c;连着表单校验也动…...

一,初始 MyBatis-Plus

一&#xff0c;初始 MyBatis-Plus 文章目录 一&#xff0c;初始 MyBatis-Plus1. MyBatis-Plus 的概述2. 入门配置第一个 MyBatis-Plus 案例3. 补充说明&#xff1a;3.1 通用 Mapper 接口介绍3.1.1 Mapper 接口的 “增删改查”3.1.1.1 查询所有记录3.1.1.2 插入一条数据3.1.1.3 …...

安卓13删除下拉栏中的关机按钮版本2 android13删除下拉栏关机按钮

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 顶部导航栏下拉可以看到,底部这里有个设置按钮,点击可以进入设备的设置页面,这里我们将更改为删除,不同用户通过这个地方进入设置。我们之前写过一个文章也是一样的删除…...

快递物流单号识别API接口代码

官网&#xff1a;快递鸟 API参数 一、接口描述/说明 &#xff08;1&#xff09;该接口仅对运单号做出识别&#xff0c;识别可能属于的一家或多家快递公司。 &#xff08;2&#xff09;接口并不返回物流轨迹&#xff0c;用户可结合即时查询接口和订阅查询接口完成轨迹查询、订…...

AI时代的程序员:如何保持和提升核心竞争力

1.引言 随着AIGC&#xff08;如 ChatGPT、Midjourney、Claude 等&#xff09;大语言模型的快速崛起&#xff0c;AI辅助编程工具逐渐成为程序员工作的重要组成部分。这一转变不仅改变了工作方式&#xff0c;更深刻影响了程序员的职业角色和技术路径。有人担心&#xff0c;AI将取…...

Oracle 数据库常用命令与操作指南

Oracle 数据库是企业级系统中常用的数据库管理系统&#xff0c;掌握基础的命令可以让你在日常管理中更加高效。本指南将介绍几条常用的 Oracle 数据库命令&#xff0c;涵盖用户权限管理、修改用户密码、删除用户、以及其他日常操作。 目录 授权用户操作权限使用最高权限登录 O…...

spring boot项目对接人大金仓

先确认一下依赖 第一 是否引入了mybatis-plus多数据源&#xff0c;如果引入了请将版本保持在3.5.0以上 <dependency><groupId>com.baomidou</groupId><artifactId>dynamic-datasource-spring-boot-starter</artifactId><version>${dynam…...

《操作系统 - 清华大学》1 -2:操作系统概述 —— 什么是操作系统

文章目录 1. 操作系统定义2. 操作系统的位置3. 操作系统软件的分类4. 操作系统软件的组成5. 操作系统内核特征 现在来继续讲什么是操作系统&#xff0c;操作系统什么样的&#xff1f;它是一个程序&#xff0c;它和其他程序是什么样的关系&#xff1f;然后它有些什么样的组成&am…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子&#xff0c;用于处理异步操作&#xff08;如数据加载&#xff09;中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误&#xff1a;捕获在 loader 或 action 中发生的异步错误替…...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查

在对接支付宝API的时候&#xff0c;遇到了一些问题&#xff0c;记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...

云计算——弹性云计算器(ECS)

弹性云服务器&#xff1a;ECS 概述 云计算重构了ICT系统&#xff0c;云计算平台厂商推出使得厂家能够主要关注应用管理而非平台管理的云平台&#xff0c;包含如下主要概念。 ECS&#xff08;Elastic Cloud Server&#xff09;&#xff1a;即弹性云服务器&#xff0c;是云计算…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)

要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况&#xff0c;可以通过以下几种方式模拟或触发&#xff1a; 1. 增加CPU负载 运行大量计算密集型任务&#xff0c;例如&#xff1a; 使用多线程循环执行复杂计算&#xff08;如数学运算、加密解密等&#xff09;。运行图…...

OpenLayers 分屏对比(地图联动)

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能&#xff0c;和卷帘图层不一样的是&#xff0c;分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么&#xff1f;1.1.2 感知机的工作原理 1.2 感知机的简单应用&#xff1a;基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...