GRL-图强化学习
GRL代码解析
- 一、agent.py
- 二、drl.py
- 三、env.py
- 四、policy.py
- 五、utils.py
一、agent.py
这个Python文件agent.py
实现了一个强化学习(Reinforcement Learning, RL)的智能体,用于在图环境(graph environment)中进行学习。以下是文件的主要部分的概述:
-
导入依赖:
- 导入了
matplotlib.pyplot
用于绘图,tqdm
用于在循环中显示进度条。 - 从
utils.py
和policy.py
中导入了一些功能性代码(graph_nn
是图神经网络)。 - 从
drl.py
导入了REINFORCE
类,这是强化学习的一种算法。 - 从
cora_gcn.py
中导入了CoraGraphEnv
,可能是图环境的一个实现。 - 从
env.py
中导入graph_env
,可能是定义的环境。 - 从
torch
库中导入了设备管理和概率分布。
- 导入了
-
环境配置:
- 设置了使用
CUDA
(如果可用)或者CPU
。 - 设置随机种子以保证可复现性。
- 实例化了
graph_env
(图形环境)。
- 设置了使用
-
超参数定义:
- 定义了学习速率
learning_rate
,剧集数量episodes
,折扣因子gamma
,以及日志打印间隔log_interval
。
- 定义了学习速率
-
策略网络:
- 实例化了图神经网络
graph_nn
作为策略网络,根据环境动作空间、输入维度和隐藏维度。
- 实例化了图神经网络
-
学习器:
- 实例化了
REINFORCE
算法作为学习器,传入策略网络、学习速率和折扣因子。
- 实例化了
-
学习循环:
- 使用
tqdm
进行进度显示,迭代episodes
次。 - 在每次迭代中重置环境,执行一系列操作直到达到环境的
done
状态。 - 在每个步骤中,获取当前状态下的动作概率分布,选择动作,并与环境交互获得下一个状态、奖励和是否完成。
- 将这些数据存入学习器的记忆中。
- 更新累计奖励。
- 每次剧集结束后通过
learn()
方法更新策略网络。
- 使用
-
可视化结果:
- 收集每集的奖励,并绘制奖励随时间变化的曲线。
- 将奖励曲线保存为图片。
整体上,这是一个图神经网络通过强化学习来优化策略的任务,代码使用了REINFORCE算法进行策略学习,并最终保存奖励曲线图。
二、drl.py
这个Python源代码文件drl.py
实现了一个简单的强化学习算法类REINFORCE
,该类使用了策略梯度方法(Policy Gradient Method)进行参数优化。以下是文件概述:
-
目的:
- 定义并实现了一个名为
REINFORCE
的强化学习算法类。 - 用于优化给定的策略函数(例如图神经网络模型)。
- 定义并实现了一个名为
-
主要特征:
- 依赖于PyTorch库来构建和训练模型。
- 使用了Adam优化算法进行参数优化。
- 包含了一个经验数据存储池(experience buffer)用于存储经验数据。
- 引入了基线(baseline)以提高学习稳定性。
-
类成员:
policy
:策略函数,待优化的神经网络模型。optimizer
:优化算子,用于更新模型参数。gamma
:折扣因子,用于计算未来的回报。experience_buffer
:存储经验数据的列表。baseline
:用于减少方差且提高学习效率的基线。
-
方法:
__init__
:初始化方法,设置优化器和相关参数。memory_data(self, data)
:将新的经验数据添加到经验池中。learn(self)
:- 计算折扣回报并进行反向传播。
- 如果基线数据少于100个,直接用累计折现回报作为loss。
- 如果基线数据超过100个,使用最近10个回报的平均值作为基线,以减少方差。
-
注意事项:
- 代码中有大量的空行,应该清理。
- 在计算
loss
时,应注意符号的使用,避免潜在的错误。 - 确认
prob
是否应该是一个log概率,这在策略梯度方法中是常见的。 - 基线计算(在
else
部分)通过转换最近的回报为一个PyTorch张量来计算,这需要和模型的数据类型保持一致。
总结:drl.py
文件定义了强化学习算法REINFORCE
,主要用于通过梯度上升法来优化给定策略网络。其中包含了保存经验数据、计算折扣回报、更新模型参数等方法。
三、env.py
这个env.py
文件定义了一个基于图的环境模型类graph_env
,它是OpenAI Gym环境的一个封装器。以下是概述:
-
目的: 旨在将标准的Gym环境(在这个例子中是’CartPole-v1’)的状态转换成图数据结构,以便可以使用图神经网络(Graph Neural Networks,GNNs)进行学习和处理。
-
依赖:
gym
:用于导入OpenAI Gym环境。torch
:用于创建和操作张量。torch_geometric.data
:用于处理图数据结构。
-
核心类:
graph_env
:继承自gym.Env
,重写了标准的Gym环境的部分功能,使其能够返回图格式数据。
-
功能:
__init__
:初始化方法,创建一个CartPole-v1
环境的实例,并设置观察和动作空间。to_pyg_data
:将环境状态数据转换成一个可以被torch_geometric
处理的图数据结构(Data
对象),包括节点特征和边索引。reset
:重置环境到初始状态,并将这个状态转换为图数据结构。step
:根据采取的动作将环境推进到下一个状态,并返回转换后的图状态、奖励、环境是否结束以及附加信息。
-
图数据构建:
- 在
to_pyg_data
方法中,节点特征是由当前状态的不同组合构成的,边索引是由节点全排列生成的,表示图中所有可能的边。
- 在
-
适用性:
- 这个类适用于希望将图神经网络应用于像CartPole这样的经典控制问题环境的情况。
-
注意点:
- 这个简单的转换可能不足以表示所有类型的环境状态为图数据结构,特别是当环境复杂性提高时。
permutations
用于生成图中所有可能的边,这并不适用于所有图场景,因为它假设所有节点之间都存在潜在的连接。
四、policy.py
这是一个用PyTorch编写的图神经网络(Graph Neural Network, GNN)模型,主要用于处理图结构的数据。以下是该源代码的概述:
-
依赖库:
torch
:PyTorch的 核心。torch.nn
:PyTorch的神经网络模块。torch.nn.functional
:PyTorch的函数式API,用于激活函数等。torch_geometric.nn
:用于图神经网络的PyTorch几何扩展库,包含专门的图处理层。
-
设备配置:
- 自动检查是否可用GPU,并将设备设置为
cuda:0
,否则使用CPU。
- 自动检查是否可用GPU,并将设备设置为
-
类定义:
graph_nn
:一个继承自nn.Module
的图神经网络类。- 初始化参数:
action_space
:动作空间的大小,决定输出层的神经元数。input_dim
:输入特征的维度。hidden_dim
:隐藏层神经元的维度。
- 网络结构:
GCNConv
:图卷积层。nn.Linear
:两个全连接层。LayerNorm
:图归一化层(但在实际的前向传播中并没有使用)。
- 前向传播:
- 采用ReLU作为激活函数。
- 使用全局池化来减少图的特征到单点特征。
- 最后使用log-softmax作为输出层,常用于分类任务。
- 初始化参数:
-
前向传播函数:
forward(self,x,edge_index)
:定义了网络的前向传播过程,接收节点特征x
和边索引edge_index
作为输入,并输出节点的分类log-softmax结果。
-
注解:
- 代码中有一些被注释掉的部分,可能是以前版本的操作,如
self.layer_norm
的调用方式。
- 代码中有一些被注释掉的部分,可能是以前版本的操作,如
这个模型是一个基于图的结构化数据学习框架,可以用于在图上的分类问题或其他需要在节点或图级别进行预测的问题。
五、utils.py
概述:
utils.py
是一个Python模块,属于一个用于图形神经网络(Graph Neural Network, GNN)相关项目的工具脚本。以下是该模块的功能概述:
-
导入库和模块:
torch
:导入PyTorch库,用于构建和训练神经网络。torch_geometric.data.Data
:从PyTorch Geometric中导入Data类,用于处理图形数据。itertools.permutations
:导入itertools中的permutations,用于生成可迭代对象的排列。matplotlib.pyplot
:用于绘制图表。numpy
:使用NumPy进行数值计算。random
:用于生成随机数。
-
功能函数:
-
seed_torch(seed)
:设置PyTorch、NumPy和Python的随机种子,以保证可重复性。如果CUDNN可用,还将设置相关选项以确保算法的确定性执行。 -
plot_reward(reward)
:接收一个奖励数组并绘制奖励曲线。此函数使用matplotlib库来创建图表,用于分析策略执行过程中累积奖励随时间(或迭代次数)的变化。
-
-
未使用的代码:有一行代码
plt.subplot(1, 3, 1)
被注释掉,说明可能原本计划在一个更大的画布上绘制多个子图,但最终没有使用。
这个模块可能用于支持图形数据的处理、结果的可视化以及实验的可重复性。它作为项目的一部分,可以被其他脚本或模块调用以提供辅助功能。
以下是使用Markdown格式描述各个文件功能的表格:
文件路径 | 功能描述 |
---|---|
agent.py | 实现了一个强化学习智能体,用于在图环境中使用REINFORCE算法进行策略学习。 |
drl.py | 定义并实现了REINFORCE算法类,基于策略梯度方法优化策略网络。 |
env.py | 封装了标准的Gym环境,将其转换为图数据结构,以便可以使用图神经网络进行学习和处理。 |
policy.py | 实现了一个图神经网络模型,用作策略网络来处理图结构的数据并输出动作概率分布。 |
utils.py | 提供了一系列工具函数,包括设置随机种子、绘图等,用于支持图神经网络训练过程。 |
整体程序功能的概括:
这个程序是一个基于图神经网络和强化学习的框架,旨在通过策略梯度方法学习在图形环境中的最优策略。
相关文章:

GRL-图强化学习
GRL代码解析 一、agent.py二、drl.py三、env.py四、policy.py五、utils.py 一、agent.py 这个Python文件agent.py实现了一个强化学习(Reinforcement Learning, RL)的智能体,用于在图环境(graph environment)中进行学习…...

昇思25天学习打卡营第22天|Pix2Pix实现图像转换
Pix2Pix图像转换学习总结 概述 Pix2Pix是一种基于条件生成对抗网络(cGAN)的深度学习模型,旨在实现不同图像风格之间的转换,如从语义标签到真实图像、灰度图到彩色图、航拍图到地图等。这一模型由Phillip Isola等人在2017年提出&…...

全感知、全覆盖、全智能的智慧快消开源了。
智慧快消视频监控平台是一款功能强大且简单易用的实时算法视频监控系统。它的愿景是最底层打通各大芯片厂商相互间的壁垒,省去繁琐重复的适配流程,实现芯片、算法、应用的全流程组合,从而大大减少企业级应用约95%的开发成本。AI安全管理平台&…...

ABC364:D - K-th Nearest(二分)
题目 在一条数线上有 NQNQ 个点 A1,…,AN,B1,…,BQA1,…,AN,B1,…,BQ ,其中点 AiAi 的坐标为 aiai ,点 BjBj 的坐标为 bjbj 。 就每个点 j1,2,…,Qj1,2,…,Q 回答下面的问题: 设 XX 是 A1,A2,…,ANA1,A2,…,AN 中最…...

