当前位置: 首页 > news >正文

AC的改进算法——TRPO、PPO

两类AC的改进算法

整理了动手学强化学习的学习内容

1. TRPO 算法(Trust Region Policy Optimization)

1.1. 前沿

策略梯度算法即沿着梯度方向迭代更新策略参数 。但是这种算法有一个明显的缺点:当策略网络沿着策略梯度更新参数,可能由于步长太长,策略突然显著变差,进而影响训练效果。

针对以上问题,考虑在更新时找到一块信任区域(trust region),在这个区域上更新策略时能够得到某种策略性能的安全性保证,这就是信任区域策略优化(trust region policy optimization,TRPO)算法的主要思想。

1.2. 一些推导

首先,最常规的动作价值函数,状态价值函数,优势函数定义如下:
在这里插入图片描述接着,一个策略的好坏可以期望折扣奖励J(πθ)J(\pi_\theta)J(πθ)表示:
J(πθ)=Es0,a0,...[∑t=0∞γtr(st)]=Es0[Vπθ(s0)]J(\pi_\theta)=E_{s_0,a_0,...}[\sum_{t=0}^{\infty}\gamma^tr(s_t)]=E_{s_0}[V^{\pi_\theta}(s_0)]J(πθ)=Es0,a0,...[t=0γtr(st)]=Es0[Vπθ(s0)]

