当前位置: 首页 > news >正文

神经网络之RNN和LSTM(基于pytorch-api)

1.RNN

1.1简介

RNN用于处理序列数据。在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能无力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的。RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。理论上,RNN能够对任何长度的序列数据进行处理。但是在实践中,为了降低复杂性往往假设当前的状态只与前面的几个状态相关。

1.2RNN结构

1.2.1一对多

   这种结构 无疑是“看图说话”,“看视频”

 1.2.2经典结构n to n

step1 :解析数据

 涉及到词向量的解析模型:一个字或者词把他分为300维向量!!!

而x1,x2,x3,x4....x300就是这些输入的特征

step2:建立一层

为了建模序列问题,RNN引入了隐状态h(hidden state)的概念,h可以对序列形的数据提取特征,接着再转换为输出。先从h1的计算开始看:

step3

2的计算和h1类似。要注意的是,在计算时,每一步使用的参数U、W、b都是一样的,也就是说每个步骤的参数都是共享的,这是RNN的重要特点,一定要牢记。

 依次计算剩下来的(使用相同的参数U,W,b):

这里为了方便起见,只画出序列长度为4的情况,实际上,这个计算过程可以到300()下去。得到输出值的方法就是直接通过h进行计算: 

正如之前所说,一个箭头就表示对对应的向量做一次类似于f(Wx+b)的变换,这里的这个箭头就表示对h1进行一次变换,得到输出y1。的变换,这里的这个箭头就表示对h1进行一次变换,得到输出y1。

 1.3RNN缺点

1. 梯度消失与梯度爆炸问题

  • 原理

    • 在 RNN 中,梯度在反向传播过程中会经过多次矩阵乘法运算。对于传统的激活函数(如 Sigmoid 或 Tanh),其导数范围通常在 0 到 1 之间。当进行多次乘法时,梯度会变得越来越小,最终趋近于 0,这就是梯度消失问题。相反,如果权重矩阵的值较大,梯度在反向传播过程中会不断增大,导致梯度爆炸。

  • 影响

    • 梯度消失会使 RNN 难以学习到序列中的长期依赖关系。因为当梯度趋近于 0 时,网络参数几乎不会更新,早期时间步的信息无法有效地传递到后面的时间步,导致模型无法捕捉到序列中远距离元素之间的关联。

    • 梯度爆炸则会使参数更新幅度过大,导致模型不稳定,无法收敛到最优解,甚至可能使参数值变为 NaN(非数字),导致训练失败。

2. 难以处理长序列

  • 信息丢失

    • 由于 RNN 是按时间步依次处理序列数据的,随着序列长度的增加,早期时间步的信息在传递到后面的时间步时会逐渐被稀释或遗忘。这使得 RNN 在处理长序列时,难以保留和利用序列开头的重要信息,从而影响模型对整个序列的理解和处理能力。

  • 训练效率低

    • 处理长序列时,RNN 需要进行大量的时间步计算,这会导致训练时间显著增加。同时,由于梯度消失和梯度爆炸问题的存在,训练过程可能会变得不稳定,需要更多的技巧和时间来调整参数,进一步降低了训练效率。

3. 并行计算困难

  • 计算方式限制

    • RNN 的结构决定了它在处理序列数据时是按时间步顺序进行的,每个时间步的输出依赖于前一个时间步的隐藏状态。这意味着在计算当前时间步的输出时,必须等待前一个时间步的计算完成,无法像前馈神经网络那样进行大规模的并行计算。

  • 影响训练和推理速度

    • 并行计算能力的缺乏限制了 RNN 在处理大规模数据时的效率。在训练过程中,无法充分利用现代 GPU 的并行计算能力,导致训练时间变长;在推理阶段,也会影响模型的响应速度,降低系统的实时性

2.LSTM

解决梯度消失问题

长短时记忆网络的思路比较简单。原始RNN的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。那么,假如我们再增加一个状态,即c,让它来保存长期的状态,那

 新增加的状态c,称为单元状态(cell state)。我们把上图按照时间维度展开

 LSTM的关键,就是怎样控制长期状态c。在这里,LSTM的思路是使用三个控制开关。第一个开关,负责控制继续保存长期状态c;第二个开关,负责控制把即时状态输入到长期状态c;第三个开关,负责控制是否把长期状态c作为当前的LSTM的输出。三个开关的作用如下图所示:

