当前位置: 首页 > news >正文

强化学习_06_pytorch-PPO实践(Hopper-v4)

一、PPO优化

PPO的简介和实践可以看笔者之前的文章 强化学习_06_pytorch-PPO实践(Pendulum-v1)
针对之前的PPO做了主要以下优化:

  1. batch_normalize: 在mini_batch 函数中进行adv的normalize, 加速模型对adv的学习
  2. policyNet采用beta分布(0~1): 同时增加MaxMinScale 将beta分布产出值转换到action的分布空间
  3. 收集多个episode的数据,依次计算adv,后合并到一个dataloader中进行遍历:加速模型收敛

1.1 PPO2 代码

详细可见 Github: PPO2.py

class PPO2:"""PPO2算法, 采用截断方式"""def __init__(self,state_dim: int,actor_hidden_layers_dim: typ.List,critic_hidden_layers_dim: typ.List,action_dim: int,actor_lr: float,critic_lr: float,gamma: float,PPO_kwargs: typ.Dict,device: torch.device,reward_func: typ.Optional[typ.Callable]=None):dist_type = PPO_kwargs.get('dist_type', 'beta')self.dist_type = dist_typeself.actor = policyNet(state_dim, actor_hidden_layers_dim, action_dim, dist_type=dist_type).to(device)self.critic = valueNet(state_dim, critic_hidden_layers_dim).to(device)self.actor_lr = actor_lrself.critic_lr = critic_lrself.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.lmbda = PPO_kwargs['lmbda']self.k_epochs = PPO_kwargs['k_epochs'] # 一条序列的数据用来训练的轮次self.eps = PPO_kwargs['eps'] # PPO中截断范围的参数self.sgd_batch_size = PPO_kwargs.get('sgd_batch_size', 512)self.minibatch_size = PPO_kwargs.get('minibatch_size', 128)self.action_bound = PPO_kwargs.get('action_bound', 1.0)self.action_low = -1 * self.action_bound self.action_high = self.action_boundif 'action_space' in PPO_kwargs:self.action_low = self.action_space.lowself.action_high = self.action_space.highself.count = 0 self.device = deviceself.reward_func = reward_funcself.min_batch_collate_func = partial(mini_batch, mini_batch_size=self.minibatch_size)def _action_fix(self, act):if self.dist_type == 'beta':# beta 0-1 -> low ~ highreturn act * (self.action_high - self.action_low) + self.action_lowreturn act def _action_return(self, act):if self.dist_type == 'beta':# low ~ high -> 0-1 act_out = (act - self.action_low) / (self.action_high - self.action_low)return act_out * 1 + 0return act def policy(self, state):state = torch.FloatTensor(np.array([state])).to(self.device)action_dist = self.actor.get_dist(state, self.action_bound)action = action_dist.sample()action = self._action_fix(action)return action.cpu().detach().numpy()[0]def _one_deque_pp(self, samples: deque):state, action, reward, next_state, done = zip(*samples)state = torch.FloatTensor(np.stack(state)).to(self.device)action = torch.FloatTensor(np.stack(action)).to(self.device)reward = torch.tensor(np.stack(reward)).view(-1, 1).to(self.device)if self.reward_func is not None:reward = self.reward_func(reward)next_state = torch.FloatTensor(np.stack(next_state)).to(self.device)done = torch.FloatTensor(np.stack(done)).view(-1, 1).to(self.device)old_v = self.critic(state)td_target = reward + self.gamma * self.critic(next_state) * (1 - done)td_delta = td_target - old_vadvantage = compute_advantage(self.gamma, self.lmbda, td_delta, done).to(self.device)# recomputetd_target = advantage + old_vaction_dists = self.actor.get_dist(state, self.action_bound)old_log_probs = action_dists.log_prob(self._action_return(action))return state, action, old_log_probs, advantage, td_targetdef data_prepare(self, samples_list: List[deque]):state_pt_list = []action_pt_list = []old_log_probs_pt_list = []advantage_pt_list = []td_target_pt_list = []for sample in samples_list:state_i, action_i, old_log_probs_i, advantage_i, td_target_i = self._one_deque_pp(sample)state_pt_list.append(state_i)action_pt_list.append(action_i)old_log_probs_pt_list.append(old_log_probs_i)advantage_pt_list.append(advantage_i)td_target_pt_list.append(td_target_i)state = torch.concat(state_pt_list) action = torch.concat(action_pt_list) old_log_probs = torch.concat(old_log_probs_pt_list) advantage = torch.concat(advantage_pt_list) td_target = torch.concat(td_target_pt_list)return state, action, old_log_probs, advantage, td_targetdef update(self, samples_list: List[deque]):state, action, old_log_probs, advantage, td_target = self.data_prepare(samples_list)if len(old_log_probs.shape) == 2:old_log_probs = old_log_probs.sum(dim=1)d_set = memDataset(state, action, old_log_probs, advantage, td_target)train_loader = DataLoader(d_set,batch_size=self.sgd_batch_size,shuffle=True,drop_last=True,collate_fn=self.min_batch_collate_func)for _ in range(self.k_epochs):for state_, action_, old_log_prob, adv, td_v in train_loader:action_dists = self.actor.get_dist(state_, self.action_bound)log_prob = action_dists.log_prob(self._action_return(action_))if len(log_prob.shape) == 2:log_prob = log_prob.sum(dim=1)# e(log(a/b))ratio = torch.exp(log_prob - old_log_prob.detach())surr1 = ratio * advsurr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advactor_loss = torch.mean(-torch.min(surr1, surr2)).float()critic_loss = torch.mean(F.mse_loss(self.critic(state_).float(), td_v.detach().float())).float()self.actor_opt.zero_grad()self.critic_opt.zero_grad()actor_loss.backward()critic_loss.backward()torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5) self.actor_opt.step()self.critic_opt.step()return Truedef save_model(self, file_path):if not os.path.exists(file_path):os.makedirs(file_path)act_f = os.path.join(file_path, 'PPO_actor.ckpt')critic_f = os.path.join(file_path, 'PPO_critic.ckpt')torch.save(self.actor.state_dict(), act_f)torch.save(self.critic.state_dict(), critic_f)def load_model(self, file_path):act_f = os.path.join(file_path, 'PPO_actor.ckpt')critic_f = os.path.join(file_path, 'PPO_critic.ckpt')self.actor.load_state_dict(torch.load(act_f, map_location='cpu'))self.critic.load_state_dict(torch.load(critic_f, map_location='cpu'))self.actor.to(self.device)self.critic.to(self.device)self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)def train(self):self.training = Trueself.actor.train()self.critic.train()def eval(self):self.training = Falseself.actor.eval()self.critic.eval()

