当前位置: 首页 > news >正文

【论文复现】LSTM长短记忆网络

LSTM

  • 前言
  • 网络架构
    • 总线
    • 遗忘门
    • 记忆门
    • 记忆细胞
    • 输出门
  • 模型定义
    • 单个LSTM神经元的定义
    • LSTM层内结构的定义
  • 模型训练
  • 模型评估
  • 代码细节
    • LSTM层单元的首尾的处理
    • 配置Tensorflow的GPU版本

前言

LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能,同时还经常用于做时间序列。由于LSTM的原版论文相关版权问题,这里以colah大佬的博客为基础进行讲解。之前写过一篇Tensorflow中的LSTM详解,但是原理部分跟代码部分的联系并不紧密,实践性较强但是如果想要进行更加深入的调试就会出现原理性上面的问题,因此特此作文解决这个问题,想要用LSTM这个有趣的模型做出更加好的机器学习效果😊。

网络架构

LSTM框架图
这张图展示了LSTM在整体结构,下面就开始分部分介绍中间这个东东。

总线

在这里插入图片描述
这条是总线,可以实现神经元结构的保存或者更改,如果就是像上图一样一条总线贯穿不做任何改变,那么就是不改变细胞状态。那么如果想要改变细胞状态怎么办?可以通过来实现,这里的门跟高中生物中学的神经兴奋阈值比较像,用数学来表示就是sigmoid函数或者其他的激活函数,当门的输入达到要求,门就会打开,允许当前门后面的信息“穿过”门改变主线上面传递的信息,如果把每一个神经元看成一个时间节点,那么从上一个时间节点传到下一个时间节点过程中的门的开启与关闭就实现了时间序列数据的信息传递。
在这里插入图片描述

遗忘门

在这里插入图片描述
首先是遗忘门,这个门的作用是决定从上一个神经元传输到当前神经元的数据丢弃的程度,如果经过sigmoid函数以后输出0表示全部丢弃,输出1表示全部保留,这个层的输入是旧的信息和当前的新信息。

σ \sigma σ:sigmoid函数
W f W_f Wf:权重向量
b f b_f bf:偏置项,决定丢弃上一个时间节点的程度,如果是正数,表示更容易遗忘,如果是负数,表示比较容易记忆
h t − 1 h_{t-1} ht1:上一个时刻的输入
x t x_t xt:当前层的输入

记忆门

在这里插入图片描述
接下来是记忆门,这个门决定要记住什么信息,同时决定按照什么程度记住上一个状态的信息。

i t i_t it:在时间步t时刻的输入门激活值,计算方法跟上面的遗忘门是一样的,只是目的不一样,这里是记忆
C ~ t \tilde{C}_{t} C~t:表示上一个时刻的信息和当前时刻的信息的集合,但是是规则化到[-1,1]这个范围内了的

记忆细胞

在这里插入图片描述
有了上面的要记忆的信息和要丢弃的信息,记忆细胞的功能就可以得到实现,用 f t f_t ft这个标量决定上一个状态要遗忘什么,用 i t i_t it这个标量决定上一个状态要记住什么以及当前状态的信息要记住什么。这样就形成了一个记忆闭环了。

输出门

在这里插入图片描述
最后,在有了记忆细胞以后不仅仅不要将当前细胞状态记住,还要将当前的信息向下一层继续传输,实现公式中的状态转移。

o t o_t ot:跟前面的门公式都一样,但是功能是决定输出的程度
h t h_t ht:将输出规范到[-1,1]的区间,这里有两个输出的原因是在构建LSTM网络的时候需要有纵向向上的那个 h t h_t ht,然而在当前层的LSTM的神经元之间还是首尾相接的😍。

模型定义

单个LSTM神经元的定义


