pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客
LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流:
记忆单元(Cell State)
- 负责存储长期信息,并通过门控机制决定保留或丢弃信息。
遗忘门(Forget Gate, ftf_tft)
输入门(Input Gate, iti_tit)
输出门(Output Gate, oto_tot)
特性 | 传统 RNN | LSTM |
---|---|---|
记忆能力 | 短期记忆 | 长短期记忆 |
计算复杂度 | 低 | 高 |
解决梯度消失 | 否 | 是 |
适用场景 | 短序列数据 | 长序列数据 |
LSTM 应用场景
- 自然语言处理(NLP):文本生成、情感分析、机器翻译
- 时间序列预测:股票预测、天气预报、传感器数据分析
- 语音识别:自动字幕生成、语音转文字(ASR)
- 机器人与控制系统:智能体决策、自动驾驶
例子:
下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :]) # 取最后时间步的输出action_prob = self.softmax(out) # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position] # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)
相关文章:

pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客 LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流: 记忆单元(Cell State) 负责存储长期信息&…...
【ubuntu】双系统ubuntu下一键切换到Windows
ubuntu下一键切换到Windows 1.4.1 重启脚本1.4.2 快捷方式1.4.3 移动快捷方式到系统目录 按前文所述文档,开机默认启动ubuntu。Windows切换到Ubuntu直接重启就行了,而Ubuntu切换到Windows稍微有点麻烦。可编辑切换重启到Windows的快捷方式。 1.4.1 重启…...

【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状
目录 1. reshape 函数的用法 2. transpose 和 permute 函数的使用 4. squeeze 和 unsqueeze 函数的用法 5. 小节 个人主页:Icomi 专栏地址:PyTorch入门 在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架&am…...

