当前位置: 首页 > news >正文

【深度学习】LSTM、BiLSTM详解

文章目录

  • 1. LSTM简介:
  • 2. LSTM结构图:
  • 3. 单层LSTM详解
  • 4. 双层LSTM详解
  • 5. BiLSTM
  • 6. Pytorch实现LSTM示例
  • 7. nn.LSTM参数详解

1. LSTM简介:

    LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通过使用门控单元来控制信息的流动,从而缓解RNN中的梯度消失和梯度爆炸的问题。LSTM的核心是三个门:输入门遗忘门输出门

遗忘门: 遗忘门的作用是决定哪些信息从记忆单元中遗忘,它使用sigmoid激活函数,可以输出在0到1之间的值,可以理解为保留信息的比例。
输入门: 作用是决定哪些新信息被存储在记忆单元中
输出门: 输出门决定了下一个隐藏状态,即生成当前时间步的输出并传递到下一时间步
记忆单元:负责长期信息的存储,通过遗忘门和输入门的相互作用,记忆单元能够学习如何选择性地记住或忘记信息

2. LSTM结构图:

在这里插入图片描述

涉及到的计算公式如下:
在这里插入图片描述

3. 单层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为1。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    LSTM单元包含三个输入参数x、c、h;首先t1时刻作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始从c(0)和h(0)都是0矩阵,计算完成后,第一个LSTM单元输出一组h(t1)\c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

4. 双层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为2。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    第二层LSTM没有输入参数x(t1)、x(t2)、x(t3);所以我们将第一层LSTM输出的h(t1)、h(t2)、h(t3)作为第二层LSTM的输入x(t1)、x(t2)、x(t3)。第一个时间步输入的初始c(0)和h(0)都为0矩阵,计算完成后,第一个时间步输出新的一组h(t1)、c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

5. BiLSTM

单层的BiLSTM其实就是2个LSTM,一个正向去处理序列,一个反向去处理序列,处理完后,两个LSTM的输出会拼接起来。
在这里插入图片描述

6. Pytorch实现LSTM示例

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out

7. nn.LSTM参数详解

pytorch官方定义:

CLASS torch.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        proj_size=0,
        device=None,
        dtype=None
    )

input_size – 输入 x 中预期的特征数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置 num_layers=2 表示将两个 LSTM 堆叠在一起形成一个 stacked LSTM,其中第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1
bias – 如果 False,则该层不使用偏差权重 b_ih 和 b_hh。默认值:True
batch_first – 如果 True,则输入和输出张量将以 (batch, seq, feature) 而不是 (seq, batch, feature) 的形式提供。请注意,这并不适用于隐藏状态或单元状态。有关详细信息,请参见下面的输入/输出部分。默认值:False
dropout – 如果非零,则在除最后一层之外的每个 LSTM 层的输出上引入一个 Dropout 层,其 dropout 概率等于 dropout。默认值:0
bidirectional – 如果 True,则变为双向 LSTM。默认值:False
proj_size – 如果 > 0,则将使用具有相应大小的投影的 LSTM。默认值:0

对于输入序列每一个元素,每一层都会进行以下计算:
在这里插入图片描述
网络输入:
在这里插入图片描述

网络输出:
在这里插入图片描述

本文参考:https://blog.csdn.net/qq_34486832/article/details/134898868
https://pytorch.ac.cn/docs/stable/generated/torch.nn.LSTM.html#

LSTM每层的输出都要经过全连接层吗,还是直接对隐藏层进行输出?
通过在代码中对lstm的输出进行print输出:

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))print(out)print(hn)print(cn)# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out
if __name__ == "__main__":input_dim = 3        # 输入特征的维度hidden_dim = 4       # 隐藏层的维度num_layers = 1       # LSTM 层的数量output_dim = 1       # 输出特征的维度lstm = LSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)batch_size = 1seq_length = 10input_tensor = torch.randn(batch_size, seq_length, input_dim).to(device)output = lstm(input_tensor)

通过对LSTM网络的输出我们可以看到,out的最后一层与最后一层隐藏层hn一致,说明并未经过全连接层,而是直接输出隐藏层
在这里插入图片描述

