当前位置: 首页 > article >正文

别再死记硬背LSTM公式了!用PyTorch手写一个LSTM单元,5分钟搞懂门控机制

从零实现LSTM单元用PyTorch代码拆解门控机制当你第一次看到LSTM那一堆复杂的公式时是不是感觉头大遗忘门、输入门、输出门、细胞状态...这些概念听起来高大上但真正动手写代码时却不知从何下手。今天我们就用PyTorch从零开始实现一个LSTM单元让你在编写代码的过程中真正理解这些门控机制是如何协同工作的。1. 环境准备与基础概念在开始编码之前我们先快速回顾一下LSTM的核心组件。LSTMLong Short-Term Memory是一种特殊的循环神经网络它通过引入门控机制解决了传统RNN难以捕捉长期依赖的问题。与普通RNN相比LSTM多了三个关键的门控结构遗忘门决定哪些历史信息需要保留或丢弃输入门控制当前输入信息中有多少需要更新到记忆单元输出门决定当前时刻应该输出哪些信息这些门控机制都通过sigmoid函数输出0到1之间的值来控制信息流动的比例。下面是我们即将实现的LSTM单元的计算流程# 伪代码展示LSTM计算流程 def lstm_cell(x, h_prev, c_prev, Wf, Wi, Wo, Wc, bf, bi, bo, bc): # 遗忘门 f sigmoid(Wf [x, h_prev] bf) # 输入门 i sigmoid(Wi [x, h_prev] bi) # 候选记忆 c_tilde tanh(Wc [x, h_prev] bc) # 更新细胞状态 c f * c_prev i * c_tilde # 输出门 o sigmoid(Wo [x, h_prev] bo) # 计算当前隐藏状态 h o * tanh(c) return h, c准备好你的Python环境我们需要以下工具库pip install torch numpy matplotlib2. 构建LSTM单元类现在让我们用PyTorch实现一个完整的LSTMCell类。我们将逐步构建这个类并在每一步解释对应的数学原理。2.1 初始化参数首先我们需要初始化LSTM单元的所有可训练参数import torch import torch.nn as nn class LSTMCell(nn.Module): def __init__(self, input_size, hidden_size): super().__init__() self.input_size input_size self.hidden_size hidden_size # 遗忘门参数 self.W_f nn.Parameter(torch.Tensor(hidden_size, input_size hidden_size)) self.b_f nn.Parameter(torch.Tensor(hidden_size)) # 输入门参数 self.W_i nn.Parameter(torch.Tensor(hidden_size, input_size hidden_size)) self.b_i nn.Parameter(torch.Tensor(hidden_size)) # 输出门参数 self.W_o nn.Parameter(torch.Tensor(hidden_size, input_size hidden_size)) self.b_o nn.Parameter(torch.Tensor(hidden_size)) # 候选记忆参数 self.W_c nn.Parameter(torch.Tensor(hidden_size, input_size hidden_size)) self.b_c nn.Parameter(torch.Tensor(hidden_size)) self.reset_parameters() def reset_parameters(self): # 使用Xavier初始化权重 for param in self.parameters(): if param.dim() 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param)2.2 实现前向传播接下来我们实现LSTM单元的前向传播逻辑def forward(self, x, state): h_prev, c_prev state # 拼接当前输入和前一时刻的隐藏状态 combined torch.cat((x, h_prev), dim1) # 计算遗忘门 f torch.sigmoid(combined self.W_f.t() self.b_f) # 计算输入门 i torch.sigmoid(combined self.W_i.t() self.b_i) # 计算候选记忆 c_tilde torch.tanh(combined self.W_c.t() self.b_c) # 更新细胞状态 c f * c_prev i * c_tilde # 计算输出门 o torch.sigmoid(combined self.W_o.t() self.b_o) # 计算当前隐藏状态 h o * torch.tanh(c) return h, c注意在实际应用中我们通常会使用PyTorch内置的LSTM实现因为它们经过了高度优化。这里我们手动实现是为了更好地理解内部机制。3. 验证LSTM单元为了验证我们的实现是否正确让我们用一个简单的序列预测任务来测试。3.1 创建测试数据我们生成一个简单的正弦波序列import numpy as np import matplotlib.pyplot as plt # 生成正弦波序列 seq_length 100 time_steps np.linspace(0, 4*np.pi, seq_length) data np.sin(time_steps) # 可视化 plt.plot(time_steps, data) plt.title(Sine Wave Sequence) plt.xlabel(Time) plt.ylabel(Value) plt.show()3.2 训练LSTM单元现在我们训练LSTM单元来预测序列中的下一个值# 准备训练数据 def create_dataset(seq, look_back1): X, y [], [] for i in range(len(seq)-look_back): X.append(seq[i:ilook_back]) y.append(seq[ilook_back]) return torch.FloatTensor(np.array(X)), torch.FloatTensor(np.array(y)) look_back 5 X, y create_dataset(data, look_back) X X.unsqueeze(-1) # (seq_len, look_back, input_size1) # 初始化模型 input_size 1 hidden_size 32 lstm_cell LSTMCell(input_size, hidden_size) linear nn.Linear(hidden_size, 1) criterion nn.MSELoss() optimizer torch.optim.Adam(list(lstm_cell.parameters()) list(linear.parameters()), lr0.01) # 训练循环 num_epochs 100 for epoch in range(num_epochs): h torch.zeros(1, hidden_size) c torch.zeros(1, hidden_size) total_loss 0 for i in range(len(X)): # 前向传播 h, c lstm_cell(X[i], (h, c)) output linear(h) loss criterion(output, y[i:i1]) # 反向传播 optimizer.zero_grad() loss.backward(retain_graphTrue) optimizer.step() total_loss loss.item() if (epoch1) % 10 0: print(fEpoch {epoch1}, Loss: {total_loss/len(X):.4f})3.3 测试模型训练完成后我们可以用模型来预测整个序列# 预测整个序列 predictions [] h torch.zeros(1, hidden_size) c torch.zeros(1, hidden_size) with torch.no_grad(): for i in range(len(X)): h, c lstm_cell(X[i], (h, c)) output linear(h) predictions.append(output.item()) # 可视化结果 plt.plot(time_steps[look_back:], data[look_back:], labelTrue) plt.plot(time_steps[look_back:], predictions, labelPredicted) plt.legend() plt.title(LSTM Sequence Prediction) plt.xlabel(Time) plt.ylabel(Value) plt.show()4. 门控机制可视化为了更直观地理解LSTM的门控机制我们可以可视化训练过程中各个门的激活值。4.1 记录门控值修改我们的LSTMCell类使其能够记录门控值class LSTMCellWithGates(LSTMCell): def forward(self, x, state): h_prev, c_prev state combined torch.cat((x, h_prev), dim1) # 计算各个门 self.f torch.sigmoid(combined self.W_f.t() self.b_f) self.i torch.sigmoid(combined self.W_i.t() self.b_i) self.o torch.sigmoid(combined self.W_o.t() self.b_o) self.c_tilde torch.tanh(combined self.W_c.t() self.b_c) # 更新细胞状态 c self.f * c_prev self.i * self.c_tilde h self.o * torch.tanh(c) return h, c4.2 可视化门控活动使用修改后的类重新训练模型并绘制门控值# 初始化带门控记录的模型 lstm_cell LSTMCellWithGates(input_size, hidden_size) linear nn.Linear(hidden_size, 1) # 训练模型...(与前面相同的训练代码) # 获取门控值 forget_gates [] input_gates [] output_gates [] h torch.zeros(1, hidden_size) c torch.zeros(1, hidden_size) with torch.no_grad(): for i in range(len(X)): h, c lstm_cell(X[i], (h, c)) forget_gates.append(lstm_cell.f.mean().item()) input_gates.append(lstm_cell.i.mean().item()) output_gates.append(lstm_cell.o.mean().item()) # 绘制门控活动 plt.figure(figsize(12, 6)) plt.plot(time_steps[look_back:], forget_gates, labelForget Gate) plt.plot(time_steps[look_back:], input_gates, labelInput Gate) plt.plot(time_steps[look_back:], output_gates, labelOutput Gate) plt.legend() plt.title(LSTM Gate Activations Over Time) plt.xlabel(Time) plt.ylabel(Gate Value) plt.show()从可视化结果中你可以看到遗忘门在序列变化平缓时倾向于保持较高值保留更多历史信息输入门在序列变化剧烈时激活更强需要更新更多新信息输出门则根据预测需求动态调整输出信息量5. 进阶应用与优化现在你已经理解了LSTM的基本实现让我们探讨一些进阶话题。5.1 多层LSTM在实际应用中我们通常会堆叠多个LSTM层来提取更复杂的特征class MultiLayerLSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers): super().__init__() self.layers nn.ModuleList([ LSTMCell(input_size if i 0 else hidden_size, hidden_size) for i in range(num_layers) ]) def forward(self, x, states): new_states [] for i, layer in enumerate(self.layers): h, c layer(x, states[i]) new_states.append((h, c)) x h # 上一层的输出作为下一层的输入 return x, new_states5.2 双向LSTM双向LSTM可以同时考虑过去和未来的上下文信息class BiLSTM(nn.Module): def __init__(self, input_size, hidden_size): super().__init__() self.forward_lstm LSTMCell(input_size, hidden_size) self.backward_lstm LSTMCell(input_size, hidden_size) def forward(self, x): # 前向传播 h_forward, c_forward torch.zeros(1, hidden_size), torch.zeros(1, hidden_size) forward_outputs [] for i in range(len(x)): h_forward, c_forward self.forward_lstm(x[i], (h_forward, c_forward)) forward_outputs.append(h_forward) # 反向传播 h_backward, c_backward torch.zeros(1, hidden_size), torch.zeros(1, hidden_size) backward_outputs [] for i in range(len(x)-1, -1, -1): h_backward, c_backward self.backward_lstm(x[i], (h_backward, c_backward)) backward_outputs.insert(0, h_backward) # 合并双向结果 return torch.cat((forward_outputs[-1], backward_outputs[0]), dim1)5.3 性能优化技巧在实际项目中你可以考虑以下优化策略优化策略描述适用场景梯度裁剪限制梯度最大值防止梯度爆炸训练不稳定时权重dropout在LSTM层间应用dropout防止过拟合层归一化在LSTM内部添加LayerNorm加速收敛变学习率使用学习率调度器训练后期微调# 示例在LSTMCell中添加层归一化 class LayerNormLSTMCell(LSTMCell): def __init__(self, input_size, hidden_size): super().__init__(input_size, hidden_size) self.ln_f nn.LayerNorm(hidden_size) self.ln_i nn.LayerNorm(hidden_size) self.ln_o nn.LayerNorm(hidden_size) self.ln_c nn.LayerNorm(hidden_size) def forward(self, x, state): h_prev, c_prev state combined torch.cat((x, h_prev), dim1) f torch.sigmoid(self.ln_f(combined self.W_f.t() self.b_f)) i torch.sigmoid(self.ln_i(combined self.W_i.t() self.b_i)) o torch.sigmoid(self.ln_o(combined self.W_o.t() self.b_o)) c_tilde torch.tanh(self.ln_c(combined self.W_c.t() self.b_c)) c f * c_prev i * c_tilde h o * torch.tanh(c) return h, c通过这次从零实现LSTM的实践我深刻体会到理论公式和实际代码之间的差距。在编写过程中最容易出错的地方是张量维度的匹配和梯度传播的处理。建议在实现复杂模型时先从小规模数据开始验证逐步扩展到完整数据集。

