当前位置: 首页 > news >正文

LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍

标准的RNN存在梯度消失梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢?

例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 “the clouds are in the sky”,基于"the clouds are in the",预测"sky",在这样的场景中,预测的词和提供的信息之间位置间隔是非常小的,如下图所示,RNN可以捕捉到先前的信息。
在这里插入图片描述
然而,针对复杂场景,我们有一句话"I grew up in France… I speak fluent French","French"基于"France"推断,但是它们之间的间隔很远很远,RNN 会丧失学习到连接如此远信息的能力。这就是长期依赖关系。

为了解决该问题,LSTM通过引入三种门遗忘门输入门输出门控制信息的流入和流出,有助于保留长期依赖关系,并缓解梯度消失【注意:没有梯度爆炸昂】。LSTM在1997年被提出


2️⃣ 原理

下面这张图是标准的RNN结构:

  • x t x_t xt是t时刻的输入
  • s t s_t st是t时刻的隐层输出, s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(Uxt+Wst1),f表示激活函数, s t − 1 s_{t-1} st1表示t-1时刻的隐层输出
  • h t h_t ht是t时刻的输出, h t = s o f t m a x ( V ⋅ s t ) h_t=softmax(V\cdot s_t) ht=softmax(Vst)
    在这里插入图片描述

LSTM的整体结构如下图所示,第一眼看到,反正我是看不懂。前面讲到LSTM引入三种门遗忘门输入门输出门,现在我们逐一击破,一个个分析一下它们到底是什么。
在这里插入图片描述
这是3D视角的LSTM:
在这里插入图片描述
首先来看遗忘门,也就是下面这张图:

在这里插入图片描述

遗忘门输入包含两部分

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆(即隐层输出),在LSTM中当前时间步的输出 h t − 1 h_{t-1} ht1就是隐层输出 s t − 1 s_{t-1} st1
  • x t x_t xt:表示t时刻的输入

遗忘门输出为 f t f_t ft,公式表示为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma\left(W_f\cdot[h_{t-1},x_t] + b_f\right) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf b f b_f bf是遗忘门的参数, [ s t − 1 , x t ] [s_{t-1},x_t] [st1,xt]表示concat操作。 σ ( ) \sigma() σ()表示sigmoid函数。

遗忘门定我们会从长期记忆中丢弃什么信息【理解为:删除什么日记】,输出一个在 0 到 1 之间的数值,1 表示“完全保留”,0 表示“完全舍弃”。

然后来看输入门
在这里插入图片描述

输入门的输入包含两部分:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入

输入门的输出为新添加的内容 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,其具体操作为:
i t = σ ( W i ⋅ [ s t − 1 , x t ] + b i ) C ~ t = tanh ⁡ ( W C ⋅ [ s t − 1 , x t ] + b C ) \begin{aligned}i_{t}&=\sigma\left(W_i\cdot[s_{t-1},x_t] + b_i\right)\\\tilde{C}_{t}&=\tanh(W_C\cdot[s_{t-1},x_t] + b_C)\end{aligned} itC~t=σ(Wi[st1,xt]+bi)=tanh(WC[st1,xt]+bC)

输入门决定什么样的新信息被加入到长期记忆(即细胞状态)中【理解为:添加什么日记】。

然后,我们来更新长期记忆,将 C t − 1 C_{t-1} Ct1更新为 C t C_t Ct。我们把旧状态 C t − 1 C_{t-1} Ct1与遗忘门的输出 f t f_t ft相乘,忘记一些东西。接着加上输入门的输出 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,新加一些东西,最终得到新的长期记忆 C t C_t Ct。具体操作为:
在这里插入图片描述

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t

最后来看输出门
在这里插入图片描述

输出门的输入包含:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入
  • c t c_t ct:更新后的长期记忆

输出门的输出为 h t h_{t} ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1,其具体操作为:
o t = σ ( W o [ s t − 1 , x t ] + b o ) s t = h t = o t ∗ t a n h ( C t ) \begin{aligned}&o_{t}=\sigma\left(W_{o} \left[ s_{t-1},x_{t}\right] + b_{o}\right)\\&s_{t}=h_{t}=o_{t}*\mathrm{tanh}\left(C_{t}\right)\end{aligned} ot=σ(Wo[st1,xt]+bo)st=ht=ottanh(Ct)