长短时记忆网络的前向计算

前面描述的开关是怎样在算法中实现的呢?这就用到了门(gate)的概念。门实际上就是一层全连接层,它的输入是一个向量,输出是一个0到1之间的实数向量。假设W是门的权重向量,b是偏置项,那么门可以表示为:

g(x)=\sigma (Wx+b)

就是用门的输出向量按元素乘以我们需要控制的那个向量。因为门的输出是0到1之间的实数向量,那么,当门输出为0时,任何向量与之相乘都会得到0向量,这就相当于啥都不能通过;输出为1时,任何向量与之相乘都不会有任何改变,这就相当于啥都可以通过。因为δ(也就是sigmoid函数)的值域是(0,1)

LSTM用两个门来控制单元状态c的内容,一个是遗忘门(forget gate),它决定了上一时刻的单元状态有多少保留到当前时刻ct;另一个是输入门(input gate),它决定了当前时刻网络的输入xt有多少保存到单元状态ct。LSTM用输出门(output gate)来控制单元状态ct有多少输出到LSTM的当前输出值ht。

1.遗忘门
 

                       

 

 2.输入门

 下图是的~ct计算

 

我们就把LSTM关于当前的记忆~ct和长期的记忆ct-1组合在一起,形成了新的单元状态ct。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。下面,我们要看看输出门,它控制了长期记忆对当前输出的影响 

输出门:

 LSTM最终的输出,是由输出门和单元下图表示LSTM最终输出的计算:

下图表示LSTM最终输出的计算: 

 总结:

1.就是三个门就是通过全连接层来控制权重,再来控制再当前输入,你一般只需记得这一点就行。(lstm不行了)

2.集成函数:

torch.nn.LSTM(input_size, hidden_size, num_layers=1,bias=True, batch_first=False, dropout=0, bidirectional=False, proj_size=0)

1. input_size 根据词嵌入模型来

  • 类型int

  • 含义:输入特征的维度,也就是每个时间步输入向量的大小。例如,在处理词嵌入时,input_size 就是词向量的维度。

  • 设置原则:你自己的词向量转化模型

2. hidden_size 

  • 类型int

  • 含义:隐藏状态的维度,即 LSTM 单元中隐藏层的神经元数量。这个参数决定了 LSTM 能够学习和表示的信息复杂度。

  • 设置原则:应远小于词嵌入的维度,否则可能会丢失输入特征中的重要信息;但也不需要过大,以免增加模型的复杂度和计算成本。

3. num_layers

  • 类型int,默认值为 1

  • 含义:堆叠的 LSTM 层数。当 num_layers > 1 时,会构建一个多层的 LSTM 网络,上层 LSTM 的输入是下层 LSTM 的输出。

4. bias

  • 类型bool,默认值为 True

  • 含义:是否使用偏置项。如果设置为 True,LSTM 层会在计算过程中添加偏置项。

5. batch_first

  • 类型bool,默认值为 False

  • 含义:指定输入和输出张量的维度顺序。如果设置为 True,输入和输出张量的形状为 (batch_size, seq_len, input_size);如果设置为 False,形状为 (seq_len, batch_size, input_size)

  • 要用seq_len是序列长度

6. dropout

  • 类型float,默认值为 

  • 含义:在除最后一层外的每层 LSTM 输出上应用的 Dropout 概率。Dropout 是一种正则化技术,用于防止过拟合。取值范围为 [0, 1],当设置为 0 时,不应用 Dropout。

  • 0.5

7. bidirectional

  • 类型bool,默认值为 False

  • 含义:是否使用双向 LSTM。如果设置为 True,LSTM 会同时从序列的正向和反向进行处理,然后将两个方向的输出拼接起来,这有助于模型捕捉序列中的前后文信

8. proj_size

  • 类型int,默认值为 0

  • 含义:如果设置为非零值,会在 LSTM 单元中添加投影层,将隐藏状态投影到指定的维度 proj_size。这可以减少模型的参数数量。

3.torchAPI 

利用torch库搭建一个网络

class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)#三百维self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)def forward(self, x):x, _ = x #jout = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]out, _ = self.lstm(out)  # (output, (h_n, c_n))#out #(batch_size, seq_len, hidden_size)(当 batch_first=True 时out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden statereturn out

相关文章:

