当前位置: 首页 > article >正文

长短期记忆(LSTM)网络模型

一、概述

  长短期记忆(Long Short-Term Memory,LSTM)网络是一种特殊的循环神经网络(RNN),专门设计用于解决传统 RNN 在处理长序列数据时面临的梯度消失 / 爆炸问题,能够有效捕捉长距离依赖关系。其核心在于引入记忆细胞(Cell State)和门控机制(Gate Mechanism),通过控制信息的流动来实现对长期信息的存储与遗忘。

二、模型原理

  LSTM 由记忆细胞和三个门控单元(遗忘门、输入门、输出门)组成,每个门控单元通过 sigmoid 激活函数输出 0 到 1 之间的数值,表示允许信息通过的程度(0 表示完全禁止,1 表示完全允许)。

1. 记忆细胞状态

  记忆细胞状态就像一条信息传输的 “高速公路”,它贯穿整个 LSTM 网络,负责在不同时间步之间传递信息。信息在记忆细胞状态中传递时,可以相对稳定地保留较长时间,避免了传统 RNN 中信息容易丢失的问题。遗忘门和输入门共同作用于记忆细胞状态,遗忘门决定删除哪些旧信息,输入门决定添加哪些新信息 ,从而实现对记忆细胞状态的更新。

2. 输入门

  输入门负责处理当前时刻的输入信息,决定哪些新的信息会被添加到记忆细胞状态中。它利用 sigmoid 函数输出一个值,用于控制新信息的 “准入程度”。同时,输入内容通过 tanh 函数生成一个候选值向量,这个向量包含了可能要添加到记忆细胞状态中的新信息。最后,将 sigmoid 函数的输出与 tanh 函数生成的候选值向量相乘,得到实际要添加到记忆细胞状态中的信息。

3. 遗忘门

  遗忘门决定了上一时刻记忆细胞状态中哪些信息会被保留到当前时刻。它接收上一时刻的隐藏状态和当前时刻的输入,通过一个 sigmoid 激活函数输出一个 0 到 1 之间的数值。这个数值就像一把 “钥匙”,数值越接近 1,表示上一时刻的该部分信息被保留的程度越高;数值越接近 0,则表示该部分信息被遗忘的程度越高。例如,在处理一段文字序列时,如果之前的内容与当前句子的主题关联不大,遗忘门就会降低这些信息的保留程度。

4. 输出门

  输出门根据当前记忆细胞状态和隐藏状态,决定最终的输出。它首先使用 sigmoid 函数得到一个控制输出的向量。然后,对记忆细胞状态进行 tanh 处理,将处理后的记忆细胞状态与 sigmoid 函数的输出向量相乘,从而得到 LSTM 单元的最终输出。
一个典型LSTM的单元结构为

在这里插入图片描述
  也就是说,对每个LSTM单元,都有四个输入、一个输出,这四个输入也就是对同一组输入数据的线性组合,只是组合了不同参数。具体的计算过程图示为

在这里插入图片描述
  显然,相较于传统的网络结构,LSTM具有四倍的参数量。

三、优势与局限

1. 优势

  LSTM 的门控机制使其在处理长序列数据时,能够有效保留和更新信息,避免梯度消失和梯度爆炸问题,从而学习到长距离的依赖关系,在许多序列数据处理任务中取得了优异的成绩。此外,LSTM 的结构具有较好的通用性,可以适应多种不同类型的序列数据处理任务。

2. 局限

  由于其结构相对复杂,包含多个门和大量参数,训练过程通常需要更多的计算资源和时间,并且容易出现过拟合问题。同时,LSTM 在解释性方面相对较差,难以直观地理解模型是如何做出决策的。

四、应用领域

1. 自然语言处理

  在自然语言处理任务中,LSTM 被广泛应用于文本分类、机器翻译、语音识别、问答系统等。例如,在机器翻译中,LSTM 可以将源语言句子的语义信息编码成固定长度的向量,然后通过解码过程将其转换为目标语言句子;在语音识别中,LSTM 能够处理语音信号中的时间序列信息,将语音转换为文字。

