当前位置: 首页 > news >正文

什么是长短期记忆网络?

一、概念

        长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门)来控制信息的流动。其中,每个门都是一个神经网络层,用于决定哪些信息应该被保留,哪些信息应该被丢弃。LSTM的核心是细胞状态(cell state),它通过这些门的控制来更新和传递信息。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,C_{t-1}为前一个时间步的细胞状态向量,C_{t}为当前时间步的细胞状态变量,f_{t}为当前时间步的遗忘门向量,i_{t}为当前时间步的输入门向量,\bar{C_{t}}为当前时间步的候选细胞状态向量,o_{t}为当前时间步的输出门向量,W_{f},W_{i},W_{C},W_{o}分别为各门的权重矩阵,b_{f},b_{i},b_{C},b_{o}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。LSTM的核心内容包括以下几个部分:

1、遗忘门(Forget Gate)

        遗忘门决定细胞状态中哪些信息需要被遗忘。通过sigmoid激活函数,遗忘门的输出在0到1之间,表示每个细胞状态元素被保留的比例。

f_{t} = \sigma(W_{f} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{f})

2、输入门(Input Gate)

        输入门决定哪些新的信息需要被写入细胞状态。通过sigmoid激活函数,输入门的输出在0到1之间,表示每个候选细胞状态元素被写入的比例。候选细胞状态通过tanh激活函数生成,表示新的信息。

i_{t} = \sigma(W_{i} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{i})

\bar{C}_{t} = tanh(W_{C} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{C})

3、细胞状态更新

        细胞状态结合遗忘门和输入门的结果进行更新。遗忘门的输出与前一个时间步的细胞状态相乘,表示保留的旧信息。输入门的输出与候选细胞状态相乘,表示写入的新信息。两者相加得到当前时间步的细胞状态。

C_{t} = f_{t} \ast C_{t-1}+i_{t} \ast \bar{C}_{t}

4、输出门(Output Gate)

        输出门决定细胞状态的哪些部分将作为输出。通过sigmoid激活函数,输出门的输出在0到1之间,表示每个细胞状态元素被输出的比例。细胞状态通过tanh激活函数进行非线性变换,然后与输出门的输出相乘,得到当前时间步的隐藏状态。

o_{t} = \sigma(W_{o} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{o})

