当前位置: 首页 > news >正文

什么是长短期记忆网络?

一、概念

        长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门)来控制信息的流动。其中,每个门都是一个神经网络层,用于决定哪些信息应该被保留,哪些信息应该被丢弃。LSTM的核心是细胞状态(cell state),它通过这些门的控制来更新和传递信息。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,C_{t-1}为前一个时间步的细胞状态向量,C_{t}为当前时间步的细胞状态变量,f_{t}为当前时间步的遗忘门向量,i_{t}为当前时间步的输入门向量,\bar{C_{t}}为当前时间步的候选细胞状态向量,o_{t}为当前时间步的输出门向量,W_{f},W_{i},W_{C},W_{o}分别为各门的权重矩阵,b_{f},b_{i},b_{C},b_{o}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。LSTM的核心内容包括以下几个部分:

1、遗忘门(Forget Gate)

        遗忘门决定细胞状态中哪些信息需要被遗忘。通过sigmoid激活函数,遗忘门的输出在0到1之间,表示每个细胞状态元素被保留的比例。

f_{t} = \sigma(W_{f} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{f})

2、输入门(Input Gate)

        输入门决定哪些新的信息需要被写入细胞状态。通过sigmoid激活函数,输入门的输出在0到1之间,表示每个候选细胞状态元素被写入的比例。候选细胞状态通过tanh激活函数生成,表示新的信息。

i_{t} = \sigma(W_{i} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{i})

\bar{C}_{t} = tanh(W_{C} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{C})

3、细胞状态更新

        细胞状态结合遗忘门和输入门的结果进行更新。遗忘门的输出与前一个时间步的细胞状态相乘,表示保留的旧信息。输入门的输出与候选细胞状态相乘,表示写入的新信息。两者相加得到当前时间步的细胞状态。

C_{t} = f_{t} \ast C_{t-1}+i_{t} \ast \bar{C}_{t}

4、输出门(Output Gate)

        输出门决定细胞状态的哪些部分将作为输出。通过sigmoid激活函数,输出门的输出在0到1之间,表示每个细胞状态元素被输出的比例。细胞状态通过tanh激活函数进行非线性变换,然后与输出门的输出相乘,得到当前时间步的隐藏状态。

o_{t} = \sigma(W_{o} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{o})