2. 时间序列预测

  LSTM 在时间序列预测领域表现出色,如股票价格预测、天气预测、电力负荷预测等。由于 LSTM 能够有效捕捉时间序列中的长期依赖关系,相比传统方法,它可以更准确地预测未来趋势。例如,在股票价格预测中,LSTM 可以分析历史股价数据中的复杂模式,预测未来股价走势。

3. 其他领域

  此外,LSTM 还在视频分析、生物信息学等领域得到应用。在视频分析中,LSTM 可以处理视频帧序列,实现动作识别、视频内容理解等任务;在生物信息学中,LSTM 可用于基因序列分析,预测基因功能等。

五、Python实现示例

(环境:Python 3.11,PyTorch 2.4.0)

import matplotlib
matplotlib.use('TkAgg')import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset# 设置matplotlib的字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 'SimHei' 是黑体,也可设置 'Microsoft YaHei' 等
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号# 设置随机种子,保证结果可复现
torch.manual_seed(42)
np.random.seed(42)# 生成示例数据(正弦波序列)
def generate_sequence(length, freq=0.1):"""生成正弦波序列作为示例数据"""x = np.linspace(0, 2 * np.pi * freq, length)return np.sin(x)def create_sequences(data, seq_length):"""将数据转换为序列和对应目标值的形式"""xs, ys = [], []for i in range(len(data) - seq_length):x = data[i:i + seq_length]y = data[i + seq_length]xs.append(x)ys.append(y)return np.array(xs), np.array(ys)# 生成数据
seq_length = 10
data = generate_sequence(1000)
x, y = create_sequences(data, seq_length)# 转换为PyTorch张量
x_tensor = torch.FloatTensor(x).view(-1, seq_length, 1)
y_tensor = torch.FloatTensor(y).view(-1, 1)# 创建数据加载器
dataset = TensorDataset(x_tensor, y_tensor)
train_size = int(0.8 * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 定义LSTM模型
class LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers# LSTM层self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)# 全连接层self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)# 前向传播LSTMout, _ = self.lstm(x, (h0, c0))# 只取序列中的最后一个时间步的输出out = self.fc(out[:, -1, :])return out# 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_size=1, hidden_size=50, num_layers=1, output_size=1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train_model(model, train_loader, criterion, optimizer, epochs=100):model.train()for epoch in range(epochs):total_loss = 0for inputs, targets in train_loader:inputs, targets = inputs.to(device), targets.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}')# 评估模型
def evaluate_model(model, test_loader):model.eval()predictions = []actuals = []with torch.no_grad():for inputs, targets in test_loader:inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)predictions.extend(outputs.cpu().numpy())actuals.extend(targets.cpu().numpy())return np.array(predictions), np.array(actuals)# 训练模型
train_model(model, train_loader, criterion, optimizer, epochs=50)# 评估模型
predictions, actuals = evaluate_model(model, test_loader)# 可视化结果
plt.figure(figsize=(10, 6))
plt.plot(actuals, label='Actual Values')
plt.plot(predictions, label='Predicted Values')
plt.title('LSTM预测结果')
plt.xlabel('样本')
plt.ylabel('值')
plt.legend()
plt.show()

在这里插入图片描述
在这里插入图片描述
  示例实现了一个基本的 LSTM 模型,用于预测正弦波序列的下一个值。主要包括以下几个部分:

  数据生成:创建一个正弦波序列,并将其转换为适合 LSTM 训练的序列格式。
  模型定义:定义了一个包含 LSTM 层和全连接层的模型,用于处理序列数据并输出预测结果。
  训练过程:使用均方误差损失函数和 Adam 优化器训练模型。
  评估和可视化:评估模型性能并可视化预测结果与实际值的对比。

  可以通过修改参数如seq_length(序列长度)、hidden_size(LSTM 隐藏层大小)、num_layers(LSTM 层数)等来调整模型,也可以将此框架应用于其他序列预测任务。



