当前位置: 首页 > news >正文

Stable-Baselines 3 部分源代码解读 2 on_policy_algorithm.py

Stable-Baselines 3 部分源代码解读 ./common/on_policy_algorithm.py

前言

阅读PPO相关的源码,了解一下标准库是如何建立PPO算法以及各种tricks的,以便于自己的复现。

在Pycharm里面一直跳转,可以看到PPO类是最终继承于基类,也就是这个py文件的内容。

所以阅读源码就先从这里开始。: )

import 包

import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Unionimport numpy as np
import torch as th
from gym import spacesfrom stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv

OnPolicyAlgorithm 类

这个类是PPO算法类的中间曾,夹在底层基类和上层PPO类的之间。

主要是同策略算法,例如:A2C和PPO算法。

policyenvlearning_rate三者与基类base-class.py的一致

n_steps表示每次更新前需要经过的时间步,作者在这里给出了n_steps * n_envs的例子,可能的意思是,如果环境是重复的多个,打算做并行训练的话,那么就是每个子环境的时间步乘以环境的数量

batch_size经验回放的最小批次信息

gammagae_lambdaclip_rangeclip_range_vf均是具有默认值的参数,分别代表“折扣因子”、“GAE奖励中平衡偏置和方差的参数”、“为网络参数而限制幅度的范围”、“为值函数网络参数而限制幅度的范围”

normalize_advantage标志是否需要归一化优势(advantage)

ent_coefvf_coef损失计算的熵系数

max_grad_norm最大的梯度长度,梯度下降的限幅

use_sdesde_sample_freq是状态独立性探索,只适用于连续环境,与基类base-class.py的一致

target_kl限制每次更新时KL散度不能太大,因为clipping限幅不能防止大量更新

monitor_wrapper标志是否需要Gym库提供的监视器包装器

_init_setup_model是否建立模型,也就是是否在创建这个实例过程中创建初始化模型