首先,我们运行一个 sigmoid 层来确定长期记忆的哪个部分将输出出去。接着,我们把长期记忆通过 tanh 进行处理(得到一个在-1到1之间的值)并将它和 o t o_{t} ot相乘,最终将输出copy成两份 h t h_t ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1。

LSTM的结构分析完了,那为什么LSTM能够缓解梯度消失呢?

我前面写的这篇文章中介绍了为什么RNN会有梯度消失和爆炸:点这里查看

主要原因是反向传播时,梯度中有这一部分:
∏ j = k + 1 3 ∂ s j ∂ s j − 1 = ∏ j = k + 1 3 t a n h ′ W \prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3tanh^{'}W j=k+13sj1sj=j=k+13tanhW
LSTM的作用就是让 ∂ s j ∂ s j − 1 \frac{\partial s_j}{\partial s_{j-1}} sj1sj≈1

在LSTM里,隐藏层的输出换了个符号,从 s s s变成 C C C了,即 C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t。注意, f t f_t ft , i t 和 C ~ t i_{t\text{ 和}}\tilde{C}_t it C~t 都是 C t − 1 C_{t-1} Ct1的复合函数(因为它们都和 h t − 1 h_{t-1} ht1有关,而 h t − 1 h_{t-1} ht1又和 C t − 1 C_{t-1} Ct1有关)。因此我们来求一下 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct
∂ C t ∂ C t − 1 = f t + ∂ f t ∂ C t − 1 ⋅ C t − 1 + … \frac{\partial C_t}{\partial C_{t-1}}=f_t+\frac{\partial f_t}{\partial C_{t-1}}\cdot C_{t-1}+\ldots Ct1Ct=ft+Ct1ftCt1+

后面的我们就不管了,展开求导太麻烦了。这里面 f t f_t ft是遗忘门的输出,1表示完全保留旧状态,0表示完全舍弃旧状态,如果我们把 f t f_t ft设置成1或者是接近于1,那 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct就有梯度了。因此LSTM可以一定程度上缓解梯度消失,然而如果时间步很长的话,依然会存在梯度消失问题,所以只是缓解

注意:LSTM可以缓解梯度消失,但是梯度爆炸并不能解决,因为LSTM不影响参数W


3️⃣ 代码

# 创建一个LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LSTM(nn.Module):def __init__(self,input_size,hidden_size,num_layers,output_size):super().__init__()self.num_layers=num_layersself.hidden_size=hidden_size# 定义LSTM层# batch_first=True则输入形状为(batch, seq_len, input_size)self.lstm=nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)# 定义全连接层,用于输出self.fc=nn.Linear(hidden_size,output_size)def forward(self, x):# self.lstm(x)会返回两个值# out:形状为 (batch,seq_len,hidden_size)# 隐层状态和细胞状态:形状为 (batch, num_layers, hidden_size);在这里,我们忽略隐层状态和细胞状态的输出,因此使用了占位符out, _ = self.lstm(x)out = self.fc(out)return outif __name__=='__main__':input_size=10hidden_size=64num_layers=1output_size=1net=LSTM(input_size,hidden_size,num_layers,output_size)# x的形状为(batch_size, seq_len, input_size)x=torch.randn(16,8,input_size)out=net(x)print(out.shape)

输出结果为:

torch.Size([16, 8, 1]),表示有16个batch,对于每个batch,有8个时间步,每个时间步的output大小为1

4️⃣ 总结

  • 思考一个问题,对于多层LSTM,如何理解呢?
    在这里插入图片描述
    注意:图中颜色相同的其实表达的值一样, h = s h=s h=s

    1. 第一层 LSTM 首先初始隐层状态 s 0 l a y e r 1 s^{layer1}_0 s0layer1和细胞状态 c 0 l a y e r 1 c^{layer1}_0 c0layer1,然后输入 x t − 1 x_{t-1} xt1 生成隐层状态和输出 s t − 1 l a y e r 1 = h t − 1 l a r y e r 1 s^{layer1}_{t-1}=h_{t-1}^{laryer1} st1layer1=ht1laryer1和细胞状态 c t − 1 l a y e r 1 c^{layer1}_{t-1} ct1layer1
    2. 第二层 LSTM首先初始隐层状态 s 0 l a y e r 2 s^{layer2}_0 s0layer2和细胞状态 c 0 l a y e r 2 c^{layer2}_0 c0layer2,然后接收第一层的输出 h t − 1 l a r y e r 1 h_{t-1}^{laryer1} ht1laryer1作为输入,生成 s t − 1 l a y e r 2 = h t − 1 l a r y e r 2 s^{layer2}_{t-1}=h_{t-1}^{laryer2} st1layer2=ht1laryer2 c t − 1 l a y e r 2 c^{layer2}_{t-1} ct1layer2
    3. 第N层 LSTM首先初始隐层状态 s 0 l a y e r N s^{layerN}_0 s0layerN和细胞状态 c 0 l a y e r N c^{layerN}_0 c0layerN,然后接收第N-1层的输出 h t − 1 l a r y e r N − 1 h_{t-1}^{laryer N-1} ht1laryerN1作为输入,生成最终的 s t − 1 l a y e r N = h t − 1 l a r y e r 2 s^{layerN}_{t-1}=h_{t-1}^{laryer2} st1layerN=ht1laryer2 c t − 1 l a y e r N c^{layerN}_{t-1} ct1layerN
  • 为什么需要多层LSTM?
    多层 LSTM 通过增加深度来增强模型的表示能力和复杂度,能够学习到更高阶、更抽象的特征

  • 通过控制遗忘门的输出 f t f_t ft来控制梯度,以缓解梯度消失问题,但不能缓解梯度爆炸

5️⃣ 参考

  • 理解 LSTM 网络

  • 【LSTM长短期记忆网络】3D模型一目了然,带你领略算法背后的逻辑

  • 关于RNN的梯度消失&爆炸问题

相关文章:

LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍 标准的RNN存在梯度消失和梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢? 例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 “the clouds are in the sky”,基于&…...

机器学习 贝叶斯公式

这是条件概率的计算公式 𝑃(𝐴|𝐵)𝑃(B|A)𝑃(𝐴)/𝑃(𝐵) 全概率公式 𝑃(𝐵)𝑃(𝐵|𝐴)𝑃(𝐴)&am…...

Scala-注释、标识符、变量与常量-用法详解

Scala Scala-变量和数据类型-用法详解 Scala一、注释二、标识符规范三、变量和常量1. 变量(var)2. 常量(val)3. 类型推断与显式声明4. var 和 val 的区别5. Scala与Java对比Tips: 各位看客老爷万福金安,一键…...

大数据学习14之Scala面向对象--至简原则

1.类和对象 1.1基本概念 面向对象(Object Oriented)是一种编程思想,面向对象主要是把事物给对象化,包括其属性和行为。面向对象编程更贴近实际生活的思想,总体来说面向对象的底层还是面向过程,面向过程抽象…...

docker 安装之 windows安装

文章目录 1: 在Windows安装Docker报19044版本错误的时候,请大家下载4.24.1之前的版本(含4.24.1)2: Desktop-WSL kernel version too low3: docker-compose 安装 (v2.21.0) 1: 在Windows安装Docker报19044版本错误的时候,请大家下载…...

JS 实现游戏流畅移动与按键立即响应

AWSD 按键移动 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>.box1 {width: 400px;height: 400px;background: yellowgreen;margin: 0 auto;position: relative;}.box2 {width: 50px;height:…...

LabVIEW大数据处理

在物联网、工业4.0和科学实验中&#xff0c;大数据处理需求逐年上升。LabVIEW作为一款图形化编程语言&#xff0c;凭借其强大的数据采集和分析能力&#xff0c;广泛应用于实时数据处理和控制系统中。然而&#xff0c;在面对大数据处理时&#xff0c;LabVIEW也存在一些注意事项。…...

NVR录像机汇聚管理EasyNVR多品牌NVR管理工具视频汇聚技术在智慧安防监控中的应用与优势

随着信息技术的快速发展和数字化时代的到来&#xff0c;安防监控领域也在不断进行技术创新和突破。NVR管理平台EasyNVR作为视频汇聚技术的领先者&#xff0c;凭借其强大的视频处理、汇聚与融合能力&#xff0c;展现出了在安防监控领域巨大的应用潜力和价值。本文将详细介绍Easy…...

海思3403对RTSP进行目标检测

1.概述 主要功能是调过live555 testRTSPClient 简单封装的rtsp客户端库&#xff0c;拉取RTSP流&#xff0c;然后调过3403的VDEC模块进行解码&#xff0c;送个NPU进行目标检测&#xff0c;输出到hdmi&#xff0c;这样保证了开发没有sensor的时候可以识别其它摄像头的视频流&…...

Vue之插槽(slot)

插槽是vue中的一个非常强大且灵活的功能&#xff0c;在写组件时&#xff0c;可以为组件的使用者预留一些可以自定义内容的占位符。通过插槽&#xff0c;可以极大提高组件的客服用和灵活性。 插槽大体可以分为三类&#xff1a;默认插槽&#xff0c;具名插槽和作用域插槽。 下面…...

分布式服务高可用实现:复制

分布式服务高可用实现&#xff1a;复制 1. 为什么需要复制 我们可以考虑如下问题&#xff1a; 当数据量、读取或写入负载已经超过了当前服务器的处理能力&#xff0c;如何实现负载均衡&#xff1f;希望在单台服务器出现故障时仍能继续工作&#xff0c;这该如何实现&#xff…...

基于yolov8、yolov5的车型检测识别系统(含UI界面、训练好的模型、Python代码、数据集)

摘要&#xff1a;车型识别在交通管理、智能监控和车辆管理中起着至关重要的作用&#xff0c;不仅能帮助相关部门快速识别车辆类型&#xff0c;还为自动化交通监控提供了可靠的数据支撑。本文介绍了一款基于YOLOv8、YOLOv5等深度学习框架的车型识别模型&#xff0c;该模型使用了…...

机器学习—决定下一步做什么

现在已经看到了很多不同的学习算法&#xff0c;包括线性回归、逻辑回归甚至深度学习或神经网络。 关于如何构建机器学习系统的一些建议 假设你已经实现了正则化线性回归来预测房价&#xff0c;所以你有通常的学习算法的成本函数平方误差加上这个正则化项&#xff0c;但是如果…...

Java Optional详解:避免空指针异常的优雅方式

在 Java 编程中&#xff0c;空指针异常&#xff08;NullPointerException&#xff09;一直是困扰开发者的常见问题之一。为了更安全、优雅地处理可能为空的值&#xff0c;Java 8 引入了 Optional 类。Optional 提供了一种函数式的方式来表示一个值可能存在或不存在&#xff0c;…...

SpringBoot开发——整合EasyExcel实现百万级数据导入导出功能

文章目录 一、EasyExcel 框架及特性介绍二、实现步骤1、项目创建及依赖配置(pom.xml)2、项目文件结构3、配置文件(application.yml)4、启动类 Application.java5、配置类 EasyExcelConfig.java6、服务接口定义及实现 ExcelService.java7、控制器类 ExcelController.java8、…...

AcWing 1097 池塘计数 flood fill bfs搜索

代码 #include <bits/stdc.h> using namespace std;const int N 1010, M N * N;typedef pair<int, int> PII;int n, m;char g[N][N]; bool st[N][N]; PII q[M];void bfs (int xx, int yy) {int hh 0, tt -1;q[ tt] {xx, yy};st[xx][yy] true;while (hh <…...

R门 - rust第一课陈天 -内存知识学习笔记

内存 #mermaid-svg-1NFTUW33mcI2cBGB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-1NFTUW33mcI2cBGB .error-icon{fill:#552222;}#mermaid-svg-1NFTUW33mcI2cBGB .error-text{fill:#552222;stroke:#552222;}#merm…...

java itext后端生成pdf导出

public CustomApiResult<String> exportPdf(HttpServletRequest request, HttpServletResponse response) throws IOException {// 防止日志记录获取session异常request.getSession();// 设置编码格式response.setContentType("application/pdf;charsetUTF-8")…...

信号-3-信号处理

main 信号捕捉的操作 sigaction struct sigaction OS不允许信号处理方法进行嵌套&#xff1a;某一个信号正在被处理时&#xff0c;OS会自动block改信号&#xff0c;之后会自动恢复 同理&#xff0c;sigaction.sa_mask 为捕捉指定信号后临时屏蔽的表 pending什么时候清零&…...

38配置管理工具(如Ansible、Puppet、Chef)

每天五分钟学Linux | 第三十八课&#xff1a;配置管理工具&#xff08;如Ansible、Puppet、Chef&#xff09; 大家好&#xff01;欢迎再次来到我们的“每天五分钟学Linux”系列教程。在前面的课程中&#xff0c;我们学习了如何安装和配置邮件服务器。今天&#xff0c;我们将探…...

网络技术-定义配置ACL规则的语法和命令

定义ACL&#xff08;访问控制列表&#xff09;规则时&#xff0c;具体命令会根据所使用的设备和操作系统而有所不同。以下是一些常见的设备和操作系统中定义ACL规则的命令示例&#xff1a; 一&#xff1a;思科&#xff08;Cisco&#xff09;路由器/交换机 在思科设备中&#…...

动态规划一>子数组系列

题目&#xff1a; 2.解析&#xff1a; 代码&#xff1a; public int maxSubArray(int[] nums) {int n nums.length;int[] dp new int[n 1];int ret Integer.MIN_VALUE;for(int i 1; i < n; i){dp[i] Math.max(nums[i - 1], dp[i - 1] nums[i - 1]);ret Math.max(…...

一觉睡醒,全世界计算机水平下降100倍,而我却精通C语言——scanf函数

大家好啊&#xff0c;我是小象٩(๑ω๑)۶ 我的博客&#xff1a;Xiao Fei Xiangζั͡ޓއއ 很高兴见到大家&#xff0c;希望能够和大家一起交流学习&#xff0c;共同进步。* 这一节我们主要来学习scanf的基本用法&#xff0c;了解scanf返回值&#xff0c;懂得scanf占位符和…...

Altium Designer使用技巧(五)

一、敷铜(快捷键T-G-A) 1、工具栏点“设计”—>“规则”。 可以修改覆铜连线的宽度&#xff0c;也可以选择“直接连接”使得铜和网络节点完全相连。 这里方式有三种&#xff0c;可根据需要各个人习惯去设置。 2、敷铜与线全链接 &#xff08;1&#xff09;少量的话可右键“…...

Docker 的安装与使用

Docker 的安装 Docker 是一个开源的商业产品&#xff0c;有两个版本&#xff1a;社区版&#xff08;Community Edition&#xff0c;缩写为 CE&#xff09;和企业版&#xff08;Enterprise Edition&#xff0c;缩写为 EE&#xff09;。 Docker CE 的安装请参考官方文档&#xf…...

Android Studio 中三方库依赖无法找到的解决方案

目录 错误信息解析 解决方案 1. 检查依赖版本 2. 检查 Maven 仓库配置 3. 强制刷新 Gradle 缓存 4. 检查网络连接 5. 手动下载依赖 总结 相关推荐 最近&#xff0c;我在编译一个 Android 老项目时遇到了一个问题&#xff0c;错误信息显示无法找到 com.gyf.immersionba…...

PGMP练-DAY24

DAY241A program has completed and closed. The training for the receiving organization has also delivered. But the stakeholders still concern that the benefits cannot be realized in long term.What does the program manager review to improve the situation?项…...

【C++动态规划 最长公共子序列】1035. 不相交的线|1805

本文涉及知识点 C动态规划 LeetCode1035. 不相交的线 在两条独立的水平线上按给定的顺序写下 nums1 和 nums2 中的整数。 现在&#xff0c;可以绘制一些连接两个数字 nums1[i] 和 nums2[j] 的直线&#xff0c;这些直线需要同时满足&#xff1a; nums1[i] nums2[j] 且绘制的…...

FFmpeg的基本结构

FFmpeg框架可以简单分为两层&#xff0c;上层是以ffmpeg、ffplay、ffprobe为代表的命令行工具&#xff1b;其底层支撑是一些基础库&#xff0c;包含AVFormat、AVCodec、AVFilter、AVDevices、AVUtils等模块库。 常用函数如下&#xff1a; 1. AVFormat 封装/解封装模块 avf…...

react 受控组件和非受控组件

在 React 中&#xff0c;受控组件和非受控组件是两种处理表单元素&#xff08;如输入框、选择框等&#xff09;值的方式。 1. 受控组件 受控组件是指 React 组件的表单元素的值是由 React 组件的 state 来管理的。换句话说&#xff0c;React 会全程控制表单元素的值&#xff…...