二、 Pytorch实践

2.1 智能体构建与训练

PPO2主要是收集多轮的结果序列进行训练,增加训练轮数,适当降低学习率,稍微增Actor和Critic的网络深度
详细可见 Github: test_ppo.Hopper_v4_ppo2_test

import os
from os.path import dirname
import sys
import gymnasium as gym
import torch
# 笔者的github-RL库
from RLAlgo.PPO import PPO
from RLAlgo.PPO2 import PPO2
from RLUtils import train_on_policy, random_play, play, Config, gym_env_descenv_name = 'Hopper-v4'
gym_env_desc(env_name)
print("gym.__version__ = ", gym.__version__ )
path_ = os.path.dirname(__file__) 
env = gym.make(env_name, exclude_current_positions_from_observation=True,# healthy_reward=0
)
cfg = Config(env, # 环境参数save_path=os.path.join(path_, "test_models" ,'PPO_Hopper-v4_test2'), seed=42,# 网络参数actor_hidden_layers_dim=[256, 256, 256],critic_hidden_layers_dim=[256, 256, 256],# agent参数actor_lr=1.5e-4,critic_lr=5.5e-4,gamma=0.99,# 训练参数num_episode=12500,off_buffer_size=512,off_minimal_size=510,max_episode_steps=500,PPO_kwargs={'lmbda': 0.9,'eps': 0.25,'k_epochs': 4, 'sgd_batch_size': 128,'minibatch_size': 12, 'actor_bound': 1,'dist_type': 'beta'}
)
agent = PPO2(state_dim=cfg.state_dim,actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,action_dim=cfg.action_dim,actor_lr=cfg.actor_lr,critic_lr=cfg.critic_lr,gamma=cfg.gamma,PPO_kwargs=cfg.PPO_kwargs,device=cfg.device,reward_func=None
)
agent.train()
train_on_policy(env, agent, cfg, wandb_flag=False, train_without_seed=True, test_ep_freq=1000, online_collect_nums=cfg.off_buffer_size,test_episode_count=5)

