人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解
大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网络的原理、运行过程、类别、参数计算,以及应用场景,并附上基于PyTorch框架的完整可运行代码。
文章目录
- 一、RNN网络的原理
- RNN的基本结构
- 二、RNN网络的运行过程
- 三、RNN的类别
- 1. 普通RNN
- 2.LSTM模型:记忆魔法的图书馆
- 3.GRU模型:高效的记忆工作室
- 四、RNN网络的参数计算
- 五、RNN网络的应用场景
- 六、PyTorch框架搭建RNN网络的
- 七、总结
一、RNN网络的原理
RNN网络的核心思想是利用历史信息来影响当前输出。与传统的前馈神经网络不同,RNN在网络结构中引入了循环结构,使得网络能够记忆前面的信息。
RNN的基本结构
- 输入层:输入序列数据,如 x 1 , x 2 , … , x t x_{1}, x_{2}, \ldots, x_{t} x1,x2,…,xt。
- 隐藏层:包含一系列的循环单元,每个循环单元负责处理当前输入和上一时刻的隐藏状态,输出当前时刻的隐藏状态。
- 输出层:根据当前时刻的隐藏状态输出结果。
二、RNN网络的运行过程
RNN网络的运行过程可以表示为以下公式:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_{t} = f(W_{hh}h_{t-1} + W_{xh}x_{t} + b_{h}) ht=f(Whhht−1+Wxhxt+bh)
y t = g ( W h y h t + b y ) y_{t} = g(W_{hy}h_{t} + b_{y}) yt=g(Whyht+by)
其中, h t h_{t} ht表示第 t t t时刻的隐藏状态, x t x_{t} xt表示第 t t t时刻的输入, y t y_{t} yt表示第 t t t时刻的输出, W W W表示权重矩阵, b b b表示偏置向量, f f f和 g g g分别表示隐藏层和输出层的激活函数。
三、RNN的类别
1. 普通RNN
最简单的RNN结构,存在梯度消失和梯度爆炸问题,难以学习长距离依赖。
2.LSTM模型:记忆魔法的图书馆
想象一下,LSTM模型就像一个拥有魔法能力的图书馆。这个图书馆的特殊之处在于,它能够记住很久以前读过的书籍内容,并且能够决定哪些信息是重要的,需要长期保留,哪些信息是可以丢弃的。
魔法书架(细胞状态)
图书馆的中心有一个魔法书架,称为“细胞状态”。这个书架上的书可以长期保存,是图书馆记忆的核心。书架上的书籍可以随着时间流动,新的书籍可以加入,旧的书籍也可以被替换。
魔法门(门结构)
图书馆有三个魔法门:遗忘门、输入门和输出门。
遗忘门
这个门决定哪些旧书籍(信息)不再重要,需要从书架上移除。如果某个信息对未来的预测不再重要,遗忘门就会让它“消失”。
输入门
这个门负责决定哪些新书(新信息)应该被添加到书架上。它检查新来的书籍,并决定哪些是有价值的,可以增强图书馆的记忆。
输出门
这个门决定哪些书籍的内容需要被阅读(输出),以影响图书馆的下一步行动。它查看书架上的书籍,并决定哪些信息需要传递到下一个时间步。
3.GRU模型:高效的记忆工作室
现在,让我们将GRU模型想象成一个高效的记忆工作室。这个工作室的任务与图书馆相似,但是它更简洁,更高效。
工作室的一体化空间(更新门和重置门)
GRU模型将LSTM的遗忘门和输入门合并成了一个叫做“更新门”的机制。同时,它还有一个“重置门”。
更新门
这个门同时负责决定哪些信息需要被遗忘,以及哪些新信息需要被存储。它就像一个高效的助手,一边清理旧资料,一边挑选新资料。
重置门
这个门决定如何将新的输入信息与旧的记忆相结合。有时候,我们需要完全忘记旧的信息,以便更好地吸收新的信息。
四、RNN网络的参数计算
以普通RNN为例,假设输入序列长度为 T T T,隐藏层维度为 H H H,输出维度为 O O O。则网络的参数计算如下:
- 输入权重矩阵: W x h ∈ R H × X W_{xh} \in \mathbb{R}^{H \times X} Wxh∈RH×X,其中 X X X为输入维度。
- 隐藏权重矩阵: W h h ∈ R H × H W_{hh} \in \mathbb{R}^{H \times H} Whh∈RH×H。
- 输出权重矩阵: W h y ∈ R O × H W_{hy} \in \mathbb{R}^{O \times H} Why∈RO×H。
- 隐藏层偏置向量: b h ∈ R H b_{h} \in \mathbb{R}^{H} bh∈RH。
- 输出层偏置向量: b y ∈ R O b_{y} \in \mathbb{R}^{O} by∈RO。
五、RNN网络的应用场景
- 自然语言处理:如文本分类、情感分析、机器翻译等。
- 语音识别:将语音信号转换为文字。
- 时间序列预测:如股票价格预测、气温预测等。
六、PyTorch框架搭建RNN网络的
下面是基于PyTorch框架的RNN网络实现代码:
import torch
import torch.nn as nn
# 定义RNN模型
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_size# 输入权重矩阵self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))# 隐藏权重矩阵self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))# 输出权重矩阵self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))# 隐藏层偏置向量self.b_h = nn.Parameter(torch.randn(hidden_size))# 输出层偏置向量self.b_y = nn.Parameter(torch.randn(output_size))def forward(self, x):h = torch.zeros(1, self.hidden_size)for i in range(x.size(0)):h = torch.tanh(torch.mm(x[i], self.W_xh) + torch.mm(h, self.W_hh) + self.b_h)y = torch.mm(h, self.W_hy) + self.b_yreturn y
# 实例化模型
input_size = 10
hidden_size = 20
output_size = 1
model = RNN(input_size, hidden_size, output_size)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 输入数据
x = torch.randn(5, input_size)
y_true = torch.randn(output_size)
# 训练模型
for epoch in range(100):model.zero_grad()y_pred = model(x)loss = criterion(y_pred, y_true)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item()}')
# 测试模型
with torch.no_grad():y_pred = model(x)print(f'Predicted: {y_pred}, True: {y_true}')
以上我利用pytorch搭建了一个RNN模型,用于序列数据的预测。详细解释一下代码:
- 定义RNN模型类,继承自
nn.Module
。 - 在初始化方法中,定义了输入权重矩阵 W x h W_xh Wxh、隐藏权重矩阵 W h h W_hh Whh、输出权重矩阵 W h y W_hy Why、隐藏层偏置向量 b h b_h bh和输出层偏置向量 b y b_y by。
- f o r w a r d forward forward方法实现了RNN的前向传播过程。对于每个时间步,计算隐藏状态 h h h,并在最后一个时间步计算输出 y y y。
- 实例化RNN模型,设置输入维度、隐藏层维度和输出维度。
- 定义损失函数 M S E L o s s MSELoss MSELoss和优化器 S G D SGD SGD。
- 生成随机的输入数据 x x x和真实标签 y t r u e y_true ytrue。
- 训练模型,通过前向传播、计算损失、反向传播和更新权重。
- 每隔10个epoch打印损失,观察模型训练过程。
- 在测试阶段,关闭梯度计算,预测输入数据的输出,并与真实标签进行比较。
七、总结
本文详细介绍了循环神经网络(RNN)的原理、运行过程、类别、参数计算和应用场景,并通过PyTorch框架给出了一个完整的RNN模型实现。通过本文,读者可以了解到RNN在处理序列数据方面的优势,以及如何在实际应用中使用RNN。
需要注意的是,实际应用中通常会使用PyTorch提供的内置RNN模块,如nn.RNN
、nn.LSTM
和nn.GRU
,这些模块提供了更高效、更灵活的实现。以下是一个使用PyTorch内置LSTM模块的示例:
相关文章:

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解
大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网…...
服务端生成RSA密钥实例
RSA非对称加密算法的一种,这里分享一下服务端生成公钥和私钥的实例,并打印出来。 一:实例代码 package mainimport ("bufio""crypto/rand""crypto/rsa""crypto/x509""encoding/pem"&quo…...

Maven Nexus3 私服搭建、配置、项目发布指南
maven nexus私服搭建 访问nexus3官方镜像库,选择需要的版本下载:Docker Nexus docker pull sonatype/nexus3:3.49.0 创建数据目录并赋权 sudo mkdir /nexus-data && sudo chown -R 200 /nexus-data 运行(数据目录选择硬盘大的卷进行挂载) …...
东方博宜1627 - 暑期的旅游计划(2)
问题描述 期末考试结束了,小华语文、数学、英语三门功课分别考了 x、y、z 分,小华的家长说,如果小华三门功课中有一门考到 90 分或者 90 分以上,那么就去北京旅游,如果都没考到,那么就去南京玩。 请从键盘…...

FastAPI 学习之路(三十五)项目结构优化
之前我们创建的文件都是在一个目录中,但是在我们的实际开发中,肯定不能这样设计,那么我们去创建一个目录,叫models,大致如下。 主要目录是: __init__.py 是一个空文件,说明models是一个package…...

linux源码安装mysql8.0的小白教程
1.下载8.x版本的mysql MySQL :: Download MySQL Community Server (Archived Versions) 2.安装linux 我安装的是Rocky Linux8.6 3.设置ip地址,方便远程连接 使用nmcli或者nmtui设置或修改ip地址 4.使用远程连接工具MobaXterm操作: (1)将mysql8版本的压缩包上传到mybaxterm…...

如何评估独立站的外链质量?
要评估独立站的外链质量时,首先要看的不是别的,而是内容,跟你网站相关的文章内容才是最重要的,其他的一切其实都不重要。什么网站的DA,评级,网站的主要内容跟你的文章内容是否相关其实都不重要,…...
AI在编程领域的作用
AI(人工智能)在软件开发和许多其他领域都发挥着重要作用,但这并不意味着它在取代开发者。相反,AI更多地是在帮助开发者提高工作效率,解决复杂问题,并创造新的可能性。 探讨AI工具对开发者日常工作的影响 …...

医疗器械网络安全 | 漏洞扫描、渗透测试没有发现问题,是否说明我的设备是安全的?
尽管漏洞扫描、模糊测试和渗透测试在评估系统安全性方面是非常重要和有效的工具,但即使这些测试没有发现任何问题,也不能完全保证您的医疗器械是绝对安全的。这是因为安全性的评估是一个多维度、复杂且持续的过程,涉及多个方面和因素。以下是…...
【GameFramework扩展应用】6-4、GameFramework框架增加AB包加解密功能
推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录: https://blog.csdn.net/q764424567/article/details/1…...

通用图形处理器设计GPGPU基础与架构(二)
一、前言 本系列旨在介绍通用图形处理器设计GPGPU的基础与架构,因此在介绍GPGPU具体架构之前,需要了解GPGPU的编程模型,了解软件层面是怎么做到并行的,硬件层面又要怎么配合软件,乃至定出合适的架构来实现软硬件协同。…...
在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS
要在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS,可以按照以下步骤操作: 1. 安装 Tailwind CSS 及其依赖 首先,确保你的项目根目录下有 package.json 文件,然后运行以下命令来安装 Tailwind CSS 及其所需的…...

HDMI简介
本篇主要介绍HDMI常见接口以及TMDS传输技术。 文章目录 一、HDMI简介二、TMDS传输技术1.编码(encoder)2.并转串(serializer)——OSERDESE2原语3.单端转差分——OBUFDS源语 三、常见的几种信号传输方式 一、HDMI简介 HDMI(High-Definition Multimedia I…...

原作者带队,LSTM卷土重来之Vision-LSTM出世
与 DeiT 等使用 ViT 和 Vision-Mamba (Vim) 方法的模型相比,ViL 的性能更胜一筹。 AI 领域的研究者应该还记得,在 Transformer 诞生后的三年,谷歌将这一自然语言处理届的重要研究扩展到了视觉领域,也就是 Vision Transformer。后来…...

Fiddler 抓包工具抓https
Fiddler 抓包工具抓https...

详细谈谈负载均衡的startupProbe探针、livenessProbe探针、readnessProbe探针如何使用以及使用差异化
文章目录 startupProbe探针startupProbe说明示例配置参数解释 使用场景说明实例——要求: 容器在8秒内完成启动,否则杀死对应容器工作流程说明timeoutSeconds: 和 periodSeconds: 参数顺序说明 livenessProbe探针livenessProbe说明示例配置参数解释 使用…...
守望数据边界:sklearn中的离群点检测技术
守望数据边界:sklearn中的离群点检测技术 在数据分析和机器学习项目中,离群点检测是一项关键任务。离群点,又称异常值或离群点,是指那些与其他数据显著不同的观测值。这些点可能由测量误差、数据录入错误或真实的变异性造成。正确…...
python工作中遇到的坑
1. 字典拷贝 有些场景下,需要对字典拷贝一个副本。这个副本用于保存原始数据,然后原来的字典去参与其他运算,或者作为参数传递给一些函数。 例如, >>> dict_a {"name": "John", "address&q…...

中职网络安全wire0077数据包分析
从靶机服务器的FTP上下载wire0077.pcap,分析该文件,找出黑客入侵使用的协议,提交协议名称 SMTP 分析该文件,找出黑客入侵获取的zip压缩包,提交压缩包文件名 DESKTOP-M1JC4XX_2020_09_24_22_43_12.zip 分析该文件&…...

引领未来:在【PyCharm】中利用【机器学习】与【支持向量机】实现高效【图像识别】
目录 一、数据准备 1. 获取数据集 2. 数据可视化 3. 数据清洗 二、特征提取 1. 数据标准化 2. 图像增强 三、模型训练 1. 划分训练集和测试集 2. 训练 SVM 模型 3. 参数调优 四、模型评估 1. 评估模型性能 2. 可视化结果 五、预测新图像 1. 加载和预处理新图像…...

网络六边形受到攻击
大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...

手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
今日科技热点速览
🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...
Java入门学习详细版(一)
大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...

HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...

GruntJS-前端自动化任务运行器从入门到实战
Grunt 完全指南:从入门到实战 一、Grunt 是什么? Grunt是一个基于 Node.js 的前端自动化任务运行器,主要用于自动化执行项目开发中重复性高的任务,例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...

Windows安装Miniconda
一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

数学建模-滑翔伞伞翼面积的设计,运动状态计算和优化 !
我们考虑滑翔伞的伞翼面积设计问题以及运动状态描述。滑翔伞的性能主要取决于伞翼面积、气动特性以及飞行员的重量。我们的目标是建立数学模型来描述滑翔伞的运动状态,并优化伞翼面积的设计。 一、问题分析 滑翔伞在飞行过程中受到重力、升力和阻力的作用。升力和阻力与伞翼面…...

渗透实战PortSwigger靶场:lab13存储型DOM XSS详解
进来是需要留言的,先用做简单的 html 标签测试 发现面的</h1>不见了 数据包中找到了一个loadCommentsWithVulnerableEscapeHtml.js 他是把用户输入的<>进行 html 编码,输入的<>当成字符串处理回显到页面中,看来只是把用户输…...