End.

相关文章:

长短期记忆(LSTM)网络模型

一、概述 长短期记忆(Long Short-Term Memory,LSTM)网络是一种特殊的循环神经网络(RNN),专门设计用于解决传统 RNN 在处理长序列数据时面临的梯度消失 / 爆炸问题,能够有效捕捉长距离依赖关系。…...

深入理解 Linux 文件系统与日志文件分析

一、Linux 文件系统概述 1. 文件系统的基本概念 文件系统(File System)是操作系统用于管理和组织存储设备上数据的机制。它提供了一种结构,使得用户和应用程序能够方便地存储和访问数据。 2. Linux 文件系统结构 Linux 文件系统采用树状目…...

CSS3美化页面元素

1. 字体 <span>标签 字体样式⭐ 字体类型&#xff08;font-family&#xff09; 字体大小&#xff08;font-size&#xff09; 字体风格&#xff08;font-style&#xff09; 字体粗细&#xff08;font-weight&#xff09; 字体属性&#xff08;font&#xff09; 2. 文本 文…...

网络安全-等级保护(等保)3-0 等级保护测评要求现行技术标准

################################################################################ 第三章&#xff1a;测评要求、测评机构要求&#xff0c;最终目的是通过测评&#xff0c;所以我们将等保要求和测评相关要求一一对应形成表格。 GB/T 28448-2019 《信息安全技术 网络安全等…...

WPS 利用 宏 脚本拆分 Excel 多行文本到多行

文章目录 WPS 利用 宏 脚本拆分 Excel 多行文本到多行效果需求背景&#x1f6e0; 操作步骤代码实现代码详解使用场景注意事项总结 WPS 利用 宏 脚本拆分 Excel 多行文本到多行 在 Excel 工作表中&#xff0c;我们经常遇到一列中包含多行文本&#xff08;用换行符分隔&#xff…...

R语言错误处理方法大全

在R语言的批量运行中&#xff0c;常需要自动跳过错误&#xff0c;继续向下运行。 1、使用 tryCatch() 捕获错误并返回占位符 # 示例&#xff1a;循环中跳过错误继续执行 results <- numeric(5) # 预分配结果向量for(i in 1:5) {# 用 tryCatch 包裹可能出错的代码results[…...

AI“实体化”革命:具身智能如何重构体育、工业与未来生活

近年来&#xff0c;人工智能&#xff08;AI&#xff09;技术的飞速发展正在重塑各行各业&#xff0c;而具身智能&#xff08;Embodied AI&#xff09;作为AI领域的重要分支&#xff0c;正逐渐从实验室走向现实应用。具身智能的核心在于让AI系统具备物理实体&#xff0c;能够与环…...

Opencv4 c++ 自用笔记 05 形态学操作

图像形态学主要获取物体的形状与位置信息。利用具有一定形态的结构元素度量和提取图像中的对应形状&#xff0c;达到对图像分析和识别的目的。操作主要包括腐蚀、膨胀、开运算和闭运算。 像素距离与连通域 图像形态学中&#xff0c;将不与其他区域链接的独立区域称为集合或者…...

DrissionPage 数据提取技巧全解析:从入门到实战

在当今数据驱动的时代&#xff0c;网页数据提取已成为自动化办公、市场分析和爬虫开发的核心技能。作为新一代网页自动化工具&#xff0c;DrissionPage 以其独特的双模式融合设计&#xff08;Selenium Requests&#xff09;脱颖而出。本文将结合官方文档与实战案例&#xff0c…...

如何构建自适应架构的镜像

目标 我有一个服务叫xxx&#xff0c;一开始它运行在x86架构的机器上&#xff0c;所以最开始有个xxx:stable-amd64的镜像&#xff0c;后来它又需要运行在arm64架构的机器上&#xff0c;所以又重新打了个xxx:stable-arm64的镜像 但是对于安装脚本来说&#xff0c;我不希望我在拉…...

R语言基础| 创建数据集