class OnPolicyAlgorithm(BaseAlgorithm):"""The base for On-Policy algorithms (ex: A2C/PPO).:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...):param env: The environment to learn from (if registered in Gym, can be str):param learning_rate: The learning rate, it can be a functionof the current progress remaining (from 1 to 0):param n_steps: The number of steps to run for each environment per update(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel):param gamma: Discount factor:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.Equivalent to classic advantage when set to 1.:param ent_coef: Entropy coefficient for the loss calculation:param vf_coef: Value function coefficient for the loss calculation:param max_grad_norm: The maximum value for the gradient clipping:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)instead of action noise exploration (default: False):param sde_sample_freq: Sample a new noise matrix every n steps when using gSDEDefault: -1 (only sample at the beginning of the rollout):param tensorboard_log: the log location for tensorboard (if None, no logging):param monitor_wrapper: When creating an environment, whether to wrap itor not in a Monitor wrapper.:param policy_kwargs: additional arguments to be passed to the policy on creation:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 fordebug messages:param seed: Seed for the pseudo random generators:param device: Device (cpu, cuda, ...) on which the code should be run.Setting it to auto, the code will be run on the GPU if possible.:param _init_setup_model: Whether or not to build the network at the creation of the instance:param supported_action_spaces: The action spaces supported by the algorithm."""def __init__(self,policy: Union[str, Type[ActorCriticPolicy]],env: Union[GymEnv, str],learning_rate: Union[float, Schedule],n_steps: int,gamma: float,gae_lambda: float,ent_coef: float,vf_coef: float,max_grad_norm: float,use_sde: bool,sde_sample_freq: int,tensorboard_log: Optional[str] = None,monitor_wrapper: bool = True,policy_kwargs: Optional[Dict[str, Any]] = None,verbose: int = 0,seed: Optional[int] = None,device: Union[th.device, str] = "auto",_init_setup_model: bool = True,supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,):super().__init__(policy=policy,env=env,learning_rate=learning_rate,policy_kwargs=policy_kwargs,verbose=verbose,device=device,use_sde=use_sde,sde_sample_freq=sde_sample_freq,support_multi_env=True,seed=seed,tensorboard_log=tensorboard_log,supported_action_spaces=supported_action_spaces,)self.n_steps = n_stepsself.gamma = gammaself.gae_lambda = gae_lambdaself.ent_coef = ent_coefself.vf_coef = vf_coefself.max_grad_norm = max_grad_normself.rollout_buffer = None# 调用基类的_setup_model()模型if _init_setup_model:self._setup_model()def _setup_model(self) -> None:# 初始化学习率,让他可以调用self._setup_lr_schedule()# 设置随机数种子self.set_random_seed(self.seed)# 设置经验池子的类,如果观测空间是spaces.Dict类那么就赋值DictRolloutBuffer# 如果观测空间不是spaces.Dict类那么就赋值RolloutBufferbuffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer# 根据类初始化实例经验池子# 初始化经验池子的是时候将设备信息、折扣率、GAE超参数和环境的数量也传进去了self.rollout_buffer = buffer_cls(self.n_steps,self.observation_space,self.action_space,device=self.device,gamma=self.gamma,gae_lambda=self.gae_lambda,n_envs=self.n_envs,)# 初始化策略,直接输入状态空间、动作空间、可调用的学习率、是否使用状态独立性探索,以及自己制定策略# 的时候自己家的模型的参数和激活函数self.policy = self.policy_class(  # pytype:disable=not-instantiableself.observation_space,self.action_space,self.lr_schedule,use_sde=self.use_sde,**self.policy_kwargs  # pytype:disable=not-instantiable)# 将策略放到GPU/CPU中self.policy = self.policy.to(self.device)def collect_rollouts(self,env: VecEnv,callback: BaseCallback,rollout_buffer: RolloutBuffer,n_rollout_steps: int,) -> bool:# 收集环境交互数据# 这个方法使用当前的策略并将交互历史填充到RolloutBuffer经验池子中# rollout的意思是无模型的概念,而不是有模型的RL或规划里面的rollout的概念# env 用于训练的环境# callback 在每个时间步都会调用的回调函数# rollout_buffer 将收集的经验放置到rollout_buffer中# 在每个环境中需要收集的条数# 返回值是True:如果rollout_buffer收集了这么多的经验;返回值是False:如果回调函数提前终止了# 这个rollouts。"""Collect experiences using the current policy and fill a ``RolloutBuffer``.The term rollout here refers to the model-free notion and should notbe used with the concept of rollout used in model-based RL or planning.:param env: The training environment:param callback: Callback that will be called at each step(and at the beginning and end of the rollout):param rollout_buffer: Buffer to fill with rollouts:param n_rollout_steps: Number of experiences to collect per environment:return: True if function returned with at least `n_rollout_steps`collected, False if callback terminated rollout prematurely."""assert self._last_obs is not None, "No previous observation was provided"# 将策略转变到评估模式# Switch to eval mode (this affects batch norm / dropout)self.policy.set_training_mode(False)# 重置经验池子,如果使用状态独立性探索,那么就重置策略的噪声n_steps = 0rollout_buffer.reset()# Sample new weights for the state dependent explorationif self.use_sde:self.policy.reset_noise(env.num_envs)# 回调函数执行on_rollout_start()命令,跳转定义时候没有看到具体定义callback.on_rollout_start()while n_steps < n_rollout_steps:# 如果使用了状态独立性探索,并且达到了探索频率的节点,那么就重置策略的噪声if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:# Sample a new noise matrixself.policy.reset_noise(env.num_envs)# 在断开梯度的情况下,转换观测数据到tensor张量内,然后输入到策略中输出动作、价值和对数概率# 最后再将动作数据转移到numpy中with th.no_grad():# Convert to pytorch tensor or to TensorDictobs_tensor = obs_as_tensor(self._last_obs, self.device)actions, values, log_probs = self.policy(obs_tensor)actions = actions.cpu().numpy()# Rescale and perform action# 归一化动作信息,限制在动作空间的上下界clipped_actions = actions# Clip the actions to avoid out of bound errorif isinstance(self.action_space, spaces.Box):clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)# 将动作信息输入到环境中,输出新的观测、奖励数值、是否完成以及其他信息。new_obs, rewards, dones, infos = env.step(clipped_actions)# 处理回调函数和更新经验池子self.num_timesteps += env.num_envs# Give access to local variablescallback.update_locals(locals())if callback.on_step() is False:return Falseself._update_info_buffer(infos)n_steps += 1# 如果动作空间是离散空间的话,那么就转变成一个列向量if isinstance(self.action_space, spaces.Discrete):# Reshape in case of discrete actionactions = actions.reshape(-1, 1)# 判断数据是否是终止的# 终止之后计算累计奖励# Handle timeout by bootstraping with value function# see GitHub issue #633for idx, done in enumerate(dones):if (doneand infos[idx].get("terminal_observation") is not Noneand infos[idx].get("TimeLimit.truncated", False)):terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]with th.no_grad():terminal_value = self.policy.predict_values(terminal_obs)[0]rewards[idx] += self.gamma * terminal_value# 经验池子输入的是上一个状态、动作、奖励、上一个回合的开始状态、价值列表以及对数概率rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)self._last_obs = new_obsself._last_episode_starts = dones# 计算下一个状态的价值with th.no_grad():# Compute value for the last timestepvalues = self.policy.predict_values(obs_as_tensor(new_obs, self.device))# 计算回报和优势rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)callback.on_rollout_end()return Truedef train(self) -> None:# 这个是父类的方法# 在子类的实际PPO类中做了重写"""Consume current rollout data and update policy parameters.Implemented by individual algorithms."""raise NotImplementedErrordef learn(self: SelfOnPolicyAlgorithm,total_timesteps: int,callback: MaybeCallback = None,log_interval: int = 1,tb_log_name: str = "OnPolicyAlgorithm",reset_num_timesteps: bool = True,progress_bar: bool = False,) -> SelfOnPolicyAlgorithm:iteration = 0# 初始化模型total_timesteps, callback = self._setup_learn(total_timesteps,callback,reset_num_timesteps,tb_log_name,progress_bar,)callback.on_training_start(locals(), globals())while self.num_timesteps < total_timesteps:# 这里开始执行上面的函数,在环境中收集数据,收集完了就继续训练# 如果出了故障了,就在接下来跳出循环continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)if continue_training is False:break# 跌带次数+1,并根据当前的训练次数更新学习率iteration += 1self._update_current_progress_remaining(self.num_timesteps, total_timesteps)# 在控制台按照预先定义的频率输出相关信息# Display training infosif log_interval is not None and iteration % log_interval == 0:time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)self.logger.record("time/iterations", iteration, exclude="tensorboard")if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))self.logger.record("time/fps", fps)self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")self.logger.dump(step=self.num_timesteps)self.train()callback.on_training_end()return selfdef _get_torch_save_params(self) -> Tuple[List[str], List[str]]:state_dicts = ["policy", "policy.optimizer"]return state_dicts, []