# 定义单个LSTM单元
# 定义单个LSTM单元
class My_LSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTM, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size# 初始化门的权重和偏置,由于每一个神经元都有自己的偏置,所以在定义单元内部定义self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bf = nn.Parameter(torch.Tensor(hidden_size))self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bi = nn.Parameter(torch.Tensor(hidden_size))self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bo = nn.Parameter(torch.Tensor(hidden_size))self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bg = nn.Parameter(torch.Tensor(hidden_size))# 初始化输出层的权重和偏置self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))self.b = nn.Parameter(torch.Tensor(output_size))# 用于计算每一种权重的函数def cal_weight(self, input, weight, bias):return F.linear(input, weight, bias)# x是输入的数据,数据的格式是(batch, seq_len, input_size),包含的是batch个序列,每个序列有seq_len个时间步,每个时间步有input_size个特征def forward(self, x):# 初始化隐藏层和细胞状态h = torch.zeros(1, 1, self.hidden_size).to(x.device)c = torch.zeros(1, 1, self.hidden_size).to(x.device)# 遍历每一个时间步for i in range(x.size(1)):input = x[:, i, :].view(1, 1, -1) # 取出每一个时间步的数据# 计算每一个门的权重f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遗忘门i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 输入门o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 输出门C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候选值# 更新细胞状态c = f * c + i * C_# 更新隐藏层h = o * torch.tanh(c) # 将输出标准化到-1到1之间output = self.cal_weight(h, self.W, self.b) # 计算输出return output

LSTM层内结构的定义

class My_LSTMNetwork(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTMNetwork, self).__init__()self.hidden_size = hidden_sizeself.lstm = My_LSTM(input_size, hidden_size)  # 使用自定义的LSTM单元self.fc = nn.Linear(hidden_size, output_size)  # 定义全连接层def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))  # LSTM层的前向传播out = self.fc(out[:, -1, :])  # 全连接层的前向传播return out

模型训练

