当前位置: 首页 > news >正文

[PyTorch][chapter 46][LSTM -1]

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建


一  背景简介:

       1.1  RNN

         RNN 忽略o_t,L_t,y_t 模型可以简化成如下

      

       

          图中Rnn Cell 可以很清晰看出在隐藏状态h_t=f(x_t,h_{t-1})

            得到 h_t后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的h_{t+1}

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             f_t=\sigma(W_fh_{t-1}+U_{t}x_t+b_f)

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i)

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    a_t=\widetilde{c_t}=tanh(W_a h_{t-1}+U_ax_t+b_a)

           2.4 更新: 记忆单元

               c_t=f_t \odot c_{t-1}+i_t \odot a_t

             2.5  更新: 输出门

                决定是否使用隐藏值

                 o_t=\sigma(W_oh_{t-1}+U_ox_t+b_0)  

           2.6. 隐藏状态

                h_t=o_t \odot tanh(c_t)

           2.7  模型输出

                  \hat{y_t}=\sigma(Vh_t+b)

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况


三  LSTM 反向传播推导

      3.1 定义两个\delta_t

             \delta_h^t=\frac{\partial L}{\partial h_t}

            \delta_c^t=\frac{\partial L}{\partial C_t}

    3.2  定义损失函数

            损失函数L(t)分为两部分: 

             时刻t的损失函数 l(t)

             时刻t后的损失函数L(t+1)

              L(t)=\left\{\begin{matrix} l(t)+L(t+1), if: t<T\\ l(t), if: t=T \end{matrix}\right.

      3.3 最后一个时刻\tau

              

 这里面要注意这里的o^{\tau}= Vh_{\tau}+c

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   

   dl= tr((\frac{\partial L^{\tau}}{\partial h^{\tau}})^Tdh^{\tau})  ... 公式1: 微分和迹的关系

       =tr((\delta_h^{\tau})^Tdh^{\tau})

     因为

    h^{\tau}=o^{\tau} \odot tanh(c^{\tau})

   dh_T=o^{\tau}\odot(d(tanh (c^{\tau})))

           =o^{\tau} \odot (1-tanh^2(c^{\tau})) \odot dc^{\tau}

     带入上面公式1:

      dl= tr((\delta_h^{\tau})^T (o^{\tau}\odot(1-tanh^2(c^{\tau}))\odot dc^{\tau})

           =tr((\delta_h^{\tau} \odot o^{\tau} \odot(1-tanh^2(c^{\tau}))^Tdc^{\tau})

    所以

3.4   链式求导过程

       求导结果:

 

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

 ------------------------------------------------------------   

 第一项: 

             

         h_{t+1}=o_{t+1}\odot tanh(c_{t+1})

         o_{t+1}=\sigma(W_oh_t+U_ox_{t+1}+b_0)

        设 a_{t+1}=W_oh_t+U_ox_{t+1}+b_0

           则    \frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial a_{t+1}}\frac{\partial a_{t+1}}{\partial h_{t}}

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    \frac{\partial h_{t+1}}{\partial o_{t+1}}=diag(tanh(c^{t+1})) 是一个对角矩阵

                  o=\begin{bmatrix} \sigma(a_1)\\ \sigma(a_2) \\ .... \\ \sigma(a_n) \end{bmatrix}

                 \frac{\partial o_{t+1}}{\partial a_{t+1}}=diag(o_{t+1}\odot(1-o_{t+1}))

                 \frac{\partial a_{t+1}}{\partial h_{t}}=W_o

                 几个连乘起来就是第一项

               

第二项

    c_{t+1}=f_{t+1}\odot c_t+i_{t+1}\odot a_{t+1}

   f_{t+1}=\sigma(W_fh_t+U_tx_{t+1}+b_f)

   i_{t+1}=\sigma(W_ih_t+U_i x_{t+1}+b_i)

  a_{t+1}=tanh(W_a h_t +U_ax_t +b_a)

参考:

   h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

其中:

\frac{\partial h_{t+1}}{\partial c^{t+1}}=diag(o^{t+1}\odot (1-tanh^2(c^{t+1}))

\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial c_{t+1}}\frac{\partial c_{t+1}}{\partial f_{t+1}}\frac{\partial f_{t+1}}{\partial h_{t}}

 \frac{\partial c_{t+1}}{\partial f_{t+1}}=diag(c^{t})

 \frac{\partial a_{t+1}}{\partial h_{t}}=diag(f_t \odot(1-f_t))W_f

其它也是相似,就有了上面的求导结果


四  为什么能解决梯度消失

    

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如f_{t}=1,\hat{c_t}=0 时候,\frac{\partial c_t}{\partial c_{t-1}}=1,不会存在梯度消失。

       


五 模型的搭建

   

    我们最后发现:

    O_t,C_t,H_t 的维度必须一致,都是hidden_size

    通过C_t,则 I_t,F_t,\tilde{c} 最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023@author: chengxf2
"""import torch
from torch import nn
from d21 import torch as d21def normal(shape,devices):data = torch.randn(size= shape, device=devices)*0.01return datadef get_lstm_params(input_size, hidden_size,categorize_size,devices):#隐藏门参数W_xf= normal((input_size, hidden_size), devices)W_hf = normal((hidden_size, hidden_size),devices)b_f = torch.zeros(hidden_size,devices)#输入门参数W_xi= normal((input_size, hidden_size), devices)W_hi = normal((hidden_size, hidden_size),devices)b_i = torch.zeros(hidden_size,devices)#输出门参数W_xo= normal((input_size, hidden_size), devices)W_ho = normal((hidden_size, hidden_size),devices)b_o = torch.zeros(hidden_size,devices)#临时记忆单元W_xc= normal((input_size, hidden_size), devices)W_hc = normal((hidden_size, hidden_size),devices)b_c = torch.zeros(hidden_size,devices)#最终分类结果参数W_hq = normal((hidden_size, categorize_size), devices)b_q = torch.zeros(categorize_size,devices)params =[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q]for param in params:param.requires_grad_(True)return paramsdef init_lstm_state(batch_size, hidden_size, devices):cell_init = torch.zeros((batch_size, hidden_size),device=devices)hidden_init = torch.zeros((batch_size, hidden_size),device=devices)return (cell_init, hidden_init)def lstm(inputs, state, params):[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q] = params    (H,C) = stateoutputs= []for x in inputs:#input gateI = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)C = F*C+I*C_tmpH = O*torch.tanh(C)Y = (H@W_hq)+b_qoutputs.append(Y)return torch.cat(outputs, dim=0),(H,C)def main():batch_size,num_steps =32, 35train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)if __name__ == "__main__":main()


 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili

相关文章:

[PyTorch][chapter 46][LSTM -1]

前言&#xff1a; 长短期记忆网络&#xff08;LSTM&#xff0c;Long Short-Term Memory&#xff09;是一种时间循环神经网络&#xff0c;是为了解决一般的RNN&#xff08;循环神经网络&#xff09;存在的长期依赖问题而专门设计出来的。 目录&#xff1a; 背景简介 LSTM C…...

寄存器详解(二)

目录 内存中字的存储 示例&#xff1a; 数据段寄存器DS与[address] 字的传送 数据段简介 CPU提供的栈机制 栈段寄存器SS和栈顶指针寄存器SP PUSH AX指令的完整描述 示例图 POP AX指令的完整描述 示例图 栈顶超界问题 示例一&#xff1a; 示例二&#xff1a; 内存中字…...

Java AIO

在Java中&#xff0c;AIO代表异步I/O&#xff08;Asynchronous I/O&#xff09;&#xff0c;它是Java NIO的一个扩展&#xff0c;提供了更高级别的异步I/O操作。AIO允许应用程序执行非阻塞I/O操作&#xff0c;而无需使用Selector和手动轮询事件的方式。 与传统的NIO和Java NIO…...

java集合总结

1.常见集合 Collection List&#xff1a;有序可重复集合&#xff0c;可直接根据元素的索引来访问 Vector-StackArrayListLinkedList Queue&#xff1a;队列集合 Deque-LinkedList、ArrayDequePriorityQueue Set&#xff1a;无序不可重复集合&#xff0c;只能根据元素本身来访问…...

list交并补差集合

list交并补差集合 工具类依赖 <dependency><groupId>org.apache.commons</groupId><artifactId>commons-lang3</artifactId><version>3.8.1</version> </dependency><dependency><groupId>commons-collections&…...

【微信小程序】父组件修改子组件数据或调用子组件方法

一、使用场景 页面中用到了自定义组件形成父子组件关系&#xff0c;在父组件某个特定时期想要操作子组件中的数据或方法&#xff0c;比如离开页面的时候清空子组件的数据。 二、方法 父组件可以通过this.selectComponent方法获取子组件实例对象&#xff0c;这样就可以直接访…...

frp通过nginx映射multipart/x-mixed-replace; boundary=frame流媒体出外网访问

要通过Nginx访问multipart/x-mixed-replace流媒体协议&#xff0c;并通过FRP进行映射访问&#xff0c;你可以按照以下步骤进行操作&#xff1a; 配置Nginx以支持multipart/x-mixed-replace流媒体协议。你需要编辑Nginx的配置文件&#xff08;通常是nginx.conf&#xff09;&…...

Kubernetes概述

Kubernetes概述 使用kubeadm快速部署一个k8s集群 Kubernetes高可用集群二进制部署&#xff08;一&#xff09;主机准备和负载均衡器安装 Kubernetes高可用集群二进制部署&#xff08;二&#xff09;ETCD集群部署 Kubernetes高可用集群二进制部署&#xff08;三&#xff09;部署…...

Jmeter教程

目录 安装与配置 一&#xff1a;下载jdk——配置jdk环境变量 二&#xff1a;下载JMeter——配置环境变量 安装与配置 一&#xff1a;下载jdk——配置jdk环境变量 1.新建环境变量变量名:JAVA_HOME变量值&#xff1a;&#xff08;即JDK的安装路径&#xff09; 2.编辑Path%J…...

用Rust实现23种设计模式之建造者模式

当使用 Rust 实现建造者模式时&#xff0c;我们可以通过结构体和方法链来实现。建造者模式是一种创建型设计模式&#xff0c;它允许你按照特定的顺序构建复杂对象&#xff0c;同时使你能够灵活地构建不同的变体。下面是一个使用 Rust 实现建造者模式的示例&#xff0c; 在示例中…...

聚观早报 | 腾讯字节等企业驰援防汛救灾;新能源车7月销量单出炉

【聚观365】8月4日消息 腾讯字节等企业驰援防汛救灾新能源车7月销量成绩单出炉Model Y等车型低温续航衰减严重华为Mate60系列猜想图曝光支付宝做短视频引来羊毛党 腾讯字节等企业驰援防汛救灾 近日&#xff0c;京津冀地区遭遇极端降雨天气&#xff0c;引发洪涝和地质灾害&…...

Crack:CAD Exchanger SDK 3.20 Web Toolkit 应用

在CAD Exchanger SDK 版本 3.20.0中&#xff0c;我们在 Web Toolkit 中包含了绘图、BIM 和 MCAD 查看器的示例&#xff0c;以展示如何使用每个工具可视化数据。这些查看器具有显示不同类型数据的特定功能&#xff0c;允许用户根据自己的需求单独使用它们。我们将继续增强每个查…...

改造 ChatGPT-Next-Web 项目重新生成 Docker 镜像

改造 ChatGPT-Next-Web 项目重新生成 Docker 镜像 0.背景1. 修改代码2. 生成 Docker 镜像3. 上传 Docker 镜像4. 运行 Docker 镜像 0.背景 需要通过 ChatGPT-Next-Web 使用自己搭建的 OpenAI API 兼容的服务器&#xff0c;需要对 ChatGPT-Next-Web 项目的少量代码进行改造。 …...

git修改commit日志

由于公司对版本提交日志进行检查&#xff0c;如果不符合要求&#xff0c;则push失败。 以下是修改commit日志的方法&#xff1a; 1.进入到提交代码文件所在目录&#xff0c;即git所在目录下 cd app-repository 2.git log git log commit bf29e3e5e799d364fe2975677baf18c9…...

Qt之qml和widget混合编程调用

首先是创建一个widget项目 然后需要添加qml和quick的插件使用 QT quickwidgets qml 接着要在界面上创建一个quickwidget和按钮 创建一个c对象类 QObjectQml #ifndef QOBJECTQML_H #define QOBJECTQML_H#include <QObject> #include <QDebug> class QObjectQml …...

深度学习torch基础知识

torch. detach()拼接函数torch.stack()torch.nn.DataParallel()np.clip()torch.linspace()PyTorch中tensor.repeat()pytorch索引查找 index_select detach() detach是截断反向传播的梯度流 将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时&#xff0c;梯度…...

【JAVA】正则表达式是啥?

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️初识JAVA】 文章目录 前言正则表达式正则表达式语法正则表达式的特点捕获组实例 前言 如果我们想要判断给定的字符串是否符合正则表达式的过滤逻辑&#xff08;称作“匹配”&#xff09;&#xff0c…...

网络安全之原型链污染

目录&#xff1a; 目录&#xff1a; 一、概念 二、举例 三、 实操了解 总结 四、抛出原题&#xff0c;历年原题复现 第一题&#xff1a; 五、分析与原理 第二题&#xff1a; 八、分析与原理 九、具体操作&#xff0c;payload与结果 结果&#xff1a; 一、概念 Java…...

【腾讯云Cloud Studio实战训练营】使用Cloud Studio迅捷开发一个3D家具个性化定制应用

目录 前言&#xff1a; 一、腾讯云 Cloud Studio介绍&#xff1a; 1、接近本地 IDE 的开发体验 2、多环境可选&#xff0c;或连接到云主机 3、随时分享预览效果 4、兼容 VSCode 插件 5、 AI代码助手 二、腾讯云Cloud Studio项目实践&#xff08;3D家具个性化定制应用&…...

【计算机网络】第四章 网络层(一)

文章目录 第四章 网络层4.1 网络层概述4.2 网络层提供的两种服务4.2.1 小结 第四章 网络层 网络层是计算机网络体系结构中的一个关键层&#xff0c;位于传输层上方、数据链路层下方。它负责将传输层提供的数据分割成适当大小的数据包&#xff0c;并在不同网络之间进行路由选择和…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面&#xff0c;避免重复抓取&#xff0c;以节省资源和时间。 在分布式环境下&#xff0c;增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路&#xff1a;将增量判…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难&#xff0c;相信大家会学的很愉快&#xff0c;当然对于有后端基础的朋友来说&#xff0c;本期内容更加容易了解&#xff0c;当然没有基础的也别担心&#xff0c;本期内容会详细解释有关内容 本期用到的软件&#xff1a;yakit&#xff08;因为经过之前好多期…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中&#xff0c;车辆不再仅仅是传统的交通工具&#xff0c;而是逐步演变为高度智能的移动终端。这一转变的核心支撑&#xff0c;来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒&#xff08;T-Box&#xff09;方案&#xff1a;NXP S32K146 与…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...

逻辑回归暴力训练预测金融欺诈

简述 「使用逻辑回归暴力预测金融欺诈&#xff0c;并不断增加特征维度持续测试」的做法&#xff0c;体现了一种逐步建模与迭代验证的实验思路&#xff0c;在金融欺诈检测中非常有价值&#xff0c;本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...

永磁同步电机无速度算法--基于卡尔曼滤波器的滑模观测器

一、原理介绍 传统滑模观测器采用如下结构&#xff1a; 传统SMO中LPF会带来相位延迟和幅值衰减&#xff0c;并且需要额外的相位补偿。 采用扩展卡尔曼滤波器代替常用低通滤波器(LPF)&#xff0c;可以去除高次谐波&#xff0c;并且不用相位补偿就可以获得一个误差较小的转子位…...

【深度学习新浪潮】什么是credit assignment problem?

Credit Assignment Problem(信用分配问题) 是机器学习,尤其是强化学习(RL)中的核心挑战之一,指的是如何将最终的奖励或惩罚准确地分配给导致该结果的各个中间动作或决策。在序列决策任务中,智能体执行一系列动作后获得一个最终奖励,但每个动作对最终结果的贡献程度往往…...

goreplay

1.github地址 https://github.com/buger/goreplay 2.简单介绍 GoReplay 是一个开源的网络监控工具&#xff0c;可以记录用户的实时流量并将其用于镜像、负载测试、监控和详细分析。 3.出现背景 随着应用程序的增长&#xff0c;测试它所需的工作量也会呈指数级增长。GoRepl…...