当前位置: 首页 > news >正文

强化学习之Actor-Critic算法(基于值函数和策略的结合)——以CartPole环境为例

0.简介

DQN算法作为基于值函数的方法代表,基于值函数的方法只学习一个价值函数。REINFORCE算法作为基于策略的方法代表,基于策略的方法只学习一个策略函数。Actor-Critic算法则结合了两种学习方法,其本质是基于策略的方法,因为其目标是优化一个带参的策略,只是会额外学习价值函数帮助策略函数更好地学习。

我们回顾一下在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,来指导策略的更新。而值函数的概念正是基于期望回报,我们能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。让我们先回顾一下策略梯度的形式,在策略梯度中,我们可以把梯度写成下面这个形式:

image.png

其中 ψ t 可以有很多种形式:

image.png

 在 REINFORCE 的最后部分,我们提到了 REINFORCE通过蒙特卡洛采样的方法对梯度的估计是无偏的,但是方差非常大,我们可以用第三种形式引入基线 (baseline) b ( s t ) 来减小方差。此外我们也可以采用 Actor-Critic 算法,估计 一个动作价值函数 Q 来代替蒙特卡洛采样得到的回报,这便是第 4 种形式。这个时候,我们也可以把状态价值函数 V  作为基线,从偍牧但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,而 Actor-Critic 的方法则可以在每一步之后都进行更新。

我们将 Actor-Critic 分为两个部分: 分别是 Actor (策略网络) 和 Critic (价值网络):

  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于帮助 Actor 进行更新策略。
  • Actor 要做的则是与环境交互,并利用 Ctitic 价值函数来用策略梯度学习一个更好的策略。

image.png

 与 DQN 中一样,我们采取类似于目标网络的方法,上式中 r + γ V ω ( s t + 1 )作为时序差分目标,不会产生梯度来更新价值函数。所以价值函数的梯度为

image.png

然后使用梯度下降方法即可。接下来让我们总体看看 Actor-Critic 算法的流程吧!

  • 初始化策略网络参数 θ  ,价值网络参数 ω
  • 不断进行如下循环 (每个循环是一条序列) :

。 用当前策略 π θ 平样轨 迹 { s 1 , a 1 , r 1 , s 2 , a 2 , r 2 … }

。 为每一步数据计算: δ t = r t + γ V ω ( s t + 1 ) − V ω ( s )

。 更新价值参数 w = w + α ω ∑ t δ t ∇ ω V ω ( s )

。 更新策略参数 θ = θ + α θ ∑ t δ t ∇ θ log ⁡ π θ ( a ∣ s )

 好了!这就是 Actor-Critic 算法的流程啦,让我们来用代码实现它看看效果如何吧!

1.导库

import gym
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

2.策略网络PolicyNet定义

class PolicyNet(torch.nn.Module):#策略网络def __init__(self,statedim,hiddendim,actiondim):super(PolicyNet,self).__init__()self.fc1=torch.nn.Linear(statedim,hiddendim)self.fc2=torch.nn.Linear(hiddendim,actiondim)def forward(self,x):x=torch.nn.functional.relu(self.fc1(x))return torch.nn.functional.softmax(self.fc2(x),dim=1)

3.价值网络ValueNet定义

class ValueNet(torch.nn.Module):#价值网络def __init__(self,statedim,hiddendim):super(ValueNet,self).__init__()self.fc1=torch.nn.Linear(statedim,hiddendim)self.fc2=torch.nn.Linear(hiddendim,1)def forward(self,x):x=torch.nn.functional.relu(self.fc1(x))return self.fc2(x)

4.ActorCritic算法实现