相关文章:

别再死记硬背LSTM公式了!用PyTorch手写一个LSTM单元,5分钟搞懂门控机制

从零实现LSTM单元:用PyTorch代码拆解门控机制 当你第一次看到LSTM那一堆复杂的公式时,是不是感觉头大?遗忘门、输入门、输出门、细胞状态...这些概念听起来高大上,但真正动手写代码时却不知从何下手。今天我们就用PyTorch从零开始…...

【YOLOv11】034、YOLOv11在边缘设备部署:使用TensorRT加速NVIDIA Jetson平台

深夜的调试日志:当YOLOv11遇上Jetson Nano 上周三凌晨两点,实验室的Jetson Nano风扇还在嘶吼。屏幕上显示着YOLOv11的检测帧率:3.2 FPS。这个数字让人清醒——项目要求的实时检测是25 FPS。原生的PyTorch模型在边缘设备上的无力感,在这个深夜格外清晰。这不是算法问题,是…...

从FHSS到OFDMA:Wi-Fi协议演进中的核心技术变革

1. Wi-Fi协议演进简史:从"慢车道"到"信息高速公路" 1997年,当IEEE首次发布802.11标准时,最高2Mbps的传输速率在今天看来简直像蜗牛爬行。记得我第一次接触早期Wi-Fi时,下载一首MP3歌曲需要等待近10分钟&#…...

