当前位置: 首页 > news >正文

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网络的原理、运行过程、类别、参数计算,以及应用场景,并附上基于PyTorch框架的完整可运行代码。

文章目录

  • 一、RNN网络的原理
    • RNN的基本结构
  • 二、RNN网络的运行过程
  • 三、RNN的类别
    • 1. 普通RNN
    • 2.LSTM模型:记忆魔法的图书馆
    • 3.GRU模型:高效的记忆工作室
  • 四、RNN网络的参数计算
  • 五、RNN网络的应用场景
  • 六、PyTorch框架搭建RNN网络的
  • 七、总结

一、RNN网络的原理

RNN网络的核心思想是利用历史信息来影响当前输出。与传统的前馈神经网络不同,RNN在网络结构中引入了循环结构,使得网络能够记忆前面的信息。

RNN的基本结构

  1. 输入层:输入序列数据,如 x 1 , x 2 , … , x t x_{1}, x_{2}, \ldots, x_{t} x1,x2,,xt
  2. 隐藏层:包含一系列的循环单元,每个循环单元负责处理当前输入和上一时刻的隐藏状态,输出当前时刻的隐藏状态。
  3. 输出层:根据当前时刻的隐藏状态输出结果。
    在这里插入图片描述

二、RNN网络的运行过程

RNN网络的运行过程可以表示为以下公式:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_{t} = f(W_{hh}h_{t-1} + W_{xh}x_{t} + b_{h}) ht=f(Whhht1+Wxhxt+bh)
y t = g ( W h y h t + b y ) y_{t} = g(W_{hy}h_{t} + b_{y}) yt=g(Whyht+by)
其中, h t h_{t} ht表示第 t t t时刻的隐藏状态, x t x_{t} xt表示第 t t t时刻的输入, y t y_{t} yt表示第 t t t时刻的输出, W W W表示权重矩阵, b b b表示偏置向量, f f f g g g分别表示隐藏层和输出层的激活函数。

三、RNN的类别

1. 普通RNN

最简单的RNN结构,存在梯度消失和梯度爆炸问题,难以学习长距离依赖。

2.LSTM模型:记忆魔法的图书馆

想象一下,LSTM模型就像一个拥有魔法能力的图书馆。这个图书馆的特殊之处在于,它能够记住很久以前读过的书籍内容,并且能够决定哪些信息是重要的,需要长期保留,哪些信息是可以丢弃的。

魔法书架(细胞状态)
图书馆的中心有一个魔法书架,称为“细胞状态”。这个书架上的书可以长期保存,是图书馆记忆的核心。书架上的书籍可以随着时间流动,新的书籍可以加入,旧的书籍也可以被替换。
魔法门(门结构)
图书馆有三个魔法门:遗忘门、输入门和输出门。
遗忘门
这个门决定哪些旧书籍(信息)不再重要,需要从书架上移除。如果某个信息对未来的预测不再重要,遗忘门就会让它“消失”。
输入门
这个门负责决定哪些新书(新信息)应该被添加到书架上。它检查新来的书籍,并决定哪些是有价值的,可以增强图书馆的记忆。
输出门
这个门决定哪些书籍的内容需要被阅读(输出),以影响图书馆的下一步行动。它查看书架上的书籍,并决定哪些信息需要传递到下一个时间步。
在这里插入图片描述

3.GRU模型:高效的记忆工作室

现在,让我们将GRU模型想象成一个高效的记忆工作室。这个工作室的任务与图书馆相似,但是它更简洁,更高效。
工作室的一体化空间(更新门和重置门)
GRU模型将LSTM的遗忘门和输入门合并成了一个叫做“更新门”的机制。同时,它还有一个“重置门”。
更新门
这个门同时负责决定哪些信息需要被遗忘,以及哪些新信息需要被存储。它就像一个高效的助手,一边清理旧资料,一边挑选新资料。
重置门
这个门决定如何将新的输入信息与旧的记忆相结合。有时候,我们需要完全忘记旧的信息,以便更好地吸收新的信息。
在这里插入图片描述

四、RNN网络的参数计算