class ActorCritic:#演员-评论家算法def __init__(self,statedim,hiddendim,actiondim,actor_learningrate,critic_learningrate,gamma,device):self.actor=PolicyNet(statedim,hiddendim,actiondim).to(device)#策略网络self.critic=ValueNet(statedim,hiddendim).to(device)#价值网络self.actor_optimizer=torch.optim.Adam(self.actor.parameters(),lr=actor_learningrate)#策略网络优化器self.critic_optimizer=torch.optim.Adam(self.critic.parameters(),lr=critic_learningrate)#价值网络优化器self.gamma=gammaself.device=devicedef takeaction(self,state):#根据策略网络采取动作state=torch.tensor([state],dtype=torch.float).to(self.device)probs=self.actor(state)actiondist=torch.distributions.Categorical(probs)action=actiondist.sample()return action.item()#返回选择的动作的索引的标量形式def update(self,transitiondist):#更新策略网络和价值网络states,actions,rewards,nextstates,dones=transitiondist["states"],transitiondist["actions"],transitiondist["rewards"],transitiondist["nextstates"],transitiondist["dones"]states=torch.tensor(states,dtype=torch.float).to(self.device)actions=torch.tensor(actions).view(-1,1).to(self.device)rewards=torch.tensor(rewards,dtype=torch.float).view(-1,1).to(self.device)nextstates=torch.tensor(nextstates,dtype=torch.float).to(self.device)dones=torch.tensor(dones,dtype=torch.float).view(-1,1).to(self.device)td_target=rewards+self.gamma*self.critic(nextstates)*(1-dones)#时序差分目标td_delta=td_target-self.critic(states)#时序差分误差log_probs=torch.log(self.actor(states).gather(1,actions))#.detach() 来创建一个与原始张量值相同但不可训练的副本。这个副本可以在不影响原始张量的情况下进行各种操作,并且不会在反向传播中被更新。actor_loss=torch.mean(-log_probs*td_delta.detach())#策略网络的损失函数;#.detach()的作用是将这个张量从计算图中分离出来,这样在计算损失时不会对其进行反向传播,通常是为了防止某些不希望被更新的部分被意外更新。critic_loss=torch.mean(torch.nn.functional.mse_loss(self.critic(states),td_target.detach()))#均方差损失函数self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()#计算策略网络的梯度critic_loss.backward()#计算价值网络的梯度self.actor_optimizer.step()#策略网络参数更新self.critic_optimizer.step()#价值网络参数更新

5.训练本算法的函数实现

def train_on_policy_agent(env,agent,episodesnum,pbarnum,printreturnnum,seedid):#训练演员-评论家算法returnlist=[]for k in range(pbarnum):with tqdm(total=int(episodesnum/pbarnum),desc='Iteration %d' % k) as pbar:for episode in range(int(episodesnum/pbarnum)):episodereturn=0transitiondist={"states":[],"actions":[],"nextstates":[],"rewards":[],"dones":[]}state=env.reset(seed=seedid)[0]done=Falsewhile not done:action=agent.takeaction(state)nextstate,reward,done,truncated,_=env.step(action)done=done or truncatedtransitiondist["states"].append(state)transitiondist["actions"].append(action)transitiondist["nextstates"].append(nextstate)transitiondist["rewards"].append(reward)transitiondist["dones"].append(done)state=nextstateepisodereturn+=rewardreturnlist.append(episodereturn)agent.update(transitiondist)if (episode+1)%(printreturnnum)==0:pbar.set_postfix({"episode":"%d"%(episodesnum/pbarnum*k+episode+1),"return":"%.3f"%np.mean(returnlist[-printreturnnum:])})pbar.update(1)return returnlist

6.移动平滑处理时间序列函数实现

def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))

7.参数配置

actor_learningrate=1e-3
critic_learningrate=1e-2
episodesnum=1000
hiddendim=128
gamma=0.98
pbarnum=10
printreturnnum=10
seedid=0
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

8.车杆环境实验

env=gym.make("CartPole-v1")#env=gym.make("CartPole-v1",render_mode="human")
env.reset(seed=seedid)
torch.manual_seed(seedid)
statedim=env.observation_space.shape[0]
actiondim=env.action_space.n
agent=ActorCritic(statedim,hiddendim,actiondim,actor_learningrate,critic_learningrate,gamma,device)
returnlist=train_on_policy_agent(env,agent,episodesnum,pbarnum,printreturnnum,seedid)
episodelist=list(range(len(returnlist)))
plt.plot(episodelist,returnlist)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("Actor-Critic on {}-{}".format(env.spec.name,env.spec.id))
plt.show()
mvreturn=moving_average(returnlist,9)
plt.plot(episodelist,mvreturn)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("Actor-Critic on {}-{}".format(env.spec.name,env.spec.id))
plt.show()

9.实验结果

Actor-Critic算法很快收敛到最优策略,训练过程非常稳定,抖动情况与REINFORCE算法相比有了明显改进,这说明价值函数的引入减少了方差。 

10.小结

