当前位置: 首页 > news >正文

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网络的原理、运行过程、类别、参数计算,以及应用场景,并附上基于PyTorch框架的完整可运行代码。

文章目录

  • 一、RNN网络的原理
    • RNN的基本结构
  • 二、RNN网络的运行过程
  • 三、RNN的类别
    • 1. 普通RNN
    • 2.LSTM模型:记忆魔法的图书馆
    • 3.GRU模型:高效的记忆工作室
  • 四、RNN网络的参数计算
  • 五、RNN网络的应用场景
  • 六、PyTorch框架搭建RNN网络的
  • 七、总结

一、RNN网络的原理

RNN网络的核心思想是利用历史信息来影响当前输出。与传统的前馈神经网络不同,RNN在网络结构中引入了循环结构,使得网络能够记忆前面的信息。

RNN的基本结构

  1. 输入层:输入序列数据,如 x 1 , x 2 , … , x t x_{1}, x_{2}, \ldots, x_{t} x1,x2,,xt
  2. 隐藏层:包含一系列的循环单元,每个循环单元负责处理当前输入和上一时刻的隐藏状态,输出当前时刻的隐藏状态。
  3. 输出层:根据当前时刻的隐藏状态输出结果。
    在这里插入图片描述

二、RNN网络的运行过程

RNN网络的运行过程可以表示为以下公式:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_{t} = f(W_{hh}h_{t-1} + W_{xh}x_{t} + b_{h}) ht=f(Whhht1+Wxhxt+bh)
y t = g ( W h y h t + b y ) y_{t} = g(W_{hy}h_{t} + b_{y}) yt=g(Whyht+by)
其中, h t h_{t} ht表示第 t t t时刻的隐藏状态, x t x_{t} xt表示第 t t t时刻的输入, y t y_{t} yt表示第 t t t时刻的输出, W W W表示权重矩阵, b b b表示偏置向量, f f f g g g分别表示隐藏层和输出层的激活函数。

三、RNN的类别

1. 普通RNN

最简单的RNN结构,存在梯度消失和梯度爆炸问题,难以学习长距离依赖。

2.LSTM模型:记忆魔法的图书馆

想象一下,LSTM模型就像一个拥有魔法能力的图书馆。这个图书馆的特殊之处在于,它能够记住很久以前读过的书籍内容,并且能够决定哪些信息是重要的,需要长期保留,哪些信息是可以丢弃的。

魔法书架(细胞状态)
图书馆的中心有一个魔法书架,称为“细胞状态”。这个书架上的书可以长期保存,是图书馆记忆的核心。书架上的书籍可以随着时间流动,新的书籍可以加入,旧的书籍也可以被替换。
魔法门(门结构)
图书馆有三个魔法门:遗忘门、输入门和输出门。
遗忘门
这个门决定哪些旧书籍(信息)不再重要,需要从书架上移除。如果某个信息对未来的预测不再重要,遗忘门就会让它“消失”。
输入门
这个门负责决定哪些新书(新信息)应该被添加到书架上。它检查新来的书籍,并决定哪些是有价值的,可以增强图书馆的记忆。
输出门
这个门决定哪些书籍的内容需要被阅读(输出),以影响图书馆的下一步行动。它查看书架上的书籍,并决定哪些信息需要传递到下一个时间步。
在这里插入图片描述

3.GRU模型:高效的记忆工作室

现在,让我们将GRU模型想象成一个高效的记忆工作室。这个工作室的任务与图书馆相似,但是它更简洁,更高效。
工作室的一体化空间(更新门和重置门)
GRU模型将LSTM的遗忘门和输入门合并成了一个叫做“更新门”的机制。同时,它还有一个“重置门”。
更新门
这个门同时负责决定哪些信息需要被遗忘,以及哪些新信息需要被存储。它就像一个高效的助手,一边清理旧资料,一边挑选新资料。
重置门
这个门决定如何将新的输入信息与旧的记忆相结合。有时候,我们需要完全忘记旧的信息,以便更好地吸收新的信息。
在这里插入图片描述

