当前位置: 首页 > news >正文

LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍

标准的RNN存在梯度消失梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢?

例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 “the clouds are in the sky”,基于"the clouds are in the",预测"sky",在这样的场景中,预测的词和提供的信息之间位置间隔是非常小的,如下图所示,RNN可以捕捉到先前的信息。
在这里插入图片描述
然而,针对复杂场景,我们有一句话"I grew up in France… I speak fluent French","French"基于"France"推断,但是它们之间的间隔很远很远,RNN 会丧失学习到连接如此远信息的能力。这就是长期依赖关系。

为了解决该问题,LSTM通过引入三种门遗忘门输入门输出门控制信息的流入和流出,有助于保留长期依赖关系,并缓解梯度消失【注意:没有梯度爆炸昂】。LSTM在1997年被提出


2️⃣ 原理

下面这张图是标准的RNN结构:

  • x t x_t xt是t时刻的输入
  • s t s_t st是t时刻的隐层输出, s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(Uxt+Wst1),f表示激活函数, s t − 1 s_{t-1} st1表示t-1时刻的隐层输出
  • h t h_t ht是t时刻的输出, h t = s o f t m a x ( V ⋅ s t ) h_t=softmax(V\cdot s_t) ht=softmax(Vst)
    在这里插入图片描述

LSTM的整体结构如下图所示,第一眼看到,反正我是看不懂。前面讲到LSTM引入三种门遗忘门输入门输出门,现在我们逐一击破,一个个分析一下它们到底是什么。
在这里插入图片描述
这是3D视角的LSTM:
在这里插入图片描述
首先来看遗忘门,也就是下面这张图:

在这里插入图片描述

遗忘门输入包含两部分

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆(即隐层输出),在LSTM中当前时间步的输出 h t − 1 h_{t-1} ht1就是隐层输出 s t − 1 s_{t-1} st1
  • x t x_t xt:表示t时刻的输入

遗忘门输出为 f t f_t ft,公式表示为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma\left(W_f\cdot[h_{t-1},x_t] + b_f\right) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf b f b_f bf是遗忘门的参数, [ s t − 1 , x t ] [s_{t-1},x_t] [st1,xt]表示concat操作。 σ ( ) \sigma() σ()表示sigmoid函数。

遗忘门定我们会从长期记忆中丢弃什么信息【理解为:删除什么日记】,输出一个在 0 到 1 之间的数值,1 表示“完全保留”,0 表示“完全舍弃”。

然后来看输入门
在这里插入图片描述

输入门的输入包含两部分:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入

输入门的输出为新添加的内容 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,其具体操作为:
i t = σ ( W i ⋅ [ s t − 1 , x t ] + b i ) C ~ t = tanh ⁡ ( W C ⋅ [ s t − 1 , x t ] + b C ) \begin{aligned}i_{t}&=\sigma\left(W_i\cdot[s_{t-1},x_t] + b_i\right)\\\tilde{C}_{t}&=\tanh(W_C\cdot[s_{t-1},x_t] + b_C)\end{aligned} itC~t=σ(Wi[st1,xt]+bi)=tanh(WC[st1,xt]+bC)

输入门决定什么样的新信息被加入到长期记忆(即细胞状态)中【理解为:添加什么日记】。

然后,我们来更新长期记忆,将 C t − 1 C_{t-1} Ct1更新为 C t C_t Ct。我们把旧状态 C t − 1 C_{t-1} Ct1与遗忘门的输出 f t f_t ft相乘,忘记一些东西。接着加上输入门的输出 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,新加一些东西,最终得到新的长期记忆 C t C_t Ct。具体操作为:
在这里插入图片描述

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t

最后来看输出门
在这里插入图片描述

输出门的输入包含:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入
  • c t c_t ct:更新后的长期记忆

