当前位置: 首页 > news >正文

ray.rllib 入门实践-5: 训练算法

        前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

方式1: algo.train()

        rllib 中的 Algorithm 类自带了.train() 函数,实现算法训练,前面几个博客教程均是采用的这种方式。这里仅再提供一下示例, 不再赘述:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径## 构建算法
algo = config.build()## 训练
for i in range(3):result = algo.train() print(f"episode_{i}")

方式2:tune.Tuner()

        以上方式只能固定训练超参数,不能对训练超参数寻优。ray中还有一个模块 tune, 专门用于算法训练过程中超参数调参。

        在使用tune.Tuner()执行rllib算法训练时, 可以默认为tune背后自动执行了以下操作:

algo = PPOConfig().build() ## 构建算法

result = algo.train()  ## 算法训练

print(pretty_print(result))  ## 每完成一次algo.train, 打印一次阶段性训练结果

algo.save_checkpoint()  ## 保存训练模型

并且遍历了多个超参数组合,多次进行训练。直到达到停止训练的条件(自己配置)。 

        基于tune的rllib训练示例如下(代码篇幅比较大是因为添加的功能模块和注释比较多,后面的介绍主要以 方式一为主,所以对于这种方式,这里介绍的多一些):

import ray 
from ray.rllib.algorithms.ppo import PPO,PPOConfig 
from ray import train, tune
import torch 
import os 
import shutil 
from ray.tune.logger import pretty_print 
import gymnasium as gym ray.init()####  配置算法  ####
config = PPOConfig()
config = config.training(lr=tune.grid_search([0.01, 0.001]))
config = config.environment(env="CartPole-v1")####  配置 tune  ###### 准备 tune 的 stop_condition,多个条件之间是”或“的关系。有一个满足即停止训练。 
##   'episode_reward_mean'关键字将在 ”ray-2.40”版中中被抛弃,
##   届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'
stop_condition = {'episode_reward_mean':10,    ## 这里设置的结束条件很宽松,所以能够快速结束训练。"training_iteration":3}## 准备 tune 的过程文件存储路径
storage_path = "F:/codes/RLlib_study/ray_results"
os.makedirs(storage_path, exist_ok=True)## 准备 tune 的 checkpoint_config
##    tune 默认保存每个 algo.train() 训练得到的 checkpoint.
##    通过以下配置,可以对此进行自定义修改
checkpoint_config =  train.CheckpointConfig(num_to_keep=None, ## 保存几个checkpoint, None 表示保存所有checkpointcheckpoint_at_end=True) ## 是否在训练结束后保存 checkpoint. ## 配置 tuner
tuner = tune.Tuner(PPO,                                       ## 需要是一个 rllib 的 Algorithm 类, 从 ray.rllib.algorithms 导入, 也可以是自定义的,后面介绍                   run_config = train.RunConfig(stop = stop_condition, checkpoint_config = checkpoint_config,  ## 用于设置保存哪个checkpoint. storage_path = storage_path,            ## 如果不设置, 默认存储路径是 “~/ray-results” 或 “C:/用户/xxx/ray_results”),param_space=config,  ## 这里定义了参数搜索调优空间,是一个 PPOConfig 对象
)## 执行训练                         
results = tuner.fit() ## tuner 返回一个Result表格对象,该对象允许进一步分析训练结果并检索经过训练的智能体的checkpoint。
print("====训练结束====")## 获取最佳训练结果
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
## 以 "episode_reward_mean" 为选择指标, 从results里面选择checkpoint, 选择模式是“max”,
## 'episode_reward_mean'关键字将在ray-2.40版中中被抛弃,届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'## 从最佳训练结果中提取对应的 checkpoint , 并保存
checkpoint_save_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
os.makedirs(checkpoint_save_dir, exist_ok=True)best_checkpoint = best_result.checkpoint 
if best_checkpoint:with best_checkpoint.as_directory() as checkpoint_dir:print(f"====最佳模型路径位于:{checkpoint_dir}====")## 把最佳模型转存到指定位置。 shutil.rmtree(checkpoint_save_dir)shutil.copytree(checkpoint_dir,checkpoint_save_dir)print(f"====保存最佳模型到:{checkpoint_save_dir}====")## 加载保存的最佳模型
checkpoint_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"==== 加载最佳模型: {checkpoint_dir}")## evaluate 模型
env_name = "CartPole-v1"
env = gym.make(env_name)## 模型推断: method-1
step = 0
episode_reward = 0
terminated = truncated = False obs,info = env.reset()
while not terminated and not truncated:action = algo.compute_single_action(obs)obs, reward, terminated, truncated, info = env.step(action)episode_reward += rewardstep += 1print(f"step = {step}, reward = {reward}, action = {action}, obs = {obs}, episode_reward = {episode_reward}")

相关文章:

ray.rllib 入门实践-5: 训练算法

前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。 环境配置&#x…...

FPGA 使用 CLOCK_LOW_FANOUT 约束

使用 CLOCK_LOW_FANOUT 约束 您可以使用 CLOCK_LOW_FANOUT 约束在单个时钟区域中包含时钟缓存负载。在由全局时钟缓存直接驱动的时钟网段 上对 CLOCK_LOW_FANOUT 进行设置,而且全局时钟缓存扇出必须低于 2000 个负载。 注释: 当与其他时钟约束配合…...

选择的阶段性质疑

条条大路通罗马,每个人选择的道路,方向并不一样,但不妨碍都可以到达终点,而往往大家会更推崇自己走过的路径。 自己靠什么走向成功,自己用了什么方法,奉行什么原则或者理念,也会尽可能传播这种&…...

固有频率与模态分析

目录 引言 1. 固有频率:物体的“天生节奏” 1.1 定义 1.2 关键特点 1.3 实际意义 2. 有限元中的模态分析:给结构“体检振动” 2.1 模态分析的意义 2.2 实际案例 2.2.1 桥梁模态分析 2.2.2 飞机机翼模态分析 2.2.3 具体事例 3. 模态分析的工具…...

数科OFD证照生成原理剖析与平替方案实现

一、 引言 近年来,随着电子发票的普及,OFD格式作为我国电子发票的标准格式,其应用范围日益广泛。然而,由于不同软件生成的OFD文件存在差异,以及用户对OFD文件处理需求的多样化,OFD套餐转换工具应运而生。本…...

CAN总线数据采集与分析

CAN总线数据采集与分析 目录 CAN总线数据采集与分析1. 引言2. 数据采集2.1 数据采集简介2.2 数据采集实现3. 数据分析3.1 数据分析简介3.2 数据分析实现4. 数据可视化4.1 数据可视化简介4.2 数据可视化实现5. 案例说明5.1 案例1:数据采集实现5.2 案例2:数据分析实现5.3 案例3…...

SpringSecurity:There is no PasswordEncoder mapped for the id “null“

文章目录 一、情景说明二、分析三、解决 一、情景说明 在整合SpringSecurity功能的时候 我先是去实现认证功能 也就是,去数据库比对用户名和密码 相关的类: UserDetailsServiceImpl implements UserDetailsService 用于SpringSecurity查询数据库 Logi…...

ResNet 残差网络

目录 网络结构 残差块(Residual Block) ResNet网络结构示意图 残差块(Residual Block)细节 基本残差块(ResNet-18/34) Bottleneck残差块(ResNet-50/101/152) 残差连接类型对比 变体网…...

CAPL编程常见问题与解决方案深度解析

CAPL编程常见问题与解决方案深度解析 目录 CAPL编程常见问题与解决方案深度解析引言1. CAPL编程核心难点剖析1.1 典型问题分类2. 六大典型问题场景解析案例1:定时器资源竞争导致逻辑错乱2.1.1 问题现象2.1.2 根因分析2.1.3 解决方案案例2:大数据量报文处理引发性能瓶颈2.2.1 …...

信号处理以及队列

下面是一个使用C和POSIX信号处理以及队列的简单示例。这个示例展示了如何使用信号处理程序将信号放入队列中&#xff0c;并在主循环中处理这些信号。 #include <iostream> #include <csignal> #include <queue> #include <mutex> #include <thread…...

Linux pkill 命令使用详解

简介 pkill 命令用于根据进程名称、用户、组或其他属性终止进程。它是 procps-ng 包的一部分&#xff0c;通常比 kill 更受欢迎&#xff0c;因为它无需查找进程 ID (PID)。 常用选项 -<signal>, --signal <signal>&#xff1a;定义要发送给每个匹配进程的信号&am…...

react注意事项

1.状态的定义以及修改 2.排序用lodash进行排序 import _ from lodassh 3.利用className插件进行动态类名的使用 4.表单使用 5.react中获取dom...

【开源免费】基于SpringBoot+Vue.JS在线考试学习交流网页平台(JAVA毕业设计)

本文项目编号 T 158 &#xff0c;文末自助获取源码 \color{red}{T158&#xff0c;文末自助获取源码} T158&#xff0c;文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...

怎样在PPT中启用演讲者视图功能?

怎样在PPT中启用演讲者视图功能&#xff1f; 如果你曾经参加过重要的会议或者演讲&#xff0c;你就会知道&#xff0c;演讲者视图&#xff08;Presenter View&#xff09;对PPT展示至关重要。它不仅能帮助演讲者更好地掌控演讲节奏&#xff0c;还能提供额外的提示和支持&#…...

UE AController

定义和功能 AController是一种特定于游戏的控制器&#xff0c;在UE框架中用于定义玩家和AI的控制逻辑。AController负责处理玩家输入&#xff0c;并根据这些输入驱动游戏中的角色或其他实体的行为。设计理念 AController设计用于分离控制逻辑与游戏角色&#xff0c;增强游戏设计…...

H264原始码流格式分析

1.H264码流结构组成 H.264裸码流&#xff08;Raw Bitstream&#xff09;数据主要由一系列的NALU&#xff08;网络抽象层单元&#xff09;组成。每个NALU包含一个NAL头和一个RBSP&#xff08;原始字节序列载荷&#xff09;。 1.1 H.264码流层次 H.264码流的结构可以分为两个层…...

JAVA 接口、抽象类的关系和用处 详细解析

接口 - Java教程 - 廖雪峰的官方网站 一个 抽象类 如果实现了一个接口&#xff0c;可以只选择实现接口中的 部分方法&#xff08;所有的方法都要有&#xff0c;可以一部分已经写具体&#xff0c;另一部分继续保留抽象&#xff09;&#xff0c;原因在于&#xff1a; 抽象类本身…...

反向代理模块b

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当于…...

Nuitka打包python脚本

Python脚本打包 Python是解释执行语言&#xff0c;需要解释器才能运行代码&#xff0c;这就导致在开发机上编写的代码在别的电脑上无法直接运行&#xff0c;除非目标机器上也安装了Python解释器&#xff0c;有时候还需要额外安装Python第三方包&#xff0c;相当麻烦。 事实上P…...

pytorch线性回归模型预测房价例子

import torch import torch.nn as nn import torch.optim as optim import numpy as np# 1. 创建线性回归模型类 class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear nn.Linear(1, 1) # 1个输入特征&…...

React Native 导航系统实战(React Navigation)

导航系统实战&#xff08;React Navigation&#xff09; React Navigation 是 React Native 应用中最常用的导航库之一&#xff0c;它提供了多种导航模式&#xff0c;如堆栈导航&#xff08;Stack Navigator&#xff09;、标签导航&#xff08;Tab Navigator&#xff09;和抽屉…...

Zustand 状态管理库:极简而强大的解决方案

Zustand 是一个轻量级、快速和可扩展的状态管理库&#xff0c;特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

鱼香ros docker配置镜像报错:https://registry-1.docker.io/v2/

使用鱼香ros一件安装docker时的https://registry-1.docker.io/v2/问题 一键安装指令 wget http://fishros.com/install -O fishros && . fishros出现问题&#xff1a;docker pull 失败 网络不同&#xff0c;需要使用镜像源 按照如下步骤操作 sudo vi /etc/docker/dae…...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机&#xff0c;它可以执行Java字节码。Java虚拟机是Java平台的一部分&#xff0c;Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...