相关文章:

Stable-Baselines 3 部分源代码解读 2 on_policy_algorithm.py

Stable-Baselines 3 部分源代码解读 ./common/on_policy_algorithm.py 前言 阅读PPO相关的源码&#xff0c;了解一下标准库是如何建立PPO算法以及各种tricks的&#xff0c;以便于自己的复现。 在Pycharm里面一直跳转&#xff0c;可以看到PPO类是最终继承于基类&#xff0c;也…...

15. Qt中OPenGL的参数传递问题

1. 说明 在OPenGL中&#xff0c;需要使用GLSL语言来编写着色器的函数&#xff0c;在顶点着色器和片段着色器之间需要参数值的传递&#xff0c;且在CPU中的数据也需要传递到顶点着色器中进行使用。本文简单介绍几种参数传递的方式&#xff1a; &#xff08;本文内容仅个人理解&…...

注意,这本2区SCI期刊最快18天录用,还差一步录用只因犯了这个错

发表案例分享&#xff1a; 2区医学综合类SCI&#xff0c;仅18天录用&#xff0c;录用后28天见刊 2023.02.10 | 见刊 2023.01.13 | Accepted 2023.01.11 | 提交返修稿 2022.12.26 | 提交论文至期刊部系统 录用截图来源&#xff1a;期刊部投稿系统 见刊截图来源&#xff1a…...

Could not find resource jdbc.properties问题的解决

