当前位置: 首页 > article >正文

pytorch实现门控循环单元 (GRU)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客  

特性GRULSTM
计算效率更快,参数更少相对较慢,参数更多
结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)
处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖
训练速度训练更快,梯度更稳定训练较慢,占用更多内存

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 🏁 迷宫环境(5×5)
class MazeEnv:def __init__(self, size=5):self.size = sizeself.state = (0, 0)  # 起点self.goal = (size-1, size-1)  # 终点self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # 右、左、下、上def reset(self):self.state = (0, 0)  # 重置起点return self.statedef step(self, action):dx, dy = self.actions[action]x, y = self.statenx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))reward = 1 if (nx, ny) == self.goal else -0.1done = (nx, ny) == self.goalself.state = (nx, ny)return (nx, ny), reward, done# 🤖 GRU 策略网络
class GRUPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUPolicy, self).__init__()self.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.gru(x, hidden)out = self.fc(out[:, -1, :])  # 只取最后时间步return out, hidden# 🎯 训练参数
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()# 🎓 训练
num_episodes = 500
epsilon = 1.0  # 初始的ε值,控制探索的概率
epsilon_min = 0.01  # 最小ε值
epsilon_decay = 0.995  # ε衰减率
best_path = []  # 用于存储最佳路径for episode in range(num_episodes):state = env.reset()hidden = torch.zeros(1, 1, 16)  # GRU 初始状态states, actions, rewards = [], [], []logits_list = []  for _ in range(20):  # 最多 20 步state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)logits, hidden = policy(state_tensor, hidden)logits_list.append(logits)# ε-greedy 策略if random.random() < epsilon:action = random.choice(range(4))  # 随机选择动作else:action = torch.argmax(logits, dim=1).item()  # 选择最大值对应的动作next_state, reward, done = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:print(f"Episode {episode} - Reached Goal!")# 找到最优路径best_path = states + [next_state]  # 当前 episode 的路径breakstate = next_state# 计算损失logits = torch.cat(logits_list, dim=0)  # (T, 4)action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)loss = loss_fn(logits, action_tensor)  optimizer.zero_grad()loss.backward()optimizer.step()# 衰减 εepsilon = max(epsilon_min, epsilon * epsilon_decay)if episode % 100 == 0:print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")# 🧐 确保 best_path 已经记录
if len(best_path) == 0:print("No path found during training.")
else:print(f"Best path: {best_path}")# 🚀 测试路径(只绘制最佳路径)
fig, ax = plt.subplots(figsize=(6,6))# 初始化迷宫图
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 迷宫
ax.imshow(maze, cmap="coolwarm", origin="upper")# 画网格
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)# 画出最佳路径(红色)
for (x, y) in best_path:ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))# 画起点和终点
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")plt.title("GRU RL Agent - Best Path")
plt.show()

相关文章:

pytorch实现门控循环单元 (GRU)

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 特性GRULSTM计算效率更快&#xff0c;参数更少相对较慢&#xff0c;参数更多结构复杂度只有两个门&#xff08;更新门和重置门&#xff09;三个门&#xff08;输入门、遗忘门、输出门&#xff09;处理长时依赖一般适…...

有没有个性化的UML图例

绿萝小绿萝 (53****338) 2012-05-10 11:55:45 各位大虾&#xff0c;有没有个性化的UML图例 绿萝小绿萝 (53****338) 2012-05-10 11:56:03 例如部署图或时序图的图例 潘加宇 (35***47) 2012-05-10 12:24:31 "个性化"指的是&#xff1f; 你的意思使用你自己的图标&…...

在CentOS服务器上部署DeepSeek R1

在CentOS服务器上部署DeepSeek R1,并通过公网IP与其进行对话,可以按照以下步骤操作: 一、环境准备 系统要求: CentOS 8+(需支持AVX512指令集)。 硬件配置: GPU版本:NVIDIA驱动520+,CUDA 11.8+。 CPU版本:至少16核处理器,64GB内存。 存储空间:原始模型需要30GB,量…...

Vue3.0实战:大数据平台可视化

文章目录 创建vue3.0项目项目初始化项目分辨率响应式设置项目顶部信息条创建页面主体创建全局引入echarts和axios后台接口创建express销售总量图实现完整项目下载项目任何问题都可在评论区,或者直接私信即可。 创建vue3.0项目 创建项目: vue create vueecharts选择第三项:…...

洛谷 P1130 红牌 C语言

题目描述 某地临时居民想获得长期居住权就必须申请拿到红牌。获得红牌的过程是相当复杂&#xff0c;一共包括 N 个步骤。每一步骤都由政府的某个工作人员负责检查你所提交的材料是否符合条件。为了加快进程&#xff0c;每一步政府都派了 M 个工作人员来检查材料。不幸的是&…...

