当前位置: 首页 > news >正文

什么是长短期记忆网络?

一、概念

        长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门)来控制信息的流动。其中,每个门都是一个神经网络层,用于决定哪些信息应该被保留,哪些信息应该被丢弃。LSTM的核心是细胞状态(cell state),它通过这些门的控制来更新和传递信息。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,C_{t-1}为前一个时间步的细胞状态向量,C_{t}为当前时间步的细胞状态变量,f_{t}为当前时间步的遗忘门向量,i_{t}为当前时间步的输入门向量,\bar{C_{t}}为当前时间步的候选细胞状态向量,o_{t}为当前时间步的输出门向量,W_{f},W_{i},W_{C},W_{o}分别为各门的权重矩阵,b_{f},b_{i},b_{C},b_{o}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。LSTM的核心内容包括以下几个部分:

1、遗忘门(Forget Gate)

        遗忘门决定细胞状态中哪些信息需要被遗忘。通过sigmoid激活函数,遗忘门的输出在0到1之间,表示每个细胞状态元素被保留的比例。

f_{t} = \sigma(W_{f} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{f})

2、输入门(Input Gate)

        输入门决定哪些新的信息需要被写入细胞状态。通过sigmoid激活函数,输入门的输出在0到1之间,表示每个候选细胞状态元素被写入的比例。候选细胞状态通过tanh激活函数生成,表示新的信息。

i_{t} = \sigma(W_{i} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{i})

\bar{C}_{t} = tanh(W_{C} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{C})

3、细胞状态更新

        细胞状态结合遗忘门和输入门的结果进行更新。遗忘门的输出与前一个时间步的细胞状态相乘,表示保留的旧信息。输入门的输出与候选细胞状态相乘,表示写入的新信息。两者相加得到当前时间步的细胞状态。

C_{t} = f_{t} \ast C_{t-1}+i_{t} \ast \bar{C}_{t}

4、输出门(Output Gate)

        输出门决定细胞状态的哪些部分将作为输出。通过sigmoid激活函数,输出门的输出在0到1之间,表示每个细胞状态元素被输出的比例。细胞状态通过tanh激活函数进行非线性变换,然后与输出门的输出相乘,得到当前时间步的隐藏状态。

o_{t} = \sigma(W_{o} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{o})