以如下开头的内容&#xff1a; Exception in thread "main" org.apache.ibatis.exceptions.PersistenceException: ### Error building SqlSession. ### The error may exist in SQL Mapper Configuration 出现以上问题是没有在src/main/resources下创建jdbc.prop…...

【面试题】==与equals区别、Hashcode作用、hashcode相同equals()也一定为true吗?泛型特点与好处

文章目录1. 和 equals 的区别是什么&#xff1f;2.Hashcode的作用3. 两个对象的hashCode() 相同&#xff0c; 那么equals()也一定为 true吗&#xff1f;4.泛型常用特点5.使用泛型的好处&#xff1f;1. 和 equals 的区别是什么&#xff1f; “” 对于基本类型和引用类型 的作…...

Flex布局中的flex属性

1.flex-grow&#xff0c;flex-shrink&#xff0c;flex-basis取值含义 flex-grow&#xff1a; 延申性描述。在满足“延申条件”时&#xff0c;flex容器中的项目会按照设置的flex-grow值的比例来延申&#xff0c;占满容器剩余空间。 取值情况&#xff1a; 取负值无效。取0值表示不…...

SpringBoot + Ant Design Pro Vue实现动态路由和菜单的前后端分离框架

Ant Design Pro Vue默认路由和菜单配置是采用中心化的方式&#xff0c;在 router.config.js统一配置和管理&#xff0c;同时也提供了动态获取路由和菜单的解决方案&#xff0c;并将在2.0.3版本中提供&#xff0c;因到目前为止&#xff0c;官方发布的版本为2.0.2&#xff0c;所以…...

robotframework自动化测试环境搭建

环境说明 win10 python版本&#xff1a;3.8.3rc1 安装清单 安装配置 selenium安装 首先检查pip命令是否安装&#xff1a; C:\Users\name>pipUsage:pip <command> [options]Commands:install Install packages.download Do…...

尚硅谷《Redis7》(小白篇)

尚硅谷《Redis7 》&#xff08;小白篇&#xff09; 02 redis 是什么 官方网站&#xff1a; https://redis.io/ 作者 Git Hub https://github.com/antirez 03 04 05 能做什么 06 去哪下 Download https://redis.io/download/ redis中文文档 https://www.redis.com.cn/docu…...

并非从0开始的c++ day6

并非从0开始的c day6二级指针练习-文件读写位运算位逻辑运算符按位取反 ~位于&#xff08;AND&#xff09;&#xff1a;&位或&#xff08;OR&#xff09;&#xff1a; |位异或: ^移位运算符左移<<右移>>多维数组一维数组数组名一维数组名传入到函数参数中数组指…...

PMP考前冲刺2.22 | 2023新征程,一举拿证

承载2023新一年的好运让我们迈向PMP终点一起冲刺&#xff01;一起拿证&#xff01;每日5道PMP习题助大家上岸PMP&#xff01;&#xff01;&#xff01;题目1-2&#xff1a;1.在新产品开发过程中&#xff0c;项目经理关注到行业排名第一的公司刚刚发布同类型的产品。相比竞品&am…...

RxJava的订阅过程

要使用Rxjava首先要导入两个包&#xff0c;其中rxandroid是rxjava在android中的扩展 implementation io.reactivex:rxandroid:1.2.1implementation io.reactivex:rxjava:1.2.0首先从最基本的Observable的创建到订阅开始分析 Observable.create(new Observable.OnSubscribe<S…...

【2.22】MySQL、Redis、动态规划

认识Redis Redis是一种基于内存的数据库&#xff0c;对数据的读写操作都是在内存中完成的&#xff0c;因此读写速度非常快&#xff0c;常用于缓存&#xff0c;消息队列&#xff0c;分布式锁等场景。 Redis提供了多种数据类型来支持不同的业务场景&#xff0c;比如String(字符串…...

2年手动测试,裸辞后找不到工作怎么办?

我们可以从以下几个方面来具体分析下&#xff0c;想通了&#xff0c;理解透了&#xff0c;才能更好的利用资源提升自己。一、我会什么&#xff1f;先说第一个我会什么&#xff1f;第一反应&#xff1a;我只会功能测试&#xff0c;在之前的4年的中我只做了功能测试。内心存在一种…...

