57 长短期记忆网络(LSTM)_by《李沐:动手学深度学习v2》pytorch版
系列文章目录
文章目录
- 系列文章目录
- 长短期记忆网络(LSTM)
- 门控记忆元
- 输入门、忘记门和输出门
- 候选记忆元 (相当于RNN中计算 H t H_t Ht)
- 记忆元
- 隐状态
- 从零开始实现
- 初始化模型参数
- 定义模型
- 训练和预测
- 简洁实现
- 小结
- 练习
长短期记忆网络(LSTM)
长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。
解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM)它有许多与门控循环单元(GRU)一样的属性。有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,却比门控循环单元早诞生了近20年。
门控记忆元
可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。
长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。
有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。
为了控制记忆元,我们需要许多门。
其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。
另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。
我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。让我们看看这在实践中是如何运作的。
输入门、忘记门和输出门
就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图所示。它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。
因此,这三个门的值都在 ( 0 , 1 ) (0, 1) (0,1)的范围内。
label:
lstm_0
我们来细化一下长短期记忆网络的数学表达。
假设有 h h h个隐藏单元,批量大小为 n n n,输入数为 d d d。
因此,输入为 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} Xt∈Rn×d,前一时间步的隐状态为 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht−1∈Rn×h。
相应地,时间步 t t t的门被定义如下:
输入门是 I t ∈ R n × h \mathbf{I}_t \in \mathbb{R}^{n \times h} It∈Rn×h,
遗忘门是 F t ∈ R n × h \mathbf{F}_t \in \mathbb{R}^{n \times h} Ft∈Rn×h,
输出门是 O t ∈ R n × h \mathbf{O}_t \in \mathbb{R}^{n \times h} Ot∈Rn×h。
它们的计算方法如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) , F t = σ ( X t W x f + H t − 1 W h f + b f ) , O t = σ ( X t W x o + H t − 1 W h o + b o ) , \begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned} ItFtOt=σ(XtWxi+Ht−1Whi+bi),=σ(XtWxf+Ht−1Whf+bf),=σ(XtWxo+Ht−1Who+bo),
其中 W x i , W x f , W x o ∈ R d × h \mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h} Wxi,Wxf,Wxo∈Rd×h和 W h i , W h f , W h o ∈ R h × h \mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h} Whi,Whf,Who∈Rh×h是权重参数, b i , b f , b o ∈ R 1 × h \mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h} bi,bf,bo∈R1×h是偏置参数。
候选记忆元 (相当于RNN中计算 H t H_t Ht)
由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C ~ t ∈ R n × h \tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h} C~t∈Rn×h。
它的计算与上面描述的三个门的计算类似,但是使用 tanh \tanh tanh函数作为激活函数,函数的值范围为 ( − 1 , 1 ) (-1, 1) (−1,1)。
下面导出在时间步 t t t处的方程:
C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) , \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c), C~t=tanh(XtWxc+Ht−1Whc+bc),
其中 W x c ∈ R d × h \mathbf{W}_{xc} \in \mathbb{R}^{d \times h} Wxc∈Rd×h和 W h c ∈ R h × h \mathbf{W}_{hc} \in \mathbb{R}^{h \times h} Whc∈Rh×h是权重参数, b c ∈ R 1 × h \mathbf{b}_c \in \mathbb{R}^{1 \times h} bc∈R1×h是偏置参数。
候选记忆元的如下图 :numref:lstm_1
所示。
label:
lstm_1
记忆元
在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。
类似地,在长短期记忆网络中,也有两个门用于这样的目的:
输入门 I t \mathbf{I}_t It控制采用多少来自 C ~ t \tilde{\mathbf{C}}_t C~t的新数据,而遗忘门 F t \mathbf{F}_t Ft控制保留多少过去的记忆元 C t − 1 ∈ R n × h \mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} Ct−1∈Rn×h的内容。
使用按元素乘法,得出:
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t . \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t. Ct=Ft⊙Ct−1+It⊙C~t.
如果遗忘门始终为 1 1 1且输入门始终为 0 0 0,则过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct−1将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。
这样我们就得到了计算记忆元的流程图,如 :numref:lstm_2
。
label:
lstm_2
隐状态
最后,我们需要定义如何计算隐状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} Ht∈Rn×h,这就是输出门发挥作用的地方。
在长短期记忆网络中,它仅仅是记忆元的 tanh \tanh tanh的门控版本。
这就确保了 H t \mathbf{H}_t Ht的值始终在区间 ( − 1 , 1 ) (-1, 1) (−1,1)内:
H t = O t ⊙ tanh ( C t ) . \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t). Ht=Ot⊙tanh(Ct).
只要输出门接近 1 1 1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 0 0 0,我们只保留记忆元内的所有信息,而不需要更新隐状态(相当于重置隐状态)。
下图 :numref:lstm_3
提供了数据流的图形化演示。
label:
lstm_3
从零开始实现
现在,我们从零开始实现长短期记忆网络。我们首先加载时光机器数据集。
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
初始化模型参数
接下来,我们需要定义和初始化模型参数。
如前所述,超参数num_hiddens
定义隐藏单元的数量。
我们按照标准差 0.01 0.01 0.01的高斯分布初始化权重,并将偏置项设为 0 0 0。
def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three() # 输入门参数W_xf, W_hf, b_f = three() # 遗忘门参数W_xo, W_ho, b_o = three() # 输出门参数W_xc, W_hc, b_c = three() # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params
定义模型
在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。因此,我们得到以下的状态初始化。
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))
[实际模型]的定义与我们前面讨论的一样:
提供三个门和一个额外的记忆元。
请注意,只有隐状态才会传递到输出层,而记忆元 C t \mathbf{C}_t Ct不直接参与输出计算。
def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y) #Y的shape是(批量大小,词表长度)只有这里输出了批量大小的预测,之后才能用来计算损失return torch.cat(outputs, dim=0), (H, C)
训练和预测
让我们通过实例化RNN从零实现中引入的RNNModelScratch类来训练一个长短期记忆网络,就如我们在GRU中所做的一样。
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 14.5, 27965.3 tokens/sec on cuda:0
time traveller te at at at at at at at at at at at at at at at a
traveller te at at at at at at at at at at at at at at at a<Figure size 350x250 with 1 Axes>
简洁实现
使用高级API,我们可以直接实例化LSTM
模型。
高级API封装了前文介绍的所有配置细节。
这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 11.2, 233619.5 tokens/sec on cuda:0
time traveller the the the the the the the the the the the the t
traveller the the the the the the the the the the the the t<Figure size 350x250 with 1 Axes>
长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。
多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。
然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(例如门控循环单元)的成本是相当高的。
在后面的内容中,我们将讲述更高级的替代模型,如Transformer。
小结
- 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
- 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
- 长短期记忆网络可以缓解梯度消失和梯度爆炸。
练习
- 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
- 如何更改模型以生成适当的单词,而不是字符序列?
- 在给定隐藏层维度的情况下,比较门控循环单元、长短期记忆网络和常规循环神经网络的计算成本。要特别注意训练和推断成本。
- 既然候选记忆元通过使用 tanh \tanh tanh函数来确保值范围在 ( − 1 , 1 ) (-1,1) (−1,1)之间,那么为什么隐状态需要再次使用 tanh \tanh tanh函数来确保输出值范围在 ( − 1 , 1 ) (-1,1) (−1,1)之间呢?
- 实现一个能够基于时间序列进行预测而不是基于字符序列进行预测的长短期记忆网络模型。
相关文章:

57 长短期记忆网络(LSTM)_by《李沐:动手学深度学习v2》pytorch版
系列文章目录 文章目录 系列文章目录长短期记忆网络(LSTM)门控记忆元输入门、忘记门和输出门候选记忆元 (相当于RNN中计算 H t H_t Ht)记忆元隐状态 从零开始实现初始化模型参数定义模型训练和预测 简洁实现小结练习 长短期记忆网络(LSTM&a…...

Linux系统安装教程
Linux安装流程 一、前置准备工作二、开始安装Linux 一、前置准备工作 安装好VMWare虚拟机,并下载Linux系统的安装包; Linux安装包路径为:安装包链接 , 提取码为:4tiM 二、开始安装Linux...

Redis: Sentinel工作原理和故障迁移流程
Sentinel 哨兵几个核心概念 1 ) 定时任务 Sentinel 它是如何工作的,是如何感知到其他的 Sentinel 节点以及 Master/Slave节点的就是通过它的一系列定时任务来做到的,它内部有三个定时任务 第一个就是每一秒每个 Sentinel 对其他 Sentinel 和 Redis 节点…...

通信工程学习:什么是IGMP因特网组管理协议
IGMP:因特网组管理协议 IGMP(Internet Group Management Protocol,因特网组管理协议)是TCP/IP协议簇中负责组播成员管理的协议。它主要用于在用户主机和与其直接相连的组播路由器之间建立和维护组播组成员关系。以下是关于IGMP协议…...
高效批量导入多个SQL文件至SQL Server数据库的实用方法
当需要批量导入多个SQL文件到SQL Server数据库时,可以通过以下几种方法来实现: 方法一:使用SQLCMD命令行工具(亲测可用) 准备SQL文件:确保所有的SQL文件都位于同一个文件夹内,并且文件扩展名为…...

【树莓派系列】树莓派wiringPi库详解,官方外设开发
树莓派wiringPi库详解,官方外设开发 文章目录 树莓派wiringPi库详解,官方外设开发一、安装wiringPi库二、wiringPi库API大全1.硬件初始化函数2.通用GPIO控制函数3.时间控制函数4.串口通信串口API串口通信配置多串口通信配置串口自发自收测试串口间通信测…...
前端模块化CommonJs、ESM、AMD总结
前端开发模式进化史 前端工程化正是为了应对这些演化中出现的挑战和需求而发展起来的: 前后端混合:服务端渲染,javascript仅实现交互前后端分离:借助 ajax 实现前后端分离、单页应用(SPA)等新模式模块化开发:npm 管理…...