以普通RNN为例,假设输入序列长度为 T T T,隐藏层维度为 H H H,输出维度为 O O O。则网络的参数计算如下:

  1. 输入权重矩阵: W x h ∈ R H × X W_{xh} \in \mathbb{R}^{H \times X} WxhRH×X,其中 X X X为输入维度。
  2. 隐藏权重矩阵: W h h ∈ R H × H W_{hh} \in \mathbb{R}^{H \times H} WhhRH×H
  3. 输出权重矩阵: W h y ∈ R O × H W_{hy} \in \mathbb{R}^{O \times H} WhyRO×H
  4. 隐藏层偏置向量: b h ∈ R H b_{h} \in \mathbb{R}^{H} bhRH
  5. 输出层偏置向量: b y ∈ R O b_{y} \in \mathbb{R}^{O} byRO

五、RNN网络的应用场景

  1. 自然语言处理:如文本分类、情感分析、机器翻译等。
  2. 语音识别:将语音信号转换为文字。
  3. 时间序列预测:如股票价格预测、气温预测等。

六、PyTorch框架搭建RNN网络的

下面是基于PyTorch框架的RNN网络实现代码:

import torch
import torch.nn as nn
# 定义RNN模型
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_size# 输入权重矩阵self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))# 隐藏权重矩阵self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))# 输出权重矩阵self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))# 隐藏层偏置向量self.b_h = nn.Parameter(torch.randn(hidden_size))# 输出层偏置向量self.b_y = nn.Parameter(torch.randn(output_size))def forward(self, x):h = torch.zeros(1, self.hidden_size)for i in range(x.size(0)):h = torch.tanh(torch.mm(x[i], self.W_xh) + torch.mm(h, self.W_hh) + self.b_h)y = torch.mm(h, self.W_hy) + self.b_yreturn y
# 实例化模型
input_size = 10
hidden_size = 20
output_size = 1
model = RNN(input_size, hidden_size, output_size)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 输入数据
x = torch.randn(5, input_size)
y_true = torch.randn(output_size)
# 训练模型
for epoch in range(100):model.zero_grad()y_pred = model(x)loss = criterion(y_pred, y_true)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item()}')
# 测试模型
with torch.no_grad():y_pred = model(x)print(f'Predicted: {y_pred}, True: {y_true}')

以上我利用pytorch搭建了一个RNN模型,用于序列数据的预测。详细解释一下代码:

  1. 定义RNN模型类,继承自nn.Module
  2. 在初始化方法中,定义了输入权重矩阵 W x h W_xh Wxh、隐藏权重矩阵 W h h W_hh Whh、输出权重矩阵 W h y W_hy Why、隐藏层偏置向量 b h b_h bh和输出层偏置向量 b y b_y by
  3. f o r w a r d forward forward方法实现了RNN的前向传播过程。对于每个时间步,计算隐藏状态 h h h,并在最后一个时间步计算输出 y y y
  4. 实例化RNN模型,设置输入维度、隐藏层维度和输出维度。
  5. 定义损失函数 M S E L o s s MSELoss MSELoss和优化器 S G D SGD SGD
  6. 生成随机的输入数据 x x x和真实标签 y t r u e y_true ytrue
  7. 训练模型,通过前向传播、计算损失、反向传播和更新权重。
  8. 每隔10个epoch打印损失,观察模型训练过程。
  9. 在测试阶段,关闭梯度计算,预测输入数据的输出,并与真实标签进行比较。

七、总结

本文详细介绍了循环神经网络(RNN)的原理、运行过程、类别、参数计算和应用场景,并通过PyTorch框架给出了一个完整的RNN模型实现。通过本文,读者可以了解到RNN在处理序列数据方面的优势,以及如何在实际应用中使用RNN。
需要注意的是,实际应用中通常会使用PyTorch提供的内置RNN模块,如nn.RNNnn.LSTMnn.GRU,这些模块提供了更高效、更灵活的实现。以下是一个使用PyTorch内置LSTM模块的示例:

相关文章:

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网…...

服务端生成RSA密钥实例

