当前位置: 首页 > news >正文

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网络的原理、运行过程、类别、参数计算,以及应用场景,并附上基于PyTorch框架的完整可运行代码。

文章目录

  • 一、RNN网络的原理
    • RNN的基本结构
  • 二、RNN网络的运行过程
  • 三、RNN的类别
    • 1. 普通RNN
    • 2.LSTM模型:记忆魔法的图书馆
    • 3.GRU模型:高效的记忆工作室
  • 四、RNN网络的参数计算
  • 五、RNN网络的应用场景
  • 六、PyTorch框架搭建RNN网络的
  • 七、总结

一、RNN网络的原理

RNN网络的核心思想是利用历史信息来影响当前输出。与传统的前馈神经网络不同,RNN在网络结构中引入了循环结构,使得网络能够记忆前面的信息。

RNN的基本结构

  1. 输入层:输入序列数据,如 x 1 , x 2 , … , x t x_{1}, x_{2}, \ldots, x_{t} x1,x2,,xt
  2. 隐藏层:包含一系列的循环单元,每个循环单元负责处理当前输入和上一时刻的隐藏状态,输出当前时刻的隐藏状态。
  3. 输出层:根据当前时刻的隐藏状态输出结果。
    在这里插入图片描述

二、RNN网络的运行过程

RNN网络的运行过程可以表示为以下公式:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_{t} = f(W_{hh}h_{t-1} + W_{xh}x_{t} + b_{h}) ht=f(Whhht1+Wxhxt+bh)
y t = g ( W h y h t + b y ) y_{t} = g(W_{hy}h_{t} + b_{y}) yt=g(Whyht+by)
其中, h t h_{t} ht表示第 t t t时刻的隐藏状态, x t x_{t} xt表示第 t t t时刻的输入, y t y_{t} yt表示第 t t t时刻的输出, W W W表示权重矩阵, b b b表示偏置向量, f f f g g g分别表示隐藏层和输出层的激活函数。

三、RNN的类别

1. 普通RNN

最简单的RNN结构,存在梯度消失和梯度爆炸问题,难以学习长距离依赖。

2.LSTM模型:记忆魔法的图书馆

想象一下,LSTM模型就像一个拥有魔法能力的图书馆。这个图书馆的特殊之处在于,它能够记住很久以前读过的书籍内容,并且能够决定哪些信息是重要的,需要长期保留,哪些信息是可以丢弃的。

魔法书架(细胞状态)
图书馆的中心有一个魔法书架,称为“细胞状态”。这个书架上的书可以长期保存,是图书馆记忆的核心。书架上的书籍可以随着时间流动,新的书籍可以加入,旧的书籍也可以被替换。
魔法门(门结构)
图书馆有三个魔法门:遗忘门、输入门和输出门。
遗忘门
这个门决定哪些旧书籍(信息)不再重要,需要从书架上移除。如果某个信息对未来的预测不再重要,遗忘门就会让它“消失”。
输入门
这个门负责决定哪些新书(新信息)应该被添加到书架上。它检查新来的书籍,并决定哪些是有价值的,可以增强图书馆的记忆。
输出门
这个门决定哪些书籍的内容需要被阅读(输出),以影响图书馆的下一步行动。它查看书架上的书籍,并决定哪些信息需要传递到下一个时间步。
在这里插入图片描述

3.GRU模型:高效的记忆工作室

现在,让我们将GRU模型想象成一个高效的记忆工作室。这个工作室的任务与图书馆相似,但是它更简洁,更高效。
工作室的一体化空间(更新门和重置门)
GRU模型将LSTM的遗忘门和输入门合并成了一个叫做“更新门”的机制。同时,它还有一个“重置门”。
更新门
这个门同时负责决定哪些信息需要被遗忘,以及哪些新信息需要被存储。它就像一个高效的助手,一边清理旧资料,一边挑选新资料。
重置门
这个门决定如何将新的输入信息与旧的记忆相结合。有时候,我们需要完全忘记旧的信息,以便更好地吸收新的信息。
在这里插入图片描述

四、RNN网络的参数计算

以普通RNN为例,假设输入序列长度为 T T T,隐藏层维度为 H H H,输出维度为 O O O。则网络的参数计算如下:

  1. 输入权重矩阵: W x h ∈ R H × X W_{xh} \in \mathbb{R}^{H \times X} WxhRH×X,其中 X X X为输入维度。
  2. 隐藏权重矩阵: W h h ∈ R H × H W_{hh} \in \mathbb{R}^{H \times H} WhhRH×H
  3. 输出权重矩阵: W h y ∈ R O × H W_{hy} \in \mathbb{R}^{O \times H} WhyRO×H
  4. 隐藏层偏置向量: b h ∈ R H b_{h} \in \mathbb{R}^{H} bhRH
  5. 输出层偏置向量: b y ∈ R O b_{y} \in \mathbb{R}^{O} byRO

五、RNN网络的应用场景

  1. 自然语言处理:如文本分类、情感分析、机器翻译等。
  2. 语音识别:将语音信号转换为文字。
  3. 时间序列预测:如股票价格预测、气温预测等。