四、RNN网络的参数计算

以普通RNN为例,假设输入序列长度为 T T T,隐藏层维度为 H H H,输出维度为 O O O。则网络的参数计算如下:

  1. 输入权重矩阵: W x h ∈ R H × X W_{xh} \in \mathbb{R}^{H \times X} WxhRH×X,其中 X X X为输入维度。
  2. 隐藏权重矩阵: W h h ∈ R H × H W_{hh} \in \mathbb{R}^{H \times H} WhhRH×H
  3. 输出权重矩阵: W h y ∈ R O × H W_{hy} \in \mathbb{R}^{O \times H} WhyRO×H
  4. 隐藏层偏置向量: b h ∈ R H b_{h} \in \mathbb{R}^{H} bhRH
  5. 输出层偏置向量: b y ∈ R O b_{y} \in \mathbb{R}^{O} byRO

五、RNN网络的应用场景

  1. 自然语言处理:如文本分类、情感分析、机器翻译等。
  2. 语音识别:将语音信号转换为文字。
  3. 时间序列预测:如股票价格预测、气温预测等。

六、PyTorch框架搭建RNN网络的

下面是基于PyTorch框架的RNN网络实现代码:

import torch
import torch.nn as nn
# 定义RNN模型
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_size# 输入权重矩阵self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))# 隐藏权重矩阵self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))# 输出权重矩阵self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))# 隐藏层偏置向量self.b_h = nn.Parameter(torch.randn(hidden_size))# 输出层偏置向量self.b_y = nn.Parameter(torch.randn(output_size))def forward(self, x):h = torch.zeros(1, self.hidden_size)for i in range(x.size(0)):h = torch.tanh(torch.mm(x[i], self.W_xh) + torch.mm(h, self.W_hh) + self.b_h)y = torch.mm(h, self.W_hy) + self.b_yreturn y
# 实例化模型
input_size = 10
hidden_size = 20
output_size = 1
model = RNN(input_size, hidden_size, output_size)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 输入数据
x = torch.randn(5, input_size)
y_true = torch.randn(output_size)
# 训练模型
for epoch in range(100):model.zero_grad()y_pred = model(x)loss = criterion(y_pred, y_true)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item()}')
# 测试模型
with torch.no_grad():y_pred = model(x)print(f'Predicted: {y_pred}, True: {y_true}')

以上我利用pytorch搭建了一个RNN模型,用于序列数据的预测。详细解释一下代码:

  1. 定义RNN模型类,继承自nn.Module
  2. 在初始化方法中,定义了输入权重矩阵 W x h W_xh Wxh、隐藏权重矩阵 W h h W_hh Whh、输出权重矩阵 W h y W_hy Why、隐藏层偏置向量 b h b_h bh和输出层偏置向量 b y b_y by
  3. f o r w a r d forward forward方法实现了RNN的前向传播过程。对于每个时间步,计算隐藏状态 h h h,并在最后一个时间步计算输出 y y y
  4. 实例化RNN模型,设置输入维度、隐藏层维度和输出维度。
  5. 定义损失函数 M S E L o s s MSELoss MSELoss和优化器 S G D SGD SGD
  6. 生成随机的输入数据 x x x和真实标签 y t r u e y_true ytrue
  7. 训练模型,通过前向传播、计算损失、反向传播和更新权重。
  8. 每隔10个epoch打印损失,观察模型训练过程。
  9. 在测试阶段,关闭梯度计算,预测输入数据的输出,并与真实标签进行比较。

七、总结

