pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客
LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流:
记忆单元(Cell State)
- 负责存储长期信息,并通过门控机制决定保留或丢弃信息。
遗忘门(Forget Gate, ftf_tft)
输入门(Input Gate, iti_tit)
输出门(Output Gate, oto_tot)
特性 | 传统 RNN | LSTM |
---|---|---|
记忆能力 | 短期记忆 | 长短期记忆 |
计算复杂度 | 低 | 高 |
解决梯度消失 | 否 | 是 |
适用场景 | 短序列数据 | 长序列数据 |
LSTM 应用场景
- 自然语言处理(NLP):文本生成、情感分析、机器翻译
- 时间序列预测:股票预测、天气预报、传感器数据分析
- 语音识别:自动字幕生成、语音转文字(ASR)
- 机器人与控制系统:智能体决策、自动驾驶
例子:
下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :]) # 取最后时间步的输出action_prob = self.softmax(out) # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0) # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position] # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)
相关文章:

pytorch实现长短期记忆网络 (LSTM)
人工智能例子汇总:AI常见的算法和例子-CSDN博客 LSTM 通过 记忆单元(cell) 和 三个门控机制(遗忘门、输入门、输出门)来控制信息流: 记忆单元(Cell State) 负责存储长期信息&…...

【ubuntu】双系统ubuntu下一键切换到Windows
ubuntu下一键切换到Windows 1.4.1 重启脚本1.4.2 快捷方式1.4.3 移动快捷方式到系统目录 按前文所述文档,开机默认启动ubuntu。Windows切换到Ubuntu直接重启就行了,而Ubuntu切换到Windows稍微有点麻烦。可编辑切换重启到Windows的快捷方式。 1.4.1 重启…...

【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状
目录 1. reshape 函数的用法 2. transpose 和 permute 函数的使用 4. squeeze 和 unsqueeze 函数的用法 5. 小节 个人主页:Icomi 专栏地址:PyTorch入门 在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架&am…...