语音识别播报人工智能分类垃圾桶(论文+源码)

2.1 需求分析 本次语音识别播报人工智能分类垃圾桶&#xff0c;设计功能要求如下∶ 1、具有四种垃圾桶&#xff0c;分别为用来回收厨余垃圾&#xff0c;有害垃圾&#xff0c;可回收垃圾&#xff0c;其他垃圾。 2、当用户语音说出“旧报纸”&#xff0c;“剩菜”等特定词语时…...

MVC、MVP和MVVM模式

MVC模式中&#xff0c;视图和模型之间直接交互&#xff0c;而MVP模式下&#xff0c;视图与模型通过Presenter进行通信&#xff0c;MVVM则采用双向绑定&#xff0c;减少手动同步视图和模型的工作。每种模式都有其优缺点&#xff0c;适合不同规模和类型的项目。 ### MVVM 与 MVP…...

shiro学习五:使用springboot整合shiro。在前面学习四的基础上,增加shiro的缓存机制,源码讲解:认证缓存、授权缓存。

文章目录 前言1. 直接上代码最后在讲解1.1 新增的pom依赖1.2 RedisCache.java1.3 RedisCacheManager.java1.4 jwt的三个类1.5 ShiroConfig.java新增Bean 2. 源码讲解。2.1 shiro 缓存的代码流程。2.2 缓存流程2.2.1 认证和授权简述2.2.2 AuthenticatingRealm.getAuthentication…...

负载均衡器高可用部署

Haproxy 和 Keepalived安装Haproxy配置文件准备Keepalived配置及健康检查启动Haproxy & Keepalived服务继续上一篇文章《K8S集群架构及主机准备》,下面介绍负载均衡器搭建过程 Haproxy 和 Keepalived安装 在负载均衡器两个主机上安装即可 apt install haproxy keepalived…...

属性编程与权限编程

问题 如何获取文件的大小&#xff0c;时间戳以及类型等信息&#xff1f; 再论 inode 文件的物理载体是硬盘&#xff0c;硬盘的最小存储单元是扇区 (每个扇区 512 字节) 文件系统以 块 为单位(每个块 8 个扇区) 管理文件数据 文件元信息 (创建者、创建日期、文件大小&#x…...

用 HTML、CSS 和 JavaScript 实现抽奖转盘效果

顺序抽奖 前言 这段代码实现了一个简单的抽奖转盘效果。页面上有一个九宫格布局的抽奖区域&#xff0c;周围八个格子分别放置了不同的奖品名称&#xff0c;中间是一个 “开始抽奖” 的按钮。点击按钮后&#xff0c;抽奖区域的格子会快速滚动&#xff0c;颜色不断变化&#xf…...

R语言绘制有向无环图(DAG)

有向无环图&#xff08;Directed Acyclic Graph&#xff0c;简称DAG&#xff09;是一种特殊的有向图&#xff0c;它由一系列顶点和有方向的边组成&#xff0c;其中不存在任何环路。这意味着从任一顶点出发&#xff0c;沿着箭头方向移动&#xff0c;你永远无法回到起始点。 从流…...

报错Too many open files

1、先查看系统最大打开文件数 # 查看当前系统打开文件最大数 # ulimit -a core file size (blocks, -c) 0 data seg size (kbytes, -d) unlimited scheduling priority (-e) 0 file size (blocks, -f) unlimited pending signal…...

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC&#xff1f; 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping&#xff08;路由映射&#xff09; 3.2一般参数传递 3.3RequestParam&#xff08;参数重命名&#xff09; 3.4RequestBody&#xff08;传递JSON数据&#xff09; 3.5Pa…...

129.求根节点到叶节点数字之和(遍历思想)

Problem: 129.求根节点到叶节点数字之和 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 直接利用二叉树的先序遍历&#xff0c;将遍历过程中的节点值先利用字符串拼接起来遇到根节点时再转为数字并累加起来&#xff0c;在归的过程中&#xf…...

unity中的动画混合树

为什么需要动画混合树&#xff0c;动画混合树有什么作用&#xff1f; 在Unity中&#xff0c;动画混合树&#xff08;Animation Blend Tree&#xff09;是一种用于管理和混合多个动画状态的工具&#xff0c;包括1D和2D两种类型&#xff0c;以下是其作用及使用必要性的介绍&…...

AWS EMR使用Apache Kylin快速分析大数据

在AWS Elastic MapReduce&#xff08;EMR&#xff09;集群上部署和使用Apache Kylin&#xff0c;以实现对大规模数据集的快速分析&#xff0c;企业可以充分利用云计算的强大资源和Kylin的数据分析能力&#xff0c;实现快速、高效的数据分析。以下是该案例的详细步骤和要点&…...

MySQL存储过程和存储函数_mysql 存储过 call proc_stat_data(3,null)