大模型GUI系列论文阅读 DAY4续:《Large Language Model Agent for Fake News Detection》
摘要 在当前的数字时代,在线平台上虚假信息的迅速传播对社会福祉、公众信任和民主进程构成了重大挑战,并影响着关键决策和公众舆论。为应对这些挑战,自动化假新闻检测机制的需求日益增长。 预训练的大型语言模型(LLMs࿰…...

论文阅读(九):通过概率图模型建立连锁不平衡模型和进行关联研究:最新进展访问之旅
1.论文链接:Modeling Linkage Disequilibrium and Performing Association Studies through Probabilistic Graphical Models: a Visiting Tour of Recent Advances 摘要: 本章对概率图模型(PGMs)的最新进展进行了深入的回顾&…...
python小知识-typing注解你的程序
python小知识-typing注解你的程序 1. Typing的简介 typing 是 Python 的一个标准库,它提供了类型注解的支持,但并不会强制类型检查。类型注解在 Python 3.5 中引入,并在后续版本中得到了增强和扩展。typing 库允许开发者为变量、函数参数和…...

git基础使用--1--版本控制的基本概念
git基础使用–1–版本控制的基本概念 1.版本控制的需求背景,即为啥需要版本控制 先说啥叫版本,这个就不多说了吧,我们写代码的时候肯定不可能一蹴而就,肯定是今天写一点,明天写一点,对于项目来讲ÿ…...

“新月智能武器系统”CIWS,开启智能武器的新纪元
新月人物传记:人物传记之新月篇-CSDN博客 相关文章链接:星际战争模拟系统:新月的编程之道-CSDN博客 新月智能护甲系统CMIA--未来战场的守护者-CSDN博客 “新月之智”智能战术头盔系统(CITHS)-CSDN博客 目录 智能武…...

JVM运行时数据区域-附面试题
Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途,以及创建和销毁的时间,有的区域随着虚拟机进程的启动而一直存在,有些区域则是 依赖用户线程的启动和结束而建立和销毁。 1. 程序计…...

增删改查(CRUD)操作
文章目录 MySQL系列:1.CRUD简介2.Create(创建)2.1单行数据全列插入2.2 单行数据指定插入2.3 多⾏数据指定列插⼊ 3.Retrieve(读取)3.1 Select查询3.1.1 全列查询3.1.2 指定列查询3.1.3 查询字段为表达式(都是临时表不会对原有表数据产生影响)…...
Vue.js `Suspense` 和异步组件加载
Vue.js Suspense 和异步组件加载 今天我们来聊聊 Vue 3 中的一个强大特性:<Suspense> 组件,以及它如何帮助我们更优雅地处理异步组件加载。如果你曾在 Vue 项目中处理过异步组件加载,那么这篇文章将为你介绍一种更简洁高效的方式。 什…...

HTB:LinkVortex[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用gobuster对靶机进行路径FUZZ 使用ffuf堆靶机…...

Linux命令入门
Linux命令入门 ls命令 ls命令的作用是列出目录下的内容,语法细节如下: 1s[-a -l -h] [Linux路径] -a -l -h是可选的选项 Linux路径是此命令可选的参数 当不使用选项和参数,直接使用ls命令本体,表示:以平铺形式,列出当前工作目录下的内容 ls命令的选项 -a -a选项&a…...

【问题】Chrome安装不受支持的扩展 解决方案
此扩展程序已停用,因为它已不再受支持 Chromium 建议您移除它。详细了解受支持的扩展程序 此扩展程序已停用,因为它已不再受支持 详情移除 解决 1. 解压扩展 2.打开manifest.json 3.修改版本 将 manifest_version 改为3及以上 {"manifest_ver…...

【题解】AtCoder Beginner Contest ABC391 D Gravity
题目大意 原题面链接 在一个 1 0 9 W 10^9\times W 109W 的平面里有 N N N 个方块。我们用 ( x , y ) (x,y) (x,y) 表示第 x x x 列从下往上数的 y y y 个位置。第 i i i 个方块的位置是 ( x i , y i ) (x_i,y_i) (xi,yi)。现在执行无数次操作,每一次…...

使用 SpringBoot+Thymeleaf 模板引擎进行 Web 开发
目录 一、什么是 Thymeleaf 模板引擎 二、Thymeleaf 模板引擎的 Maven 坐标 三、配置 Thymeleaf 四、访问页面 五、访问静态资源 六、Thymeleaf 使用示例 七、Thymeleaf 常用属性 前言 在现代 Web 开发中,模板引擎被广泛用于将动态内容渲染到静态页面中。Thy…...
【Java异步编程】CompletableFuture综合实战:泡茶喝水与复杂的异步调用
文章目录 一. 两个异步任务的合并:泡茶喝水二. 复杂的异步调用:结果依赖,以及异步执行调用等 一. 两个异步任务的合并:泡茶喝水 下面的代码中我们实现泡茶喝水。这里分3个任务:任务1负责洗水壶、烧开水,任…...
Nginx知识
nginx 精简的配置文件 worker_processes 1; # 可以理解为一个内核一个worker # 开多了可能性能不好events {worker_connections 1024; } # 一个 worker 可以创建的连接数 # 1024 代表默认一般不用改http {include mime.types;# 代表引入的配置文件# mime.types 在 ngi…...
Unity开发游戏使用XLua的基础
Unity使用Xlua的常用编码方式,做一下记录 1、C#调用lua 1、Lua解析器 private LuaEnv env new LuaEnv();//保持它的唯一性void Start(){env.DoString("print(你好lua)");//env.DoString("require(Main)"); 默认在resources文件夹下面//帮助…...

AI-ISP论文Learning to See in the Dark解读
论文地址:Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…...

利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

多模态2025:技术路线“神仙打架”,视频生成冲上云霄
文|魏琳华 编|王一粟 一场大会,聚集了中国多模态大模型的“半壁江山”。 智源大会2025为期两天的论坛中,汇集了学界、创业公司和大厂等三方的热门选手,关于多模态的集中讨论达到了前所未有的热度。其中,…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合
强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)
文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...

Mac下Android Studio扫描根目录卡死问题记录
环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中,提示一个依赖外部头文件的cpp源文件需要同步,点…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...

高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...
算法刷题-回溯
今天给大家分享的还是一道关于dfs回溯的问题,对于这类问题大家还是要多刷和总结,总体难度还是偏大。 对于回溯问题有几个关键点: 1.首先对于这类回溯可以节点可以随机选择的问题,要做mian函数中循环调用dfs(i&#x…...