2.2 训练出的智能体观测

最后将训练的最好的网络拿出来进行观察

agent.load_model(cfg.save_path)
agent.eval()
env_ = gym.make(env_name, exclude_current_positions_from_observation=True,render_mode='human') # , render_mode='human'
play(env_, agent, cfg, episode_count=3, play_without_seed=True, render=True)

在这里插入图片描述

相关文章:

强化学习_06_pytorch-PPO实践(Hopper-v4)

一、PPO优化 PPO的简介和实践可以看笔者之前的文章 强化学习_06_pytorch-PPO实践(Pendulum-v1) 针对之前的PPO做了主要以下优化: batch_normalize: 在mini_batch 函数中进行adv的normalize, 加速模型对adv的学习policyNet采用beta分布(0~1): 同时增加MaxMinScale …...

Scala Intellij编译错误:idea报错xxxx“is already defined as”

今天写scala代码时,Idea报了这样的错误,如下图所示: 一般情况下原因分两种: 第一是我们定义的类或对象重复多次出现,编译器无法确定使用哪个定义。 这通常是由于以下几个原因导致的: 重复定义:在同一个文件…...

面试笔记系列五之MySql+Mybaits基础知识点整理及常见面试题

myibatis执行过程 1读取MyBatis的配置文件。 mybatis-config.xml为MyBatis的全局配置文件,用于配置数据库连接信息。 2加载映射文件。映射文件即SQL映射文件,该文件中配置了操作数据库的SQL语句,需要在MyBatis配置文件mybatis-config.xml中…...

掌握Pillow:Python图像处理的艺术

掌握Pillow:Python图像处理的艺术 引言Python与图像处理的概述Pillow库基础导入Pillow库基本概念图像的打开、保存和显示 图像操作基础图像的剪裁图像的旋转和缩放色彩转换和滤镜应用文字和图形的绘制 高级图像处理图像的合成与蒙版操作像素级操作与图像增强复杂图形…...

React最常用的几个hook

React最常用的几个Hook包括:useState、useEffect、useRef以及useContext。 useState: 用于在函数组件中添加状态管理。它返回一个数组,第一个元素是当前状态的值,第二个元素是更新状态的函数。在使用时,可以通过解构赋…...

自然语言处理Gensim入门:建模与模型保存

文章目录 自然语言处理Gensim入门:建模与模型保存关于gensim基础知识1. 模块导入2. 内部变量定义3. 主函数入口 (if __name__ __main__:)4. 加载语料库映射5. 加载和预处理语料库6. 根据方法参数选择模型训练方式7. 保存模型和变换后的语料8.代码 自然语言处理Gens…...

Windows 10中Visual Studio Code(VSCode)无法自动打开终端的解决办法

1.检查设置: 打开VSCode。点击左侧菜单栏的“文件”(File)。选择“首选项”(Preferences)。点击“设置”(Settings)。在搜索框中输入“shell”,然后点击“settings.json”进行编辑。…...

python dictionary 字典中的内置函数介绍及其示例