其中,s0∼ρ0(s0)s_0 \sim \rho_0(s_0)s0ρ0(s0)at∼πθ(at∣st)a_t \sim \pi_\theta(a_t|s_t)atπθ(atst)at+1∼P(st+1∣st,at)a_{t+1} \sim P(s_{t+1}|s_t,a_t)at+1P(st+1st,at)
由于初始状态s0s_0s0的分布ρ0\rho_0ρ0和策略无关,因此上述策略πθ\pi_\thetaπθ下的优化目标J(πθ)J(\pi_\theta)J(πθ)可以写成在新策略πθ′\pi_{\theta'}πθ的期望形式:
在这里插入图片描述从而,推导新旧策略的目标函数之间的差距:
deltaJ=A将时序差分残差定义为优势函数A:
在这里插入图片描述所以只要我们能找到一个新策略,使得J(θ′)−J(θ)>=0J(\theta')-J(\theta)>=0J(θ)J(θ)>=0,就能保证策略性能单调递增。

但是直接求解该式是非常困难的,因为πθ′\pi_{\theta'}πθ是我们需要求解的策略,但我们又要用它来收集样本。把所有可能的新策略都拿来收集数据,然后判断哪个策略满足上述条件的做法显然是不现实的。

于是 TRPO 做了一步近似操作,对状态访问分布进行了相应处理。具体而言,忽略两个策略之间的状态访问分布变化,直接采用旧的策略的状态分布,定义如下替代优化目标:
在这里插入图片描述当新旧策略非常接近时,状态访问分布变化很小,这么近似是合理的。其中,动作仍然用新策略πθ′\pi_{\theta'}πθ采样得到,我们可以用重要性采样对动作分布进行处理:
在这里插入图片描述为了保证新旧策略足够接近,TRPO 使用了KL散度来衡量策略之间的距离,并给出了整体的优化公式:
优化这里的不等式约束定义了策略空间中的一个 KL 球,被称为信任区域。在这个区域中,可以认为当前学习策略和环境交互的状态分布与上一轮策略最后采样的状态分布一致,进而可以基于一步行动的重要性采样方法使当前学习策略稳定提升。

1.3. 近似求解

直接求解上式带约束的优化问题比较麻烦,TRPO 在其具体实现中做了一步近似操作来快速求解。
对目标函数和约束在θk\theta_kθk进行泰勒展开,分别用 1 阶、2 阶进行近似:
在这里插入图片描述于是我们的优化目标变成了:
在这里插入图片描述此时,我们可以用KKT条件直接导出上述问题的解:
解

1.4. 共轭梯度

一般来说,用神经网络表示的策略函数的参数数量都是成千上万的,计算和存储黑塞矩阵的逆矩阵会耗费大量的内存资源和时间。

TRPO 通过共轭梯度法(conjugate gradient method)回避了这个问题,它的核心思想是直接计算x=H−1gx=H^{-1}gx=H1gxxx即参数更新方向。假设满足 KL距离约束的参数更新时的最大步长为β=θ′−θ\beta=\theta'-\thetaβ=θθ
于是,根据 KL 距离约束条件12(θ′−θk)TH(θ′−θk)<=δ\frac{1}{2}(\theta'-\theta_k)^TH(\theta'-\theta_k)<=\delta21(θθk)TH(θθk)<=δ,有12(βx)TH(βx)=δ\frac{1}{2}(\beta x)^TH(\beta x)=\delta21(βx)TH(βx)=δ。求解β\betaβ,得到β=2δxTHx\beta=\sqrt{\frac{2\delta}{x^THx}}β=xTHx2δ。因此,此时参数更新方式为
θk+1=θk+2δxTHxx\theta_{k+1}=\theta_k+\sqrt{\frac{2\delta}{x^THx}}xθk+1=θk+xTHx2δx
因此,只要可以直接计算x=H−1gx=H^{-1}gx=H1g,就可以根据该式更新参数,问题转化为解Hx=gHx=gHx=g。实际上HHH为对称正定矩阵,所以我们可以使用共轭梯度法来求解。
共轭梯度法的具体流程如下:
在这里插入图片描述在共轭梯度运算过程中,直接计算αk\alpha_kαkrk+1r_{k+1}rk+1需要计算和存储海森矩阵HHH。为了避免这种大矩阵的出现,我们只计算HxHxHx向量,而不直接计算和存储HHH矩阵。这样做比较容易,因为对于任意的列向量vvv,容易验证:
Hv即先用梯度和向量vvv点乘后计算梯度。

    def hessian_matrix_vector_product(self, states, old_action_dists, vector):# 计算黑塞矩阵和一个向量的乘积new_action_dists = torch.distributions.Categorical(self.actor(states))kl = torch.mean(torch.distributions.kl.kl_divergence(old_action_dists,new_action_dists))  # 计算平均KL距离kl_grad = torch.autograd.grad(kl,self.actor.parameters(),create_graph=True)kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])# KL距离的梯度先和向量进行点积运算kl_grad_vector_product = torch.dot(kl_grad_vector, vector)grad2 = torch.autograd.grad(kl_grad_vector_product,self.actor.parameters())grad2_vector = torch.cat([grad.view(-1) for grad in grad2])return grad2_vectordef conjugate_gradient(self, grad, states, old_action_dists):  # 共轭梯度法求解方程x = torch.zeros_like(grad)r = grad.clone()p = grad.clone()rdotr = torch.dot(r, r)for i in range(10):  # 共轭梯度主循环Hp = self.hessian_matrix_vector_product(states, old_action_dists,p)alpha = rdotr / torch.dot(p, Hp)x += alpha * pr -= alpha * Hpnew_rdotr = torch.dot(r, r)if new_rdotr < 1e-10:breakbeta = new_rdotr / rdotrp = r + beta * prdotr = new_rdotrreturn x

1.5. 线性搜索

由于 TRPO 算法用到了泰勒展开的 1 阶和 2 阶近似,这并非精准求解,因此,θ\thetaθ可能未必比θk\theta_kθk好,或未必能满足 KL 散度限制。TRPO 在每次迭代的最后进行一次线性搜索,以确保找到满足条件。具体来说,就是找到一个最小的非负整数iii,使得按照
θk+1=θk+αi2δxTHxx\theta_{k+1}=\theta_{k}+\alpha^i \sqrt{\frac{2\delta}{x^THx}}xθk+1=θk+αixTHx2δx

求出的θk+1\theta_{k+1}θk+1依然满足最初的 KL 散度限制,并且确实能够提升目标函数,这KaTeX parse error: Undefined control sequence: \apha at position 1: \̲a̲p̲h̲a̲ ̲\in (0,1)其中是一个决定线性搜索长度的超参数。

1.6. 总结

至此,我们已经基本上清楚了 TRPO 算法的大致过程,它具体的算法流程如下:
在这里插入图片描述

2. PPO 算法(Trust Region Policy Optimization)

2.1. 前沿

PPO 算法作为TRPO算法的改进版,但是其算法实现更加简单。并且大量的实验结果表明,与TRPO相比,PPO能学习得一样好(甚至更快),这使得PPO成为非常流行的强化学习算法。如果我们想要尝试在一个新的环境中使用强化学习算法,那么 PPO 就属于可以首先尝试的算法。

PPO 的优化目标与 TRPO 相同,但 PPO用了一些相对简单的方法来求解(TRPO 使用泰勒展开近似、共轭梯度、线性搜索等方法直接求解)。具体来说,PPO 有两种形式,一是 PPO-惩罚,二是 PPO-截断,接下来对这两种形式进行介绍。

2.2. PPO-惩罚

PPO-Penalty拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。即:
无约束的优化问题dk=DKLπθk(πθk,πθ)d_k=D_{KL}^{\pi_{\theta_k}}(\pi_{\theta_k},\pi_{\theta})dk=DKLπθk(πθk,πθ)β\betaβ的更新规则如下:

  1. 如果dk<δ/1.5d_k<\delta/1.5dk<δ/1.5,那么βk+1=βk/2\beta_{k+1}=\beta_k/2βk+1=βk/2
  2. 如果dk>δ×1.5d_k>\delta \times 1.5dk>δ×1.5,那么βk+1=βk×2\beta_{k+1}=\beta_k \times 2βk+1=βk×2
  3. 否则βk+1=βk\beta_{k+1}=\beta_kβk+1=βk

其中,δ\deltaδ是事先设定的一个超参数,用于限制学习策略和之前一轮策略的差距。

2.3 PPO-截断

PPO的另一种形式 PPO-截断(PPO-Clip) 更加直接,它在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大,即:
在这里插入图片描述其中clip(x,l,r):=max(min(x,r),l)clip(x,l,r):=max(min(x,r),l)clip(x,l,r):=max(min(x,r),l) ,即把xxx限制在[l,r][l,r][l,r]内。上式中ϵ\epsilonϵ是一个超参数,表示进行截断(clip)的范围。

如果Aπθk(s,a)>0A^{\pi_{\theta_k}}(s,a)>0Aπθk(s,a)>0,说明这个动作的价值高于平均,最大化这个式子会增大πθ(a∣s)πθk(a∣s)\frac{\pi_\theta (a|s)}{\pi_{\theta_k} (a|s)}πθk(as)πθ(as),但不会让其超过1+ϵ1+\epsilon1+ϵ。反之,如果Aπθk(s,a)<0A^{\pi_{\theta_k}}(s,a)<0Aπθk(s,a)<0,最大化这个式子会减小πθ(a∣s)πθk(a∣s)\frac{\pi_\theta (a|s)}{\pi_{\theta_k} (a|s)}πθk(as)πθ(as),但不会让其超过1−ϵ1-\epsilon1ϵ。如下图所示。
在这里插入图片描述

代码

最后,两个算法的代码可参考GitHub,Good Night!

相关文章:

AC的改进算法——TRPO、PPO

两类AC的改进算法 整理了动手学强化学习的学习内容 1. TRPO 算法&#xff08;Trust Region Policy Optimization&#xff09; 1.1. 前沿 策略梯度算法即沿着梯度方向迭代更新策略参数 。但是这种算法有一个明显的缺点&#xff1a;当策略网络沿着策略梯度更新参数&#xff0c…...

【C++学习】list的使用及模拟实现

&#x1f431;作者&#xff1a;一只大喵咪1201 &#x1f431;专栏&#xff1a;《C学习》 &#x1f525;格言&#xff1a;你只管努力&#xff0c;剩下的交给时间&#xff01; list的使用及模拟实现&#x1f63c;构造函数&#x1f435;模拟实现&#x1f63c;迭代器&#x1f435;…...

动态规划专题精讲1

致前行的人&#xff1a; 要努力&#xff0c;但不要着急&#xff0c;繁花锦簇&#xff0c;硕果累累都需要过程&#xff01; 前言&#xff1a; 本篇文章为大家带来一种重要的算法题&#xff0c;就是动态规划类型相关的题目&#xff0c;动态规划类的题目在笔试和面试中是考察非常高…...

PPO(proximal policy optimization)算法

博客写到一半发现有篇讲的很清楚&#xff0c;直接化缘了 https://www.jianshu.com/p/9f113adc0c50 Policy gradient 强化学习的目标&#xff1a;学习到一个策略πθ(a∣s)\pi\theta(a|s)πθ(a∣s)来最大化期望回报。 一种直接的方法就是在策略空间中直接搜索来得到最优策略&…...

ElasticSearch基本使用

title: ElasticSearch基本使用 date: 2022-08-29 00:00:00 tags: ElasticSearch基本使用 categories:ElasticSearch 基本概念 随着ES版本的升级&#xff0c;文中有些概念可能已经废弃。 索引词(term) 一个能够被索引的精确值&#xff0c;区分大小写&#xff0c;可以通过term查…...

windows微软商店下载应用失败/下载故障的解决办法;如何在网页上下载微软商店的应用

一、问题背景 设置惠普打印机时&#xff0c;需要安装hp smart&#xff0c;但是官方只提供微软商店这一下载渠道。 点击安装HP Smart&#xff0c;确定进入微软商店下载。 完全加载不出来&#xff0c;可能是因为开了代理。 把代理关了&#xff0c;就能正常打开了。 但是点击“…...

MySQL进阶篇之InnoDB存储引擎

06、InnoDB引擎 6.1、逻辑存储结构 表空间&#xff08;Tablespace&#xff09; 表空间在MySQL中最终会生成ibd文件&#xff0c;一个mysql实例可以对应多个表空间&#xff0c;用于存储记录、索引等数据。 段&#xff08;Segment&#xff09; 段&#xff0c;分为数据段&#x…...

商标侵权行为的种类有哪些

商标侵权行为的种类有哪些 1、商标侵权行为的种类有以下七种&#xff1a; (1)未经商标注册人的许可&#xff0c;在同一种商品上使用与其注册商标相同的商标的; (2)未经商标注册人的许可&#xff0c;在同一种商品上使用与其注册商标近似的商标&#xff0c;或者在类似商品上使…...

Similarity-Preserving KD(ICCV 2019)原理与代码解析

paper&#xff1a;Similarity-Preserving Knowledge Distillationcode&#xff1a;https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/SP.py背景本文的灵感来源于作者观察到在一个训练好的网络中&#xff0c;语义上相似的输入倾向于引起相似的…...

在Linux和Windows上安装seata-1.6.0

记录&#xff1a;381场景&#xff1a;在CentOS 7.9操作系统上&#xff0c;安装seata-1.6.0。在Windows上操作系统上&#xff0c;安装seata-1.6.0。Seata&#xff0c;一款开源的分布式事务解决方案&#xff0c;致力于提供高性能和简单易用的分布式事务服务。版本&#xff1a;JDK…...

兼职任务平台收集(二)分享给有需要的朋友们

互联网时代&#xff0c;给人们带来了很大的便利。信息交流、生活缴费、足不出户购物、便捷出行、线上医疗、线上教育等等很多。可以说&#xff0c;网络的时代会一直存在着。很多人也在互联网上赚到了第一桶金&#xff0c;这跟他们的努力和付出是息息相关的。所谓一份耕耘&#…...

目标检测三大数据格式VOC,YOLO,COCO的详细介绍

注&#xff1a;本文仅供学习&#xff0c;未经同意请勿转载 说明&#xff1a;该博客来源于xiaobai_Ry:2020年3月笔记 对应的PDF下载链接在&#xff1a;待上传 目录 目标检测常见数据集总结 V0C数据集(Annotation的格式是xmI) A. 数据集包含种类: B. V0C2007和V0C2012的区别…...

SpringBoot实现统一返回接口(除AOP)

起因 关于使用AOP去实现统一返回接口在之前的博客中我们已经实现了&#xff0c;但我突然突发奇想&#xff0c;SpringBoot中异常类的统一返回好像是通过RestControllerAdvice 这个注解去完成的&#xff0c;那我是否也可以通过这个注解去实现统一返回接口。 正文 这个方法主要…...

ChatGpt - 基于人工智能检索进行论文写作

摘要 ChatGPT 是一款由 OpenAI 训练的大型语言模型,可用于各种自然语言处理任务,包括论文写作。使用 ChatGPT 可以帮助作者提高论文的语言流畅度、增强表达能力和提高文章质量。在写作过程中,作者可以使用 ChatGPT 生成自然语言的段落、句子、单词或者短语,作为启发式的写…...

实例三:MATLAB APP design-多项式函数拟合

一、APP 界面设计展示 注:在左侧点击数据导入,选择自己的数据表,如果数据导入成功,在右侧的空白框就会显示数据导入成功。在多项式项数右侧框中输入项数,例如2、3、4等,点击计算按钮,右侧坐标框就会显示函数图像,在平均相对误差下面的空白框显示平均相对误差。...

springboot多种方式注入bean获取Bean

springboot动态注入bean1、创建Bean(demo)2、动态注入Bean3、通过注解注入Bean4、通过config配置注入Bean5、通过Import注解导入6、使用FactoryBean接口7、实现BeanDefinitionRegistryPostProcessor接口1、创建Bean(demo) Data public class Demo(){private String name;publi…...

Markdown及其语法详细介绍(全面)

文章目录一、基本语法1.标题2.段落和换行3.强调4.列表5.链接6.图片7.引用8.代码9.分割线10表格二、扩展语法1.标题锚点标题 {#anchor}2.脚注3.自动链接4.任务列表5.删除线6.表情符号7.数学公式三、Markdown 应用1.文档编辑2.博客写作3.代码笔记四、常见的工具和平台支持 Markdo…...

在Linux和Windows上安装sentinel-1.8.5

记录&#xff1a;380场景&#xff1a;在CentOS 7.9操作系统上&#xff0c;安装sentinel-1.8.5。在Windows上操作系统上&#xff0c;安装sentinel-1.8.5。Sentinel是面向分布式、多语言异构化服务架构的流量治理组件。版本&#xff1a;JDK 1.8 sentinel-1.8.5 CentOS 7.9官网地址…...

面试攻略,Java 基础面试 100 问(十)

StringBuffer、StringBuilder、String区别 线程安全 StringBuffer&#xff1a;线程安全&#xff0c;StringBuilder&#xff1a;线程不安全。 因为 StringBuffer 的所有公开方法都是 synchronized 修饰的&#xff0c;而 StringBuilder 并没有 synchronized 修饰。 StringBuf…...

Zero-shot(零次学习)简介

zero-shot基本概念 首先通过一个例子来引入zero-shot的概念。假设我们已知驴子和马的形态特征&#xff0c;又已知老虎和鬣狗都是又相间条纹的动物&#xff0c;熊猫和企鹅是黑白相间的动物&#xff0c;再次的基础上&#xff0c;我们定义斑马是黑白条纹相间的马科动物。不看任何斑…...

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下&#xff0c;无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作&#xff0c;还是游戏直播的画面实时传输&#xff0c;低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架&#xff0c;凭借其灵活的编解码、数据…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

测试markdown--肇兴

day1&#xff1a; 1、去程&#xff1a;7:04 --11:32高铁 高铁右转上售票大厅2楼&#xff0c;穿过候车厅下一楼&#xff0c;上大巴车 &#xffe5;10/人 **2、到达&#xff1a;**12点多到达寨子&#xff0c;买门票&#xff0c;美团/抖音&#xff1a;&#xffe5;78人 3、中饭&a…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

在四层代理中还原真实客户端ngx_stream_realip_module

一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡&#xff08;如 HAProxy、AWS NLB、阿里 SLB&#xff09;发起上游连接时&#xff0c;将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后&#xff0c;ngx_stream_realip_module 从中提取原始信息…...

Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器

第一章 引言&#xff1a;语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域&#xff0c;文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量&#xff0c;支撑着搜索引擎、推荐系统、…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...