大模型GUI系列论文阅读 DAY4续:《Large Language Model Agent for Fake News Detection》
摘要 在当前的数字时代,在线平台上虚假信息的迅速传播对社会福祉、公众信任和民主进程构成了重大挑战,并影响着关键决策和公众舆论。为应对这些挑战,自动化假新闻检测机制的需求日益增长。 预训练的大型语言模型(LLMs࿰…...

论文阅读(九):通过概率图模型建立连锁不平衡模型和进行关联研究:最新进展访问之旅
1.论文链接:Modeling Linkage Disequilibrium and Performing Association Studies through Probabilistic Graphical Models: a Visiting Tour of Recent Advances 摘要: 本章对概率图模型(PGMs)的最新进展进行了深入的回顾&…...

python小知识-typing注解你的程序
python小知识-typing注解你的程序 1. Typing的简介 typing 是 Python 的一个标准库,它提供了类型注解的支持,但并不会强制类型检查。类型注解在 Python 3.5 中引入,并在后续版本中得到了增强和扩展。typing 库允许开发者为变量、函数参数和…...

git基础使用--1--版本控制的基本概念
git基础使用–1–版本控制的基本概念 1.版本控制的需求背景,即为啥需要版本控制 先说啥叫版本,这个就不多说了吧,我们写代码的时候肯定不可能一蹴而就,肯定是今天写一点,明天写一点,对于项目来讲ÿ…...

“新月智能武器系统”CIWS,开启智能武器的新纪元
新月人物传记:人物传记之新月篇-CSDN博客 相关文章链接:星际战争模拟系统:新月的编程之道-CSDN博客 新月智能护甲系统CMIA--未来战场的守护者-CSDN博客 “新月之智”智能战术头盔系统(CITHS)-CSDN博客 目录 智能武…...

JVM运行时数据区域-附面试题
Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途,以及创建和销毁的时间,有的区域随着虚拟机进程的启动而一直存在,有些区域则是 依赖用户线程的启动和结束而建立和销毁。 1. 程序计…...

增删改查(CRUD)操作
文章目录 MySQL系列:1.CRUD简介2.Create(创建)2.1单行数据全列插入2.2 单行数据指定插入2.3 多⾏数据指定列插⼊ 3.Retrieve(读取)3.1 Select查询3.1.1 全列查询3.1.2 指定列查询3.1.3 查询字段为表达式(都是临时表不会对原有表数据产生影响)…...

Vue.js `Suspense` 和异步组件加载
Vue.js Suspense 和异步组件加载 今天我们来聊聊 Vue 3 中的一个强大特性:<Suspense> 组件,以及它如何帮助我们更优雅地处理异步组件加载。如果你曾在 Vue 项目中处理过异步组件加载,那么这篇文章将为你介绍一种更简洁高效的方式。 什…...

HTB:LinkVortex[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用gobuster对靶机进行路径FUZZ 使用ffuf堆靶机…...

Linux命令入门
Linux命令入门 ls命令 ls命令的作用是列出目录下的内容,语法细节如下: 1s[-a -l -h] [Linux路径] -a -l -h是可选的选项 Linux路径是此命令可选的参数 当不使用选项和参数,直接使用ls命令本体,表示:以平铺形式,列出当前工作目录下的内容 ls命令的选项 -a -a选项&a…...

【问题】Chrome安装不受支持的扩展 解决方案
此扩展程序已停用,因为它已不再受支持 Chromium 建议您移除它。详细了解受支持的扩展程序 此扩展程序已停用,因为它已不再受支持 详情移除 解决 1. 解压扩展 2.打开manifest.json 3.修改版本 将 manifest_version 改为3及以上 {"manifest_ver…...

【题解】AtCoder Beginner Contest ABC391 D Gravity
题目大意 原题面链接 在一个 1 0 9 W 10^9\times W 109W 的平面里有 N N N 个方块。我们用 ( x , y ) (x,y) (x,y) 表示第 x x x 列从下往上数的 y y y 个位置。第 i i i 个方块的位置是 ( x i , y i ) (x_i,y_i) (xi,yi)。现在执行无数次操作,每一次…...

使用 SpringBoot+Thymeleaf 模板引擎进行 Web 开发
目录 一、什么是 Thymeleaf 模板引擎 二、Thymeleaf 模板引擎的 Maven 坐标 三、配置 Thymeleaf 四、访问页面 五、访问静态资源 六、Thymeleaf 使用示例 七、Thymeleaf 常用属性 前言 在现代 Web 开发中,模板引擎被广泛用于将动态内容渲染到静态页面中。Thy…...

【Java异步编程】CompletableFuture综合实战:泡茶喝水与复杂的异步调用
文章目录 一. 两个异步任务的合并:泡茶喝水二. 复杂的异步调用:结果依赖,以及异步执行调用等 一. 两个异步任务的合并:泡茶喝水 下面的代码中我们实现泡茶喝水。这里分3个任务:任务1负责洗水壶、烧开水,任…...

Nginx知识
nginx 精简的配置文件 worker_processes 1; # 可以理解为一个内核一个worker # 开多了可能性能不好events {worker_connections 1024; } # 一个 worker 可以创建的连接数 # 1024 代表默认一般不用改http {include mime.types;# 代表引入的配置文件# mime.types 在 ngi…...

Unity开发游戏使用XLua的基础
Unity使用Xlua的常用编码方式,做一下记录 1、C#调用lua 1、Lua解析器 private LuaEnv env new LuaEnv();//保持它的唯一性void Start(){env.DoString("print(你好lua)");//env.DoString("require(Main)"); 默认在resources文件夹下面//帮助…...

AI-ISP论文Learning to See in the Dark解读
论文地址:Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…...

OpenCV:开运算
目录 1. 简述 2. 用腐蚀和膨胀实现开运算 2.1 代码示例 2.2 运行结果 3. 开运算接口 3.1 参数详解 3.2 代码示例 3.3 运行结果 4. 开运算应用场景 5. 注意事项 6. 总结 相关阅读 OpenCV:图像的腐蚀与膨胀-CSDN博客 OpenCV:闭运算-CSDN博客 …...

38. RTC实验
一、RTC原理详解 1、6U内部自带到了一个RTC外设,确切的说是SRTC。6U和6ULL的RTC内容在SNVS章节。6U的RTC分为LP和HP。LP叫做SRTC,HP是RTC,但是HP的RTC掉电以后数据就丢失了,即使用了纽扣电池也没用。所以必须要使用LP,…...

Flutter 新春第一弹,Dart 宏功能推进暂停,后续专注定制数据处理支持
在去年春节,Flutter 官方发布了宏(Macros)编程的原型支持, 同年的 5 月份在 Google I/O 发布的 Dart 3.4 宣布了宏的实验性支持,但是对于 Dart 内部来说,从启动宏编程实验开始已经过去了几年,但…...

巴菲特价值投资思想的核心原则
巴菲特价值投资思想的核心原则 关键词:安全边际、长期投资、内在价值、管理团队、经济护城河、简单透明 摘要:本文深入探讨了巴菲特价值投资思想的核心原则,包括安全边际、长期投资、企业内在价值、优秀管理团队、经济护城河和简单透明的业务…...

C 或 C++ 中用于表示常量的后缀:1ULL
1ULL 是一个在 C 或 C 中用于表示常量的后缀,它具体指示编译器将这个数值视为特定类型的整数。让我们详细解释一下: 1ULL 的含义 1: 这是最基本的部分,表示数值 1。U: 表示该数值是无符号(Unsigned)的。这意味着它只…...

vue3中el-input无法获得焦点的问题
文章目录 现象两次nextTick()加setTimeout()解决结论 现象 el-input被外层div包裹了,设置autofocus不起作用: <el-dialog v-model"visible" :title"title" :append-to-bodytrue width"50%"><el-form v-model&q…...

程序诗篇里的灵动笔触:指针绘就数据的梦幻蓝图<3>
大家好啊,我是小象٩(๑ω๑)۶ 我的博客:Xiao Xiangζั͡ޓއއ 很高兴见到大家,希望能够和大家一起交流学习,共同进步。 今天我们来对上一节做一些小补充,了解学习一下assert断言,指针的使用和传址调用…...

(三)QT——信号与槽机制——计数器程序
目录 前言 信号(Signal)与槽(Slot)的定义 一、系统自带的信号和槽 二、自定义信号和槽 三、信号和槽的扩展 四、Lambda 表达式 总结 前言 信号与槽机制是 Qt 中的一种重要的通信机制,用于不同对象之间的事件响…...

Qt 5.14.2 学习记录 —— 이십이 QSS
文章目录 1、概念2、基本语法3、给控件应用QSS设置4、选择器1、子控件选择器2、伪类选择器 5、样式属性box model 6、实例7、登录界面 1、概念 参考了CSS,都是对界面的样式进行设置,不过功能不如CSS强大。 可通过QSS设置样式,也可通过C代码…...

Hot100之哈希
1两数之和 题目 思路解析 解法1--两次循环 解法2--哈希表一次循环 代码 解法1--两次循环 class Solution {public int[] twoSum(int[] nums, int target) {int nums1[] new int[2];int length nums.length;for (int i 0; i < length; i) {for (int j i 1; j < …...