Python字典内置方法: 本文介绍了Python字典(dictionary)中的内置函数及其用法示例。字典是Python中非常常用的一种数据结构,它允许我们通过键(key)来快速查找、添加、修改或删除值(value&#…...

pdf转word文档怎么转?分享4种转换方法

pdf转word文档怎么转?在日常工作中,我们经常遇到需要将PDF文件转换为Word文档的情况。无论是为了编辑、修改还是为了重新排版,将PDF转为Word都显得尤为重要。那么,PDF转Word文档怎么转呢?今天,就为大家分享…...

深度测试:指定DoC ID对ES写入性能的影响

在[[使用python批量写入ES索引数据]]中已经介绍了如何批量写入ES数据。基于该流程实际测试一下指定文档ID对ES性能的影响有多大。 一句话版 指定ID比不指定ID的性能下降了63%,且加剧趋势。 以下是测评验证的细节。 百万数据量 索引默认使用1分片和1副本。 指定…...

【JGit】 AddCommand 新增的文件不能添加到暂存区

执行git.add().addFilepattern(".").setUpdate(true).call() 。新增的文件不能添加到暂存区,为什么? 在 JGit 中,setUpdate(true) 方法用于在调用 AddCommand 的 addFilepattern() 方法时,将已跟踪文件标记为需要更新。…...

golang学习6,glang的web的restful接口传参

1.get传参 //get请求 返回json 接口传参r.GET("/getJson/:id", controller.GetUserInfo) 1.2.接收处理 package controllerimport "github.com/gin-gonic/gin"func GetUserInfo(c *gin.Context) {_ c.Param("id")ReturnSucess(c, 200, &quo…...

Carla自动驾驶仿真八:两种查找CARLA地图坐标点的方法

文章目录 前言一、通过Spectator获取坐标二、通过道路ID获取坐标总结 前言 CARLA没有直接的方法给使用者查找地图坐标点来生成车辆,这里推荐两种实用的方法在特定的地方生成车辆。 一、通过Spectator获取坐标 1、Spectator(观察者)&#xf…...

HarmonyOS | 状态管理(八) | PersistentStorage(持久化存储UI状态)

系列文章目录 1.HarmonyOS | 状态管理(一) | State装饰器 2.HarmonyOS | 状态管理(二) | Prop装饰器 3.HarmonyOS | 状态管理(三) | Link装饰器 4.HarmonyOS | 状态管理(四) | Provide和Consume装饰器 5.HarmonyOS | 状态管理(五) | Observed装饰器和ObjectLink装饰器 6.Harmo…...

Git 突破 文件尺寸限制

前言 当Git本地存储里右超过50MB,却又确实需要上传的时候,就需要用到了不是 解决 本代码就是把大文件进行拆解成小文件,然后上传。 等到拉取下来的时候,可以直接再进行合并,合并成原文件 代码如下,仅供…...

HarmonyOS开发云工程与开发云函数

创建函数 您可直接在DevEco Studio创建函数、编写函数业务代码、为函数配置调用触发器。 1.右击“cloudfunctions”目录,选择“New > Cloud Function”。 2.输入函数名称后,点击“OK”。 函数名称仅支持小写英文字母、数字、中划线(-&a…...

SpringMVC了解

1.springMVC概述 Spring MVC(Model-View-Controller)是基于 Java 的 Web 应用程序框架,用于开发 Web 应用程序。它通过将应用程序分为模型(Model)、视图(View)和控制器(Controller&a…...

day44((VueJS)路由的懒加载使用 路由的元信息(meta) 路由守卫函数 vant组件库的应用)

一.路由懒加载的使用 使用原因 1.使用原因1) 使用一般写法(即直接填写组件的缺点)当使用这种写法,页面在初次加载会将所有路由配置表的添加的组件一次性全部加载,如果项目中组件代码量庞大,就需要很长时间…...

非线性优化资料整理

做课题看了一些非线性优化的资料,整理一下,以方便查看: 优化的中文博客 数值优化|笔记整理(8)——带约束优化:引入,梯度投影法 (附代码)QP求解器对比对于MPC的QP求解器 数值优化| 二次规划的…...

踩坑wow.js 和animate.css一起使用没有效果

踩坑wow.js 和animate.css一起使用没有效果 问题及解决方法一、电脑系统配置问题二、版本问题 问题及解决方法 一、电脑系统配置问题 在系统属性里面把窗口内的动画和元素勾选 二、版本问题 使用wow加animate4.4.1也就是最新本,打开网页没有任何动画效果 但是把…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数,对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

<6>-MySQL表的增删查改

目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表&#xf…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器

第一章 引言:语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域,文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量,支撑着搜索引擎、推荐系统、…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...