h_{t} = o_{t} \ast tanh(C_{t})

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成正弦波数据
def generate_sine_wave(seq_length, num_samples):x = np.linspace(0, num_samples, num_samples)y = np.sin(x)data = []for i in range(len(y) - seq_length):data.append(y[i:i+seq_length+1])return np.array(data)# 定义LSTM模型
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out# 超参数设置
seq_length = 50
num_samples = 1000
input_size = 1
hidden_size = 50
output_size = 1
num_layers = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 5
test_size = 0.2  # 测试集占比# 生成数据
data = generate_sine_wave(seq_length, num_samples)
X = data[:, :-1]
y = data[:, -1]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)# 转换为Tensor
X_train = torch.tensor(X_train.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_train = torch.tensor(y_train.reshape(-1, output_size), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_test = torch.tensor(y_test.reshape(-1, output_size), dtype=torch.float32)# 创建数据加载器
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型、损失函数和优化器
model = LSTMModel(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()for i, (inputs, labels) in enumerate(train_loader):outputs = model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():predicted = []actual = []for inputs, labels in test_loader:outputs = model(inputs)predicted.extend(outputs.numpy())actual.extend(labels.numpy())# 绘制结果
plt.plot(actual, label='Actual data')
plt.plot(predicted, label='Predicted data')
plt.legend()
plt.show()

四、总结

        LSTM能够捕捉长时间依赖关系,使得模型在处理长序列数据时表现得比标准的RNN更好。但由于LSTM的计算依赖于前一个时间步的输出,这使得这样的网络结构难以并行化,在处理大规模数据时的效率较低。

相关文章:

什么是长短期记忆网络?

一、概念 长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门&#xff09…...

git中有关old mode 100644、new mode 10075的问题解决小结

在 Git 版本控制系统中,文件权限变更是一种常见情况。当你看到类似 old mode 100644 和 new mode 100755 的信息时,这通常表示文件的权限发生了变化。本文将详细解析这种情况,并提供解决方法和注意事项。 问题背景 在 Git 中,文…...

Jenkins上生成的allure report打不开怎么处理

目录 问题背景: 原因: 解决方案: Jenkins上修改配置 通过Groovy脚本在Script Console中设置和修改系统属性 步骤 验证是否清空成功 进一步的定制 也可以使用Nginx去解决 使用逆向代理服务器Nginx: 通过合理调整CSP配置&a…...

JSR303校验教学

1、什么是JSR303校验 JSR是Java Specification Requests的缩写,意思是Java 规范提案。是指向JCP(Java Community Process)提出新增一个标准化技术规范的正式请求。任何人都可以提交JSR,以向Java平台增添新的API和服务。JSR已成为Java界的一个重要标准。…...

使用DeepSeek技巧:提升内容创作效率与质量

一、引言 在当今快节奏的数字时代,内容创作的需求不断增加,无论是企业营销、个人博客还是学术研究,高效且高质量的内容生成变得至关重要。DeepSeek作为一款先进的人工智能写作助手,凭借其强大的语言生成能力,为创作者…...

【第六天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-一种常见的贪心算法(持续更新)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的贪心算法2.贪心算法3.详细的贪心代码1)一种常见的贪心算法 总结 前言 提示:这里…...

C# Winform制作一个登录系统

using System; using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms;namespace 登录 {p…...

算法总结-哈希表

文章目录 1.赎金信1.答案2.思路 2.字母异位词分组1.答案2.思路 3.两数之和1.答案2.思路 4.快乐数1.答案2.思路 5.最长连续序列1.答案2.思路 1.赎金信 1.答案 package com.sunxiansheng.arithmetic.day14;/*** Description: 383. 赎金信** Author sun* Create 2025/1/22 11:10…...

向下调整算法(详解)c++

算法流程: 与⽗结点的权值作⽐较,如果⽐它⼤,就与⽗亲交换; 交换完之后,重复 1 操作,直到⽐⽗亲⼩,或者换到根节点的位置 大家可能会有点疑惑,这个是大根堆,22是怎么跑到…...

蓝桥杯之c++入门(一)【C++入门】

目录 前言5. 算术操作符5.1 算术操作符5.2 浮点数的除法5.3 负数取模5.4 数值溢出5.5 练习练习1:计算 ( a b ) ⋆ c (ab)^{\star}c (ab)⋆c练习2:带余除法练习3:整数个位练习4:整数十位练习5:时间转换练习6&#xff…...

使用Python爬虫获取1688商品拍立淘API接口(item_search_img)的实战指南

在电商领域,通过图片搜索商品(拍立淘)已经成为一种重要的商品检索方式。1688平台的item_search_img接口允许用户通过上传图片来搜索相似商品,这为商品信息采集和市场分析提供了极大的便利。本文将详细介绍如何使用Python爬虫技术调…...

ElasticSearch-文档元数据乐观并发控制

文章目录 什么是文档?文档元数据文档的部分更新Update 乐观并发控制 最近日常工作开发过程中使用到了 ES,最近在检索资料的时候翻阅到了 ES 的官方文档,里面对 ES 的基础与案例进行了通俗易懂的解释,读下来也有不少收获&#xff0…...

使用Navicat Premium管理数据库时,如何关闭事务默认自动提交功能?

使用Navicat Premium管理数据库时,最糟心的事情莫过于事务默认自动提交,也就是你写完语句运行时,它自动执行commit提交至数据库,此时你就无法进行回滚操作。 建议您尝试取消勾选“选项”中的“自动开始事务”,点击“工…...

【单细胞-第三节 多样本数据分析】

文件在单细胞\5_GC_py\1_single_cell\1.GSE183904.Rmd GSE183904 数据原文 1.获取临床信息 筛选样本可以参考临床信息 rm(list ls()) library(tinyarray) a geo_download("GSE183904")$pd head(a) table(a$Characteristics_ch1) #统计各样本有多少2.批量读取 学…...

(java) IO流

学习IO流之前,我们需要先认识file对象,帮助我们更好的使用IO流 1.1 file 作用:关联硬盘上的文件 写法: File(String path); (推荐)File(String parent, String child); //由父级路径,再子级路径拼接而成File(File p…...

2025年1月个人工作生活总结

本文为 2025年1月工作生活总结。 研发编码 使用sqlite3命令行查询表数据 可以直接使用sqlite3查询数据表,不需进入命令行模式。示例如下: sqlite3 database_name.db "SELECT * FROM table_name;"linux shell使用read超时一例 先前有个编译…...

线性调整器——耗能型调整器

线性调整器又称线性电压调节器,以下是关于它的介绍: 基本工作原理 线性调整器的基本电路如图1.1(a)所示,晶体管Q1(工作于线性状态,或非开关状态)构成一个连接直流源V和输出端V。的可调电气电阻,直流源V由60Hz隔离变压器(电气隔离和整流&#…...

【2025美赛D题】为更美好的城市绘制路线图建模|建模过程+完整代码论文全解全析

你是否在寻找数学建模比赛的突破点?数学建模进阶思路! 作为经验丰富的美赛O奖、国赛国一的数学建模团队,我们将为你带来本次数学建模竞赛的全面解析。这个解决方案包不仅包括完整的代码实现,还有详尽的建模过程和解析&#xff0c…...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.28 存储之道:跨平台数据持久化方案

好的,我将按照您的要求生成一篇高质量的Python NumPy文章。以下是第28篇《存储之道:跨平台数据持久化方案》的完整内容,包括目录、正文和参考文献。 1.28 存储之道:跨平台数据持久化方案 目录 #mermaid-svg-n1z37AP8obEgptkD {f…...

拼车(1094)

1094. 拼车 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:bool carPooling(vector<vector<int>>& trips, int capacity) {uint32_t passenger_cnt 0;//将原数据按照from排序auto func_0 [](vector<int> & …...

应用升级/灾备测试时使用guarantee 闪回点迅速回退

1.场景 应用要升级,当升级失败时,数据库回退到升级前. 要测试系统,测试完成后,数据库要回退到测试前。 相对于RMAN恢复需要很长时间&#xff0c; 数据库闪回只需要几分钟。 2.技术实现 数据库设置 2个db_recovery参数 创建guarantee闪回点&#xff0c;不需要开启数据库闪回。…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

条件运算符

C中的三目运算符&#xff08;也称条件运算符&#xff0c;英文&#xff1a;ternary operator&#xff09;是一种简洁的条件选择语句&#xff0c;语法如下&#xff1a; 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true&#xff0c;则整个表达式的结果为“表达式1”…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问&#xff1a;构建自定义开发环境 引言 临时运维一个古董项目&#xff0c;无文档&#xff0c;无环境&#xff0c;无交接人&#xff0c;俗称三无。 运行设备的环境老&#xff0c;本地环境版本高&#xff0c;ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...