JavaWeb - 8 - 请求响应 分层解耦
请求响应 请求(HttpServletRequest):获取请求数据 响应(HttpServletResponse):设置响应数据 BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程…...
1G,2G,3G,4G,5G各代通信技术的关键技术,联系和区别
目录 1G2G3G4G5G各代通信技术的联系和区别联系区别 1G 1G的主要特点是无线移动化。关键技术为蜂窝组网,支持频率复用和移动切换,可以实现个人和个人移动状态下不间断的语音通信。 1G通信系统现已关闭,其主要缺点是串好和盗号。 2G 数字化…...

【宽搜】2. leetcode 102 二叉树的层序遍历
题目描述 题目链接:二叉树的层序遍历 根据上一篇文章的模板可以直接写代码,需要改变的就是将N叉树的child改为二叉树的left和right。 代码 class Solution { public:vector<vector<int>> levelOrder(TreeNode* root) {vector<vector&…...
Go语言实现长连接并发框架 - 请求分发器
文章目录 前言接口结构体接口实现项目地址最后 前言 你好,我是醉墨居士,我们上篇博客实现了任务管理器的功能,接下来这篇博客我们将要实现请求分发模块的开发 接口 trait/dispatcher.go type Dispatcher interface {Start()Dispatch(conn…...

Redis: 集群测试和集群原理
集群测试 1 ) SET/GET 命令 测试 set 和 get 因为其他命令也基本相似,我们在 101 节点上尝试连接 103 $ /usr/local/redis/bin/redis-cli -c -a 123456 -h 192.168.10.103 -p 6376我们在插入或读取一个 key的时候,会对这个key做一个hash运算,…...

问题解决实录 | bash 中 tmux 颜色显示不全
点我进入博客 如下图,tmux 中颜色显示不全: echo $TERM输出的是 screen 但在 bash 里面输出的是 xterm-256 color 在 bash 里面输入: touch ~/.tmux.conf vim ~/.tmux.conf set -g default-terminal "xterm-256color"使之生效 source …...
古典舞在线交流平台:SpringBoot设计与实现详解
摘 要 随着互联网技术的发展,各类网站应运而生,网站具有新颖、展现全面的特点。因此,为了满足用户古典舞在线交流的需求,特开发了本古典舞在线交流平台。 本古典舞在线交流平台应用Java技术,MYSQL数据库存储数据&#…...

五子棋双人对战项目(6)——对战模块(解读代码)
目录 一、约定前后端交互接口的参数 1、房间准备就绪 (1)配置 websocket 连接路径 (2)构造 游戏就绪 的 响应对象 2、“落子” 的请求和响应 (1)“落子” 请求对象 (2)“落子…...

查缺补漏----I/O中断处理过程
中断优先级包括响应优先级和处理优先级,响应优先级由硬件线路或查询程序的查询顺序决定,不可动态改变。处理优先级可利用中断屏蔽技术动态调整,以实现多重中断。下面来看他们如何运用在中断处理过程中: 中断控制器位于CPU和外设之…...
Java API接口开发规范
文章目录 一、命名规范1.1 接口命名1.2 变量命名 二、接收参数规范2.1 请求体(Body)2.2 查询参数(Query Parameters) 三、参数检验四、接收方式规范五、异常类处理六、统一返回格式的定义七、API接口的幂等性(Idempote…...
Go语言实现长连接并发框架 - 任务管理器
文章目录 前言接口结构体接口实现项目地址最后 前言 你好,我是醉墨居士,我们上篇博客实现了路由分组的功能,接下来这篇博客我们将要实现任务管理模块 接口 trait/task_mgr.go type TaskMgr interface {RouterGroupStart()StartWorker(tas…...
【大数据】深入解析分布式数据库:架构、技术与未来
目录 1. 分布式数据库的定义2. 架构类型2.1 主从架构2.2 同步与异步复制2.3 分片架构 3. 技术实现3.1 一致性模型3.2 CAP理论3.3 数据存储引擎 4. 应用场景5. 选择分布式数据库的因素5.1 数据一致性需求5.2 读写负载5.3 成本5.4 技术栈兼容性 6. 未来发展趋势总结 分布式数据库…...

uniapp框架中实现文件选择上传组件,可以选择图片、视频等任意文件并上传到当前绑定的服务空间
前言 uni-file-picker是uniapp中的一个文件选择器组件,用于选择本地文件并返回选择的文件路径或文件信息。该组件支持选择单个文件或多个文件,可以设置文件的类型、大小限制,并且可以进行文件预览。 提示:以下是本篇文章正文内容,下面案例可供参考 uni-file-picker组件具…...

【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频
使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...

【Oracle】分区表
个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
IP如何挑?2025年海外专线IP如何购买?
你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...