SQL注入靶场23-37关实战通关攻略

本文将展示sql注入靶场23-37关的通关思路 第二十三关(GET - 报错注入:过滤注释符,用引号闭合) 进入第二十三关发现又回到了GET参数,但是有区别,这关将#和-- qwe等等注释符加入了黑名单,屏蔽掉…...

ABAP批量导入Excel数据实战:从文件选择到数据库插入的完整流程

ABAP高效Excel数据导入:从基础实现到性能优化的完整指南 在企业级SAP系统开发中,Excel数据批量导入是每个ABAP开发者必须掌握的技能。无论是期初数据加载、日常业务数据维护,还是系统间数据交换,高效可靠的数据导入机制都能显著提…...

AI投毒情报预警 | Xinference国产推理框架遭受供应链窃密后门投毒

风险概述 北京时间4月22日16点,悬镜AI安全情报中心在Pypi官方仓库中监测到国产热门开源AI模型推理框架 Xinference 短时间内连续发布2.6.0、2.6.1及2.6.2三个版本更新,并且在这三个新版本框架源码中都检出混淆代码及高风险恶意行为。在混淆恶意代码中发现…...

NHSE:动物森友会存档编辑工具全面指南

NHSE:动物森友会存档编辑工具全面指南 【免费下载链接】NHSE Animal Crossing: New Horizons save editor 项目地址: https://gitcode.com/gh_mirrors/nh/NHSE 你是否厌倦了在《集合啦!动物森友会》中反复刷资源、等待稀有村民出现?想…...