Actor-Critic算法是基于值函数和基于策略的方法的叠加,价值模块Critic在策略模块Actor采样的数据中学习分辨什么是好的动作,什么是不好的动作,进而指导Actor进行策略更新,随着Actor训练不断进行,与环境交互产生的数据分布也发生改变,这需要Critic尽快适应新数据分布并给出好的判别。TRPO、PPO、DDPG、SAC等深度强化学习算法都是在Actor-Critic算法基础上进行发展改进的,其作为基础,深入理解大有裨益。

相关文章:

强化学习之Actor-Critic算法(基于值函数和策略的结合)——以CartPole环境为例

0.简介 DQN算法作为基于值函数的方法代表,基于值函数的方法只学习一个价值函数。REINFORCE算法作为基于策略的方法代表,基于策略的方法只学习一个策略函数。Actor-Critic算法则结合了两种学习方法,其本质是基于策略的方法,因为其目…...

Linux学习记录(五)-------三类读写函数

文章目录 三种读写函数1.行缓存2.无缓存3.全缓存4.fgets和fputs5.gets和puts 三种读写函数 1.行缓存 遇到新行(\n),或者写满缓存时,即调用系统函数 读:fgets,gets,printf,fprintf,sprintf写:fputs,puts,scanf 2.无缓…...

2024年8月13日(lvs NAT脚本 RS脚本 ds脚本)

lvs-nat模式的优点配置简单,缺点是请求和响应都必须经过ds,容易称为性能瓶颈 希望有这样的模式,请求的时候使用input链进行负载均衡,响应的时候就不要经过ds,直接由rs响应给客户端 在nat模式的时候,请求vip,接收vip的响应 构想 请求vip,接受rip响应,这是不允许lvs-dr模式 NAT脚…...

css实现水滴效果图

效果图&#xff1a; <template><div style"width: 100%;height:500px;padding:20px;"><div class"water"></div></div> </template> <script> export default {data() {return {};},watch: {},created() {},me…...

接口测试面试题目,你都会了吗?

面试题 什么是接口测试&#xff1f; 接口自动化测试的流程是什么&#xff1f; GET请求和POST请求区别是什么&#xff1f; 接口测试的常用工具有哪些&#xff1f; HTTP接口的请求参数类型有哪些&#xff1f; 如何从上一个接口获取相关的响应数据传递到下一个接口&#xff1…...

jmeter-beanshell学习16-自定义函数

之前写了一个从文件获取指定数据&#xff0c;用的时候发现不太好用&#xff0c;写了一大段&#xff0c;只能取出一个数&#xff0c;再想取另一个数&#xff0c;再粘一大段。太不好看了&#xff0c;就想到了函数。查了一下确实可以写。 public int test(a,b){return ab; } ctes…...

LogicFlow工作流在React和Vue3中的使用

LogicFlow 是一款流程图编辑框架&#xff0c;提供了一系列流程图交互、编辑所必需的功能和简单灵活的节点自定义、插件等拓展机制&#xff0c;方便我们快速在业务系统内满足类流程图的需求。 核心能力 可视化模型&#xff1a;通过 LogicFlow 提供的直观可视化界面&#xff0c…...

Python循环语句:不到长城心不死

Python中的循环语句是编程中非常重要的结构&#xff0c;它们允许你重复执行一段代码多次&#xff0c;直到满足某个条件为止。Python提供了两种主要的循环类型&#xff1a;for循环和while循环。 文章目录 1. for 循环2. while 循环循环控制语句range() 函数结合循环语句和 rang…...

Unity教程(九)角色攻击的改进

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程&#xff08;零&#xff09;Unity和VS的使用相关内容 Unity教程&#xff08;一&#xff09;开始学习状态机 Unity教程&#xff08;二&#xff09;角色移动的实现 Unity教程&#xff08;三&#xff09;角色跳跃的实现 Unity教程&…...

宠物空气净化器真的能除毛吗?有哪些选购技巧和品牌推荐修改版

夏日炎炎&#xff0c;有猫超甜。作为一名资深铲屎官&#xff0c;家里养有猫让我倍感幸福&#xff0c;夏天里有空调、有西瓜、有猫&#xff0c;这几个搭配在一起真的是超级爽。但在这么高温的夏天&#xff0c;家里养有宠物还是有不少烦恼的。比如家里的浮毛一直飘&#xff0c;似…...