六、PyTorch框架搭建RNN网络的

下面是基于PyTorch框架的RNN网络实现代码:

import torch
import torch.nn as nn
# 定义RNN模型
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_size# 输入权重矩阵self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))# 隐藏权重矩阵self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))# 输出权重矩阵self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))# 隐藏层偏置向量self.b_h = nn.Parameter(torch.randn(hidden_size))# 输出层偏置向量self.b_y = nn.Parameter(torch.randn(output_size))def forward(self, x):h = torch.zeros(1, self.hidden_size)for i in range(x.size(0)):h = torch.tanh(torch.mm(x[i], self.W_xh) + torch.mm(h, self.W_hh) + self.b_h)y = torch.mm(h, self.W_hy) + self.b_yreturn y
# 实例化模型
input_size = 10
hidden_size = 20
output_size = 1
model = RNN(input_size, hidden_size, output_size)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 输入数据
x = torch.randn(5, input_size)
y_true = torch.randn(output_size)
# 训练模型
for epoch in range(100):model.zero_grad()y_pred = model(x)loss = criterion(y_pred, y_true)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item()}')
# 测试模型
with torch.no_grad():y_pred = model(x)print(f'Predicted: {y_pred}, True: {y_true}')

以上我利用pytorch搭建了一个RNN模型,用于序列数据的预测。详细解释一下代码:

  1. 定义RNN模型类,继承自nn.Module
  2. 在初始化方法中,定义了输入权重矩阵 W x h W_xh Wxh、隐藏权重矩阵 W h h W_hh Whh、输出权重矩阵 W h y W_hy Why、隐藏层偏置向量 b h b_h bh和输出层偏置向量 b y b_y by
  3. f o r w a r d forward forward方法实现了RNN的前向传播过程。对于每个时间步,计算隐藏状态 h h h,并在最后一个时间步计算输出 y y y
  4. 实例化RNN模型,设置输入维度、隐藏层维度和输出维度。
  5. 定义损失函数 M S E L o s s MSELoss MSELoss和优化器 S G D SGD SGD
  6. 生成随机的输入数据 x x x和真实标签 y t r u e y_true ytrue
  7. 训练模型,通过前向传播、计算损失、反向传播和更新权重。
  8. 每隔10个epoch打印损失,观察模型训练过程。
  9. 在测试阶段,关闭梯度计算,预测输入数据的输出,并与真实标签进行比较。

七、总结

本文详细介绍了循环神经网络(RNN)的原理、运行过程、类别、参数计算和应用场景,并通过PyTorch框架给出了一个完整的RNN模型实现。通过本文,读者可以了解到RNN在处理序列数据方面的优势,以及如何在实际应用中使用RNN。
需要注意的是,实际应用中通常会使用PyTorch提供的内置RNN模块,如nn.RNNnn.LSTMnn.GRU,这些模块提供了更高效、更灵活的实现。以下是一个使用PyTorch内置LSTM模块的示例:

相关文章:

人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程11-PyTorch神经网络之循环神经网络RNN与代码详解。循环神经网络(Recurrent Neural Network,RNN)是一种处理序列数据的神经网络。本文将详细介绍RNN网…...

服务端生成RSA密钥实例