在R语言中&#xff0c;有多种数据类型&#xff0c;用以存储和处理数据。每种数据类型都有其特定的用途和操作函数&#xff0c;使得R语言在处理各种数据分析任务时非常灵活和强大&#xff1a; 向量&#xff08;Vector&#xff09;: 向量是R语言中最基本的数据类型&#xff0c;它…...

剑指offer15_数值的整数次方

数值的整数次方 实现函数 double Power(double base, int exponent) 题目要求 计算 base exponent \text{base}^{\text{exponent}} baseexponent&#xff1a; 不得使用库函数不需要考虑大数问题&#xff0c;绝对误差不超过 10 − 2 10^{-2} 10−2不会出现底数和指数同为 0…...

Centos7搭建zabbix6.0

此方法适用于zabbix6以上版本zabbix6.0前期环境准备&#xff1a;Lamp&#xff08;linux httpd mysql8.0 php&#xff09;mysql官网下载位置&#xff1a;https://dev.mysql.com/downloads/mysql/Zabbix源码包地址&#xff1a;https://www.zabbix.com/cn/download_sourcesZabbix6…...

使用Redis的四个常见问题及其解决方案

Redis 缓存穿透 定义&#xff1a;redis查询一个不存在的数据&#xff0c;导致每次都查询数据库 解决方案&#xff1a; 如果查询的数据为空&#xff0c;在redis对应的key缓存空数据&#xff0c;并设置短TTL。 因为缓存穿透通常是因为被恶意用不存在的查询参数进行压测攻击&…...

Docker 部署前后端分离项目

1.Docker 1.1 什么是 Docker &#xff1f; Docker 是一种开源的 容器化平台&#xff0c;用于开发、部署和运行应用程序。它通过 容器&#xff08;Container&#xff09; 技术&#xff0c;将应用程序及其依赖项打包在一个轻量级、可移植的环境中&#xff0c;确保应用在不同计算…...

云游戏混合架构

云游戏混合架构通过整合本地计算资源与云端能力&#xff0c;形成了灵活且高性能的技术体系&#xff0c;其核心架构及技术特征可概括如下&#xff1a; 一、混合架构的典型模式 分层混合模式‌ 前端应用部署于公有云&#xff08;如渲染流化服务&#xff09;&#xff0c;后端逻辑…...

【小红书】API接口,获取笔记核心数据

小红书笔记核心数据API接口详解 - 深圳小于科技提供专业数据服务 深圳小于科技&#xff08;官网&#xff1a;https://www.szlessthan.com&#xff09;推出的小红书笔记核心数据API接口&#xff0c;为开发者提供精准的笔记互动数据分析能力&#xff0c;助力内容运营与商业决策。…...

会议室钥匙总丢失?换预约功能的智能门锁更安全

在企业日常运营中&#xff0c;会议室作为重要的沟通与协作场所&#xff0c;其管理效率与安全性直接影响着企业的运作顺畅度。然而&#xff0c;传统会议室管理方式中钥匙丢失、管理不便等问题频发&#xff0c;给企业带来了不少困扰。近期&#xff0c;某企业引入了启辰智慧预约系…...

Redis底层数据结构之跳表(SkipList)

SkipList是Redis有序结合ZSet底层的数据结构&#xff0c;也是ZSet的灵魂所在。与之相应的&#xff0c;Redis还有一个无序集合Set&#xff0c;这两个在底层的实现是不一样的。 标准的SkipList&#xff1a; 跳表的本质是一个链表。链表这种结构虽然简单清晰&#xff0c;但是在查…...

跨架构镜像打包问题及解决方案

问题背景&#xff1a; 需求&#xff1a; 有一个镜像是 docker.io 的&#xff0c;是 docker.io/aquasec/kube-bench:v0.10.6&#xff0c;我想把该镜像在本地电脑&#xff08;可翻墙&#xff09;下载下来&#xff0c;然后 docker save 打包成一个 tar 包&#xff0c;传输到服务器…...

云原生时代 Kafka 深度实践:05性能调优与场景实战