Leetcode6. N字形变换

一、题目描述&#xff1a; 将一个给定字符串 s 根据给定的行数 numRows &#xff0c;以从上往下、从左到右进行 Z 字形排列。 比如输入字符串为 “PAYPALISHIRING” 行数为 3 时&#xff0c;排列如下&#xff1a; 之后&#xff0c;你的输出需要从左往右逐行读取&#xff0c;产…...

将Nginx 核心知识点扒了个底朝天(十)

ngx_http_upstream_module的作用是什么? ngx_http_upstream_module用于定义可通过fastcgi传递、proxy传递、uwsgi传递、memcached传递和scgi传递指令来引用的服务器组。 什么是C10K问题? C10K问题是指无法同时处理大量客户端(10,000)的网络套接字。 Nginx是否支持将请求压…...

GPU显卡环境配置安装

前言 最近公司购买了一张RTX3090的显卡和一台新的服务器&#xff0c;然后对机器的GPU环境进行了安装和配置&#xff0c;然后简单记录一下 环境版本 操作系统&#xff1a;Centos7.8 显卡型号&#xff1a;RTX3090 Python版本&#xff1a;3.7.6 Tensorflow版本&#xff1a;2…...

CIMCAI super unmanned intelligent gate container damage detect

世界港航人工智能领军者企业CIMCAI中集飞瞳打造全球最先进超级智能闸口无人闸口ceaspectusG™视频流动态感知集装箱箱况残损检测箱况残损识别率99%以上&#xff0c;箱信息识别率99.95%以上World port shipping AI leader CIMCAIThe worlds most advanced super intelligent gat…...

web概念概述

软件架构&#xff1a;1. C/S: Client/Server 客户端/服务器端* 在用户本地有一个客户端程序&#xff0c;在远程有一个服务器端程序* 如&#xff1a;QQ&#xff0c;迅雷...* 优点&#xff1a;1. 用户体验好* 缺点&#xff1a;1. 开发、安装&#xff0c;部署&#xff0c;维护 麻烦…...

编译原理笔记(1)绪论

文章目录1.什么是编译2.编译系统的结构3.词法分析概述4.语法分析概述5.语义分析概述6.中间代码生成和后端概述1.什么是编译 编译的定义&#xff1a;将高级语言翻译成汇编语言或机器语言的过程。前者称为源语言&#xff0c;后者称为目标语言。 高级语言源程序的处理过程&#…...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来&#xff0c;Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

Redis数据倾斜问题解决

Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中&#xff0c;部分节点存储的数据量或访问量远高于其他节点&#xff0c;导致这些节点负载过高&#xff0c;影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

Java编程之桥接模式

定义 桥接模式&#xff08;Bridge Pattern&#xff09;属于结构型设计模式&#xff0c;它的核心意图是将抽象部分与实现部分分离&#xff0c;使它们可以独立地变化。这种模式通过组合关系来替代继承关系&#xff0c;从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...

Vite中定义@软链接

在webpack中可以直接通过符号表示src路径&#xff0c;但是vite中默认不可以。 如何实现&#xff1a; vite中提供了resolve.alias&#xff1a;通过别名在指向一个具体的路径 在vite.config.js中 import { join } from pathexport default defineConfig({plugins: [vue()],//…...

探索Selenium:自动化测试的神奇钥匙

目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...

uniapp 小程序 学习(一)

利用Hbuilder 创建项目 运行到内置浏览器看效果 下载微信小程序 安装到Hbuilder 下载地址 &#xff1a;开发者工具默认安装 设置服务端口号 在Hbuilder中设置微信小程序 配置 找到运行设置&#xff0c;将微信开发者工具放入到Hbuilder中&#xff0c; 打开后出现 如下 bug 解…...

实战三:开发网页端界面完成黑白视频转为彩色视频

​一、需求描述 设计一个简单的视频上色应用&#xff0c;用户可以通过网页界面上传黑白视频&#xff0c;系统会自动将其转换为彩色视频。整个过程对用户来说非常简单直观&#xff0c;不需要了解技术细节。 效果图 ​二、实现思路 总体思路&#xff1a; 用户通过Gradio界面上…...