Cursor 官宣AI新玩具:Canvas

推荐阅读 IDEA 官宣:终于可以爽用Cursor了! 重磅!前端再次被碾压,比 Cursor 更强的 AI 工具发布了! Cursor 3.1 发布:VS Code 那一套要失效了吗? 💡 前言:以前和 A…...

安全编程实践常见漏洞与防范措施

在数字化时代,软件安全已成为开发过程中不可忽视的核心问题。安全编程实践旨在通过规范代码编写方式,预防潜在漏洞,降低被攻击风险。由于开发者的疏忽或知识盲区,常见漏洞如注入攻击、缓冲区溢出等仍频繁出现。本文将聚焦三类典型…...

从malloc到memsafe_c:2026规范强制要求的4类API替换清单,不改业务逻辑也能通过ISO/IEC 17961合规审计

第一章:现代 C 语言内存安全编码规范 2026 成本控制策略在嵌入式系统、操作系统内核与高性能服务开发中,C 语言仍占据不可替代地位,但传统内存操作(如裸指针算术、未校验的 malloc 返回值、strcpy 类危险函数)已成为安…...

Linux文件系统(一):从磁盘结构到文件系统基础

目录 一、计算机存储体系 1. 从计算机到磁盘 2. 什么是磁盘 二、磁盘的物理结构 1. 磁盘组成 2. 数据写入原理 三、磁盘的存储结构 1. 扇区、磁道、柱面 2. 磁盘与数组 单磁道展开 同半径磁道展开 全盘展开 C / C 数组思维的线性化 四、磁盘寻址方式 1. CHS 寻址…...