h_{t} = o_{t} \ast tanh(C_{t})

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成正弦波数据
def generate_sine_wave(seq_length, num_samples):x = np.linspace(0, num_samples, num_samples)y = np.sin(x)data = []for i in range(len(y) - seq_length):data.append(y[i:i+seq_length+1])return np.array(data)# 定义LSTM模型
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out# 超参数设置
seq_length = 50
num_samples = 1000
input_size = 1
hidden_size = 50
output_size = 1
num_layers = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 5
test_size = 0.2  # 测试集占比# 生成数据
data = generate_sine_wave(seq_length, num_samples)
X = data[:, :-1]
y = data[:, -1]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)# 转换为Tensor
X_train = torch.tensor(X_train.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_train = torch.tensor(y_train.reshape(-1, output_size), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_test = torch.tensor(y_test.reshape(-1, output_size), dtype=torch.float32)# 创建数据加载器
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型、损失函数和优化器
model = LSTMModel(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()for i, (inputs, labels) in enumerate(train_loader):outputs = model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():predicted = []actual = []for inputs, labels in test_loader:outputs = model(inputs)predicted.extend(outputs.numpy())actual.extend(labels.numpy())# 绘制结果
plt.plot(actual, label='Actual data')
plt.plot(predicted, label='Predicted data')
plt.legend()
plt.show()

四、总结

        LSTM能够捕捉长时间依赖关系,使得模型在处理长序列数据时表现得比标准的RNN更好。但由于LSTM的计算依赖于前一个时间步的输出,这使得这样的网络结构难以并行化,在处理大规模数据时的效率较低。

相关文章:

什么是长短期记忆网络?

一、概念 长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门&#xff09…...

git中有关old mode 100644、new mode 10075的问题解决小结

在 Git 版本控制系统中,文件权限变更是一种常见情况。当你看到类似 old mode 100644 和 new mode 100755 的信息时,这通常表示文件的权限发生了变化。本文将详细解析这种情况,并提供解决方法和注意事项。 问题背景 在 Git 中,文…...

Jenkins上生成的allure report打不开怎么处理

目录 问题背景: 原因: 解决方案: Jenkins上修改配置 通过Groovy脚本在Script Console中设置和修改系统属性 步骤 验证是否清空成功 进一步的定制 也可以使用Nginx去解决 使用逆向代理服务器Nginx: 通过合理调整CSP配置&a…...

JSR303校验教学

1、什么是JSR303校验 JSR是Java Specification Requests的缩写,意思是Java 规范提案。是指向JCP(Java Community Process)提出新增一个标准化技术规范的正式请求。任何人都可以提交JSR,以向Java平台增添新的API和服务。JSR已成为Java界的一个重要标准。…...

使用DeepSeek技巧:提升内容创作效率与质量

一、引言 在当今快节奏的数字时代,内容创作的需求不断增加,无论是企业营销、个人博客还是学术研究,高效且高质量的内容生成变得至关重要。DeepSeek作为一款先进的人工智能写作助手,凭借其强大的语言生成能力,为创作者…...

【第六天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-一种常见的贪心算法(持续更新)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的贪心算法2.贪心算法3.详细的贪心代码1)一种常见的贪心算法 总结 前言 提示:这里…...

C# Winform制作一个登录系统

using System; using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms;namespace 登录 {p…...

算法总结-哈希表

文章目录 1.赎金信1.答案2.思路 2.字母异位词分组1.答案2.思路 3.两数之和1.答案2.思路 4.快乐数1.答案2.思路 5.最长连续序列1.答案2.思路 1.赎金信 1.答案 package com.sunxiansheng.arithmetic.day14;/*** Description: 383. 赎金信** Author sun* Create 2025/1/22 11:10…...

向下调整算法(详解)c++

算法流程: 与⽗结点的权值作⽐较,如果⽐它⼤,就与⽗亲交换; 交换完之后,重复 1 操作,直到⽐⽗亲⼩,或者换到根节点的位置 大家可能会有点疑惑,这个是大根堆,22是怎么跑到…...

蓝桥杯之c++入门(一)【C++入门】

目录 前言5. 算术操作符5.1 算术操作符5.2 浮点数的除法5.3 负数取模5.4 数值溢出5.5 练习练习1:计算 ( a b ) ⋆ c (ab)^{\star}c (ab)⋆c练习2:带余除法练习3:整数个位练习4:整数十位练习5:时间转换练习6&#xff…...

使用Python爬虫获取1688商品拍立淘API接口(item_search_img)的实战指南

在电商领域,通过图片搜索商品(拍立淘)已经成为一种重要的商品检索方式。1688平台的item_search_img接口允许用户通过上传图片来搜索相似商品,这为商品信息采集和市场分析提供了极大的便利。本文将详细介绍如何使用Python爬虫技术调…...

ElasticSearch-文档元数据乐观并发控制

文章目录 什么是文档?文档元数据文档的部分更新Update 乐观并发控制 最近日常工作开发过程中使用到了 ES,最近在检索资料的时候翻阅到了 ES 的官方文档,里面对 ES 的基础与案例进行了通俗易懂的解释,读下来也有不少收获&#xff0…...

使用Navicat Premium管理数据库时,如何关闭事务默认自动提交功能?

使用Navicat Premium管理数据库时,最糟心的事情莫过于事务默认自动提交,也就是你写完语句运行时,它自动执行commit提交至数据库,此时你就无法进行回滚操作。 建议您尝试取消勾选“选项”中的“自动开始事务”,点击“工…...

【单细胞-第三节 多样本数据分析】

文件在单细胞\5_GC_py\1_single_cell\1.GSE183904.Rmd GSE183904 数据原文 1.获取临床信息 筛选样本可以参考临床信息 rm(list ls()) library(tinyarray) a geo_download("GSE183904")$pd head(a) table(a$Characteristics_ch1) #统计各样本有多少2.批量读取 学…...

(java) IO流

学习IO流之前,我们需要先认识file对象,帮助我们更好的使用IO流 1.1 file 作用:关联硬盘上的文件 写法: File(String path); (推荐)File(String parent, String child); //由父级路径,再子级路径拼接而成File(File p…...

2025年1月个人工作生活总结

本文为 2025年1月工作生活总结。 研发编码 使用sqlite3命令行查询表数据 可以直接使用sqlite3查询数据表,不需进入命令行模式。示例如下: sqlite3 database_name.db "SELECT * FROM table_name;"linux shell使用read超时一例 先前有个编译…...

线性调整器——耗能型调整器

线性调整器又称线性电压调节器,以下是关于它的介绍: 基本工作原理 线性调整器的基本电路如图1.1(a)所示,晶体管Q1(工作于线性状态,或非开关状态)构成一个连接直流源V和输出端V。的可调电气电阻,直流源V由60Hz隔离变压器(电气隔离和整流&#…...

【2025美赛D题】为更美好的城市绘制路线图建模|建模过程+完整代码论文全解全析

你是否在寻找数学建模比赛的突破点?数学建模进阶思路! 作为经验丰富的美赛O奖、国赛国一的数学建模团队,我们将为你带来本次数学建模竞赛的全面解析。这个解决方案包不仅包括完整的代码实现,还有详尽的建模过程和解析&#xff0c…...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.28 存储之道:跨平台数据持久化方案

好的,我将按照您的要求生成一篇高质量的Python NumPy文章。以下是第28篇《存储之道:跨平台数据持久化方案》的完整内容,包括目录、正文和参考文献。 1.28 存储之道:跨平台数据持久化方案 目录 #mermaid-svg-n1z37AP8obEgptkD {f…...

拼车(1094)

1094. 拼车 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:bool carPooling(vector<vector<int>>& trips, int capacity) {uint32_t passenger_cnt 0;//将原数据按照from排序auto func_0 [](vector<int> & …...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查

在对接支付宝API的时候&#xff0c;遇到了一些问题&#xff0c;记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题&#xff08;可多选&#xff09; 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘&#xff1a;专注于发现数据中…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版&#xff01;&#xff01;&#xff01;6.8截至答题&#xff0c;大家注意呀&#xff01; 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:&#xff08; B &#xff09; A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作

一、上下文切换 即使单核CPU也可以进行多线程执行代码&#xff0c;CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短&#xff0c;所以CPU会不断地切换线程执行&#xff0c;从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

高防服务器能够抵御哪些网络攻击呢?

高防服务器作为一种有着高度防御能力的服务器&#xff0c;可以帮助网站应对分布式拒绝服务攻击&#xff0c;有效识别和清理一些恶意的网络流量&#xff0c;为用户提供安全且稳定的网络环境&#xff0c;那么&#xff0c;高防服务器一般都可以抵御哪些网络攻击呢&#xff1f;下面…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...