RSA非对称加密算法的一种,这里分享一下服务端生成公钥和私钥的实例,并打印出来。 一:实例代码 package mainimport ("bufio""crypto/rand""crypto/rsa""crypto/x509""encoding/pem"&quo…...

Maven Nexus3 私服搭建、配置、项目发布指南

maven nexus私服搭建 访问nexus3官方镜像库,选择需要的版本下载:Docker Nexus docker pull sonatype/nexus3:3.49.0 创建数据目录并赋权 sudo mkdir /nexus-data && sudo chown -R 200 /nexus-data 运行(数据目录选择硬盘大的卷进行挂载) …...

东方博宜1627 - 暑期的旅游计划(2)

问题描述 期末考试结束了,小华语文、数学、英语三门功课分别考了 x、y、z 分,小华的家长说,如果小华三门功课中有一门考到 90 分或者 90 分以上,那么就去北京旅游,如果都没考到,那么就去南京玩。 请从键盘…...

FastAPI 学习之路(三十五)项目结构优化

之前我们创建的文件都是在一个目录中,但是在我们的实际开发中,肯定不能这样设计,那么我们去创建一个目录,叫models,大致如下。 主要目录是: __init__.py 是一个空文件,说明models是一个package…...

linux源码安装mysql8.0的小白教程

1.下载8.x版本的mysql MySQL :: Download MySQL Community Server (Archived Versions) 2.安装linux 我安装的是Rocky Linux8.6 3.设置ip地址,方便远程连接 使用nmcli或者nmtui设置或修改ip地址 4.使用远程连接工具MobaXterm操作: (1)将mysql8版本的压缩包上传到mybaxterm…...

如何评估独立站的外链质量?

要评估独立站的外链质量时,首先要看的不是别的,而是内容,跟你网站相关的文章内容才是最重要的,其他的一切其实都不重要。什么网站的DA,评级,网站的主要内容跟你的文章内容是否相关其实都不重要,…...

AI在编程领域的作用

AI(人工智能)在软件开发和许多其他领域都发挥着重要作用,但这并不意味着它在取代开发者。相反,AI更多地是在帮助开发者提高工作效率,解决复杂问题,并创造新的可能性。 探讨AI工具对开发者日常工作的影响 …...

医疗器械网络安全 | 漏洞扫描、渗透测试没有发现问题,是否说明我的设备是安全的?

尽管漏洞扫描、模糊测试和渗透测试在评估系统安全性方面是非常重要和有效的工具,但即使这些测试没有发现任何问题,也不能完全保证您的医疗器械是绝对安全的。这是因为安全性的评估是一个多维度、复杂且持续的过程,涉及多个方面和因素。以下是…...

【GameFramework扩展应用】6-4、GameFramework框架增加AB包加解密功能

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录: https://blog.csdn.net/q764424567/article/details/1…...

通用图形处理器设计GPGPU基础与架构(二)

一、前言 本系列旨在介绍通用图形处理器设计GPGPU的基础与架构,因此在介绍GPGPU具体架构之前,需要了解GPGPU的编程模型,了解软件层面是怎么做到并行的,硬件层面又要怎么配合软件,乃至定出合适的架构来实现软硬件协同。…...

在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS

要在一个使用了 Sass 的 React Webpack 项目中安装和使用 Tailwind CSS,可以按照以下步骤操作: 1. 安装 Tailwind CSS 及其依赖 首先,确保你的项目根目录下有 package.json 文件,然后运行以下命令来安装 Tailwind CSS 及其所需的…...

HDMI简介

本篇主要介绍HDMI常见接口以及TMDS传输技术。 文章目录 一、HDMI简介二、TMDS传输技术1.编码(encoder)2.并转串(serializer)——OSERDESE2原语3.单端转差分——OBUFDS源语 三、常见的几种信号传输方式 一、HDMI简介 HDMI(High-Definition Multimedia I…...

原作者带队,LSTM卷土重来之Vision-LSTM出世

与 DeiT 等使用 ViT 和 Vision-Mamba (Vim) 方法的模型相比,ViL 的性能更胜一筹。 AI 领域的研究者应该还记得,在 Transformer 诞生后的三年,谷歌将这一自然语言处理届的重要研究扩展到了视觉领域,也就是 Vision Transformer。后来…...

Fiddler 抓包工具抓https

Fiddler 抓包工具抓https...

详细谈谈负载均衡的startupProbe探针、livenessProbe探针、readnessProbe探针如何使用以及使用差异化

文章目录 startupProbe探针startupProbe说明示例配置参数解释 使用场景说明实例——要求: 容器在8秒内完成启动,否则杀死对应容器工作流程说明timeoutSeconds: 和 periodSeconds: 参数顺序说明 livenessProbe探针livenessProbe说明示例配置参数解释 使用…...

守望数据边界:sklearn中的离群点检测技术

守望数据边界:sklearn中的离群点检测技术 在数据分析和机器学习项目中,离群点检测是一项关键任务。离群点,又称异常值或离群点,是指那些与其他数据显著不同的观测值。这些点可能由测量误差、数据录入错误或真实的变异性造成。正确…...

python工作中遇到的坑

1. 字典拷贝 有些场景下,需要对字典拷贝一个副本。这个副本用于保存原始数据,然后原来的字典去参与其他运算,或者作为参数传递给一些函数。 例如, >>> dict_a {"name": "John", "address&q…...

中职网络安全wire0077数据包分析

从靶机服务器的FTP上下载wire0077.pcap,分析该文件,找出黑客入侵使用的协议,提交协议名称 SMTP 分析该文件,找出黑客入侵获取的zip压缩包,提交压缩包文件名 DESKTOP-M1JC4XX_2020_09_24_22_43_12.zip 分析该文件&…...

引领未来:在【PyCharm】中利用【机器学习】与【支持向量机】实现高效【图像识别】

目录 一、数据准备 1. 获取数据集 2. 数据可视化 3. 数据清洗 二、特征提取 1. 数据标准化 2. 图像增强 三、模型训练 1. 划分训练集和测试集 2. 训练 SVM 模型 3. 参数调优 四、模型评估 1. 评估模型性能 2. 可视化结果 五、预测新图像 1. 加载和预处理新图像…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练

前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现

摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,社区养老保险系统小程序被用户普遍使用,为方…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化

是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可&#xff0c…...

DAY 26 函数专题1

函数定义与参数知识点回顾:1. 函数的定义2. 变量作用域:局部变量和全局变量3. 函数的参数类型:位置参数、默认参数、不定参数4. 传递参数的手段:关键词参数5 题目1:计算圆的面积 任务: 编写一…...

OCR MLLM Evaluation

为什么需要评测体系?——背景与矛盾 ​​ 能干的事:​​ 看清楚发票、身份证上的字(准确率>90%),速度飞快(眨眼间完成)。​​干不了的事:​​ 碰到复杂表格(合并单元…...