神经网络之RNN和LSTM(基于pytorch-api)

1.RNN 1.1简介 RNN用于处理序列数据。在传统的神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多问题却无能无力。例如,你要预测句子的下一个单词是…...

leetcode第39题组合总和

原题出于leetcode第39题https://leetcode.cn/problems/combination-sum/description/题目如下: 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 ,并以…...

【UI设计——视频播放界面分享】

视频播放界面设计分享 在本次设计分享中,带来一个视频播放界面的设计作品。 此界面采用了简洁直观的布局。顶部是导航栏,包含主页、播放、搜索框等常见功能,方便用户快速找到所需操作。搜索框旁输入 “萌宠成长记”,体现了对特定内…...

动态规划刷题

文章目录 动态规划三步问题题目解析代码 动态规划 1. 状态表示:dp[i],表示dp表中i下标位置的值 2. 状态转移方程:以i位置位置的状态,最近的一步来划分问题,比如可以将状态拆分成前状态来表示现状态,dp[i] …...

stm32week5

stm32学习 二.外设 14.串口发送数据包 数据包的定义: HEX数据包(以0xFF为包头,0xFE为包尾,实际上可自定义): 固定包长,含包头包尾可变包长,含包头包尾 对于数据中不会出现包头包尾的数据可以用可变包长…...

fastapi中的patch请求

目录 示例测试使用 curl 访问:使用 requests 访问:预期返回: 浏览器访问 示例 下面是一个使用 app.patch("") 的 FastAPI 示例,该示例实现了一个简单的用户信息更新 API。我们使用 pydantic 定义数据模型,并…...

系统架构设计师—计算机基础篇—计算机网络

文章目录 网络互联模型网络协议与标准应用层协议FTP协议TFTP协议 HTTP协议HTTPS协议 DHCP动态主机配置协议DNS协议迭代查询递归查询 传输层协议网络层协议IPV4协议IPV6协议IPV6数据报的目的地址IPV4到IPV6的过渡技术 网络设计分层设计接入层汇聚层核心层 网络布线综合布线系统工…...

MATLAB中asManyOfPattern函数用法