hive中分区与分桶的区别
过去,在学习hive的过程中学习过分桶与分区。但是,却未曾将分区与分桶做详细比较。今天,回顾skew join时涉及到了分桶这一概念,一时间无法区分出分区与分桶的区别。查阅资料,特地记录下来。 一、Hive分区 1.分区一般是…...

Blender材质-PBR与纹理材质
1.PBR PBR:Physically Based Rendering 基于物理的渲染 BRDF:Bidirection Reflectance Distribution Function 双向散射分散函数 材质着色操作如下图: 2.纹理材质 左上角:编辑器类型中选择,着色器编辑器 新建着色器 -> 新建纹理 -> 新…...

微软的Edge浏览器如何设置兼容模式
微软的Edge浏览器如何设置兼容模式? Microsoft Edge 在浏览部分网站的时候,会被标记为不兼容,会有此网站需要Internet Explorer的提示,虽然可以手动点击在 Microsoft Edge 中继续浏览,但是操作起来相对复杂,…...

SpringBoot开启多端口探究(1)
文章目录 前情提要发散探索从management.port开始确定否需要开启额外端口额外端口是如何开启的ManagementContextFactory的故事从哪儿来创建过程 management 相关API如何被注册 小结 前情提要 最近遇到一个需求,在单个服务进程上开启多网络端口,将API的…...

