长短期记忆网络(Long Short-Term Memory,LSTM)
简介:个人学习分享,如有错误,欢迎批评指正。
长短期记忆网络(Long Short-Term Memory,简称LSTM)是一种特殊的循环神经网络(Recurrent Neural Network,简称RNN)架构,专门设计用于处理和预测序列数据中的长依赖关系。LSTM由Sepp Hochreiter和Jürgen Schmidhuber在1997年提出,旨在克服传统RNN在处理长序列时面临的梯度消失和梯度爆炸问题。
背景与动机
传统的RNN在处理序列数据(如时间序列、自然语言等)时,通过其循环结构能够记忆和利用先前的信息。然而,随着序列长度的增加,RNN在训练过程中会遇到梯度消失或梯度爆炸的问题,导致模型难以学习到长期依赖关系
。该限制使得RNN在许多需要捕捉长距离依赖的任务中的表现不理想。LSTM通过引入门控机制,有效地解决了这一问题,使得网络能够在更长的序列中保持信息。
一、RNN的基本结构与局限性
1. RNN的基本结构
RNN通过循环连接来处理序列数据。对于一个序列输入 ( x 1 , x 2 , … , x T ) (x_1, x_2, \ldots, x_T) (x1,x2,…,xT),RNN在每个时间步 t t t 更新新隐藏状态 h t h_t ht:
h t = tanh ( W x h x t + W h h h t − 1 + b h ) h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht−1+bh)
其中, W x h W_{xh} Wxh 和 W h h W_{hh} Whh 是权重矩阵, b h b_h bh 是偏置项, tanh \tanh tanh 是激活函数。
2. RNN的局限性
-
梯度消失与梯度爆炸: 在反向传播过程中,长序列会导致梯度在时间步上传播的迅速减小或增大,使得模型难以学习长期依赖。(因为RNN在时序上共享参数,梯度在反向传播过程中,不断连乘,数值不是越来越大就是越来越小)
-
长期依赖难以捕捉: 由于梯度衰减,RNN难以记住序列中较早的信息。(梯度小幅更新的网络层会停止学习,这些通常是较早的层。由于这些层不学习,RNN无法记住它在较长序列中学习到的内容,因此它的记忆是短期的。)
二、LSTM的核心理念
LSTM旨在解决传统循环神经网络(RNN)在处理长序列时面临的梯度消失和梯度爆炸问题。其核心思想是通过引入门控机制(Gates)来控制信息的流动,允许网络选择性地记住或遗忘信息,从而有效地捕捉长期开依赖关系
。
1. 信息流动与记忆保持
在RNN中,隐藏状态 h t h_t ht 通过时间步传递,理论上可以保留任意长的历史信息。然而,实际训练中,由于梯度在反向传播的逐步消失或爆炸,RNN难以有效学习到长距离的依赖关系。LSTM通过设计专门的结构,确保关键信息可以在长时间内被有效传递和更新。
2. 门控机制的引入
LSTM引入了三个主要的门控单元——遗忘门
(Forget Gate)、输入门
(Input Gate)和输出门
(Output Gate)。这些门通过学习动态地控制信息的保留和更新,从而实现对长期和短期记忆的有效管理。
三、LSTM的详细结构
1. LSTM单元的组成
一个标准的LSTM单元包括以下几个关键部分:
- 记忆单元(Cell State, C t C_t Ct):负责存储
长期记忆
。 - 隐藏状态(Hidden State, h t h_t ht):传递
短期记忆和输出
。 - 遗忘门(Forget Gate, f t f_t ft):决定
遗忘多少过去的信息
。 - 输入门(Input Gate, i t i_t it):决定
接受多少新信息
。 - 候选记忆单元(Candidate Cell State, C ~ t \tilde{C}_t C~t):
生成新信息
用于更新记忆单元。 - 输出门(Output Gate, o t o_t ot):决定
输出多少记忆单元的信息
。
2. 信息流动路径
信息在LSTM单元中的流动可以分为以下几个步骤:
- 遗忘阶段:决定从记忆单元中遗忘多少信息。
- 输入阶段:决定接收多少新信息,并生成候选记忆单元。
- 更新记忆单元:结合遗忘门和输入门的输出,更新记忆单元。
- 输出阶段:决定输出多少记忆单元的信息作为隐藏状态。
3. 详细数学表示
以下是每个门控单元和记忆更新的详细数学表达:
3.1 遗忘门(Forget Gate)
遗忘门决定记忆单元中哪些信息需要被遗忘
。通过一个 sigmoid 激活函数,输出值 f t f_t ft 在 0 到 1 之间,每个元素决定对应记忆单元信息的保留程度
。
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- σ \sigma σ:Sigmoid 激活函数。
- W f W_f Wf:遗忘门的权重矩阵。
- h t − 1 h_{t-1} ht−1:前一时刻的隐藏状态。
- x t x_t xt:当前时刻的输入。
- b f b_f bf:遗忘门的偏置。
解释:
- 当 f t f_t ft
接近 1 时
,记忆单元中的信息被保留
。 - 当 f t f_t ft
接近 0 时
,记忆单元中的信息被遗忘
。
3.2 输入门(Input Gate)与候选记忆单元(Candidate Cell State)
输入门控制新信息的加入。它由两个部分组成:
- 输入门层(Input Gate Layer):
通过 sigmoid 函数确定哪些部分需要更新
。 - 候选记忆单元层(Candidate Cell State Layer):
通过 tanh 函数生成新的候选记忆信息
。
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
- i t i_t it:输入门的输出。
- C ~ t \tilde{C}_t C~t:候选记忆单元。
- 其他符号含义同上。
解释:
- i t i_t it 决定了
记忆单元中哪些部分将被更新
。 - C ~ t \tilde{C}_t C~t 提供了
新的信息
,用于更新记忆单元。
3.3 更新记忆单元
结合遗忘门和输入门的输出,更新记忆单元状态 C t C_t Ct。
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
- ⊙ \odot ⊙:逐元素相乘。
- C t − 1 C_{t-1} Ct−1:前一时刻的记忆单元状态。
解释:
- f t ⊙ C t − 1 f_t \odot C_{t-1} ft⊙Ct−1:保留部分记忆单元中的信息。
- i t ⊙ C ~ t i_t \odot \tilde{C}_t it⊙C~t:添加新信息到记忆单元中。
3.4 输出门(Output Gate)
输出门决定了下一隐藏状态 h t h_t ht 的值,即当前时刻 LSTM 单元的输出。
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
h t = o t ⊙ tanh ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
- o t o_t ot:输出门的输出。
- h t h_t ht:当前时刻的隐藏状态。
解释:
- o t o_t ot 决定了
记忆单元中哪些部分将被输出
。 - h t h_t ht 是通过将输出门的输出与记忆单元状态的 tanh \tanh tanh 变换相乘得到的。
四、LSTM的工作流程详解
1. 前向传播过程
在每个时间步,LSTM单元按照以下步骤进行信息处理:
- 输入接收:接收当前输入 x t x_t xt 和前一时刻的隐藏状态 h t − 1 h_{t-1} ht−1。
- 计算遗忘门 f t f_t ft:通过遗忘门决定从记忆单元中遗忘多少信息。
- 计算输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t:决定添加多少新信息到记忆单元中。
- 更新记忆单元 C t C_t Ct:结合遗忘门和输入门的输出,更新记忆单元状态。
- 计算输出门 o t o_t ot:决定记忆单元中的哪些信息输出。
- 生成隐藏状态 h t h_t ht:通过输出门控制的记忆单元状态生成当前时刻的隐藏状态。
2. 反向传播与梯度传播
LSTM通过反向传播算法(Backpropagation Through Time, BPTT)进行训练。在反向传播过程中,梯度需要通过时间步传递。LSTM的设计通过门控机制有效地缓解了梯度消失和梯度爆炸
的问题。
2.1 梯度流动
- 直接路径:记忆单元 C t C_t Ct 通过加法操作与 f t f_t ft 和 i t i_t it 连接,允许梯度直接在时间步上传播,从而减缓梯度消失。
- 门控制机制的调节:遗忘门和输入门通过 sigmoid 激活函数动态调节梯度的流动。当需要保留信息时,门的激活值接近1,允许梯度通过;反之则减小梯度流动。
2.2 反向传播的数学细节
假设损失函数为 L L L,则需要计算每个参数对 L L L 的偏导数。以下是主要梯度计算步骤:
- 计算损失对输出 h t h_t ht 的梯度:
∂ L ∂ h t \frac{\partial L}{\partial h_t} ∂ht∂L
- 计算输出门 o t o_t ot 的梯度:
∂ L ∂ o t = ∂ L ∂ h t ⊙ tanh ( C t ) \frac{\partial L}{\partial o_t} = \frac{\partial L}{\partial h_t} \odot \tanh(C_t) ∂ot∂L=∂ht∂L⊙tanh(Ct)
- 计算记忆单元 C t C_t Ct 的梯度:
∂ L ∂ C t = ∂ L ∂ h t ⊙ o t ⊙ ( 1 − tanh 2 ( C t ) ) + ∂ L ∂ C t + 1 ⊙ f t + 1 \frac{\partial L}{\partial C_t} = \frac{\partial L}{\partial h_t} \odot o_t \odot (1 - \tanh^2(C_t)) + \frac{\partial L}{\partial C_{t+1}} \odot f_{t+1} ∂Ct∂L=∂ht∂L⊙ot⊙(1−tanh2(Ct))+∂Ct+1∂L⊙ft+1
- 计算遗忘门 f t f_t ft 和输入门 i t i_t it 的梯度:
∂ L ∂ f t = ∂ L ∂ C t ⊙ C t − 1 \frac{\partial L}{\partial f_t} = \frac{\partial L}{\partial C_t} \odot C_{t-1} ∂ft∂L=∂Ct∂L⊙Ct−1
∂ L ∂ i t = ∂ L ∂ C t ⊙ C ~ t \frac{\partial L}{\partial i_t} = \frac{\partial L}{\partial C_t} \odot \tilde{C}_t ∂it∂L=∂Ct∂L⊙C~t
- 计算候选记忆单元 C ~ t \tilde{C}_t C~t 的梯度:
∂ L ∂ C ~ t = ∂ L ∂ C t ⊙ i t \frac{\partial L}{\partial \tilde{C}_t} = \frac{\partial L}{\partial C_t} \odot i_t ∂C~t∂L=∂Ct∂L⊙it
- 计算各个门控单元的激活函数的梯度:
∂ L ∂ z = ∂ L ∂ gate output ⋅ gate output ⋅ ( 1 − gate output ) \frac{\partial L}{\partial z} = \frac{\partial L}{\partial \text{gate output}} \cdot \text{gate output} \cdot (1 - \text{gate output}) ∂z∂L=∂gate output∂L⋅gate output⋅(1−gate output)
其中, z z z 表示门的线性组合输入。
- 更新权重和偏置:通过链式法则将梯度传递给权重和偏置,并使用优化算法(如Adam、RMSprop等)更新参数。
3. 参数更新
LSTM的参数包括遗忘门、输入门、候选记忆单元和输出门的权重和偏置。具体参数更新步骤如下:
- 计算各参数的梯度:
∂ L ∂ W f , ∂ L ∂ b f , ∂ L ∂ W i , ∂ L ∂ b i , … \frac{\partial L}{\partial W_f}, \quad \frac{\partial L}{\partial b_f}, \quad \frac{\partial L}{\partial W_i}, \quad \frac{\partial L}{\partial b_i}, \ldots ∂Wf∂L,∂bf∂L,∂Wi∂L,∂bi∂L,…
- 应用优化算法(如Adam)根据梯度更新参数:
θ = θ − η ⋅ ∂ L ∂ θ \theta = \theta - \eta \cdot \frac{\partial L}{\partial \theta} θ=θ−η⋅∂θ∂L
其中, η \eta η 是学习率, θ \theta θ 代表参数。
五、LSTM的门控制详解
1. 遗忘门(Forget Gate)
在LSTM中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为“遗忘门”的结构完成。该遗忘门会读取上一个输出 h t − 1 h_{t-1} ht−1 和当前输入 x t x_t xt,做一个Sigmoid 的非线性映射,然后输出一个向量 f t f_t ft (该向量每一个维度的值都在0到1之间,1表示完全保留,0表示完全舍弃,相当于记住了重要的,忘记了无关紧要的
),最后与细胞状态 C t − 1 C_{t-1} Ct−1 相乘。
遗忘门的作用是决定从记忆单元中丢弃多少过去的信息。其通过当前输入 x t x_t xt 和前一隐藏状态 h t − 1 h_{t-1} ht−1 计算得出。
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
特性:
- 当 f t f_t ft 接近 1 时,记忆单元中的信息被保留。
- 当 f t f_t ft 接近 0 时,记忆单元中的信息被遗忘。
重要性:
允许网络动态决定保留或丢弃信息,有助于捕捉长期依赖关系
。
2. 输入门(Input Gate)与候选记忆单元
下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分:
输入门控制当前输入的信息如何更新到记忆单元中。它包含两个部分:
- 输入门层:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
sigmoid层
称“输入门层”决定了哪些部分的候选记忆单元将被更新。
- 候选记忆单元层:
C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
tanh层
生成新的候选记忆信息。
特性:
- 输入门 i t i_t it 控制新信息的流入。
- 候选记忆单元 C ~ t \tilde{C}_t C~t 提供新的信息以更新记忆单元。
细胞状态
现在是更新旧细胞状态的时间了, C t − 1 C_{t-1} Ct−1 更新为 C t C_{t} Ct 。我们把旧状态与 f t f_t ft相乘,丢弃掉我们确定需要丢弃的信息,接着加上 i t ∗ C ~ t i_t*\tilde{C}_t it∗C~t。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。
3. 输出门(Output Gate)
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。
首先,我们运行一个sigmoid层来确定细胞状态的哪个部分将输出出去。
接着,我们把细胞状态通过tanh进行处理(得到一个在-1到1之间的值)并将它和sigmoid门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。
输出门决定了当前时刻的隐藏状态 h t h_t ht 以及输出。
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
h t = o t ⊙ tanh ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
特性:
- 输出门 o t o_t ot 控制了记忆单元中哪些部分的信息被输出。
- 结合记忆单元状态,通过 tanh \tanh tanh 函数提供非线性变换。
重要性:
决定了隐藏状态
h t h_t ht中包含的信息,从而影响下一时刻的计算和最终输出
。
六、信息流动与记忆更新的详细过程
1. 信息流动示例
假设有一个序列 x = [ x 1 , x 2 , … , x T ] x = [x_1, x_2, \ldots, x_T] x=[x1,x2,…,xT],LSTM 在每个时间步 t t t 的计算过程如下:
-
时间步 t t t:
- 输入 x t x_t xt 和前一隐藏状态 h t − 1 h_{t-1} ht−1。
- 计算遗忘门 f t f_t ft。
- 计算输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t。
- 更新记忆单元 C t C_t Ct。
- 计算输出门 o t o_t ot。
- 生成隐藏状态 h t h_t ht。
-
信息流动:
- 记忆单元 C t C_t Ct 是通过保留部分 C t − 1 C_{t-1} Ct−1 和添加新信息 C ~ t \tilde{C}_t C~t 来更新的。
- 隐藏状态 h t h_t ht 是通过输出门控制的 C t C_t Ct 的 tanh \tanh tanh 变换生成的,作为当前时刻的输出,并传递到下一个时间步。
2. 记忆更新的动态
-
保留与遗忘:通过遗忘门 f t f_t ft,网络决定了哪些历史信息需要被保留,哪些需要被遗忘。这使得网络能够保留重要的长期信息,而忽略无关的短期信息。
-
新信息的引入:通过输入门 i t i_t it 和候选记忆单元 C ~ t \tilde{C}_t C~t,网络决定了引入多少新的信息到记忆单元中,从而更新当前的记忆状态。
-
输出的生成:通过输出门 o t o_t ot,网络决定了当前记忆单元状态中哪些部分需要被输出,从而影响当前时刻的隐藏状态 h t h_t ht。
3. 实例说明
假设我们正在处理一个文本序列,目标是预测下一个单词。LSTM在每个时间步接收当前单词的嵌入向量 x t x_t xt,并基于前一时刻的隐藏状态 h t − 1 h_{t-1} ht−1 和记忆单元状态 C t − 1 C_{t-1} Ct−1 进行计算:
- 遗忘门决定:例如,网络可能决定忘记当前一个时间步的某些主题信息(如一个名词)。
- 输入门决定:网络可能决定引入新的信息(如一个动词)。
- 记忆单元更新:结合遗忘和输入门的输出,记忆单元状态被更新为保留了重要的主题信息,并引入了新的动词信息。
- 输出门决定:网络根据新的记忆单元状态生成当前的隐藏状态,用于预测下一个单词。
这种动态调整的机制使得LSTM能够在处理长文本时,保持对主题的长期记忆,同时灵活地引入新的信息。
七、LSTM的变种与扩展
LSTM有许多变种和扩展,旨在改进其性能或适应特定的应用场景。以下是几种常见的变种:
1. 双向LSTM(Bidirectional LSTM)
概述:
- 双向LSTM由
两个LSTM单元组成,一个处理序列的正向信息,另一个处理序列的反向信息
。 - 最终的隐藏状态是两个方向隐藏状态的组合(通常是拼接或求和)。
公式:
h t → = LSTM ( x t , h t − 1 → ) \overrightarrow{h_t} = \text{LSTM}(x_t, \overrightarrow{h_{t-1}}) ht=LSTM(xt,ht−1)
h t ← = LSTM ( x t , h t + 1 ← ) \overleftarrow{h_t} = \text{LSTM}(x_t, \overleftarrow{h_{t+1}}) ht=LSTM(xt,ht+1)
h t = [ h t → ; h t ← ] h_t = [\overrightarrow{h_t}; \overleftarrow{h_t}] ht=[ht;ht]
优点:
- 能够同时利用前后文信息,提高对上下文的理解能力。
应用场景:
- 自然语言处理中如命名实体识别、语义理解等任务。
2. 堆叠LSTM(Stacked LSTM)
概述:
- 堆叠LSTM通过
将多个LSTM层堆叠在一起,形成更深的网络结构
。 - 每一层的输出作为下一层的输入,增加模型的表达能力。
公式:
h t ( l ) = LSTM ( l ) ( h t ( l − 1 ) , h t − 1 ( l ) ) h_t^{(l)} = \text{LSTM}^{(l)}(h_t^{(l-1)}, h_{t-1}^{(l)}) ht(l)=LSTM(l)(ht(l−1),ht−1(l))
其中, l l l 表示层数, h t ( 0 ) = x t h_t^{(0)} = x_t ht(0)=xt。
优势:
- 提升模型的复杂度和拟合能力。
- 更好地捕捉高级特征和抽象信息。
应用场景:
- 需要深层特征提取的任务,如复杂的自然语言处理任务和时间序列预测。
3. 卷积LSTM(Convolutional LSTM, ConvLSTM)
概述:
- 卷积LSTM
结合了卷积神经网络(CNN)和LSTM,适用于处理具有空间结构的时空数据
。 - 门控制机制中的全连接操作被卷积操作取代。
公式:
f t = σ ( W f ∗ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f * [h_{t-1}, x_t] + b_f) ft=σ(Wf∗[ht−1,xt]+bf)
i t = σ ( W i ∗ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i * [h_{t-1}, x_t] + b_i) it=σ(Wi∗[ht−1,xt]+bi)
C ~ t = tanh ( W C ∗ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh (W_C * [h_{t-1}, x_t] + b_C) C~t=tanh(WC∗[ht−1,xt]+bC)
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
o t = σ ( W o ∗ [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o * [h_{t-1}, x_t] + b_o) ot=σ(Wo∗[ht−1,xt]+bo)
h t = o t ⊙ tanh ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
其中, ∗ * ∗ 表示卷积操作。
优势:
- 能够捕捉时空数据中的空间依赖和时间依赖。
应用场景:
- 视频预测、天气预报、交通流量预测等。
4. Peephole LSTM
概述:
- Peephole LSTM在
门控制中引入了记忆单元状态的直接连接,使门控单元能够访问记忆单元的状态
。
公式:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + V f ⋅ C t − 1 + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + V_f \cdot C_{t-1} + b_f) ft=σ(Wf⋅[ht−1,xt]+Vf⋅Ct−1+bf)
i t = σ ( W i ⋅ [ h t − 1 , x t ] + V i ⋅ C t − 1 + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + V_i \cdot C_{t-1} + b_i) it=σ(Wi⋅[ht−1,xt]+Vi⋅Ct−1+bi)
o t = σ ( W o ⋅ [ h t − 1 , x t ] + V o ⋅ C t + b o ) o_t = \sigma (W_o \cdot [h_{t-1}, x_t] + V_o \cdot C_t + b_o) ot=σ(Wo⋅[ht−1,xt]+Vo⋅Ct+bo)
优势:
- 提高模型对记忆单元状态的感知能力,增强门控机制的表现力。
应用场景:
- 需要更精细控制记忆单元状态的任务,如精确的时间序列预测。
5. 注意力机制与LSTM结合
概述:
- 将注意力机制(Attention Mechanism)与LSTM结合,
使模型能够动态地关注输入序列中与当前输出最相关的部分
。
优势:
- 提升模型的性能和解释性,尤其是在处理长序列时能够更有效地利用重要信息。
应用场景:
- 机器翻译、文本摘要、图像描述生成等任务。
示例:
- 在机器翻译中,注意力机制使得解码器在生成每个目标词时,能够关注源句子中最相关的词,从而提高翻译质量。
八、LSTM的实现细节与优化
1. 权重矩阵的初始化
权重初始化对LSTM的训练至关重要,常用的方法包括:
-
Xavier初始化(Glorot Initialization):
Variance = 2 输入维度 + 输出维度 \text{Variance} = \frac{2}{\text{输入维度} + \text{输出维度}} Variance=输入维度+输出维度2
适用于Sigmoid和tanh激活函数。
-
He初始化:
Variance = 2 输入维度 \text{Variance} = \frac{2}{\text{输入维度}} Variance=输入维度2
适用于ReLU激活函数。
重要性:
- 合适的初始化方法有助于加速收敛,防止梯度消失或爆炸。
2. 激活函数的选择
-
Sigmoid函数:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+e−x1
- 用于门控单元,输出范围在0到1之间。
- 控制信息的流动。
-
双曲正切函数(tanh):
tanh ( x ) = e x − e − x e x + e − x \tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} tanh(x)=ex+e−xex−e−x
- 用于生成候选记忆单元和输出隐藏状态。
- 输出范围在-1到1之间,提供非线性变换。
3. 正则化技术
为了防止LSTM过拟合,常用的正则化技术包括:
-
Dropout:
- 在训练过程中随机丢弃部分神经元,防止模型过于依赖某些特征。
- 适用于LSTM的各个门控单元和隐藏层。
-
L2正则化:
- 在损失函数中加入权重的平方和,限制权重的大小。
- 有助于防止权重过大,减少过拟合风险。
4. 梯度裁剪(Gradient Clipping)
梯度裁剪用于防止梯度爆炸问题,特别是在处理长序列时。
实现方法:
-
全局梯度裁剪:
将所有参数的梯度组合成一个向量,如果其范数超过预设阈值,则按比例缩放
。如果 ∥ g ∥ > 阈值 ,则 g = g ∥ g ∥ × 阈值 \text{如果} \, \|g\| > \text{阈值} \, \text{,则} \, g = \frac{g}{\|g\|} \times \text{阈值} 如果∥g∥>阈值,则g=∥g∥g×阈值
-
按参数裁剪:
分别对每个参数的梯度进行裁剪,确保每个参数的梯度在阈值范围内
。
重要性:
- 防止梯度过大导致训练不稳定或参数更新过度。
5. 批量处理与序列填充
在实际应用中,为了提高训练效率,通常采用批量处理(Batch Processing)技术。然而,序列数据的长度可能不同,需要进行填充(Padding)以统一长度。
步骤:
-
序列填充(Padding):
- 将所有序列填充到相同的长度,通常在序列的末尾添加零向量。
- 确定一个最大序列长度,超出的部分截断,不足的部分填充。
-
掩码(Masking):
- 使用掩码标记填充的位置,使得模型在计算损失和梯度时忽略填充部分。
6. 优化算法的选择
常用的优化算法包括:
-
Adam优化器:
- 结合了动量和自适应学习率的优势。
- 适用于大多数深度学习任务。
-
RMSprop:
- 适用于处理非平稳目标,常用于循环神经网络。
-
SGD(随机梯度下降):
- 适用于大规模数据,但通常需要较长的训练时间和学习率调整。
九、案例:多变量时间序列预测及python代码
案例概述
使用长短期记忆网络(LSTM)进行多变量时间序列预测。我们将以股票价格预测为例,利用多个相关特征(如开盘价、收盘价、最高价、最低价、成交量等)来预测未来的收盘价。
1. 数据收集与预处理
1.1 获取股票数据
我们将使用yfinance库从雅虎财经获取苹果公司(AAPL)的股票数据。首先,确保已安装必要的库:
pip install yfinance pandas numpy scikit-learn matplotlib tensorflow
1.2 导入必要的库
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
1.3 下载股票数据
# 下载苹果公司(AAPL)的历史股票数据
df = yf.download('AAPL', start='2010-01-01', end='2023-12-31')# 查看数据
print(df.head())
1.4 处理缺失值和异常值
# 检查缺失值
print(df.isnull().sum())# 填充缺失值(如果有)
df.fillna(method='ffill', inplace=True)
1.5 特征工程(可选)
在这个例子中,我们将使用原始的开盘价(Open)、最高价(High)、最低价(Low)、收盘价(Close)和成交量(Volume)作为特征。此外,可以创建一些技术指标,如移动平均线(MA)、相对强弱指数(RSI)等,但为了简化,我们将仅使用基本特征。
1.6 数据归一化
LSTM对数据的尺度敏感,因此需要对数据进行归一化处理。我们将使用MinMaxScaler将数据缩放到0到1之间。
# 选择特征
features = ['Open', 'High', 'Low', 'Close', 'Volume']
data = df[features]# 初始化Scaler
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data)# 将归一化后的数据转换为DataFrame
scaled_df = pd.DataFrame(scaled_data, columns=features, index=df.index)
1.7 创建时间序列样本
我们将使用过去60天的数据来预测第61天的收盘价。
def create_sequences(data, seq_length):X = []y = []for i in range(seq_length, len(data)):X.append(data[i-seq_length:i])y.append(data[i, 3]) # 'Close'的索引为3return np.array(X), np.array(y)SEQ_LENGTH = 60# 转换为numpy数组
scaled_array = scaled_df.values# 创建序列
X, y = create_sequences(scaled_array, SEQ_LENGTH)print(f'Input shape: {X.shape}')
print(f'Target shape: {y.shape}')
1.8 拆分训练集和测试集
通常,时间序列数据按时间顺序拆分,不能随机拆分。
# 定义训练集比例
TRAIN_SIZE = 0.8
train_size = int(len(X) * TRAIN_SIZE)X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]print(f'Training samples: {X_train.shape[0]}')
print(f'Testing samples: {X_test.shape[0]}')
2. 构建LSTM模型
2.1 定义模型架构
我们将构建一个堆叠的LSTM模型,包含两个LSTM层和一个全连接层。为了防止过拟合,我们将在LSTM层之间添加Dropout层。
# 获取输入特征数量
n_features = X_train.shape[2]# 构建模型
model = Sequential()# 第一层LSTM
model.add(LSTM(units=50, return_sequences=True, input_shape=(SEQ_LENGTH, n_features)))
model.add(Dropout(0.2))# 第二层LSTM
model.add(LSTM(units=50, return_sequences=False))
model.add(Dropout(0.2))# 全连接层
model.add(Dense(units=25))
model.add(Dense(units=1)) # 输出一个值,即预测的收盘价# 查看模型摘要
model.summary()
2.2 编译模型
我们将使用均方误差(MSE)作为损失函数,优化器选择Adam。
model.compile(optimizer='adam', loss='mean_squared_error')
3. 训练模型
3.1 训练模型
为了防止过拟合,我们将使用Early Stopping回调,当验证损失在连续5个周期内不再改善时停止训练。
# 定义Early Stopping
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)# 训练模型
history = model.fit(X_train, y_train,epochs=100,batch_size=64,validation_split=0.2,callbacks=[early_stop],verbose=1
)
3.2 可视化训练过程
# 绘制训练和验证损失
plt.figure(figsize=(12, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss During Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
4. 评估与预测
4.1 在测试集上评估模型
# 预测
predictions = model.predict(X_test)# 反归一化
predictions = scaler.inverse_transform(np.concatenate((np.zeros((predictions.shape[0], 4)), predictions), axis=1)
)[:, 4]# 真实值反归一化
y_test_rescaled = scaler.inverse_transform(np.concatenate((np.zeros((y_test.shape[0], 4)), y_test.reshape(-1, 1)), axis=1)
)[:, 4]# 计算均方根误差(RMSE)
rmse = np.sqrt(np.mean((predictions - y_test_rescaled) ** 2))
print(f'RMSE on Test Set: {rmse:.2f}')
4.2 可视化预测结果
# 创建一个DataFrame来存储真实值和预测值
test_dates = df.index[-len(y_test):]
comparison_df = pd.DataFrame({'Date': test_dates,'Actual Close': y_test_rescaled,'Predicted Close': predictions
})# 设置日期为索引
comparison_df.set_index('Date', inplace=True)# 绘制图表
plt.figure(figsize=(14, 7))
plt.plot(comparison_df['Actual Close'], label='Actual Close Price')
plt.plot(comparison_df['Predicted Close'], label='Predicted Close Price')
plt.title('Actual vs Predicted Close Price')
plt.xlabel('Date')
plt.ylabel('Close Price USD')
plt.legend()
plt.show()
4.3 预测未来价格
为了预测未来几天的收盘价,我们需要使用最新的60天数据作为输入。
# 假设我们要预测未来5天的收盘价
future_days = 5
last_sequence = scaled_array[-SEQ_LENGTH:]for _ in range(future_days):# 预测下一天的收盘价pred = model.predict(last_sequence.reshape(1, SEQ_LENGTH, n_features))# 反归一化pred_rescaled = scaler.inverse_transform(np.concatenate((np.zeros((1, 4)), pred), axis=1))[:, 4][0]print(f'Predicted Close Price: {pred_rescaled:.2f}')# 更新序列,移除最早的一天,添加预测值# 这里我们仅更新'Close'价格,其他特征保持不变或进行合理假设new_entry = last_sequence[-1].copy()new_entry[3] = pred # 更新'Close'价格# 这里简单地将新_entry的其他特征与'Close'价格相同,实际应用中应使用更合理的策略last_sequence = np.vstack([last_sequence[1:], new_entry])
十、LSTM的优势与局限
1. 优势
- 捕捉长期依赖:通过门控机制,LSTM能够有效地捕捉和保持长期依赖信息,解决了传统RNN的梯度消失问题。
- 灵活性高:适用于各种类型的序列数据,如文本、时间序列、音频、视频等。
- 稳定的训练过程:相较于传统RNN,LSTM更容易训练,梯度消失和爆炸问题得到缓解。
- 强大的表达能力:通过堆叠和双向等变种,LSTM能够捕捉复杂的模式和特征。
2. 局限
- 计算复杂度高:LSTM单元包含多个门控机制,参数较多,计算开销较大,导致训练和推理时间较长。
- 训练时间长:由于结构复杂,尤其在处理长序列时,训练时间相对较长。
- 模型解释性有限:尽管LSTM能够有效地捕捉序列中的依赖关系,但其内部工作机制对于人类来说不够直观,解释性差。
- 过拟合风险:在数据量不足的情况下,LSTM容易过拟合,需要采取正则化措施。
总结
长短期记忆网络(LSTM)通过引入遗忘门、输入门和输出门,有效地解决了传统RNN在处理长序列时的梯度消失和梯度爆炸问题,使其能够捕捉和保持长期依赖信息。LSTM广泛应用于自然语言处理、时间序列预测、视频分析等领域,并通过各种变种(如双向LSTM、堆叠LSTM、卷积LSTM等)进一步提升了其性能和适用性。尽管LSTM在处理序列数据方面表现出色,但其计算复杂度和训练时间仍然是需要考虑的因素。随着深度学习技术的不断发展,LSTM及其衍生模型将在更多应用场景中发挥重要作用。
参考文献:
一幅图真正理解LSTM、BiLSTM
结~~~
相关文章:

长短期记忆网络(Long Short-Term Memory,LSTM)
简介:个人学习分享,如有错误,欢迎批评指正。 长短期记忆网络(Long Short-Term Memory,简称LSTM)是一种特殊的循环神经网络(Recurrent Neural Network,简称RNN)架构&#…...
WHAT - 引入第三方组件或项目使用需要注意什么
目录 1. 功能匹配2. 社区与维护3. 兼容性4. 性能5. 易用性6. 安全性7. 授权和许可证8. 国际化支持9. 依赖性10. 未来维护 在前端开发过程中引入第三方组件或项目时,应该从以下几个方面进行考虑,以确保引入的组件能够有效解决问题并适合长期维护ÿ…...

原生鸿蒙操作系统HarmonyOS NEXT(HarmonyOS 5)正式发布
华为于10月22日19:00举办“原生鸿蒙之夜暨华为全场景新品发布会”。此次发布会推出全新的原生鸿蒙操作系统HarmonyOS NEXT(HarmonyOS 5)以及nova 13、WATCH Ultimate、MatePad Pro等新品。 据介绍,此前已经发布过的鸿蒙系统,由于系…...

WindTerm配置快捷键Ctrl+C和Ctrl+V
WindTerm配置快捷键CtrlC和CtrlV 平时使用ssh和sftp连接的时候,经常使用windterm, 但是windterm里面找不到相关的快捷键设置, 因为操作习惯,想把CtrlC和CtrlV分别配置为复制和粘贴,其他的快捷键操作可以按照该方法进…...

AOP学习
corol调用serverce不在是直接调用的是调用底层代理对象,由代理对象统一帮我们处理 AOP常见概念 通知类型 切面顺序...

【ubuntu18.04】ubuntu18.04升级cmake-3.29.8及还原系统自带cmake操作说明
参考链接 cmake升级、更新(ubuntu18.04)-CSDN博客 升级cmake操作说明 下载链接 Download CMake 下载版本 下载软件包 cmake-3.30.3-linux-x86_64.tar.gz 拷贝软件包到虚拟机 cp /var/run/vmblock-fuse/blockdir/jrY8KS/cmake-3.29.8-linux-x86_64…...

利用Docker搭建一套Mycat2+MySQL8一主一从、读写分离的最简单集群(保姆教程)
文章目录 1、Mycat介绍1.1、mycat简介1.2、mycat重要概念1.3、Mycat1.x与Mycat2功能对比1.2、主从复制原理 2、前提准备3、集群规划4、安装和配置mysql主从复制4.1、master节点安装mysql8容器4.2、slave节点安装mysql8容器4.2、配置主从复制4.3、测试主从复制配置 5、安装mycat…...

算法——python实现堆排序
文章目录 堆排序二叉树堆堆排序的过程:代码实现python中的heapq模块 堆排序 二叉树 关于二叉树的操作,其实核心就是 父节点找子节点,子节点找父节点 如果要将二叉树存储到队列中,就需要找出 父子节点之间的规律: 父…...

uniapp-components(封装组件)
<myitem></myitem> 在其他类里面这样调用。...
avue-crud组件,输入框回车搜索问题
crud组件,输入框回车搜索问题。 文档是并没有标注,实际上已经具备此功能。 需要在curd的option增加属性 searchEnter: true 即可实现输入内容后回车搜索。 avue的一些踩坑记录 - 前端小小菜 - 博客园...

STM32F407ZGT6定时器相关测试
结论: 20us以下的IO翻转操作,存在误差输出比较定时器使能与禁用功能正常输入捕获定时器使能与禁用功能正常单通道输出比较、输入捕获均正常多通道输出比较波形无干扰,但仍是存在20us以下的IO翻转操作存在误差多通道输入捕获正常 一、单一通…...

群晖通过 Docker 安装 GitLab
Docker 配置容器步骤都是大同小异的,可以参考: 群晖通过 Docker 安装 Gitea-CSDN博客 1. 在 Docker 文件夹中创建 GitLab,并创建子文件夹 2. 设置权限 3. 打开 Docker 应用,并在注册表搜索 gitlab-ce 4. 选择 gitlab-ce 映像运行…...
1.Node.js环境搭建(windows)
一、环境搭建(windows) 1.1下载并安装 https://nodejs.org/dist/v18.20.4/node-v18.20.4-x64.msi1.2测试和查看版本 #cmd命令 node -v输出: #能正确输出版本号,说明安装成功 v18.20.41.3使用nodejs启动第一个js #hello.js console.log(hello world!…...

链上相遇,节点之间的悸动与牵连
公主请阅 1. 返回倒数第 k 个节点1.1 题目说明1.2 题目分析1.3 解法一代码以及解释1.3 解法二代码以及解释 2.相交链表2.1 题目说明示例 1示例 2示例 3 2.2 题目分析2.3 代码部分2.4 代码分析 1. 返回倒数第 k 个节点 题目传送门 1.1 题目说明 题目名称: 面试题 02…...

一些简单的编程题(Java与C语言)
引言: 这篇文章呢,小编将会举一些简单的编程题用来帮助大家理解一下Java代码,并且与C语言做个对比,不过这篇文章所出现的题目小编不会向随缘解题系列里面那样详细的讲解每一到题,本篇文章的主要目的是帮助小编和读者们…...

java计算机毕设课设—愤怒小鸟游戏(附源码、文章、相关截图、部署视频)
这是什么系统? 资源获取方式再最下方 java计算机毕设课设—愤怒小鸟游戏(附源码、文章、相关截图、部署视频) 基于Java的愤怒小鸟游戏,我们不仅复刻了原版游戏的核心玩法,还增加了一些创新元素。游戏以2D图形界面呈现,玩家需要…...

【ARM】MDK-Flex服务管理软件使用说明
【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 记录MDK网络版部署工具Imtools.exe 的各个界面中相关配置的功能说明 2、 问题场景 解决客户咨询,该服务管理软件如何使用,为客户使用服务管理软件后期自行维护增加一定指导作用。 3、软硬件环…...
【H2O2|全栈】WPS/Office系列有哪些好用的快捷方式?
目录 WPS/Office 前言 准备工作 Office通用快捷键 PPT快捷键 Excel快捷键 Word快捷键 结束语 WPS/Office 前言 本章节属于前端前置知识,即使不学习前端,在工作中掌握常见的WPS/Office办公技能也是十分重要的。在本篇中,我将会分享常…...

对比学习)
目录 概念 数据增强 损失函数 NCE(noise contrastive estimation) Info NCE CV上的发展 InstDisc InvaSpread CPC CMC MoCo simCLR MoCo v2 SimCLR v2 SwAV BYOL SimSiam MoCo v3 DiNO 概念 通过利用样本之间的相似性和不相似性&…...
第十六届蓝桥杯嵌入式真题
蓝桥杯嵌入式第十二届省赛真题二 蓝桥杯嵌入式第十三届省赛真题一 蓝桥杯嵌入式第十三届省赛真题二 蓝桥杯嵌入式第十四届省赛真题 蓝桥杯嵌入式第十四届模拟考试一 蓝桥杯嵌入式第十四届模拟考试二 蓝桥杯嵌入式第十五届模拟考试一 蓝桥杯嵌入式第十五届模拟考试二 蓝…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...
JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案
JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停 1. 安全点(Safepoint)阻塞 现象:JVM暂停但无GC日志,日志显示No GCs detected。原因:JVM等待所有线程进入安全点(如…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...

算法笔记2
1.字符串拼接最好用StringBuilder,不用String 2.创建List<>类型的数组并创建内存 List arr[] new ArrayList[26]; Arrays.setAll(arr, i -> new ArrayList<>()); 3.去掉首尾空格...

算法岗面试经验分享-大模型篇
文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer (1)资源 论文&a…...

springboot整合VUE之在线教育管理系统简介
可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生,小白用户,想学习知识的 有点基础,想要通过项…...

HubSpot推出与ChatGPT的深度集成引发兴奋与担忧
上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...
redis和redission的区别
Redis 和 Redisson 是两个密切相关但又本质不同的技术,它们扮演着完全不同的角色: Redis: 内存数据库/数据结构存储 本质: 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能: 提供丰…...