相关文章:

【深度学习】LSTM、BiLSTM详解

文章目录 1. LSTM简介:2. LSTM结构图:3. 单层LSTM详解4. 双层LSTM详解5. BiLSTM6. Pytorch实现LSTM示例7. nn.LSTM参数详解 1. LSTM简介: LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通…...

分子对接--软件安装

分子对接相关软件安装 一、软件 AutoDock,下载链接: linkMGLtools,下载链接: link 自行选择合适版本下载,这里主要叙述在win上的具体安装流程: 下载得到: 二、运行 运行autodocksuite-4.2.6.i86Windows得到&#…...

【Python无敌】在 QGIS 中使用 Python

QGIS 中有 Python 的运行环境,可以很好地执行各种任务。 这里的问题是如何在 Jupyter 中调用 QGIS 的功能。 首先可以肯定的是涉及到 GUI 的一些任务是无法在 Jupyter 中访问的, 这样可以用的功能主要是地处理工具。 按如下方式进行了尝试。 原想使用 gdal:hillshade ,但是…...

全面解读:低代码开发平台的必备要素——系统策划篇

在传统开发过程中,系统策划起着举足轻重的作用,它宛如一位幕后的总指挥,把控着整个软件开发项目的走向。而随着技术的不断进步,低代码开发平台逐渐崭露头角,它以快速开发、降低技术门槛等优势吸引了众多企业和开发者的…...

Vue开发自动生成验证码功能 前端实现不使用第三方插件实现随机验证码功能,生成的验证码添加干扰因素

Vue实现不使用第三方插件,开发随机生成验证码功能 效果图,其中包含了短信验证码功能,以及验证码输入是否正确功能 dom结构 <div class="VerityInputTu"><div class="labelClass">图形验证码</div><div class="tuxingInput…...

# filezilla连接 虚拟机ubuntu系统出错“尝试连接 ECONNREFUSED - 连接被服务器拒绝, 失败,无法连接服务器”解决方案

filezilla连接 虚拟机ubuntu系统出错“尝试连接 ECONNREFUSED - 连接被服务器拒绝&#xff0c; 失败&#xff0c;无法连接服务器”解决方案 一、问题描述&#xff1a; 当我们用filezilla客户端 连接 虚拟机ubuntu系统时&#xff0c;报错“尝试连接 ECONNREFUSED - 连接被服务…...

2024/11/13 英语每日一段

The new policy has drawn many critics. Data and privacy experts said the Metropolitan Transit Authority’s new initiative doesn’t address the underlying problem that causes fare evasion, which is related to poverty and access. Instead, the program tries “…...

【全栈开发平台】全面解析 StackBlitz 最新力作 Bolt.new:AI 驱动的全栈开发平台

文章目录 [TOC]&#x1f31f; Bolt.new 的独特价值1. **无需配置&#xff0c;立刻开发**2. **AI 驱动&#xff0c;智能生成代码**3. **极致的速度与安全性**4. **一键部署&#xff0c;轻松上线**5. **免费开放&#xff0c;生态丰富** &#x1f6e0;️ Bolt.new 使用教程一、快速…...

文献解读-DNAscope: High accuracy small variant calling using machine learning

关键词&#xff1a;基准与方法研究&#xff1b;基因测序&#xff1b;变异检测&#xff1b; 文献简介 标题&#xff08;英文&#xff09;&#xff1a;DNAscope: High accuracy small variant calling using machine learning标题&#xff08;中文&#xff09;&#xff1a;DNAsc…...

成都睿明智科技有限公司解锁抖音电商新玩法

在这个短视频风起云涌的时代&#xff0c;抖音电商以其独特的魅力迅速崛起&#xff0c;成为众多商家争夺的流量高地。而在这片充满机遇与挑战的蓝海中&#xff0c;成都睿明智科技有限公司犹如一颗璀璨的新星&#xff0c;以其专业的抖音电商服务&#xff0c;助力无数品牌实现从零…...

【操作系统】——调度算法

