ray.rllib-入门实践-11: 自定义模型/网络
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。
2. 向ray注册自定义的模型
3. 在config中配置使用自定义的模型
环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
一、 定义自己的模型
需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...
示例如下:
## 1. 定义自己的模型
import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym
from gymnasium import spaces
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict, ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"=========================== custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value
在rllib中,每个算法的所有网络都被汇集到同一个 ModelV2 类下,供算法调用。actor 网络和critic网络可以在外面定义,也可以在model内部直接定义。 model的forward用于返回actor网络的输出, value_function函数用于返回critic网络的输出。 网络结构和网络层共享可以自定义设置。输入输出接口,需要与上面保持严格一致。
二、 向ray注册自定义模型
ray.rllib.model.ModelCatalog 类,用于向ray注册自定义的model, 还可以用于获取env的 preprocessors 和 action distributions。
import ray
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
三、 在算法中配置并使用自定义的模型
主要是在 config.training() 模块中的 model 子模块中传入两个配置信息:
1)"custom_model":"my_torch_model" ,
2)"custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})
两个关键字固定不变,填入自己注册的模型名和对应的模型参数即可。
可以有以下三种配置代码的编写方式:
配置方法1:
## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置使用自定义的模型
config = config.training(model= {"custom_model":"my_torch_model" , "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})
## 主要在上面两行配置使用自己的模型
## 配置 model 的 "custom_model" 项,用于指定rllib算法所使用的模型
## 配置 model 的 "custom_model_config" 项,用于传入自定义的网络参数,供自定义的model使用。
## 这两个关键词不可更改。algo = config.build()
## 4. 执行训练
result = algo.train()
print(pretty_print(result))
与以上配置内容一样,还可以用以下两种配置写法:
配置方法2:
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
model_config_dict = {}
model_config_dict["custom_model"] = "my_torch_model"
model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
config = config.training(model= model_config_dict) algo = config.build()
配置方法3(推荐):
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}algo = config.build()
代码汇总:
"""
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。 需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法import torch.nn as nnfrom ray.rllib.models.torch.torch_modelv2 import TorchModelV2class CustomTorchModel(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。 def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...2. 向ray注册自定义的模型from ray.rllib.models import ModelCatalogModelCatalog.register_custom_model("wzg_torch_model", CustomTorchModel)3. 在config中配置使用自定义的模型model_config_dict = {"custom_model":"wzg_torch_model","custom_model_config":{"custom_arg1": 1,"custom_arg2": 2}}config = PPOConfig()# config = config.training(model = model_config_dict)config.model["custom_model"] = "wzg_torch_model"config.model["custom_model_config"] = {"custom_arg1": 1,"custom_arg2": 2}
"""## 1. 定义自己的模型
import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym
from gymnasium import spaces
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict, ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"=========================== custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value ## 2. 向ray注册自定义模型
import ray
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
ray.init()## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") # ## 配置自定义模型:方法 1
# config = config.training(model= {"custom_model":"my_torch_model" ,
# "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})
# ## 配置自定义模型:方法 2
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model"
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config = config.training(model= model_config_dict) ## 配置自定义模型: 方法 3 (个人更喜欢, 因为嵌套层次少)
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}## 错误方法:
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model"
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config.model = model_config_dict # 会清空 model 里面的其他默认配置,导致报错algo = config.build()## 4. 执行训练
result = algo.train()
print(pretty_print(result))
相关文章:
ray.rllib-入门实践-11: 自定义模型/网络
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤: 1. 定义自己的模型。 2. 向ray注册自定义的模型 3. 在config中配置使用自定义的模型 环境配置: torch2.5.1 ray2.10.0 ray[rllib]2.10.0 ray[tune]2.10.0 ray[serve]2.10.0 numpy1.23.…...
C语言小项目——通讯录
功能介绍: 1.联系人信息:姓名年龄性别地址电话 2.通讯录中可以存放100个人的信息 3.功能: 1>增加联系人 2>删除指定联系人 3>查找指定联系人的信息 4>修改指定联系人的信息 5显示所有联系人的信息 6>排序(名字&…...
C#牵手Blazor,解锁跨平台Web应用开发新姿势
一、引言 在当今数字化时代,Web 应用已成为人们生活和工作中不可或缺的一部分 ,而开发跨平台的 Web 应用则是满足不同用户需求、扩大应用影响力的关键。C# 作为一种强大的编程语言,拥有丰富的类库和强大的功能,在企业级开发、游戏…...
PCIE模式配置
对于VU系列FPGA,当DMA/Bridge Subsystem for PCI Express IP配置为Bridge模式时,等同于K7系列中的AXI Memory Mapped To PCI Express IP。...
【论文阅读】RT-SKETCH: GOAL-CONDITIONED IMITATION LEARNING FROM HAND-DRAWN SKETCHES
RT-Sketch:基于手绘草图的目标条件模仿学习 摘要:在目标条件模仿学习(imitation learning,IL)中,自然语言和图像通常被用作目标表示。然而,自然语言可能存在歧义,图像则可能过于具体…...
【由浅入深认识Maven】第2部分 maven依赖管理与仓库机制
文章目录 第二篇:Maven依赖管理与仓库机制一、前言二、依赖管理基础1.依赖声明2. 依赖范围(Scope)3. 依赖冲突与排除 三、Maven的仓库机制1. 本地仓库2. 中央仓库3. 远程仓库 四、 版本管理策略1. 固定版本2. 版本范围 五、 总结 第二篇&…...
centos 安全配置基线
CentOS 是一个广泛使用的操作系统,为了确保系统的安全性,需要遵循一系列的安全基线。以下是详细的 CentOS 安全基线配置建议: 通过配置核查,CentOS操作系统未安装入侵防护软件,无法检测到对重要节点进行入侵的 解决方案: 安装入侵…...
备赛蓝桥杯之第十五届职业院校组省赛第一题:智能停车系统
提示:本篇文章仅仅是作者自己目前在备赛蓝桥杯中,自己学习与刷题的学习笔记,写的不好,欢迎大家批评与建议 由于个别题目代码量与题目量偏大,请大家自己去蓝桥杯官网【连接高校和企业 - 蓝桥云课】去寻找原题࿰…...
力扣 Hot 100 题解 (js版)更新ing
🚩哈希表 ✅ 1. 两数之和 Code: 暴力法 复杂度分析: 时间复杂度: ∗ O ( N 2 ) ∗ *O(N^2)* ∗O(N2)∗,其中 N 是数组中的元素数量。最坏情况下数组中任意两个数都要被匹配一次。空间复杂度:O(1)。 /…...
DeepSeek-R1:性能对标 OpenAI,开源助力 AI 生态发展
DeepSeek-R1:性能对标 OpenAI,开源助力 AI 生态发展 在人工智能领域,大模型的竞争一直备受关注。最近,DeepSeek 团队发布了 DeepSeek-R1 模型,并开源了模型权重,这一举动无疑为 AI 领域带来了新的活力。今…...
CY T 4 BB 5 CEB Q 1 A EE GS MCAL配置 - MCU组件
1、ResourceM 配置 选择芯片信号: 2、MCU 配置 2.1 General配置 1) McuDevErrorDetect: - 启用或禁用MCU驱动程序模块的开发错误通知功能。 - 注意:采用DET错误检测机制作为安全机制(故障检测)时,不能禁用开发错误检测。2) McuGetRamStateApi - enable/disable th…...
传输层协议TCP与UDP:深入解析与对比
传输层协议TCP与UDP:深入解析与对比 目录 传输层协议TCP与UDP:深入解析与对比引言1. 传输层协议概述2. TCP协议详解2.1 TCP的特点2.2 TCP的三次握手与四次挥手三次握手四次挥手 2.3 TCP的流量控制与拥塞控制2.4 TCP的可靠性机制 3. UDP协议详解3.1 UDP的…...
校园商铺管理系统设计与实现(代码+数据库+LW)
摘 要 信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲,遇到了互联网时代才发现能补上自…...
【JavaWeb学习Day13】
Tlias智能学习系统 需求: 部门管理:查询、新增、修改、删除 员工管理:查询、新增、修改、删除和文件上传 报表统计 登录认证 日志管理 班级、学员管理(实战内容) 部门管理: 01准备工作 开发规范-…...
springboot使用tomcat浅析
springboot使用tomcat浅析 关于外部tomcat maven pom配置 // 打包时jar包改为war包 <packaging>war</packaging>// 内嵌的tomcat的scope标签影响范围设置为provided,只在编译和测试时有效,打包时不带入 <dependency><groupId>…...
rust 自定义错误(十二)
错误定义: let file_content parse_file("test.txt");if let Err(e) file_content {println!("Error: {:?}", e);}let file_content parse_file2("test.txt");if let Err(e) file_content {match e {ParseFileError::File > …...
如何使用CRM数据分析优化销售和客户关系?
嘿,大家好!你有没有想过为什么有些公司在市场上如鱼得水,而另一些却在苦苦挣扎?答案可能就藏在他们的销售策略和客户关系管理(CRM)系统里。今天我们要聊的就是如何通过有效的 CRM 数据分析来提升你的销售额…...
导出地图为pdf文件
有时我们只是想创建能共享的pdf文件,而不是将地图打印出来,arcpy的ExportToPDF()函数可以实现该功能. 操作方法: 1.在arcmap中打开目标地图 2.导入arcpy.mapping模块 import arcpy.mapping as mapping 3.引用当前活动地图文档,把该引用赋值给变量 mxd mapping.MapDocumen…...
Qt 控件与布局管理
1. Qt 控件的父子继承关系 在 Qt 中,继承自 QWidget 的类,通常会在构造函数中接收一个 parent 参数。 这个参数用于指定当前空间的父控件,从而建立控件间的父子关系。 当一个控件被设置为另一控件的子控件时,它会自动成为该父控…...
电力场效应晶体管(电力 MOSFET),全控型器件
电力场效应晶体管(Power MOSFET)属于全控型器件是一种电压触发的电力电子器件,一种载流子导电(单极性器件)一个器件是由一个个小的mosfet组成以下是相关介绍: 工作原理(栅极电压控制漏极电流&a…...
一文讲解Java中的重载、重写及里氏替换原则
提到重载和重写,Java小白应该都不陌生,接下来就通过这篇文章来一起回顾复习下吧! 重载和重写有什么区别呢? 如果一个类有多个名字相同但参数不同的方法,我们通常称这些方法为方法重载Overload。如果方法的功能是一样…...
StarRocks常用命令
目录 1、StarRocks 集群管理&配置命令 2、StarRocks 常用操作命令 3、StarRocks 数据导入和导出 1、StarRocks 集群管理&配置命令 查询 FE 节点信息 SHOW frontends; SHOW PROC /frontends; mysql -h192.168.1.250 -P9030 -uroot -p -e "SHOW PROC /dbs;"…...
Pandas基础02(DataFrame创建/索引/切片/属性/方法/层次化索引)
DataFrame数据结构 DataFrame 是一个二维表格的数据结构,类似于数据库中的表格或 Excel 工作表。它由多个 Series 组成,每个 Series 共享相同的索引。DataFrame 可以看作是具有列名和行索引的二维数组。设计初衷是将Series的使用场景从一维拓展到多维。…...
Meta-CoT:通过元链式思考增强大型语言模型的推理能力
大型语言模型(LLMs)在处理复杂推理任务时面临挑战,这突显了其在模拟人类认知中的不足。尽管 LLMs 擅长生成连贯文本和解决简单问题,但在需要逻辑推理、迭代方法和结果验证的复杂任务(如高级数学问题和抽象问题解决&…...
【时时三省】(C语言基础)二进制输入输出
山不在高,有仙则名。水不在深,有龙则灵。 ----CSDN 时时三省 二进制输入 用fread可以读取fwrite输入的内容 字符串以文本的形式写进去的时候,和以二进制写进去的内容是一样的 整数和浮点型以二进制写进去是不一样的 二进制输出 fwrite 字…...
【go语言】数组和切片
一、数组 1.1 什么是数组 数组是一组数:数组需要是相同类型的数据的集合;数组是需要定义大小的;数组一旦定义了大小是不可以改变的。 1.2 数组的声明 数组和其他变量定义没有什么区别,唯一的就是这个是一组数,需要给…...
10.片元
**片元(Fragment)**是渲染管线中的一个重要概念,可以理解为“潜在的像素”。用通俗易懂的方式来解释: 通俗解释:片元就像候选的颜料点 想象你是一个画家,正在画一幅画: 片元是候选的颜料点&…...
SQL-leetcode—1179. 重新格式化部门表
1179. 重新格式化部门表 表 Department: ---------------------- | Column Name | Type | ---------------------- | id | int | | revenue | int | | month | varchar | ---------------------- 在 SQL 中,(id, month) 是表的联合主键。 这个表格有关…...
k8s简介,k8s环境搭建
目录 K8s简介环境搭建和准备工作修改主机名(所有节点)配置静态IP(所有节点)关闭防火墙和seLinux,清除iptables规则(所有节点)关闭交换分区(所有节点)修改/etc/hosts文件&…...
Docker常用知识点问题
1.dockerfile基础命令及作用 —copy和add区别 —为什么要指定workdir —expose作用,能不能不用,不用会导致什么情况? —env,不用怎么打镜像 —from 2.dockerfile编写规范 —jdk版本 —依赖问题 —shell指令引用 —字体和时区配置 …...