目录 语法 说明 示例 匹配尽可能多的模式实例 指定要匹配的最小模式数 指定要匹配的最小和最大模式数 asManyOfPattern函数的功能是模式匹配次数尽可能多。 语法 newpat asManyOfPattern(pat) newpat asManyOfPattern(pat,minPattern) newpat asManyOfPattern(pat,m…...

Kafka面试题及原理

1. 消息可靠性(不丢失) 使用Kafka在消息的收发过程都会出现消息丢失,Kafka分别给出了解决方案 生产者发送消息到Brocker丢失消息在Brocker中存储丢失消费者从Brocker 幂等方案:【分布式锁、数据库锁(悲观锁、乐观锁…...

Grok 3 AI 角色扮演提示词 化身顶级设计师

Grok 3:设计下一个大型软件项目的终极工具 🔥 Grok 3 是一个革命性的工具,能够在短短 一小时 内,帮助你完成软件项目设计中最关键的步骤。无论是创建用户画像、设计网站地图,还是编写用户故事及验收标准,G…...

从零开始设计一个完整的网站:HTML、CSS、PHP、MySQL 和 JavaScript 实战教程

前言 本文将从实战角度出发,带你一步步设计一个完整的网站。我们将从 静态网页 开始,然后加入 动态功能(使用 PHP),连接 数据库,最后加入 JavaScript 实现交互功能。通过这个教程,你将掌握一个…...

CSS 对齐:深入理解与技巧实践

CSS 对齐:深入理解与技巧实践 引言 在网页设计中,元素的对齐是至关重要的。一个页面中元素的对齐方式直接影响到页面的美观度和用户体验。CSS 提供了丰富的对齐属性,使得开发者可以轻松实现各种对齐效果。本文将深入探讨 CSS 对齐的原理、方法和技巧,帮助开发者更好地掌握…...

oracle游标为什么没有共享,统计一下原因

-- Script Code为什么没共享 define sql_id bs391f0yq5tpw;set serveroutput onDECLAREv_count number;v_sql varchar2(500);v_sql_id varchar2(30) : &sql_id; BEGINv_sql_id : lower(v_sql_id);dbms_output.put_line(chr(13)||chr(10));dbms_output.put_line(sql_id: ||…...

IDEA中.gitignore未忽略指定文件的问题排查与解决

IDEA 中.gitignore 未忽略.env 文件的问题排查与解决 在使用 IntelliJ IDEA 进行项目开发时,合理利用.gitignore文件来管理版本控制是非常重要的。它能帮助我们排除一些不需要纳入版本管理的文件,比如包含敏感信息的.env文件。然而,有时我们会遇到一种情况:明明已经将.env…...

通往 AI 之路:Python 机器学习入门-语法基础

第一章 Python 语法基础 Python 是一种简单易学的编程语言,广泛用于数据分析、机器学习和人工智能领域。在学习机器学习之前,我们需要先掌握 Python 的基本语法。本章将介绍 Python 的变量与数据类型、条件语句、循环、函数以及文件操作,帮助…...

形象生动讲解Linux 虚拟化 I/O

用现实生活的比喻和简单例子来解释 Linux 虚拟化 I/O,就像给朋友讲故事一样。 虚拟化 I/O 要解决什么问题? 想象你有一栋大房子(物理服务器),想把它分割成多个小公寓(虚拟机)出租。每个租客&…...

6. Nginx 动静分离配置案例(附有详细说明+配图)

6. Nginx 动静分离配置案例(附有详细说明配图) 文章目录 6. Nginx 动静分离配置案例(附有详细说明配图)1. 动静分离概述说明2. 先使用传统方式实现,不使用 Nginx3. 使用上 Nginx 实现动静分离优化步骤4. 最后: 1. 动静分离概述说明 什么是动静分离&…...

数据集笔记:新加坡停车费

data.gov.sg 该数据集包含 新加坡各停车场的停车费,具体信息包括: 停车场名称(Carpark):如 Toa Payoh Lorong 8、Ang Mo Kio Hub、Bras Basah Complex 等。停车区域类别(Category)&#xff1a…...

SQL经典题型

查询不在表里的数据,一张学生表,一张学生的选课表,要求查出没有选课的学生? select students.student_name from students left join course_selection on students.student_idcourse_selection.student_id where course_selecti…...

最新Java面试题,常见面试题及答案汇总

Java最新常见面试题 答案汇总 原文地址:https://blog.csdn.net/sufu1065/article/details/88051083 1、面试题模块汇总 面试题包括以下十九个模块: Java 基础、容器、多线程、反射、对象拷贝、Java Web 模块、异常、网络、设计模式、Spring/Spring MVC…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...

python如何将word的doc另存为docx

将 DOCX 文件另存为 DOCX 格式(Python 实现) 在 Python 中,你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是,.doc 是旧的 Word 格式,而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

.Net Framework 4/C# 关键字(非常用,持续更新...)

一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…...

初学 pytest 记录

安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念,确保一个租户(在这个系统中可能是一个公司或一个独立的客户)的数据对其他租户是不可见的。在 RuoYi 框架(您当前项目所使用的基础框架)中,这通常是通过在数据表中增加一个…...

scikit-learn机器学习

# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...

深入浅出Diffusion模型:从原理到实践的全方位教程

I. 引言:生成式AI的黎明 – Diffusion模型是什么? 近年来,生成式人工智能(Generative AI)领域取得了爆炸性的进展,模型能够根据简单的文本提示创作出逼真的图像、连贯的文本,乃至更多令人惊叹的…...

C++--string的模拟实现

一,引言 string的模拟实现是只对string对象中给的主要功能经行模拟实现,其目的是加强对string的底层了解,以便于在以后的学习或者工作中更加熟练的使用string。本文中的代码仅供参考并不唯一。 二,默认成员函数 string主要有三个成员变量,…...

高分辨率图像合成归一化流扩展

大家读完觉得有帮助记得关注和点赞!!! 1 摘要 我们提出了STARFlow,一种基于归一化流的可扩展生成模型,它在高分辨率图像合成方面取得了强大的性能。STARFlow的主要构建块是Transformer自回归流(TARFlow&am…...

C#最佳实践:为何优先使用as或is而非强制转换

C#最佳实践:为何优先使用as或is而非强制转换 在 C# 的编程世界里,类型转换是我们经常会遇到的操作。就像在现实生活中,我们可能需要把不同形状的物品重新整理归类一样,在代码里,我们也常常需要将一个数据类型转换为另…...