&#x1f339;&#x1f60a;&#x1f339;博客主页&#xff1a;【Hello_shuoCSDN博客】 ✨操作系统详见 【操作系统专项】 ✨C语言知识详见&#xff1a;【C语言专项】 目录 先来先服务&#xff08;FCFS, First Come First Serve) 短作业优先&#xff08;SJF, Shortest Job Fi…...

MySQL LOAD DATA INFILE导入数据报错

1.导入命令 LOAD DATA INFILE "merge.csv" INTO TABLE 报名数据 FIELDS TERMINATED BY , ENCLOSED BY " LINES TERMINATED BY \n IGNORE 1 LINES; 2.表结构 CREATE TABLE IF NOT EXISTS 报名数据 ( pid VARCHAR(100) NOT NULL, 查询日期 VARCHAR(25) NO…...

AI 写作(五)核心技术之文本摘要:分类与应用(5/10)

一、文本摘要&#xff1a;AI 写作的关键技术 文本摘要在 AI 写作中扮演着至关重要的角色。在当今信息爆炸的时代&#xff0c;人们每天都被大量的文本信息所包围&#xff0c;如何快速有效地获取关键信息成为了一个迫切的需求。文本摘要技术正是为了解决这个问题而诞生的&#x…...

CTFL(二)贯穿软件开发生存周期中的测试

贯穿软件开发生存周期中的测试 验收测试&#xff08;acceptance testing&#xff09;&#xff0c;黑盒测试&#xff08;black-box testing&#xff09;&#xff0c;组件集成测试&#xff08;component integration testing&#xff09;&#xff0c;组件测试&#xff08;compone…...

PMIC FS8405

FS8495 具有多个SMPS和LDO的故障安全系统基础芯片。   FS8X 大多数参数都是通过OTP寄存器设置的。 概述 FS85/FS84设备系列是按照ASIL D流程开发的,FS84具有ASIL B能力,而FS85具有ASIL D能力。所有的设备选项都是引脚到引脚和软件兼容的。   FS85/FS84是一种汽车功能安全…...

matlab建模入门指导

本文以水池中鸡蛋温度随时间的变化为切入点&#xff0c;对其进行数学建模并进行MATLAB求解&#xff0c;以更为通俗地进行数学建模问题入门指导。 一、问题简述 一个煮熟的鸡蛋有98摄氏度&#xff0c;将它放在18摄氏度的水池中&#xff0c;五分钟后鸡蛋的温度为38摄氏度&#x…...

微搭低代码入门03函数

目录 1 函数的定义与调用2 参数与返回值3 默认参数4 将功能拆分成小函数5 函数表达式6 箭头函数7 低代码中的函数总结 在用低代码开发软件的时候&#xff0c;除了我们上两节介绍的变量、条件语句外&#xff0c;还有一个重要的概念叫函数。函数是执行特定功能的代码片段&#xf…...

零基础Java第十六期:抽象类接口(二)

目录 一、接口&#xff08;补&#xff09; 1.1. 数组对象排序 1.2. 克隆接口 1.3. 浅拷贝和深拷贝 1.4. 抽象类和接口的区别 一、接口&#xff08;补&#xff09; 1.1. 数组对象排序 我们在讲一维数组的时候&#xff0c;使用到冒泡排序来对数组里的元素进行从小到大或从大…...

【css】html里面的图片宽度设为百分比,高度要与宽度一样

场景&#xff1a;展示图片列表的时候&#xff0c;原始图片宽高不一致。 外层div的宽度自适应&#xff0c;图片宽度不能固定数值&#xff0c;只能设置百分比。图片高度也不能设置固定数值。 如何让图片的高度与图片的宽度一样呢&#xff1f; html代码 &#xff1a; <div cl…...

前端三大组件之CSS,三大选择器,游戏网页仿写

回顾 full stack全栈 Web前端三大组件 结构(html) 样式(css) 动作/交互(js) --- 》 框架vue&#xff0c;安哥拉 div 常用的标签 扩展标签 列表 ul/ol order——有序号 unordered——没序号的黑点 <!DOCTYPE html> <html><head><meta charset"…...