输出门的输出为 h t h_{t} ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1,其具体操作为:
o t = σ ( W o [ s t − 1 , x t ] + b o ) s t = h t = o t ∗ t a n h ( C t ) \begin{aligned}&o_{t}=\sigma\left(W_{o} \left[ s_{t-1},x_{t}\right] + b_{o}\right)\\&s_{t}=h_{t}=o_{t}*\mathrm{tanh}\left(C_{t}\right)\end{aligned} ot=σ(Wo[st1,xt]+bo)st=ht=ottanh(Ct)

首先,我们运行一个 sigmoid 层来确定长期记忆的哪个部分将输出出去。接着,我们把长期记忆通过 tanh 进行处理(得到一个在-1到1之间的值)并将它和 o t o_{t} ot相乘,最终将输出copy成两份 h t h_t ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1。

LSTM的结构分析完了,那为什么LSTM能够缓解梯度消失呢?

我前面写的这篇文章中介绍了为什么RNN会有梯度消失和爆炸:点这里查看

主要原因是反向传播时,梯度中有这一部分:
∏ j = k + 1 3 ∂ s j ∂ s j − 1 = ∏ j = k + 1 3 t a n h ′ W \prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3tanh^{'}W j=k+13sj1sj=j=k+13tanhW
LSTM的作用就是让 ∂ s j ∂ s j − 1 \frac{\partial s_j}{\partial s_{j-1}} sj1sj≈1

在LSTM里,隐藏层的输出换了个符号,从 s s s变成 C C C了,即 C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t。注意, f t f_t ft , i t 和 C ~ t i_{t\text{ 和}}\tilde{C}_t it C~t 都是 C t − 1 C_{t-1} Ct1的复合函数(因为它们都和 h t − 1 h_{t-1} ht1有关,而 h t − 1 h_{t-1} ht1又和 C t − 1 C_{t-1} Ct1有关)。因此我们来求一下 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct
∂ C t ∂ C t − 1 = f t + ∂ f t ∂ C t − 1 ⋅ C t − 1 + … \frac{\partial C_t}{\partial C_{t-1}}=f_t+\frac{\partial f_t}{\partial C_{t-1}}\cdot C_{t-1}+\ldots Ct1Ct=ft+Ct1ftCt1+

后面的我们就不管了,展开求导太麻烦了。这里面 f t f_t ft是遗忘门的输出,1表示完全保留旧状态,0表示完全舍弃旧状态,如果我们把 f t f_t ft设置成1或者是接近于1,那 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct就有梯度了。因此LSTM可以一定程度上缓解梯度消失,然而如果时间步很长的话,依然会存在梯度消失问题,所以只是缓解

注意:LSTM可以缓解梯度消失,但是梯度爆炸并不能解决,因为LSTM不影响参数W


3️⃣ 代码

# 创建一个LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LSTM(nn.Module):def __init__(self,input_size,hidden_size,num_layers,output_size):super().__init__()self.num_layers=num_layersself.hidden_size=hidden_size# 定义LSTM层# batch_first=True则输入形状为(batch, seq_len, input_size)self.lstm=nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)# 定义全连接层,用于输出self.fc=nn.Linear(hidden_size,output_size)def forward(self, x):# self.lstm(x)会返回两个值# out:形状为 (batch,seq_len,hidden_size)# 隐层状态和细胞状态:形状为 (batch, num_layers, hidden_size);在这里,我们忽略隐层状态和细胞状态的输出,因此使用了占位符out, _ = self.lstm(x)out = self.fc(out)return outif __name__=='__main__':input_size=10hidden_size=64num_layers=1output_size=1net=LSTM(input_size,hidden_size,num_layers,output_size)# x的形状为(batch_size, seq_len, input_size)x=torch.randn(16,8,input_size)out=net(x)print(out.shape)

输出结果为:

torch.Size([16, 8, 1]),表示有16个batch,对于每个batch,有8个时间步,每个时间步的output大小为1

4️⃣ 总结

  • 思考一个问题,对于多层LSTM,如何理解呢?
    在这里插入图片描述
    注意:图中颜色相同的其实表达的值一样, h = s h=s h=s

    1. 第一层 LSTM 首先初始隐层状态 s 0 l a y e r 1 s^{layer1}_0 s0layer1和细胞状态 c 0 l a y e r 1 c^{layer1}_0 c0layer1,然后输入 x t − 1 x_{t-1} xt1 生成隐层状态和输出 s t − 1 l a y e r 1 = h t − 1 l a r y e r 1 s^{layer1}_{t-1}=h_{t-1}^{laryer1} st1layer1=ht1laryer1和细胞状态 c t − 1 l a y e r 1 c^{layer1}_{t-1} ct1layer1
    2. 第二层 LSTM首先初始隐层状态 s 0 l a y e r 2 s^{layer2}_0 s0layer2和细胞状态 c 0 l a y e r 2 c^{layer2}_0 c0layer2,然后接收第一层的输出 h t − 1 l a r y e r 1 h_{t-1}^{laryer1} ht1laryer1作为输入,生成 s t − 1 l a y e r 2 = h t − 1 l a r y e r 2 s^{layer2}_{t-1}=h_{t-1}^{laryer2} st1layer2=ht1laryer2 c t − 1 l a y e r 2 c^{layer2}_{t-1} ct1layer2
    3. 第N层 LSTM首先初始隐层状态 s 0 l a y e r N s^{layerN}_0 s0layerN和细胞状态 c 0 l a y e r N c^{layerN}_0 c0layerN,然后接收第N-1层的输出 h t − 1 l a r y e r N − 1 h_{t-1}^{laryer N-1} ht1laryerN1作为输入,生成最终的 s t − 1 l a y e r N = h t − 1 l a r y e r 2 s^{layerN}_{t-1}=h_{t-1}^{laryer2} st1layerN=ht1laryer2 c t − 1 l a y e r N c^{layerN}_{t-1} ct1layerN
  • 为什么需要多层LSTM?
    多层 LSTM 通过增加深度来增强模型的表示能力和复杂度,能够学习到更高阶、更抽象的特征

  • 通过控制遗忘门的输出 f t f_t ft来控制梯度,以缓解梯度消失问题,但不能缓解梯度爆炸

5️⃣ 参考

  • 理解 LSTM 网络

  • 【LSTM长短期记忆网络】3D模型一目了然,带你领略算法背后的逻辑

  • 关于RNN的梯度消失&爆炸问题

相关文章:

LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍 标准的RNN存在梯度消失和梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢? 例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 “the clouds are in the sky”,基于&…...

机器学习 贝叶斯公式

这是条件概率的计算公式 𝑃(𝐴|𝐵)𝑃(B|A)𝑃(𝐴)/𝑃(𝐵) 全概率公式 𝑃(𝐵)𝑃(𝐵|𝐴)𝑃(𝐴)&am…...

Scala-注释、标识符、变量与常量-用法详解

Scala Scala-变量和数据类型-用法详解 Scala一、注释二、标识符规范三、变量和常量1. 变量(var)2. 常量(val)3. 类型推断与显式声明4. var 和 val 的区别5. Scala与Java对比Tips: 各位看客老爷万福金安,一键…...

大数据学习14之Scala面向对象--至简原则

1.类和对象 1.1基本概念 面向对象(Object Oriented)是一种编程思想,面向对象主要是把事物给对象化,包括其属性和行为。面向对象编程更贴近实际生活的思想,总体来说面向对象的底层还是面向过程,面向过程抽象…...

docker 安装之 windows安装

文章目录 1: 在Windows安装Docker报19044版本错误的时候,请大家下载4.24.1之前的版本(含4.24.1)2: Desktop-WSL kernel version too low3: docker-compose 安装 (v2.21.0) 1: 在Windows安装Docker报19044版本错误的时候,请大家下载…...

JS 实现游戏流畅移动与按键立即响应

AWSD 按键移动 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>.box1 {width: 400px;height: 400px;background: yellowgreen;margin: 0 auto;position: relative;}.box2 {width: 50px;height:…...

LabVIEW大数据处理

在物联网、工业4.0和科学实验中&#xff0c;大数据处理需求逐年上升。LabVIEW作为一款图形化编程语言&#xff0c;凭借其强大的数据采集和分析能力&#xff0c;广泛应用于实时数据处理和控制系统中。然而&#xff0c;在面对大数据处理时&#xff0c;LabVIEW也存在一些注意事项。…...

NVR录像机汇聚管理EasyNVR多品牌NVR管理工具视频汇聚技术在智慧安防监控中的应用与优势

随着信息技术的快速发展和数字化时代的到来&#xff0c;安防监控领域也在不断进行技术创新和突破。NVR管理平台EasyNVR作为视频汇聚技术的领先者&#xff0c;凭借其强大的视频处理、汇聚与融合能力&#xff0c;展现出了在安防监控领域巨大的应用潜力和价值。本文将详细介绍Easy…...

海思3403对RTSP进行目标检测

1.概述 主要功能是调过live555 testRTSPClient 简单封装的rtsp客户端库&#xff0c;拉取RTSP流&#xff0c;然后调过3403的VDEC模块进行解码&#xff0c;送个NPU进行目标检测&#xff0c;输出到hdmi&#xff0c;这样保证了开发没有sensor的时候可以识别其它摄像头的视频流&…...

Vue之插槽(slot)

插槽是vue中的一个非常强大且灵活的功能&#xff0c;在写组件时&#xff0c;可以为组件的使用者预留一些可以自定义内容的占位符。通过插槽&#xff0c;可以极大提高组件的客服用和灵活性。 插槽大体可以分为三类&#xff1a;默认插槽&#xff0c;具名插槽和作用域插槽。 下面…...

分布式服务高可用实现:复制

分布式服务高可用实现&#xff1a;复制 1. 为什么需要复制 我们可以考虑如下问题&#xff1a; 当数据量、读取或写入负载已经超过了当前服务器的处理能力&#xff0c;如何实现负载均衡&#xff1f;希望在单台服务器出现故障时仍能继续工作&#xff0c;这该如何实现&#xff…...

基于yolov8、yolov5的车型检测识别系统(含UI界面、训练好的模型、Python代码、数据集)

摘要&#xff1a;车型识别在交通管理、智能监控和车辆管理中起着至关重要的作用&#xff0c;不仅能帮助相关部门快速识别车辆类型&#xff0c;还为自动化交通监控提供了可靠的数据支撑。本文介绍了一款基于YOLOv8、YOLOv5等深度学习框架的车型识别模型&#xff0c;该模型使用了…...

机器学习—决定下一步做什么

现在已经看到了很多不同的学习算法&#xff0c;包括线性回归、逻辑回归甚至深度学习或神经网络。 关于如何构建机器学习系统的一些建议 假设你已经实现了正则化线性回归来预测房价&#xff0c;所以你有通常的学习算法的成本函数平方误差加上这个正则化项&#xff0c;但是如果…...

Java Optional详解:避免空指针异常的优雅方式

在 Java 编程中&#xff0c;空指针异常&#xff08;NullPointerException&#xff09;一直是困扰开发者的常见问题之一。为了更安全、优雅地处理可能为空的值&#xff0c;Java 8 引入了 Optional 类。Optional 提供了一种函数式的方式来表示一个值可能存在或不存在&#xff0c;…...

SpringBoot开发——整合EasyExcel实现百万级数据导入导出功能

文章目录 一、EasyExcel 框架及特性介绍二、实现步骤1、项目创建及依赖配置(pom.xml)2、项目文件结构3、配置文件(application.yml)4、启动类 Application.java5、配置类 EasyExcelConfig.java6、服务接口定义及实现 ExcelService.java7、控制器类 ExcelController.java8、…...

AcWing 1097 池塘计数 flood fill bfs搜索

代码 #include <bits/stdc.h> using namespace std;const int N 1010, M N * N;typedef pair<int, int> PII;int n, m;char g[N][N]; bool st[N][N]; PII q[M];void bfs (int xx, int yy) {int hh 0, tt -1;q[ tt] {xx, yy};st[xx][yy] true;while (hh <…...

R门 - rust第一课陈天 -内存知识学习笔记

内存 #mermaid-svg-1NFTUW33mcI2cBGB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-1NFTUW33mcI2cBGB .error-icon{fill:#552222;}#mermaid-svg-1NFTUW33mcI2cBGB .error-text{fill:#552222;stroke:#552222;}#merm…...

java itext后端生成pdf导出

public CustomApiResult<String> exportPdf(HttpServletRequest request, HttpServletResponse response) throws IOException {// 防止日志记录获取session异常request.getSession();// 设置编码格式response.setContentType("application/pdf;charsetUTF-8")…...

信号-3-信号处理

main 信号捕捉的操作 sigaction struct sigaction OS不允许信号处理方法进行嵌套&#xff1a;某一个信号正在被处理时&#xff0c;OS会自动block改信号&#xff0c;之后会自动恢复 同理&#xff0c;sigaction.sa_mask 为捕捉指定信号后临时屏蔽的表 pending什么时候清零&…...

38配置管理工具(如Ansible、Puppet、Chef)

每天五分钟学Linux | 第三十八课&#xff1a;配置管理工具&#xff08;如Ansible、Puppet、Chef&#xff09; 大家好&#xff01;欢迎再次来到我们的“每天五分钟学Linux”系列教程。在前面的课程中&#xff0c;我们学习了如何安装和配置邮件服务器。今天&#xff0c;我们将探…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段&#xff1a; 构建阶段&#xff08;Build Stage&#xff09;&#xff1a…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis&#xff1f;2.为什么要使用redis作为mysql的缓存&#xff1f;3.什么是缓存雪崩、缓存穿透、缓存击穿&#xff1f;3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度

文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...

【Nginx】使用 Nginx+Lua 实现基于 IP 的访问频率限制

使用 NginxLua 实现基于 IP 的访问频率限制 在高并发场景下&#xff0c;限制某个 IP 的访问频率是非常重要的&#xff0c;可以有效防止恶意攻击或错误配置导致的服务宕机。以下是一个详细的实现方案&#xff0c;使用 Nginx 和 Lua 脚本结合 Redis 来实现基于 IP 的访问频率限制…...

关于uniapp展示PDF的解决方案

在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项&#xff1a; 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库&#xff1a; npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...

MyBatis中关于缓存的理解

MyBatis缓存 MyBatis系统当中默认定义两级缓存&#xff1a;一级缓存、二级缓存 默认情况下&#xff0c;只有一级缓存开启&#xff08;sqlSession级别的缓存&#xff09;二级缓存需要手动开启配置&#xff0c;需要局域namespace级别的缓存 一级缓存&#xff08;本地缓存&#…...

c# 局部函数 定义、功能与示例

C# 局部函数&#xff1a;定义、功能与示例 1. 定义与功能 局部函数&#xff08;Local Function&#xff09;是嵌套在另一个方法内部的私有方法&#xff0c;仅在包含它的方法内可见。 • 作用&#xff1a;封装仅用于当前方法的逻辑&#xff0c;避免污染类作用域&#xff0c;提升…...

数据结构:递归的种类(Types of Recursion)

目录 尾递归&#xff08;Tail Recursion&#xff09; 什么是 Loop&#xff08;循环&#xff09;&#xff1f; 复杂度分析 头递归&#xff08;Head Recursion&#xff09; 树形递归&#xff08;Tree Recursion&#xff09; 线性递归&#xff08;Linear Recursion&#xff09;…...