RSA非对称加密算法的一种,这里分享一下服务端生成公钥和私钥的实例,并打印出来。 一:实例代码 package mainimport ("bufio""crypto/rand""crypto/rsa""crypto/x509""encoding/pem"&quo…...

Maven Nexus3 私服搭建、配置、项目发布指南

maven nexus私服搭建 访问nexus3官方镜像库,选择需要的版本下载:Docker Nexus docker pull sonatype/nexus3:3.49.0 创建数据目录并赋权 sudo mkdir /nexus-data && sudo chown -R 200 /nexus-data 运行(数据目录选择硬盘大的卷进行挂载) …...

东方博宜1627 - 暑期的旅游计划(2)

问题描述 期末考试结束了,小华语文、数学、英语三门功课分别考了 x、y、z 分,小华的家长说,如果小华三门功课中有一门考到 90 分或者 90 分以上,那么就去北京旅游,如果都没考到,那么就去南京玩。 请从键盘…...

FastAPI 学习之路(三十五)项目结构优化

之前我们创建的文件都是在一个目录中,但是在我们的实际开发中,肯定不能这样设计,那么我们去创建一个目录,叫models,大致如下。 主要目录是: __init__.py 是一个空文件,说明models是一个package…...

linux源码安装mysql8.0的小白教程

1.下载8.x版本的mysql MySQL :: Download MySQL Community Server (Archived Versions) 2.安装linux 我安装的是Rocky Linux8.6 3.设置ip地址,方便远程连接 使用nmcli或者nmtui设置或修改ip地址 4.使用远程连接工具MobaXterm操作: (1)将mysql8版本的压缩包上传到mybaxterm…...

如何评估独立站的外链质量?

要评估独立站的外链质量时,首先要看的不是别的,而是内容,跟你网站相关的文章内容才是最重要的,其他的一切其实都不重要。什么网站的DA,评级,网站的主要内容跟你的文章内容是否相关其实都不重要,…...

AI在编程领域的作用

AI(人工智能)在软件开发和许多其他领域都发挥着重要作用,但这并不意味着它在取代开发者。相反,AI更多地是在帮助开发者提高工作效率,解决复杂问题,并创造新的可能性。 探讨AI工具对开发者日常工作的影响 …...

医疗器械网络安全 | 漏洞扫描、渗透测试没有发现问题,是否说明我的设备是安全的?

尽管漏洞扫描、模糊测试和渗透测试在评估系统安全性方面是非常重要和有效的工具,但即使这些测试没有发现任何问题,也不能完全保证您的医疗器械是绝对安全的。这是因为安全性的评估是一个多维度、复杂且持续的过程,涉及多个方面和因素。以下是…...

【GameFramework扩展应用】6-4、GameFramework框架增加AB包加解密功能

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录: https://blog.csdn.net/q764424567/article/details/1…...

通用图形处理器设计GPGPU基础与架构(二)

一、前言 本系列旨在介绍通用图形处理器设计GPGPU的基础与架构,因此在介绍GPGPU具体架构之前,需要了解GPGPU的编程模型,了解软件层面是怎么做到并行的,硬件层面又要怎么配合软件,乃至定出合适的架构来实现软硬件协同。…...

在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS

要在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS,可以按照以下步骤操作: 1. 安装 Tailwind CSS 及其依赖 首先,确保你的项目根目录下有 package.json 文件,然后运行以下命令来安装 Tailwind CSS 及其所需的…...

HDMI简介

本篇主要介绍HDMI常见接口以及TMDS传输技术。 文章目录 一、HDMI简介二、TMDS传输技术1.编码(encoder)2.并转串(serializer)——OSERDESE2原语3.单端转差分——OBUFDS源语 三、常见的几种信号传输方式 一、HDMI简介 HDMI(High-Definition Multimedia I…...

原作者带队,LSTM卷土重来之Vision-LSTM出世

与 DeiT 等使用 ViT 和 Vision-Mamba (Vim) 方法的模型相比,ViL 的性能更胜一筹。 AI 领域的研究者应该还记得,在 Transformer 诞生后的三年,谷歌将这一自然语言处理届的重要研究扩展到了视觉领域,也就是 Vision Transformer。后来…...

Fiddler 抓包工具抓https

Fiddler 抓包工具抓https...

详细谈谈负载均衡的startupProbe探针、livenessProbe探针、readnessProbe探针如何使用以及使用差异化

文章目录 startupProbe探针startupProbe说明示例配置参数解释 使用场景说明实例——要求: 容器在8秒内完成启动,否则杀死对应容器工作流程说明timeoutSeconds: 和 periodSeconds: 参数顺序说明 livenessProbe探针livenessProbe说明示例配置参数解释 使用…...

守望数据边界:sklearn中的离群点检测技术

守望数据边界:sklearn中的离群点检测技术 在数据分析和机器学习项目中,离群点检测是一项关键任务。离群点,又称异常值或离群点,是指那些与其他数据显著不同的观测值。这些点可能由测量误差、数据录入错误或真实的变异性造成。正确…...

python工作中遇到的坑

1. 字典拷贝 有些场景下,需要对字典拷贝一个副本。这个副本用于保存原始数据,然后原来的字典去参与其他运算,或者作为参数传递给一些函数。 例如, >>> dict_a {"name": "John", "address&q…...

中职网络安全wire0077数据包分析

从靶机服务器的FTP上下载wire0077.pcap,分析该文件,找出黑客入侵使用的协议,提交协议名称 SMTP 分析该文件,找出黑客入侵获取的zip压缩包,提交压缩包文件名 DESKTOP-M1JC4XX_2020_09_24_22_43_12.zip 分析该文件&…...

引领未来:在【PyCharm】中利用【机器学习】与【支持向量机】实现高效【图像识别】

目录 一、数据准备 1. 获取数据集 2. 数据可视化 3. 数据清洗 二、特征提取 1. 数据标准化 2. 图像增强 三、模型训练 1. 划分训练集和测试集 2. 训练 SVM 模型 3. 参数调优 四、模型评估 1. 评估模型性能 2. 可视化结果 五、预测新图像 1. 加载和预处理新图像…...

240707-Sphinx配置Pydata-Sphinx-Theme

Step A. 最终效果 Step B. 为什么选择Pydata-Sphinx-Theme主题 Gallery of sites using this theme — PyData Theme 0.15.4 documentation Step 1. 创建并激活Conda环境 conda create -n rtd_pydata python3.10 conda activate rtd_pydataStep 2. 安装默认的工具包 pip in…...

华为如何做成数字化转型?

目录 企业数字化转型是什么? 华为如何定义数字化转型? 为什么做数字化转型? 怎么做数字化转型? 华为IPD的最佳实践之“金蝶云” 企业数字化转型是什么? 先看一下案例,华为经历了多次战略转型&#xf…...

Python | Leetcode Python题解之第229题多数元素II

题目: 题解: class Solution:def majorityElement(self, nums: List[int]) -> List[int]:cnt {}ans []for v in nums:if v in cnt:cnt[v] 1else:cnt[v] 1for item in cnt.keys():if cnt[item] > len(nums)//3:ans.append(item)return ans...

TCP/IP模型和OSI模型的区别(面试题)

OSI模型,是国际标准化组织ISO制定的用于计算机或通讯系统间互联的标准化体系,主要分为7个层级: 物理层数据链路层网络层传输层会话层表示层应用层 虽然OSI模型在理论上更全面,但是在实际网络通讯中,TCP/IP模型更加实…...

UML建模工具Draw.io简介

新书速览|《UML 2.5基础、建模与设计实践 Draw.io是一个非常出色的免费、开源、简洁、方便的绘图软件,利用这款软件可以绘制出生动有趣的图形,包括流程图、地图、网络架构图、UML用例图、流程图等。它支持各种快捷键,免费提供了1000多张画图…...

qt udp 协议 详解

1.qt udp 协议链接举例 在Qt框架中,使用UDP协议进行通信主要依赖于QUdpSocket类。以下是一个基于Qt的UDP通信示例,包括UDP套接字的创建、绑定端口、发送和接收数据报的步骤。 1. 创建UDP套接字 首先,需要创建一个QUdpSocket对象。这通常在…...

ubuntu 换源

sudo apt update 错误如下 Ign:1 http://security.ubuntu.com/ubuntu focal-security InRelease Ign:2 http://us.archive.ubuntu.com/ubuntu focal InRelease Err:3 http://security.ubuntu.com/ubuntu focal-security Release SECURITY: URL redirect target…...

基于ssm的图书管理系统的设计与实现

摘 要 在当今信息技术日新月异的时代背景下,图书管理领域正经历着深刻的变革,传统的管理模式已难以适应现代社会的快节奏和高要求,逐渐向数字化、智能化的方向演进。本论文聚焦于这一转变趋势,致力于设计并成功实现一个基于 SSM&…...

python压缩PDF方案(Ghostscript+pdfc)

第一步:安装Ghostscript Ghostscript是一套建基于Adobe、PostScript及可移植文档格式(PDF)的页面描述语言等而编译成的免费软件。它可以作为文件格式转换器,如PostScript和PDF转换器,也为编程提供API。[1]PDF压缩本质…...

kotlin 基础

文章目录 1、安装 Java 和 Kotlin 环境2、程序代码基本结构3、变量的声明与使用4、数据类型5、数字类型的运算1)布尔类型2)字符类型3)字符串类型 6、 选择结构1)(if - else)2) 选择结构(when&am…...