sqlsever 分布式存储查询

当数据存储在不同的服务器上的时候怎么取出来进行正常管连呢?比如你有 A 和B 两个服务器 里面存有两个表 分别是 A_TABLE、B_TABLE 其中 他们的关联关系是 ID 互相关联 1.创建链接服务器如果在B数据库要访问A数据库 那么 就在B数据库创建 -- 创建链接服务器 EXEC sp_addlink…...

deeponet(nature原文部分重点提取)

论文链接&#xff1a;Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators | Nature Machine Intelligence 原文部分重点提取 DeepONets 会产生小的泛化误差 隐式类型算子还可以描述我们对其形式没有任何数学知识的系统 De…...

LeetCode【0036】有效的数独

本文目录 1 中文题目2 求解方法&#xff1a;python内置函数set2.1 方法思路2.2 Python代码2.3 复杂度分析 3 题目总结 1 中文题目 请根据以下规则判断一个 9 x 9 的数独是否有效。 数字 1-9 在每一行只能出现一次。数字 1-9 在每一列只能出现一次。数字 1-9 在每一个以粗实线…...

Typecho登陆与评论添加Geetest极验证,支持PJAX主题(如Handsome)

Typecho登陆与评论添加Geetest极验证&#xff0c;支持PJAX主题&#xff08;如Handsome&#xff09; 起因 最近垃圾评论比较多&#xff0c;为了防止一些机器人&#xff0c;我给博客添加了一些评论过滤机制&#xff0c;并为评论添加了验证码。 原本使用的插件是noisky/typecho…...

前端入门一之ES6--面向对象、够着函数和原型、继承、ES5新增方法、函数进阶、严格模式、高阶函数、闭包

前言 JS是前端三件套之一&#xff0c;也是核心&#xff0c;本人将会更新JS基础、JS对象、DOM、BOM、ES6等知识点&#xff0c;这篇是ES6;这篇文章是本人大一学习前端的笔记&#xff1b;欢迎点赞 收藏 关注&#xff0c;本人将会持续更新。 文章目录 JS高级 ES61、面向对象1.1…...

脑机接口、嵌入式 AI 、工业级 MR、空间视频和下一代 XR 浏览器丨RTE2024 空间计算和新硬件专场回顾

这一轮硬件创新由 AI 引爆&#xff0c;或许最大受益者仍是 AI&#xff0c;因为只有硬件才能为 AI 直接获取最真实世界的数据。 在人工智能与硬件融合的新时代&#xff0c;实时互动技术正迎来前所未有的创新浪潮。从嵌入式系统到混合现实&#xff0c;从空间视频到脑机接口&…...

RoseTTAFold MSA_emb类解读

MSA_emb 类的作用是对多序列对齐(MSA)数据进行嵌入编码,同时添加位置编码和查询编码(调用PositionalEncoding 和 QueryEncoding)以便为序列特征建模类。 源代码: class MSA_emb(nn.Module):def __init__(self, d_model=64, d_msa=21, p_drop=0.1, max_len=5000):super(…...

2411C++,C++26反射示例

参考 namespace __impl {template<auto... vals>struct replicator_type {template<typename F>constexpr void operator>>(F body) const {(body.template operator()<vals>(), ...);}};template<auto... vals>replicator_type<vals...>…...

Ubuntu上搭建Flink Standalone集群

Ubuntu上搭建Flink Standalone集群 本文部分内容转自如下链接。 环境说明 ubuntu 22.06 先执行apt-get update更新环境 第1步 安装JDK 通过apt自动拉取 openjdk8 apt-get install openjdk-8-jdk执行java -version&#xff0c;如果能显示Java版本号&#xff0c;表示安装并…...

C语言 精选真题2

题目要求&#xff1a;将形参s所指向的字符串转换为整数并且返回 知识点&#xff1a; 将字符1转化为整数1 int fun(char *s) {int flag1,n0; if(*s-) //先根据第一个符号来判断是正负&#xff1b;然后读取第二位{flag-1;s; }else if(*s){s;}while(*s>0&&…...