当前位置: 首页 > news >正文

ray.rllib 入门实践-5: 训练算法

        前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

方式1: algo.train()

        rllib 中的 Algorithm 类自带了.train() 函数,实现算法训练,前面几个博客教程均是采用的这种方式。这里仅再提供一下示例, 不再赘述:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径## 构建算法
algo = config.build()## 训练
for i in range(3):result = algo.train() print(f"episode_{i}")

方式2:tune.Tuner()

        以上方式只能固定训练超参数,不能对训练超参数寻优。ray中还有一个模块 tune, 专门用于算法训练过程中超参数调参。

        在使用tune.Tuner()执行rllib算法训练时, 可以默认为tune背后自动执行了以下操作:

algo = PPOConfig().build() ## 构建算法

result = algo.train()  ## 算法训练

print(pretty_print(result))  ## 每完成一次algo.train, 打印一次阶段性训练结果

algo.save_checkpoint()  ## 保存训练模型

并且遍历了多个超参数组合,多次进行训练。直到达到停止训练的条件(自己配置)。 

        基于tune的rllib训练示例如下(代码篇幅比较大是因为添加的功能模块和注释比较多,后面的介绍主要以 方式一为主,所以对于这种方式,这里介绍的多一些):

import ray 
from ray.rllib.algorithms.ppo import PPO,PPOConfig 
from ray import train, tune
import torch 
import os 
import shutil 
from ray.tune.logger import pretty_print 
import gymnasium as gym ray.init()####  配置算法  ####
config = PPOConfig()
config = config.training(lr=tune.grid_search([0.01, 0.001]))
config = config.environment(env="CartPole-v1")####  配置 tune  ###### 准备 tune 的 stop_condition,多个条件之间是”或“的关系。有一个满足即停止训练。 
##   'episode_reward_mean'关键字将在 ”ray-2.40”版中中被抛弃,
##   届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'
stop_condition = {'episode_reward_mean':10,    ## 这里设置的结束条件很宽松,所以能够快速结束训练。"training_iteration":3}## 准备 tune 的过程文件存储路径
storage_path = "F:/codes/RLlib_study/ray_results"
os.makedirs(storage_path, exist_ok=True)## 准备 tune 的 checkpoint_config
##    tune 默认保存每个 algo.train() 训练得到的 checkpoint.
##    通过以下配置,可以对此进行自定义修改
checkpoint_config =  train.CheckpointConfig(num_to_keep=None, ## 保存几个checkpoint, None 表示保存所有checkpointcheckpoint_at_end=True) ## 是否在训练结束后保存 checkpoint. ## 配置 tuner
tuner = tune.Tuner(PPO,                                       ## 需要是一个 rllib 的 Algorithm 类, 从 ray.rllib.algorithms 导入, 也可以是自定义的,后面介绍                   run_config = train.RunConfig(stop = stop_condition, checkpoint_config = checkpoint_config,  ## 用于设置保存哪个checkpoint. storage_path = storage_path,            ## 如果不设置, 默认存储路径是 “~/ray-results” 或 “C:/用户/xxx/ray_results”),param_space=config,  ## 这里定义了参数搜索调优空间,是一个 PPOConfig 对象
)## 执行训练                         
results = tuner.fit() ## tuner 返回一个Result表格对象,该对象允许进一步分析训练结果并检索经过训练的智能体的checkpoint。
print("====训练结束====")## 获取最佳训练结果
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
## 以 "episode_reward_mean" 为选择指标, 从results里面选择checkpoint, 选择模式是“max”,
## 'episode_reward_mean'关键字将在ray-2.40版中中被抛弃,届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'## 从最佳训练结果中提取对应的 checkpoint , 并保存
checkpoint_save_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
os.makedirs(checkpoint_save_dir, exist_ok=True)best_checkpoint = best_result.checkpoint 
if best_checkpoint:with best_checkpoint.as_directory() as checkpoint_dir:print(f"====最佳模型路径位于:{checkpoint_dir}====")## 把最佳模型转存到指定位置。 shutil.rmtree(checkpoint_save_dir)shutil.copytree(checkpoint_dir,checkpoint_save_dir)print(f"====保存最佳模型到:{checkpoint_save_dir}====")## 加载保存的最佳模型
checkpoint_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"==== 加载最佳模型: {checkpoint_dir}")## evaluate 模型
env_name = "CartPole-v1"
env = gym.make(env_name)## 模型推断: method-1
step = 0
episode_reward = 0
terminated = truncated = False obs,info = env.reset()
while not terminated and not truncated:action = algo.compute_single_action(obs)obs, reward, terminated, truncated, info = env.step(action)episode_reward += rewardstep += 1print(f"step = {step}, reward = {reward}, action = {action}, obs = {obs}, episode_reward = {episode_reward}")

相关文章:

ray.rllib 入门实践-5: 训练算法

前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。 环境配置&#x…...

FPGA 使用 CLOCK_LOW_FANOUT 约束

使用 CLOCK_LOW_FANOUT 约束 您可以使用 CLOCK_LOW_FANOUT 约束在单个时钟区域中包含时钟缓存负载。在由全局时钟缓存直接驱动的时钟网段 上对 CLOCK_LOW_FANOUT 进行设置,而且全局时钟缓存扇出必须低于 2000 个负载。 注释: 当与其他时钟约束配合…...

选择的阶段性质疑

条条大路通罗马,每个人选择的道路,方向并不一样,但不妨碍都可以到达终点,而往往大家会更推崇自己走过的路径。 自己靠什么走向成功,自己用了什么方法,奉行什么原则或者理念,也会尽可能传播这种&…...

固有频率与模态分析

目录 引言 1. 固有频率:物体的“天生节奏” 1.1 定义 1.2 关键特点 1.3 实际意义 2. 有限元中的模态分析:给结构“体检振动” 2.1 模态分析的意义 2.2 实际案例 2.2.1 桥梁模态分析 2.2.2 飞机机翼模态分析 2.2.3 具体事例 3. 模态分析的工具…...

数科OFD证照生成原理剖析与平替方案实现

一、 引言 近年来,随着电子发票的普及,OFD格式作为我国电子发票的标准格式,其应用范围日益广泛。然而,由于不同软件生成的OFD文件存在差异,以及用户对OFD文件处理需求的多样化,OFD套餐转换工具应运而生。本…...

CAN总线数据采集与分析

CAN总线数据采集与分析 目录 CAN总线数据采集与分析1. 引言2. 数据采集2.1 数据采集简介2.2 数据采集实现3. 数据分析3.1 数据分析简介3.2 数据分析实现4. 数据可视化4.1 数据可视化简介4.2 数据可视化实现5. 案例说明5.1 案例1:数据采集实现5.2 案例2:数据分析实现5.3 案例3…...

SpringSecurity:There is no PasswordEncoder mapped for the id “null“

文章目录 一、情景说明二、分析三、解决 一、情景说明 在整合SpringSecurity功能的时候 我先是去实现认证功能 也就是,去数据库比对用户名和密码 相关的类: UserDetailsServiceImpl implements UserDetailsService 用于SpringSecurity查询数据库 Logi…...

ResNet 残差网络

目录 网络结构 残差块(Residual Block) ResNet网络结构示意图 残差块(Residual Block)细节 基本残差块(ResNet-18/34) Bottleneck残差块(ResNet-50/101/152) 残差连接类型对比 变体网…...

CAPL编程常见问题与解决方案深度解析

CAPL编程常见问题与解决方案深度解析 目录 CAPL编程常见问题与解决方案深度解析引言1. CAPL编程核心难点剖析1.1 典型问题分类2. 六大典型问题场景解析案例1:定时器资源竞争导致逻辑错乱2.1.1 问题现象2.1.2 根因分析2.1.3 解决方案案例2:大数据量报文处理引发性能瓶颈2.2.1 …...

信号处理以及队列

下面是一个使用C和POSIX信号处理以及队列的简单示例。这个示例展示了如何使用信号处理程序将信号放入队列中&#xff0c;并在主循环中处理这些信号。 #include <iostream> #include <csignal> #include <queue> #include <mutex> #include <thread…...

Linux pkill 命令使用详解

简介 pkill 命令用于根据进程名称、用户、组或其他属性终止进程。它是 procps-ng 包的一部分&#xff0c;通常比 kill 更受欢迎&#xff0c;因为它无需查找进程 ID (PID)。 常用选项 -<signal>, --signal <signal>&#xff1a;定义要发送给每个匹配进程的信号&am…...

react注意事项

1.状态的定义以及修改 2.排序用lodash进行排序 import _ from lodassh 3.利用className插件进行动态类名的使用 4.表单使用 5.react中获取dom...

【开源免费】基于SpringBoot+Vue.JS在线考试学习交流网页平台(JAVA毕业设计)

本文项目编号 T 158 &#xff0c;文末自助获取源码 \color{red}{T158&#xff0c;文末自助获取源码} T158&#xff0c;文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...

怎样在PPT中启用演讲者视图功能?

怎样在PPT中启用演讲者视图功能&#xff1f; 如果你曾经参加过重要的会议或者演讲&#xff0c;你就会知道&#xff0c;演讲者视图&#xff08;Presenter View&#xff09;对PPT展示至关重要。它不仅能帮助演讲者更好地掌控演讲节奏&#xff0c;还能提供额外的提示和支持&#…...

UE AController

定义和功能 AController是一种特定于游戏的控制器&#xff0c;在UE框架中用于定义玩家和AI的控制逻辑。AController负责处理玩家输入&#xff0c;并根据这些输入驱动游戏中的角色或其他实体的行为。设计理念 AController设计用于分离控制逻辑与游戏角色&#xff0c;增强游戏设计…...

H264原始码流格式分析

1.H264码流结构组成 H.264裸码流&#xff08;Raw Bitstream&#xff09;数据主要由一系列的NALU&#xff08;网络抽象层单元&#xff09;组成。每个NALU包含一个NAL头和一个RBSP&#xff08;原始字节序列载荷&#xff09;。 1.1 H.264码流层次 H.264码流的结构可以分为两个层…...

JAVA 接口、抽象类的关系和用处 详细解析

接口 - Java教程 - 廖雪峰的官方网站 一个 抽象类 如果实现了一个接口&#xff0c;可以只选择实现接口中的 部分方法&#xff08;所有的方法都要有&#xff0c;可以一部分已经写具体&#xff0c;另一部分继续保留抽象&#xff09;&#xff0c;原因在于&#xff1a; 抽象类本身…...

反向代理模块b

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当于…...

Nuitka打包python脚本

Python脚本打包 Python是解释执行语言&#xff0c;需要解释器才能运行代码&#xff0c;这就导致在开发机上编写的代码在别的电脑上无法直接运行&#xff0c;除非目标机器上也安装了Python解释器&#xff0c;有时候还需要额外安装Python第三方包&#xff0c;相当麻烦。 事实上P…...

pytorch线性回归模型预测房价例子

import torch import torch.nn as nn import torch.optim as optim import numpy as np# 1. 创建线性回归模型类 class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear nn.Linear(1, 1) # 1个输入特征&…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分&#xff1a; 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析&#xff1a; CTR…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

IGP(Interior Gateway Protocol,内部网关协议)

IGP&#xff08;Interior Gateway Protocol&#xff0c;内部网关协议&#xff09; 是一种用于在一个自治系统&#xff08;AS&#xff09;内部传递路由信息的路由协议&#xff0c;主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效&#xff0c;它能挖掘数据中的时序信息以及语义信息&#xff0c;但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN&#xff0c;但是…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

&#x1f680; C extern 关键字深度解析&#xff1a;跨文件编程的终极指南 &#x1f4c5; 更新时间&#xff1a;2025年6月5日 &#x1f3f7;️ 标签&#xff1a;C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言&#x1f525;一、extern 是什么&#xff1f;&…...

【Redis】笔记|第8节|大厂高并发缓存架构实战与优化

缓存架构 代码结构 代码详情 功能点&#xff1a; 多级缓存&#xff0c;先查本地缓存&#xff0c;再查Redis&#xff0c;最后才查数据库热点数据重建逻辑使用分布式锁&#xff0c;二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...

C# 表达式和运算符(求值顺序)

求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如&#xff0c;已知表达式3*52&#xff0c;依照子表达式的求值顺序&#xff0c;有两种可能的结果&#xff0c;如图9-3所示。 如果乘法先执行&#xff0c;结果是17。如果5…...