Elasticsearch分布式原理:集群数据分布机制与分片路由全流程深度剖析

Elasticsearch分布式原理:集群数据分布机制与分片路由全流程深度剖析前言一、核心前置:分布式数据依赖的三大基础组件1.1 主节点(Master Node)1.2 数据节点(Data Node)1.3 分片与副本(Shard &am…...

揭秘论文优化新利器:书匠策AI,让降重与去AIGC痕迹变得如此简单!

在学术的浩瀚宇宙中,每一篇论文都是探索者智慧与汗水的结晶。然而,当重复率成为横亘在发表之路上的巨石,当AIGC(人工智能生成内容)的痕迹让论文显得机械而缺乏灵魂,我们该如何破局?别怕&#xf…...

技术支持管理中的服务台建设

技术支持管理中的服务台建设:提升效率与用户体验的关键 在数字化转型的浪潮中,企业对技术支持的依赖日益加深。服务台作为技术支持管理的核心枢纽,不仅是问题解决的“第一窗口”,更是提升用户满意度和运维效率的关键环节。一个高…...

DeepL翻译浏览器扩展:让外语内容阅读变得轻松自然

DeepL翻译浏览器扩展:让外语内容阅读变得轻松自然 【免费下载链接】deepl-chrome-extension A DeepL Translator Chrome extension 项目地址: https://gitcode.com/gh_mirrors/de/deepl-chrome-extension 在当今全球化的信息环境中,我们每天都会接…...

Rspack简介

Rspack简介 前言:在前端构建领域,Webpack 长期占据主导地位,而 Vite 的出现打破了这一格局,两者各有优势,但也都存在明显短板:Webpack 生态成熟、兼容性强,但随着项目规模扩大,构建…...

解锁学术新次元:书匠策AI——期刊论文写作的“魔法宝盒”

在学术的浩瀚宇宙里,期刊论文就像是那璀璨的星辰,照亮着知识探索的道路。可对于许多人来说,撰写一篇高质量的期刊论文,就像是在迷雾中摸索前行,困难重重。不过别担心,今天我要给大家揭开一个神秘“魔法宝盒…...

3个步骤让经典游戏重获新生:IPXWrapper如何解决现代Windows的网络兼容难题?

3个步骤让经典游戏重获新生:IPXWrapper如何解决现代Windows的网络兼容难题? 【免费下载链接】ipxwrapper 项目地址: https://gitcode.com/gh_mirrors/ip/ipxwrapper 还记得那些年,和朋友们一起在《红色警戒2》的战场上厮杀&#xff0…...

别再瞎调权重了!手把手教你用Ceph CRUSH Map优化混合存储(SSD/HDD)性能

别再瞎调权重了!手把手教你用Ceph CRUSH Map优化混合存储(SSD/HDD)性能 当你的Ceph集群同时包含SSD和HDD时,是否经常遇到这样的困扰:高IOPS业务(如数据库)和冷数据归档业务混在一起,…...

QMK Toolbox 终极指南:3分钟掌握键盘固件烧录与调试完整流程

QMK Toolbox 终极指南:3分钟掌握键盘固件烧录与调试完整流程 【免费下载链接】qmk_toolbox A Toolbox companion for QMK Firmware 项目地址: https://gitcode.com/gh_mirrors/qm/qmk_toolbox 你是否曾经想过完全掌控自己的机械键盘?想让每一个按…...

告别龟速!手把手教你给Termux换清华源,pkg update飞起来

告别龟速!手把手教你给Termux换清华源,pkg update飞起来 每次在Termux里执行pkg update时,看着那缓慢的进度条一点点往前挪,是不是感觉时间仿佛被拉长了?作为Android上最强大的终端模拟器,Termux的官方源服…...

华为VRP网络运维:从零到精通的命令实战指南

1. 华为VRP平台入门:认识你的网络操作系统 第一次接触华为VRP(Versatile Routing Platform)时,我完全被满屏的命令行吓到了。但后来发现,这就像学开车要先熟悉方向盘和档位一样,掌握几个基础命令就能让设备…...

别再用错__attribute__了!C语言高手都在用的15个实战技巧(附代码避坑)

别再用错__attribute__了!C语言高手都在用的15个实战技巧(附代码避坑) 在嵌入式开发和系统级编程中,编译器扩展特性往往是区分普通开发者和高手的关键分水岭。GNU C的__attribute__机制就像瑞士军刀中的隐藏工具——90%的开发者只…...

抖音无水印下载终极指南:3分钟学会批量保存纯净视频

抖音无水印下载终极指南:3分钟学会批量保存纯净视频 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback support…...

Mujoco+强化学习入门实战教程

前言:本文是为了方便机器人初学者快速学习Mujoco强化学习而设计的教程,循序渐进,从环境搭建到简单的运动控制再到强化学习自主探索,难度逐步提升,帮助初学者建立学习路线,思维框架,并在此基础上…...

别再为小众物种发愁了!手把手教你用R包biomaRt和AnnotationForge定制专属OrgDb数据库

突破非模式生物分析瓶颈:从零构建定制化OrgDb数据库的实战指南 当你在深夜的实验室里盯着屏幕上那些无法匹配的基因ID时,是否曾感到一丝绝望?作为一名长期与山羊、绵羊等非模式生物打交道的生物信息学研究者,我完全理解这种挫败感…...

工业级YOLO检测数据处理:C#上位机存储+报表导出全方案(含SQLite+Excel+PDF+7×24小时稳定运行)

摘要 在工业视觉检测系统中,YOLO模型的推理性能只是基础,检测结果的可靠存储、规范管理与标准化报表导出才是决定系统能否真正落地的关键。很多项目只关注模型精度,却因数据处理方案简陋导致数据丢失、追溯困难、报表不规范等问题,最终无法通过企业验收。 本文基于C# Win…...

【WPF】巧用BitmapCacheOption.OnLoad释放图像文件句柄,解决资源锁定与程序崩溃难题

1. 为什么WPF会锁定图像文件? 在WPF开发中,很多开发者都遇到过这样的尴尬场景:程序加载了一张本地图片后,想要删除或修改这个图片文件时,系统却提示"文件正在被另一个程序使用"。这种情况通常发生在使用Bitm…...

Harness Engineering:AI Agent 落地企业的工程化核心

2025年是AI Agent的爆发元年,各类智能体工具层出不穷,但落地企业生产环境时却问题频发——越权操作、逻辑混乱、无法审计的情况屡见不鲜。2026年,Harness Engineering 成为行业破局关键,它让AI Agent从「实验室玩具」变成「企业级…...

别再傻傻分不清了!一张图看懂PLM、ERP、MES、CRM在工厂里到底怎么分工协作

制造业四大核心系统协同作战指南:PLM、ERP、MES、CRM如何打通产品全生命周期 走进任何一家现代化制造企业的信息化部门,你都会听到PLM、ERP、MES、CRM这些英文缩写被频繁提及。对于初次接触这些系统的IT人员或业务管理者来说,最困惑的往往不是…...