history = model.fit(trainX, trainY, batch_size=64, epochs=50, validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型评估

为了更加直观展示,这里用画图的方法进行结果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代码细节

LSTM层单元的首尾的处理

  • 首部:由于第一个节点不用接受来自上一个节点的输入,不需要有输入,当然也有一些是添加标识。

  • 尾部:由于已经进行到当前层的最后一个节点,因此输出只需要向下一层进行传递而不用向下一个节点传递,添加标识也是可以的。

配置Tensorflow的GPU版本

这一篇写的比较好,我自己的硬件环境如下图所示,需要的可以借鉴一下,当然也可以在我提供的代码链接直接用我给的environment.yml一键构建环境😃。
在这里插入图片描述

相关文章:

【论文复现】LSTM长短记忆网络

LSTM 前言网络架构总线遗忘门记忆门记忆细胞输出门 模型定义单个LSTM神经元的定义LSTM层内结构的定义 模型训练模型评估代码细节LSTM层单元的首尾的处理配置Tensorflow的GPU版本 前言 LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能&am…...

目标检测YOLO实战应用案例100讲-【自动驾驶】激光雷达

目录 前言 算法原理 测距方法 发射单元 接收单元 扫描单元...

用C语言设计轨道电机的驱动库

一、设计目的 设计能驱动立体轨道电机的抽象驱动程序库。 二、设计要求 命名规范。设计简单,方便使用。体积小。满足电机的移动、停止、初始化、恢复等控制,甚至通过网络控制。 三、设计内容 (一)属性封装 1、定义配置结构体 // 用于配置参数 typed…...

HTML跳动的爱心

目录 写在前面 HTML简介 跳动的爱心 代码分析 运行结果 推荐文章 写在后面 写在前面 哎呀,这是谁的小心心?跳得好快吖! HTML简介 老生常谈啦,咱们还是从HTML开始吧! HTML是超文本标记语言(Hyper…...

汇编原理(二)

寄存器:所有寄存器都是16位(0-15),可以存放两个字节 AX,BX,CX,DX存放一般性数据,称为通用寄存器 AX的逻辑结构。最大存放的数据为2的16次方减1。可分为AH和AL,兼容8位寄存器。 字:1word 2Byte…...

Android Studio开发之路(十三)主题影响Button颜色问题解决及button自定义样式

一、问题描述 在开发过程中发现安卓的默认主题色是紫色,并且会导致button也是紫色,有时直接在xml布局文件中直接设置button的背景色或者设置背景图片不起效果 方案一、如果是app,可以直接设置主题颜色 比如,将主题设置为白色&a…...

eNSP学习——OSPF单区域配置

目录 相关命令 实验背景 实验目的 实验步骤 实验拓扑 实验编址 实验步骤 1、基础配置 2、部署单区域OSPF网络 3、检查OSPF单区域的配置结果 OSPF——开放式最短路径优先 基于链路状态的协议,具有收敛快、路由无环、扩展性好等优点; 相关命令 […...

深度学习中的优化算法二(Pytorch 19)

一 梯度下降 尽管梯度下降(gradient descent)很少直接用于深度学习,但了解它是理解下一节 随机梯度下降算法 的关键。例如,由于学习率过大,优化问题可能会发散,这种现象早已在梯度下降中出现。同样地&…...

R实验 方差分析

实验目的: 掌握单因素方差分析的思想和方法; 掌握多重均值检验方法; 掌握多个总体的方差齐性检验; 掌握Kruskal-Wallis秩和检验的思想和方法; 掌握多重Wilcoxon秩和检验的思想和方法。 实验内容: &…...

AI智能体|手把手教你使用扣子Coze图像流的文生图功能

大家好,我是无界生长。 AI智能体|手把手教你使用扣子Coze图像流的文生图功能本文详细介绍了Coze平台的\x26quot;图像流\x26quot;功能中的\x26quot;文生图\x26quot;节点,包括创建图像流、编排文生图节点、节点参数配置,并通过案例…...

应用程序图标提取

文章目录 [toc]提取过程提取案例——提取7-zip应用程序的图标 提取过程 找到需要提取图标的应用程序的.exe文件 复制.exe文件到桌面,并将复制的.exe文件后缀改为.zip 使用解压工具7-zip解压.zip文件 在解压后的文件夹中,在.rsrc/ICON路径下的.ico文件…...

Excel表格在线解密:轻松解密密码,快速恢复数据

忘记了excel表格密码?教你简单两步走:具体步骤如下。首先,在百度搜索中键入“密码帝官网”。其次,点击“立即开始”,在用户中心上传表格文件即可找回密码。这种方法不用下载软件,操作简单易行,适…...

springboot小结1

什么是springboot ​ Spring Boot是为了简化Spring应用的创建、运行、调试、部署等而出现的,使用它可以做到专注于Spring应用的开发,而无需过多关注XML的配置。 ​ 简单来说,它提供了一堆依赖打包Starter,并已经按照使用习惯解决…...

【Qt 学习笔记】Qt窗口 | 菜单栏 | QMenuBar的使用及说明

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ Qt窗口 | 菜单栏 | QMenuBar的使用及说明 文章编号:Qt 学习…...

Spark运行模式详解

Spark概述 Spark 可以在多种不同的运行模式下执行,每种模式都有其自身的特点和适用场景。 部署Spark集群大体上分为两种模式:单机模式与集群模式。大多数分布式框架都支持单机模式,方便开发者调试框架的运行环境。但是在生产环境中&#xff…...

vcpkg环境配置

vcpkg 使用linux相关库,设置环境变量VCPKG_ROOT,设置cmake工具链$VCPKG_ROOT/scripts\buildsystems\vcpkg.cmake set VCPKG_DEFAULT_TRIPLETx64-windows .\vcpkg.exe install fftw3 freetype gettext glibmm gtkmm libjpeg-turbo libpng libxmlpp libs…...

python学习:基础语句

目录 条件语句 循环语句 for 循环 while 循环 break continue 条件语句 Python提供了 if、elif、else 来进行逻辑判断。格式如下: Pythonif 判断条件1: 执行语句1... elif 判断条件2: 执行语句2... elif 判断条件3: 执行语句3... else: 执行语句4…...

Nginx限制IP访问详解

在Web服务器管理中,限制某些IP地址访问网站是一个常见的需求。Nginx作为一款高性能的HTTP服务器和反向代理服务器,提供了灵活强大的配置选项来实现这一功能。本文将详细讲解如何在Nginx中限制IP访问,并通过示例代码展示具体操作。 一、Nginx…...

Three.js——二维平面、二维圆、自定义二维图形、立方体、球体、圆柱体、圆环、扭结、多面体、文字

个人简介 👀个人主页: 前端杂货铺 ⚡开源项目: rich-vue3 (基于 Vue3 TS Pinia Element Plus Spring全家桶 MySQL) 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 &#x1…...

24年湖南教资认定即将开始,别被照片卡审!

24年湖南教资认定即将开始,别被照片卡审!...

PHP和Node.js哪个更爽?

先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...

C++.OpenGL (14/64)多光源(Multiple Lights)

多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...

保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek

文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...

云原生安全实战:API网关Kong的鉴权与限流详解

🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关(API Gateway) API网关是微服务架构中的核心组件,负责统一管理所有API的流量入口。它像一座…...

Go语言多线程问题

打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

用鸿蒙HarmonyOS5实现中国象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...