h_{t} = o_{t} \ast tanh(C_{t})

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成正弦波数据
def generate_sine_wave(seq_length, num_samples):x = np.linspace(0, num_samples, num_samples)y = np.sin(x)data = []for i in range(len(y) - seq_length):data.append(y[i:i+seq_length+1])return np.array(data)# 定义LSTM模型
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out# 超参数设置
seq_length = 50
num_samples = 1000
input_size = 1
hidden_size = 50
output_size = 1
num_layers = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 5
test_size = 0.2  # 测试集占比# 生成数据
data = generate_sine_wave(seq_length, num_samples)
X = data[:, :-1]
y = data[:, -1]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)# 转换为Tensor
X_train = torch.tensor(X_train.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_train = torch.tensor(y_train.reshape(-1, output_size), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_test = torch.tensor(y_test.reshape(-1, output_size), dtype=torch.float32)# 创建数据加载器
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型、损失函数和优化器
model = LSTMModel(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()for i, (inputs, labels) in enumerate(train_loader):outputs = model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():predicted = []actual = []for inputs, labels in test_loader:outputs = model(inputs)predicted.extend(outputs.numpy())actual.extend(labels.numpy())# 绘制结果
plt.plot(actual, label='Actual data')
plt.plot(predicted, label='Predicted data')
plt.legend()
plt.show()

四、总结

        LSTM能够捕捉长时间依赖关系,使得模型在处理长序列数据时表现得比标准的RNN更好。但由于LSTM的计算依赖于前一个时间步的输出,这使得这样的网络结构难以并行化,在处理大规模数据时的效率较低。

相关文章:

什么是长短期记忆网络?

一、概念 长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门&#xff09…...

git中有关old mode 100644、new mode 10075的问题解决小结

在 Git 版本控制系统中,文件权限变更是一种常见情况。当你看到类似 old mode 100644 和 new mode 100755 的信息时,这通常表示文件的权限发生了变化。本文将详细解析这种情况,并提供解决方法和注意事项。 问题背景 在 Git 中,文…...

Jenkins上生成的allure report打不开怎么处理

目录 问题背景: 原因: 解决方案: Jenkins上修改配置 通过Groovy脚本在Script Console中设置和修改系统属性 步骤 验证是否清空成功 进一步的定制 也可以使用Nginx去解决 使用逆向代理服务器Nginx: 通过合理调整CSP配置&a…...

JSR303校验教学

1、什么是JSR303校验 JSR是Java Specification Requests的缩写,意思是Java 规范提案。是指向JCP(Java Community Process)提出新增一个标准化技术规范的正式请求。任何人都可以提交JSR,以向Java平台增添新的API和服务。JSR已成为Java界的一个重要标准。…...

使用DeepSeek技巧:提升内容创作效率与质量

一、引言 在当今快节奏的数字时代,内容创作的需求不断增加,无论是企业营销、个人博客还是学术研究,高效且高质量的内容生成变得至关重要。DeepSeek作为一款先进的人工智能写作助手,凭借其强大的语言生成能力,为创作者…...

【第六天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-一种常见的贪心算法(持续更新)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的贪心算法2.贪心算法3.详细的贪心代码1)一种常见的贪心算法 总结 前言 提示:这里…...

C# Winform制作一个登录系统

using System; using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms;namespace 登录 {p…...

算法总结-哈希表

文章目录 1.赎金信1.答案2.思路 2.字母异位词分组1.答案2.思路 3.两数之和1.答案2.思路 4.快乐数1.答案2.思路 5.最长连续序列1.答案2.思路 1.赎金信 1.答案 package com.sunxiansheng.arithmetic.day14;/*** Description: 383. 赎金信** Author sun* Create 2025/1/22 11:10…...

向下调整算法(详解)c++

算法流程: 与⽗结点的权值作⽐较,如果⽐它⼤,就与⽗亲交换; 交换完之后,重复 1 操作,直到⽐⽗亲⼩,或者换到根节点的位置 大家可能会有点疑惑,这个是大根堆,22是怎么跑到…...

蓝桥杯之c++入门(一)【C++入门】

目录 前言5. 算术操作符5.1 算术操作符5.2 浮点数的除法5.3 负数取模5.4 数值溢出5.5 练习练习1:计算 ( a b ) ⋆ c (ab)^{\star}c (ab)⋆c练习2:带余除法练习3:整数个位练习4:整数十位练习5:时间转换练习6&#xff…...

使用Python爬虫获取1688商品拍立淘API接口(item_search_img)的实战指南

在电商领域,通过图片搜索商品(拍立淘)已经成为一种重要的商品检索方式。1688平台的item_search_img接口允许用户通过上传图片来搜索相似商品,这为商品信息采集和市场分析提供了极大的便利。本文将详细介绍如何使用Python爬虫技术调…...

ElasticSearch-文档元数据乐观并发控制

文章目录 什么是文档?文档元数据文档的部分更新Update 乐观并发控制 最近日常工作开发过程中使用到了 ES,最近在检索资料的时候翻阅到了 ES 的官方文档,里面对 ES 的基础与案例进行了通俗易懂的解释,读下来也有不少收获&#xff0…...

使用Navicat Premium管理数据库时,如何关闭事务默认自动提交功能?

使用Navicat Premium管理数据库时,最糟心的事情莫过于事务默认自动提交,也就是你写完语句运行时,它自动执行commit提交至数据库,此时你就无法进行回滚操作。 建议您尝试取消勾选“选项”中的“自动开始事务”,点击“工…...

【单细胞-第三节 多样本数据分析】

文件在单细胞\5_GC_py\1_single_cell\1.GSE183904.Rmd GSE183904 数据原文 1.获取临床信息 筛选样本可以参考临床信息 rm(list ls()) library(tinyarray) a geo_download("GSE183904")$pd head(a) table(a$Characteristics_ch1) #统计各样本有多少2.批量读取 学…...

(java) IO流

学习IO流之前,我们需要先认识file对象,帮助我们更好的使用IO流 1.1 file 作用:关联硬盘上的文件 写法: File(String path); (推荐)File(String parent, String child); //由父级路径,再子级路径拼接而成File(File p…...

2025年1月个人工作生活总结

本文为 2025年1月工作生活总结。 研发编码 使用sqlite3命令行查询表数据 可以直接使用sqlite3查询数据表,不需进入命令行模式。示例如下: sqlite3 database_name.db "SELECT * FROM table_name;"linux shell使用read超时一例 先前有个编译…...

线性调整器——耗能型调整器

线性调整器又称线性电压调节器,以下是关于它的介绍: 基本工作原理 线性调整器的基本电路如图1.1(a)所示,晶体管Q1(工作于线性状态,或非开关状态)构成一个连接直流源V和输出端V。的可调电气电阻,直流源V由60Hz隔离变压器(电气隔离和整流&#…...

【2025美赛D题】为更美好的城市绘制路线图建模|建模过程+完整代码论文全解全析

你是否在寻找数学建模比赛的突破点?数学建模进阶思路! 作为经验丰富的美赛O奖、国赛国一的数学建模团队,我们将为你带来本次数学建模竞赛的全面解析。这个解决方案包不仅包括完整的代码实现,还有详尽的建模过程和解析&#xff0c…...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.28 存储之道:跨平台数据持久化方案

好的,我将按照您的要求生成一篇高质量的Python NumPy文章。以下是第28篇《存储之道:跨平台数据持久化方案》的完整内容,包括目录、正文和参考文献。 1.28 存储之道:跨平台数据持久化方案 目录 #mermaid-svg-n1z37AP8obEgptkD {f…...

拼车(1094)

1094. 拼车 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:bool carPooling(vector<vector<int>>& trips, int capacity) {uint32_t passenger_cnt 0;//将原数据按照from排序auto func_0 [](vector<int> & …...

基于Python的人工智能患者风险评估预测模型构建与应用研究(下)

3.3 模型选择与训练 3.3.1 常见预测模型介绍 在构建患者风险评估模型时,选择合适的预测模型至关重要。不同的模型具有各自的优缺点和适用场景,需要根据医疗数据的特点、风险评估的目标以及计算资源等因素进行综合考虑。以下详细介绍几种常见的预测模型。 逻辑回归(Logisti…...

< OS 有关 > Android 手机 SSH 客户端 app: connectBot

connectBot 开源且功能齐全的SSH客户端,界面简洁,支持证书密钥。 下载量超 500万 方便在 Android 手机上&#xff0c;连接 SSH 服务器&#xff0c;去运行命令。 Fail2ban 12小时内抓获的 IP ~ ~ ~ ~ rootjpn:~# sudo fail2ban-client status sshd Status for the jail: sshd …...

向量和矩阵算法笔记

向量和矩阵算法笔记 Ps:因为本人实力有限,有一部分可能不太详细,若有补充评论区回复,QWQ 向量 向量的定义 首先,因为我刚刚学到高中的向量,对向量的看法呢就是一条有长度和方向的线,不过这在数学上的定义其实是不对,甚至跟我看的差别其实有点大,真正的定义就是数域…...

uniapp使用uni.navigateBack返回页面时携带参数到上个页面

我们平时开发中也经常遇到这种场景&#xff0c;跳转一个页面会进行一些操作&#xff0c;操作完成后再返回上个页面同时要携带着一些参数 其实也很简单&#xff0c;也来记录一下吧 假设从A页面 跳转到 B页面 A页面 直接上完整代码了哈&#xff0c;很简单&#xff1a; <t…...

Python 梯度下降法(二):RMSProp Optimize

文章目录 Python 梯度下降法&#xff08;二&#xff09;&#xff1a;RMSProp Optimize一、数学原理1.1 介绍1.2 公式 二、代码实现2.1 函数代码2.2 总代码 三、代码优化3.1 存在问题3.2 收敛判断3.3 函数代码3.4 总代码 四、优缺点4.1 优点4.2 缺点 Python 梯度下降法&#xff…...

Android Studio 正式版 10 周年回顾,承载 Androider 的峥嵘十年

Android Studio 1.0 宣发于 2014 年 12 月&#xff0c;而现在时间来到 2025 &#xff0c;不知不觉间 Android Studio 已经陪伴 Androider 走过十年历程。 Android Studio 10 周年&#xff0c;也代表着了我的职业生涯也超十年&#xff0c;现在回想起来依然觉得「唏嘘」&#xff…...

sem_wait的概念和使用案列

sem_wait 是 POSIX 标准中定义的一个用于同步的函数&#xff0c;它通常用于操作信号量&#xff08;semaphore&#xff09;。信号量是一个整数变量&#xff0c;可以用来控制对共享资源的访问。在多线程编程中&#xff0c;sem_wait 常用于实现线程间的同步。 概念 sem_wait 的基…...

集合的奇妙世界:Python集合的经典、避坑与实战

集合的奇妙世界&#xff1a;Python集合的经典、避坑与实战 内容简介 本系列文章是为 Python3 学习者精心设计的一套全面、实用的学习指南&#xff0c;旨在帮助读者从基础入门到项目实战&#xff0c;全面提升编程能力。文章结构由 5 个版块组成&#xff0c;内容层层递进&#x…...

专业视角深度解析:DeepSeek的核心优势何在?

杭州深度求索&#xff08;DeepSeek&#xff09;人工智能基础技术研究有限公司&#xff0c;是一家成立于2023年7月的中国人工智能初创企业&#xff0c;总部位于浙江省杭州市。该公司由量化对冲基金幻方量化&#xff08;High-Flyer&#xff09;的联合创始人梁文锋创立&#xff0c…...

MySQL 索引存储结构

索引是优化数据库查询最重要的方式之一&#xff0c;它是在 MySQL 的存储引擎层中实现的&#xff0c;所以 每一种存储引擎对应的索引不一定相同。我们可以通过下面这张表格&#xff0c;看看不同的存储引擎 分别支持哪种索引类型&#xff1a; BTree 索引和 Hash 索引是我们比较…...