pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客
LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流:
记忆单元(Cell State)
- 负责存储长期信息,并通过门控机制决定保留或丢弃信息。
遗忘门(Forget Gate, ftf_tft)

输入门(Input Gate, iti_tit)

输出门(Output Gate, oto_tot)

| 特性 | 传统 RNN | LSTM |
|---|---|---|
| 记忆能力 | 短期记忆 | 长短期记忆 |
| 计算复杂度 | 低 | 高 |
| 解决梯度消失 | 否 | 是 |
| 适用场景 | 短序列数据 | 长序列数据 |
LSTM 应用场景
- 自然语言处理(NLP):文本生成、情感分析、机器翻译
- 时间序列预测:股票预测、天气预报、传感器数据分析
- 语音识别:自动字幕生成、语音转文字(ASR)
- 机器人与控制系统:智能体决策、自动驾驶
例子:
下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :]) # 取最后时间步的输出action_prob = self.softmax(out) # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position] # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)

相关文章:
pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客 LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流: 记忆单元(Cell State) 负责存储长期信息&…...
【ubuntu】双系统ubuntu下一键切换到Windows
ubuntu下一键切换到Windows 1.4.1 重启脚本1.4.2 快捷方式1.4.3 移动快捷方式到系统目录 按前文所述文档,开机默认启动ubuntu。Windows切换到Ubuntu直接重启就行了,而Ubuntu切换到Windows稍微有点麻烦。可编辑切换重启到Windows的快捷方式。 1.4.1 重启…...
【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状
目录 1. reshape 函数的用法 2. transpose 和 permute 函数的使用 4. squeeze 和 unsqueeze 函数的用法 5. 小节 个人主页:Icomi 专栏地址:PyTorch入门 在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架&am…...
大模型GUI系列论文阅读 DAY4续:《Large Language Model Agent for Fake News Detection》
摘要 在当前的数字时代,在线平台上虚假信息的迅速传播对社会福祉、公众信任和民主进程构成了重大挑战,并影响着关键决策和公众舆论。为应对这些挑战,自动化假新闻检测机制的需求日益增长。 预训练的大型语言模型(LLMs࿰…...
论文阅读(九):通过概率图模型建立连锁不平衡模型和进行关联研究:最新进展访问之旅
1.论文链接:Modeling Linkage Disequilibrium and Performing Association Studies through Probabilistic Graphical Models: a Visiting Tour of Recent Advances 摘要: 本章对概率图模型(PGMs)的最新进展进行了深入的回顾&…...
python小知识-typing注解你的程序
python小知识-typing注解你的程序 1. Typing的简介 typing 是 Python 的一个标准库,它提供了类型注解的支持,但并不会强制类型检查。类型注解在 Python 3.5 中引入,并在后续版本中得到了增强和扩展。typing 库允许开发者为变量、函数参数和…...
git基础使用--1--版本控制的基本概念
git基础使用–1–版本控制的基本概念 1.版本控制的需求背景,即为啥需要版本控制 先说啥叫版本,这个就不多说了吧,我们写代码的时候肯定不可能一蹴而就,肯定是今天写一点,明天写一点,对于项目来讲ÿ…...
“新月智能武器系统”CIWS,开启智能武器的新纪元
新月人物传记:人物传记之新月篇-CSDN博客 相关文章链接:星际战争模拟系统:新月的编程之道-CSDN博客 新月智能护甲系统CMIA--未来战场的守护者-CSDN博客 “新月之智”智能战术头盔系统(CITHS)-CSDN博客 目录 智能武…...
JVM运行时数据区域-附面试题
Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途,以及创建和销毁的时间,有的区域随着虚拟机进程的启动而一直存在,有些区域则是 依赖用户线程的启动和结束而建立和销毁。 1. 程序计…...
增删改查(CRUD)操作
文章目录 MySQL系列:1.CRUD简介2.Create(创建)2.1单行数据全列插入2.2 单行数据指定插入2.3 多⾏数据指定列插⼊ 3.Retrieve(读取)3.1 Select查询3.1.1 全列查询3.1.2 指定列查询3.1.3 查询字段为表达式(都是临时表不会对原有表数据产生影响)…...
Vue.js `Suspense` 和异步组件加载
Vue.js Suspense 和异步组件加载 今天我们来聊聊 Vue 3 中的一个强大特性:<Suspense> 组件,以及它如何帮助我们更优雅地处理异步组件加载。如果你曾在 Vue 项目中处理过异步组件加载,那么这篇文章将为你介绍一种更简洁高效的方式。 什…...
HTB:LinkVortex[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用gobuster对靶机进行路径FUZZ 使用ffuf堆靶机…...
Linux命令入门
Linux命令入门 ls命令 ls命令的作用是列出目录下的内容,语法细节如下: 1s[-a -l -h] [Linux路径] -a -l -h是可选的选项 Linux路径是此命令可选的参数 当不使用选项和参数,直接使用ls命令本体,表示:以平铺形式,列出当前工作目录下的内容 ls命令的选项 -a -a选项&a…...
【问题】Chrome安装不受支持的扩展 解决方案
此扩展程序已停用,因为它已不再受支持 Chromium 建议您移除它。详细了解受支持的扩展程序 此扩展程序已停用,因为它已不再受支持 详情移除 解决 1. 解压扩展 2.打开manifest.json 3.修改版本 将 manifest_version 改为3及以上 {"manifest_ver…...
【题解】AtCoder Beginner Contest ABC391 D Gravity
题目大意 原题面链接 在一个 1 0 9 W 10^9\times W 109W 的平面里有 N N N 个方块。我们用 ( x , y ) (x,y) (x,y) 表示第 x x x 列从下往上数的 y y y 个位置。第 i i i 个方块的位置是 ( x i , y i ) (x_i,y_i) (xi,yi)。现在执行无数次操作,每一次…...
使用 SpringBoot+Thymeleaf 模板引擎进行 Web 开发
目录 一、什么是 Thymeleaf 模板引擎 二、Thymeleaf 模板引擎的 Maven 坐标 三、配置 Thymeleaf 四、访问页面 五、访问静态资源 六、Thymeleaf 使用示例 七、Thymeleaf 常用属性 前言 在现代 Web 开发中,模板引擎被广泛用于将动态内容渲染到静态页面中。Thy…...
【Java异步编程】CompletableFuture综合实战:泡茶喝水与复杂的异步调用
文章目录 一. 两个异步任务的合并:泡茶喝水二. 复杂的异步调用:结果依赖,以及异步执行调用等 一. 两个异步任务的合并:泡茶喝水 下面的代码中我们实现泡茶喝水。这里分3个任务:任务1负责洗水壶、烧开水,任…...
Nginx知识
nginx 精简的配置文件 worker_processes 1; # 可以理解为一个内核一个worker # 开多了可能性能不好events {worker_connections 1024; } # 一个 worker 可以创建的连接数 # 1024 代表默认一般不用改http {include mime.types;# 代表引入的配置文件# mime.types 在 ngi…...
Unity开发游戏使用XLua的基础
Unity使用Xlua的常用编码方式,做一下记录 1、C#调用lua 1、Lua解析器 private LuaEnv env new LuaEnv();//保持它的唯一性void Start(){env.DoString("print(你好lua)");//env.DoString("require(Main)"); 默认在resources文件夹下面//帮助…...
AI-ISP论文Learning to See in the Dark解读
论文地址:Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…...
python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...
51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
3.3.1_1 检错编码(奇偶校验码)
从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...
uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖
在前面的练习中,每个页面需要使用ref,onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入,需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...
根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:
根据万维钢精英日课6的内容,使用AI(2025)可以参考以下方法: 四个洞见 模型已经比人聪明:以ChatGPT o3为代表的AI非常强大,能运用高级理论解释道理、引用最新学术论文,生成对顶尖科学家都有用的…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...
LLMs 系列实操科普(1)
写在前面: 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容,原视频时长 ~130 分钟,以实操演示主流的一些 LLMs 的使用,由于涉及到实操,实际上并不适合以文字整理,但还是决定尽量整理一份笔…...