5.1 性能调优全攻略 Producer调优 批量发送与延迟发送 通过调整batch.size和linger.ms参数提升吞吐量&#xff1a; props.put(ProducerConfig.BATCH_SIZE_CONFIG, 16384); // 默认16KB props.put(ProducerConfig.LINGER_MS_CONFIG, 10); // 等待10ms以积累更多消息ba…...

Ubuntu安装Docker命令清单(以20.04为例)

在你虚拟机上完成Ubuntu的下载后打开终端&#xff01;&#xff01;&#xff01; Ubuntu安装Docker终极命令清单&#xff08;以20.04为例&#xff09; # 1. 卸载旧版本&#xff08;全新系统可跳过&#xff09; sudo apt-get remove docker docker-engine docker.io containerd …...

使用 Python 制作 GIF 动图,并打包为 EXE 可执行程序

文章目录 成品百度网盘下载&#x1f3ac; 使用 Python 制作 GIF 动图&#xff0c;并打包为 EXE 可执行程序&#xff08;含图形界面&#xff09;&#x1f9f0; 环境准备&#x1f4bb; 功能预览&#x1f9d1;‍&#x1f4bb; 完整代码&#xff08;图形界面 功能&#xff09;如何…...

HarmonyOS Next 弹窗系列教程(2)

HarmonyOS Next 弹窗系列教程&#xff08;2&#xff09; 上一章节我们讲了自定义弹出框 (openCustomDialog)&#xff0c;那对于一些简单的业务场景&#xff0c;不一定需要都是自定义&#xff0c;也可以使用 HarmonyOS Next 内置的一些弹窗效果。比如&#xff1a; 名称描述不依…...

Ubuntu 18.04 上源码安装 protobuf 3.7.0

&#x1f527; 1️⃣ 安装依赖 sudo apt update sudo apt install -y autoconf automake libtool curl make g unzip&#x1f4e5; 2️⃣ 下载源码 cd ~ git clone https://github.com/protocolbuffers/protobuf.git cd protobuf git checkout v3.7.0⚙️ 3️⃣ 编译 & 安…...

中小企业搭建网站选择虚拟主机还是云服务器?华为云有话说

这是一个很常见的问题&#xff0c;许多小企业在搭建网站时都会面临这个选择。虚拟主机和云服务器都有各自的优缺点&#xff0c;需要根据自己的需求和预算来决定。 虚拟主机是指将一台物理服务器分割成多个虚拟空间&#xff0c;每个空间都可以运行一个网站。虚拟主机的优点是价格…...

使用 HTML + JavaScript 在高德地图上实现物流轨迹跟踪系统

在电商行业蓬勃发展的今天&#xff0c;物流信息查询已成为人们日常生活中的重要需求。本文将详细介绍如何基于高德地图 API 利用 HTML JavaScript 实现物流轨迹跟踪系统的开发。 效果演示 项目概述 本项目主要包含以下核心功能&#xff1a; 地图初始化与展示运单号查询功能…...

19-项目部署(Linux)

Linux是一套免费使用和自由传播的操作系统。说到操作系统&#xff0c;大家比较熟知的应该就是Windows和MacOS操作系统&#xff0c;我们今天所学习的Linux也是一款操作系统。 我们作为javaEE开发工程师&#xff0c;将来在企业中开发时会涉及到很多的数据库、中间件等技术&#…...

html基础01:前端基础知识学习

html基础01&#xff1a;前端基础知识学习 1.个人建立打造 -- 之前知识的小总结1.1个人简历展示1.2简历信息填写页面 1.个人建立打造 – 之前知识的小总结 1.1个人简历展示 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8&qu…...

Golang学习之旅

Golang学习之旅&#xff1a;初探Go语言的奥秘 在当今这个快速发展的技术时代&#xff0c;编程语言层出不穷&#xff0c;每一种都有其独特的魅力和适用场景。作为一名对技术充满热情的开发者&#xff0c;我一直在探索新的知识&#xff0c;以提升自己的编程技能。最近&#xff0…...