优化算法:2.粒子群算法(PSO)及Python实现
一、定义 粒子群算法(Particle Swarm Optimization,PSO)是一种模拟鸟群觅食行为的优化算法。想象一群鸟在寻找食物,每只鸟都在尝试找到食物最多的位置。它们通过互相交流信息,逐渐向食物最多的地方聚集。PSO就是基于这…...

ThreadLocal面试三道题
针对ThreadLocal的面试题,我将按照由简单到困难的顺序给出三道题目,并附上参考答案的概要。 1. 简单题:请简述ThreadLocal是什么,以及它的主要作用。 参考答案: ThreadLocal是Java中的一个类,用于提供线…...

Git操作指令(已完结)
Git操作指令 一、安装git 1、设置配置信息: # global全局配置 git config --global user.name "Your username" git config --global user.email "Your email"# 显示颜色 git config --global color.ui true# 配置别名,各种指令都…...

大数据采集工具——Flume简介安装配置使用教程
Flume简介&安装配置&使用教程 1、Flume简介 一:概要 Flume 是一个可配置、可靠、高可用的大数据采集工具,主要用于将大量的数据从各种数据源(如日志文件、数据库、本地磁盘等)采集到数据存储系统(主要为Had…...

C语言 #具有展开功能的排雷游戏
文章目录 前言 一、整个排雷游戏的思维梳理 二、整体代码分布布局 三、游戏主体逻辑实现--test.c 四、整个游戏头文件的引用以及函数的声明-- game.h 五、游戏功能的具体实现 -- game.c 六、老六版本 总结 前言 路漫漫其修远兮,吾将上下而求索。 一、整个排…...

npm publish出错,‘proxy‘ config is set properly. See: ‘npm help config‘
问题:使用 npm publish发布项目依赖失败,报错 proxy config is set properly. See: npm help config 1、先查找一下自己的代理 npm config get proxy npm config get https-proxy npm config get registry2、然后将代理和缓存置空 方式一: …...

Springboot 多数据源事务
起因 在一个service方法上使用的事务,其中有方法是调用的多数据源orderDB 但是多数据源没有生效,而是使用的primaryDB 原因 spring 事务实现的方式 以 Transactional 注解为例 (也可以看 TransactionTemplate, 这个流程更简单一点)。 入口:ProxyTransa…...

Python每日学习
我是从c转来学习Python的,总感觉和c相比Python的实操简单,但是由于写c的代码多了,感觉Python的语法好奇怪 就比如说c的开头要有库(就是类似于#include <bits/stdc.h>)而且它每一项的代码结束之后要有一个表示结…...

数据库 执行sql添加删除字段
添加字段: ALTER TABLE 表明 ADD COLUMN 字段名 类型 DEFAULT NULL COMMENT 注释 AFTER 哪个字段后面; 效果: 删除字段: ALTER TABLE 表明 DROP COLUMN 字段;...

前端开发:HTML与CSS
文章目录 前言1.1、CS架构和BS架构1.2、网页构成 HTML1.web开发1.1、最简单的web应用程序1.2、HTTP协议1.2.1 、简介1.2.2、 http协议特性1.3.3、http请求协议与响应协议 2.HTML概述3.HTML标准结构4.标签的语法5.基本标签6.超链接标签6.1、超链接基本使用6.2、锚点 7.img标签8.…...

ctfshow解题方法
171 172 爆库名->爆表名->爆字段名->爆字段值 -1 union select 1,database() ,3 -- //返回数据库名 -1 union select 1,2,group_concat(table_name) from information_schema.tables where table_schema库名 -- //获取数据库里的表名 -1 union select 1,group_concat(…...

探索 Blockly:自定义积木实例
3.实例 3.1.基础块 无输入 , 无输出 3.1.1.json var textOneJson {"type": "sql_test_text_one","message0": " one ","colour": 30,"tooltip": 无输入 , 无输出 };javascriptGenerator.forBlock[sql_test_te…...

MongoDB教程(二十三):关于MongoDB自增机制
💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 文章目录 引言一、MongoD…...

展馆导览系统架构解析,从需求分析到上线运维
在物质生活日益丰富的当下,人们对精神世界的追求愈发强烈,博物馆、展馆、纪念馆等场所成为人们丰富知识、滋养心灵的热门选择。与此同时,人们对展馆的导航体验也提出了更高要求,展馆导览系统作为一种基于室内外地图相结合的位置引…...

Servlet详解(超详细)
Servlet详解 文章目录 Servlet详解一、基本概念二、Servlet的使用1、创建Servlet类2、配置Servleta. 使用web.xml配置b. 使用注解配置 3、部署Web应用4、处理HTTP请求和生成响应5、处理表单数据HTML表单Servlet 6、管理会话 三、servlet生命周期1、加载和实例化2、初始化3、 请…...

Meta AI引入Imagine Me功能,上传图片输入提示词即可实现个性化照片
AITOP100平台获悉,Meta 公司在 AI 领域再次迈出了重要的步伐,其发布的 Llama 3.1 开源 AI 模型以及对 Meta AI 功能的更新扩充引发了广泛关注。 其中,新引入的“Imagine Me”功能尤为引人注目。在这一功能下,美国地区的用户只需上…...

常用自启设置
一、开机自启动 1、编辑 vi /lib/systemd/system/nginx.service 文件,没有创建一个 touch nginx.service 然后将如下内容根据具体情况进行修改后,添加到nginx.service文件中: [Unit] Descriptionnginx Afternetwork.target remote-fs.targ…...

模块与组件、模块化与组件化的理解
在React或其他现代JavaScript框架中,模块与组件、模块化与组件化是核心概念,它们对于提高代码的可维护性、复用性和开发效率具有重要意义。以下是对这些概念的理解: 模块与组件 模块(Module) 定义:模块是…...

Rust:cargo的常用命令
1.查看版本 $ cargo --version cargo 1.79.0 (ffa9cf99a 2024-06-03) 2.创建新的项目 $ cargo new hello 创建后的目录结构为 $ tree hello/ hello/ ├── Cargo.toml └── src └── main.rs 3.运行项目 $ cd hello $ cargo run Compiling hello v0.1.0 (/home/c…...

LeetCode 3106.满足距离约束且字典序最小的字符串:模拟(贪心)
【LetMeFly】3106.满足距离约束且字典序最小的字符串:模拟(贪心) 力扣题目链接:https://leetcode.cn/problems/lexicographically-smallest-string-after-operations-with-constraint/ 给你一个字符串 s 和一个整数 k 。 定义函…...

Elasticsearch 与 MySQL 在查询和插入性能上的深度剖析
在当今的数据处理领域,选择合适的数据库对于应用的性能和效率至关重要。Elasticsearch 和 MySQL 作为两款常用的数据库,它们在查询和插入操作上的性能表现各有千秋。本文将对这两款数据库在这两个关键操作上进行详细的对比分析。 一、引言 随着数据量的…...

day4 vue2以及ElementUI
创建vue2项目 可能用到的命令行们 vue create 项目名称 // 创建项目 cd 项目名称 // 只有进入项目下,才能运行 npm run serve // 运行项目 D: //切换盘符 cd .. // 返回到上一级目录 clear // 清空终端 更改 Vue项目的端口配置 基础语法 项目创建完成之后&#…...