Qt自定义注释

前言 是谁在Qt中编写代码&#xff0c;函数注释&#xff0c;类注释时&#xff0c;注释符号一个一个的敲&#xff1f; comment注释brief简洁的 Detailed详细的 第一步&#xff1a; 打开Qt 工具->选项->文本编辑器->片段 第二步&#xff1a; 点击添加 然后点击OK…...

【模电笔记】——信号的运算和处理电路(含电压比较器)

tips&#xff1a;本章节的笔记已经打包到word文档里啦&#xff0c;建议大家下载文章顶部资源&#xff08;有时看不到是在审核中&#xff0c;等等就能下载了。手机端下载后里面的插图可能会乱&#xff0c;建议电脑下载&#xff0c;兼容性更好且易于观看&#xff09;&#xff0c;…...

Java之 equals()与==

目录 运算符用途&#xff1a;用于比较两个引用是否指向同一个对象。比较内容&#xff1a;比较的是内存地址&#xff08;引用&#xff09;。适用范围&#xff1a;适用于基本数据类型和对象引用 equals() 方法用途&#xff1a;用于比较两个对象的内容是否相同。比较内容&#xf…...

Ubuntu20.04 运行深蓝路径规划hw1

前言 环境&#xff1a; ubuntu 20.04 &#xff1b; ROS版本&#xff1a; noetic&#xff1b; 问题 1、出现PCL报错&#xff1a;#error PCL requires C14 or above catkin_make 编译时&#xff0c;出现如下错误 解决&#xff1a; 在grid_path_searcher文件夹下面的CMakeLis…...

企业如何组建安全稳定的跨国通信网络

当企业在海外设有分公司时&#xff0c;如何建立一个安全且稳定的跨国通信网络是一个关键问题。为了确保跨国通信的安全和稳定性&#xff0c;可以考虑以下几种方案。 首先&#xff0c;可以在分公司之间搭建虚拟专用网络。虚拟专用网络通过对传输数据进行加密&#xff0c;保护通信…...

WordPress原创插件:Download-block-plugin下载按钮图标美化

WordPress原创插件&#xff1a;Download-block-plugin下载按钮图标美化 https://download.csdn.net/download/huayula/89632743...

前端【详解】缓存

HTTP 缓存 https://blog.csdn.net/weixin_41192489/article/details/136446539 CDN 缓存 CDN 全称 Content Delivery Network,即内容分发网络。 用户在浏览网站的时候&#xff0c;CDN会选择一个离用户最近的CDN边缘节点来响应用户的请求 CDN边缘节点的缓存机制与HTTP 缓存相同…...

P5821 【LK R-03】密码串匹配

[题目通道](【L&K R-03】密码串匹配 - 洛谷) 一道神题。 如果没有修改操作&#xff0c;翻转A数组或B数组后就是裸的FFT了 如果每次操作都暴力修改FFT时间复杂度显然爆炸 如果每次操作都不修改&#xff0c;记下修改序列&#xff0c;询问时加上修改序列的贡献&#xff0c;…...

httpx,一个网络请求的 Python 新宠儿

大家好&#xff01;我是爱摸鱼的小鸿&#xff0c;关注我&#xff0c;收看每期的编程干货。 一个简单的库&#xff0c;也许能够开启我们的智慧之门&#xff0c; 一个普通的方法&#xff0c;也许能在危急时刻挽救我们于水深火热&#xff0c; 一个新颖的思维方式&#xff0c;也许能…...

计算机网络408考研 2014

1 计算机网络408考研2014年真题解析_哔哩哔哩_bilibili 1 111 1 11 1...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能&#xff0c;我们需要对它的功能特点进行分析&#xff1a; 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具&#xff1a; mysql&#xff1a;关系型数据库&am…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面&#xff0c;避免重复抓取&#xff0c;以节省资源和时间。 在分布式环境下&#xff0c;增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路&#xff1a;将增量判…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性&#xff1a; 隐藏字段的实现细节 提供对字段的受控访问 访问控制&#xff1a; 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性&#xff1a; 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑&#xff1a; 可以…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机

这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机&#xff0c;因为在使用过程中发现 Airsim 对外部监控相机的描述模糊&#xff0c;而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置&#xff0c;最后在源码示例中找到了&#xff0c;所以感…...