2&#xff09;很难调试存储过程。只有少数数据库管理系统允许调试存储过程。不幸的是&#xff0c;MySQL不提供调试存储过程的功能。 1.2 数据准备 创建数据库&#xff1a; DEFAULT CHARACTER SET utf8; use test;这里记得设置编码&#xff01; 创建测试表&#xff1a; DROP…...

spacemacs gnuplot

个人博客地址&#xff1a;spacemacs gnuplot | 一张假钞的真实世界 环境 Ubuntu 16.10Emacs 24 安装过程 spacemacs安装 安装Emacs sudo apt-get install emacs 安装spacemacs &#xff08;1&#xff09;如果已经存在Emacs配置文件&#xff0c;首先备份&#xff1a; c…...

Flink2支持提交StreamGraph到Flink集群

最近研究Flink源码的时候&#xff0c;发现Flink已经支持提交StreamGraph到集群了&#xff0c;替换掉了原来的提交JobGraph。 新增ExecutionPlan接口&#xff0c;将JobGraph和StreamGraph作为实现。 Flink集群Dispatcher也进行了修改&#xff0c;从JobGraph改成了接口Executio…...

Kotlin 使用 Springboot 反射执行方法并自动传参

在使用反射的时候&#xff0c;执行方法的时候在想如果Springboot 能对需要执行的反射方法的参数自动注入就好了。所以就有了下文。 知识点 获取上下文通过上下文获取 Bean通过上下文创建一个对象&#xff0c;该对象所需的参数由 Springboot 自己注入 创建参数 因为需要对反…...

索罗斯的“反身性”(Reflexivity)理论:市场如何扭曲现实?(中英双语)

索罗斯的“反身性”&#xff08;Reflexivity&#xff09;理论&#xff1a;市场如何扭曲现实&#xff1f; 一、引言&#xff1a;市场是镜子&#xff0c;还是哈哈镜&#xff1f; 在传统经济学中&#xff0c;市场通常被认为是一个理性、有效的反映现实的系统。按照经典经济学理论…...

Vue 入门到实战 七

第7章 渲染函数 目录 7.1 DOM树 7.2 什么是渲染函数 7.3 h()函数 7.3.1 基本参数 7.3.2 约束 7.3.3 使用JavaScript代替模板功能 7.1 DOM树 7.2 什么是渲染函数 在多数情况下&#xff0c;Vue推荐使用模板template来创建HTML。然而在一些应用场景中&#xff0c;需要使用J…...

系统学习算法: 专题八 二叉树中的深搜

深搜其实就是深度优先遍历&#xff08;dfs&#xff09;&#xff0c;与此相对的还有宽度优先遍历&#xff08;bfs&#xff09; 如果学完数据结构有点忘记&#xff0c;如下图&#xff0c;左边是dfs&#xff0c;右边是bfs 而二叉树的前序&#xff0c;中序&#xff0c;后序遍历都可…...

进程、线程、内存和IO模型的概念详解

进程、线程、内存和IO模型的概念详解 1 进程与线程1.1 进程1.1.1 进程分类1.1.2 进程的状态和转换1.1.3 僵尸进程和孤儿进程的区别1.1.4 进程之间的通信1.1.5 用户态和内核态1.1.6 用户空间和内核空间 1.2 线程1.2.1 线程的状态和转换1.2.2 进程与线程的区别 1.3 多进程和多线程…...

DeepSeek:AI领域的创新先锋

在人工智能领域&#xff0c;DeepSeek正以其独特的创新技术引领着行业的发展。作为一款高性能、低成本的AI模型&#xff0c;DeepSeek在架构设计、训练优化和应用场景等多个方面都展现出了显著的创新点。这些创新不仅使其在技术上取得了突破&#xff0c;也为AI的普及化和应用拓展…...

Labelme转Voc、Coco

Q&#xff1a;在github找的cv代码基本都是根据现有且流行的公共数据集格式组织的训练数据集&#xff0c;这导致我使用labelme标注好之后需要我们重新组织数据集 labelme2coco #!/usr/bin/env pythonimport argparse import collections import datetime import glob import j…...

pytorch实现变分自编码器

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 变分自编码器&#xff08;Variational Autoencoder, VAE&#xff09;是一种生成模型&#xff0c;属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布&#xff08;Latent Distribution&#xff09;&…...

使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

1. 导入必要的库 首先&#xff0c;导入我们需要的库&#xff1a;Numpy、Pytorch 和相关工具包。 import numpy as np import torch import torch.nn as nn import torch.optim as optim from sklearn.metrics import accuracy_score, recall_score, f1_score2. 自定义数据集 …...

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分&#xff0c;但是一些简单的实现可能不会去进行垃圾收集或者进行压缩&#xff0c;方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样&#xff0c;是各个…...