本文详细介绍了循环神经网络(RNN)的原理、运行过程、类别、参数计算和应用场景,并通过PyTorch框架给出了一个完整的RNN模型实现。通过本文,读者可以了解到RNN在处理序列数据方面的优势,以及如何在实际应用中使用RNN。
需要注意的是,实际应用中通常会使用PyTorch提供的内置RNN模块,如nn.RNNnn.LSTMnn.GRU,这些模块提供了更高效、更灵活的实现。以下是一个使用PyTorch内置LSTM模块的示例:

相关文章:

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网…...

服务端生成RSA密钥实例

RSA非对称加密算法的一种,这里分享一下服务端生成公钥和私钥的实例,并打印出来。 一:实例代码 package mainimport ("bufio""crypto/rand""crypto/rsa""crypto/x509""encoding/pem"&quo…...

Maven Nexus3 私服搭建、配置、项目发布指南

maven nexus私服搭建 访问nexus3官方镜像库,选择需要的版本下载:Docker Nexus docker pull sonatype/nexus3:3.49.0 创建数据目录并赋权 sudo mkdir /nexus-data && sudo chown -R 200 /nexus-data 运行(数据目录选择硬盘大的卷进行挂载) …...

东方博宜1627 - 暑期的旅游计划(2)

问题描述 期末考试结束了,小华语文、数学、英语三门功课分别考了 x、y、z 分,小华的家长说,如果小华三门功课中有一门考到 90 分或者 90 分以上,那么就去北京旅游,如果都没考到,那么就去南京玩。 请从键盘…...

FastAPI 学习之路(三十五)项目结构优化

之前我们创建的文件都是在一个目录中,但是在我们的实际开发中,肯定不能这样设计,那么我们去创建一个目录,叫models,大致如下。 主要目录是: __init__.py 是一个空文件,说明models是一个package…...

linux源码安装mysql8.0的小白教程

1.下载8.x版本的mysql MySQL :: Download MySQL Community Server (Archived Versions) 2.安装linux 我安装的是Rocky Linux8.6 3.设置ip地址,方便远程连接 使用nmcli或者nmtui设置或修改ip地址 4.使用远程连接工具MobaXterm操作: (1)将mysql8版本的压缩包上传到mybaxterm…...

如何评估独立站的外链质量?

要评估独立站的外链质量时,首先要看的不是别的,而是内容,跟你网站相关的文章内容才是最重要的,其他的一切其实都不重要。什么网站的DA,评级,网站的主要内容跟你的文章内容是否相关其实都不重要,…...

AI在编程领域的作用

AI(人工智能)在软件开发和许多其他领域都发挥着重要作用,但这并不意味着它在取代开发者。相反,AI更多地是在帮助开发者提高工作效率,解决复杂问题,并创造新的可能性。 探讨AI工具对开发者日常工作的影响 …...

医疗器械网络安全 | 漏洞扫描、渗透测试没有发现问题,是否说明我的设备是安全的?

尽管漏洞扫描、模糊测试和渗透测试在评估系统安全性方面是非常重要和有效的工具,但即使这些测试没有发现任何问题,也不能完全保证您的医疗器械是绝对安全的。这是因为安全性的评估是一个多维度、复杂且持续的过程,涉及多个方面和因素。以下是…...

【GameFramework扩展应用】6-4、GameFramework框架增加AB包加解密功能

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录: https://blog.csdn.net/q764424567/article/details/1…...

通用图形处理器设计GPGPU基础与架构(二)

一、前言 本系列旨在介绍通用图形处理器设计GPGPU的基础与架构,因此在介绍GPGPU具体架构之前,需要了解GPGPU的编程模型,了解软件层面是怎么做到并行的,硬件层面又要怎么配合软件,乃至定出合适的架构来实现软硬件协同。…...

在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS

要在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS,可以按照以下步骤操作: 1. 安装 Tailwind CSS 及其依赖 首先,确保你的项目根目录下有 package.json 文件,然后运行以下命令来安装 Tailwind CSS 及其所需的…...

HDMI简介

本篇主要介绍HDMI常见接口以及TMDS传输技术。 文章目录 一、HDMI简介二、TMDS传输技术1.编码(encoder)2.并转串(serializer)——OSERDESE2原语3.单端转差分——OBUFDS源语 三、常见的几种信号传输方式 一、HDMI简介 HDMI(High-Definition Multimedia I…...

原作者带队,LSTM卷土重来之Vision-LSTM出世

与 DeiT 等使用 ViT 和 Vision-Mamba (Vim) 方法的模型相比,ViL 的性能更胜一筹。 AI 领域的研究者应该还记得,在 Transformer 诞生后的三年,谷歌将这一自然语言处理届的重要研究扩展到了视觉领域,也就是 Vision Transformer。后来…...

Fiddler 抓包工具抓https

Fiddler 抓包工具抓https...

详细谈谈负载均衡的startupProbe探针、livenessProbe探针、readnessProbe探针如何使用以及使用差异化

文章目录 startupProbe探针startupProbe说明示例配置参数解释 使用场景说明实例——要求: 容器在8秒内完成启动,否则杀死对应容器工作流程说明timeoutSeconds: 和 periodSeconds: 参数顺序说明 livenessProbe探针livenessProbe说明示例配置参数解释 使用…...

守望数据边界:sklearn中的离群点检测技术

守望数据边界:sklearn中的离群点检测技术 在数据分析和机器学习项目中,离群点检测是一项关键任务。离群点,又称异常值或离群点,是指那些与其他数据显著不同的观测值。这些点可能由测量误差、数据录入错误或真实的变异性造成。正确…...

python工作中遇到的坑

1. 字典拷贝 有些场景下,需要对字典拷贝一个副本。这个副本用于保存原始数据,然后原来的字典去参与其他运算,或者作为参数传递给一些函数。 例如, >>> dict_a {"name": "John", "address&q…...

中职网络安全wire0077数据包分析

从靶机服务器的FTP上下载wire0077.pcap,分析该文件,找出黑客入侵使用的协议,提交协议名称 SMTP 分析该文件,找出黑客入侵获取的zip压缩包,提交压缩包文件名 DESKTOP-M1JC4XX_2020_09_24_22_43_12.zip 分析该文件&…...

引领未来:在【PyCharm】中利用【机器学习】与【支持向量机】实现高效【图像识别】

目录 一、数据准备 1. 获取数据集 2. 数据可视化 3. 数据清洗 二、特征提取 1. 数据标准化 2. 图像增强 三、模型训练 1. 划分训练集和测试集 2. 训练 SVM 模型 3. 参数调优 四、模型评估 1. 评估模型性能 2. 可视化结果 五、预测新图像 1. 加载和预处理新图像…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

ESP32读取DHT11温湿度数据

芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

无人机侦测与反制技术的进展与应用

国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机(无人驾驶飞行器,UAV)技术的快速发展,其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统,无人机的“黑飞”&…...

阿里云Ubuntu 22.04 64位搭建Flask流程(亲测)

cd /home 进入home盘 安装虚拟环境: 1、安装virtualenv pip install virtualenv 2.创建新的虚拟环境: virtualenv myenv 3、激活虚拟环境(激活环境可以在当前环境下安装包) source myenv/bin/activate 此时,终端…...

python读取SQLite表个并生成pdf文件

代码用于创建含50列的SQLite数据库并插入500行随机浮点数据,随后读取数据,通过ReportLab生成横向PDF表格,包含格式化(两位小数)及表头、网格线等美观样式。 # 导入所需库 import sqlite3 # 用于操作…...

生信服务器 | 做生信为什么推荐使用Linux服务器?

原文链接&#xff1a;生信服务器 | 做生信为什么推荐使用Linux服务器&#xff1f; 一、 做生信为什么推荐使用服务器&#xff1f; 大家好&#xff0c;我是小杜。在做生信分析的同学&#xff0c;或是将接触学习生信分析的同学&#xff0c;<font style"color:rgb(53, 1…...