当前位置: 首页 > news >正文

【强化学习】17 ——DDPG(Deep Deterministic Policy Gradient)

文章目录

  • 前言
    • DDPG特点
  • 随机策略与确定性策略
  • DDPG:深度确定性策略梯度
    • 伪代码
    • 代码实践

前言

之前的章节介绍了基于策略梯度的算法 REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO。这类算法有一个共同的特点:它们都是在线策略算法,这意味着它们的样本效率(sample efficiency)比较低。本章将要介绍的深度确定性策略梯度(deep deterministic policy gradient,DDPG)算法通过使用离线的数据以及Belllman等式去学习 Q Q Q函数,并利用 Q Q Q函数去学习策略。

DDPG特点

  • DDPG是离线学习算法
  • DDPG可以在连续的动作空间中进行使用
  • Open AI Spinning Up 中的DDPG未实现并行运行

随机策略与确定性策略

首先来回顾一下随机策略与确定性策略相关内容

随机策略

  • 离散动作: π ( a ∣ s ; θ ) = exp ⁡ { Q θ ( s , a ) } ∑ a , exp ⁡ { Q θ ( s , a ′ ) } \pi(a|s;\theta)=\frac{\exp\{Q_\theta(s,a)\}}{\sum_a,\exp\{Q_\theta(s,a^{\prime})\}} π(as;θ)=a,exp{Qθ(s,a)}exp{Qθ(s,a)},学习出价值函数之后再求取相应的softmax分布
  • 连续动作: π ( a ∣ s ; θ ) ∝ exp ⁡ { ( a − μ θ ( s ) ) 2 } \pi(a|s;\theta)\propto\exp\left\{\left(a-\mu_\theta(s)\right)^2\right\} π(as;θ)exp{(aμθ(s))2},学习出的策略符合高斯分布(均值,方差)

确定性策略

  • 离散动作: π ( s ; θ ) = arg ⁡ max ⁡ a Q θ ( s , a ) \pi(s;\theta)=\arg\max_aQ_\theta(s,a) π(s;θ)=argmaxaQθ(s,a)策略不可微,但可以通过学习价值函数再求取argmax的方式得到相应的策略
  • 连续动作: a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ)策略可微,建立相应的函数映射,通过函数求导的方式进行策略学习

那么如何利用确定性策略学习连续动作呢?首先需要一个用于估计价值的Critic模块。 Q w ( s , a ) ≃ Q π ( s , a ) Q^w(s,a)\simeq Q^\pi(s,a) Qw(s,a)Qπ(s,a) L ( w ) = E s ∼ ρ π , a ∼ π θ [ ( Q w ( s , a ) − Q π ( s , a ) ) 2 ] L(w)=\mathbb{E}_{s\sim\rho^\pi,a\sim\pi_\theta}\left[\left(Q^w(s,a)-Q^\pi(s,a)\right)^2\right] L(w)=Esρπ,aπθ[(Qw(s,a)Qπ(s,a))2]

通过与环境的交互,可以获得状态的总体分布,又因为 a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ),因此可以利用链式法则进行求导。首先是 Q Q Q函数对 a a a进行求导( Q Q Q函数通常由网络学习出来,对 a a a向量进行求导相当于是调整相应的梯度以使得获得更大的 Q Q Q值),接着因为 a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ),所以 a a a π \pi π进行求导。
J ( π θ ) = E s ∼ ρ π [ Q π ( s , a ) ] J(\pi_\theta)=\mathbb{E}_{s\sim\rho^\pi}[Q^\pi(s,a)] J(πθ)=Esρπ[Qπ(s,a)] ∇ θ J ( π θ ) = E s ∼ ρ π [ ∇ θ π θ ( s ) ∇ a Q π ( s , a ) ∣ a = π θ ( s ) ] \nabla_\theta J(\pi_\theta)=\mathbb{E}_{s\sim\rho^\pi}[\nabla_\theta\pi_\theta(s)\nabla_aQ^\pi(s,a)|_{a=\pi_\theta(s)}] θJ(πθ)=Esρπ[θπθ(s)aQπ(s,a)a=πθ(s)]

上式即为确定性策略梯度定理。确定性策略梯度定理的具体证明过程可参考《动手学强化学习》13.5 节。

DDPG:深度确定性策略梯度

在实际应用中,上述的带有神经函数近似器的actor-critic方法在面对有
挑战性的问题时是不稳定的。深度确定性策略梯度(DDPG)给出了在确定性策略梯度(DPG)基础上的解决方法:
• 经验重放(离线策略)
• 目标网络
• 在动作输入前标准化Q网络
• 添加连续噪声

下面我们来看一下 DDPG 算法的细节。DDPG 要用到4个神经网络,其中 Actor 和 Critic 各用一个网络,此外它们都各自有一个目标网络。DDPG 中 Actor 也需要目标网络因为目标网络也会被用来计算目标 Q Q Q值。DDPG 中目标网络的更新与 DQN 中略有不同:在 DQN 中,每隔一段时间将 Q Q Q网络直接复制给目标 Q Q Q网络;而在 DDPG 中,目标 Q Q Q网络的更新采取的是一种软更新(延时更新)的方式,即让目标 Q Q Q网络缓慢更新,逐渐接近网络,其公式为:
ω − ← τ ω + ( 1 − τ ) ω − \omega^-\leftarrow\tau\omega+(1-\tau)\omega^- ωτω+(1τ)ω

通常 τ \tau τ是一个比较小的数,当 τ = 1 \tau=1 τ=1时,就和 DQN 的更新方式一致了。而目标 μ \mu μ网络(策略网络)也使用这种软更新的方式。

另外,由于 Q Q Q函数存在 Q Q Q值过高估计的问题,DDPG 采用了 Double DQN 中的技术来更新 Q Q Q网络。但是,由于 DDPG 采用的是确定性策略,它本身的探索仍然十分有限。回忆一下 DQN 算法,它的探索主要由 ϵ \epsilon ϵ-贪婪策略的行为策略产生。同样作为一种离线策略的算法,DDPG 在行为策略上引入一个 N \mathcal{N} N随机噪声(原论文使用的是OU噪声,后来许多实验证明高斯噪声具有更好的效果)来进行探索。
在这里插入图片描述

OU噪声

伪代码

在这里插入图片描述

在这里插入图片描述

代码实践

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import utilclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)# action_bound是环境可以接受的动作最大值self.action_bound = action_bounddef forward(self, x):x = F.relu(self.fc1(x))return torch.tanh(self.fc2(x)) * self.action_boundclass QValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, s, a):# 拼接状态和动作cat = torch.cat([s, a], dim=1)x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)class DDPG:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,action_bound, sigma, tau, buffer_size, minimal_size, batch_size, device, numOfEpisodes, env):self.action_dim = action_dimself.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 初始化目标价值网络并设置和价值网络相同的参数self.target_critic.load_state_dict(self.critic.state_dict())# 初始化目标策略网络并设置和策略相同的参数self.target_actor.load_state_dict(self.actor.state_dict())self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差,均值直接设为0self.tau = tau  # 目标网络软更新参数self.device = deviceself.env = envself.numOfEpisodes = numOfEpisodesself.buffer_size = buffer_sizeself.minimal_size = minimal_sizeself.batch_size = batch_sizedef take_action(self, state):state = torch.FloatTensor(np.array([state])).to(self.device)action = self.actor(state).item()# 给动作添加噪声,增加探索action = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(), net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)actions = torch.tensor(np.array(transition_dict['actions']), dtype=torch.float).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)q_targets = rewards + self.gamma * (self.target_critic(next_states, self.target_actor(next_states))) * (1 - terminateds + truncateds)critic_loss = torch.mean(F.mse_loss(q_targets, self.critic(states, actions)))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()self.soft_update(self.actor, self.target_actor)  # 软更新策略网络self.soft_update(self.critic, self.target_critic)  # 软更新价值网络def DDPGtrain(self):replay_buffer = util.ReplayBuffer(self.buffer_size)returnList = []for i in range(10):with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:for episode in range(int(self.numOfEpisodes / 10)):# initialize statestate, info = self.env.reset()terminated = Falsetruncated = FalseepisodeReward = 0# Loop for each step of episode:while (not terminated) or (not truncated):action = self.take_action(state)next_state, reward, terminated, truncated, info = self.env.step(action)replay_buffer.add(state, action, reward, next_state, terminated, truncated)state = next_stateepisodeReward += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > self.minimal_size:b_s, b_a, b_r, b_ns, b_te, b_tr = replay_buffer.sample(self.batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'terminateds': b_te,'truncateds': b_tr}self.update(transition_dict)if terminated or truncated:breakreturnList.append(episodeReward)if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (self.numOfEpisodes / 10 * i + episode + 1),'return':'%.3f' % np.mean(returnList[-10:])})pbar.update(1)return returnList

超参数设置参考:

    agent = DDPG(state_dim=env.observation_space.shape[0],hidden_dim=256,action_dim=env.action_space.shape[0],actor_lr=3e-4,critic_lr=3e-3,gamma=0.99,action_bound=env.action_space.high[0],sigma=0.01,tau=0.005,buffer_size=10000,minimal_size=1000,batch_size=64,device=device,numOfEpisodes=200,env=env)

在这里插入图片描述
DDPG算法相比之前的在线学习算法,更加稳定,同时收敛速度更快。

深度确定性策略梯度算法(DDPG),它是面向连续动作空间的深度确定性策略训练的典型算法。相比于它的先期工作,即确定性梯度算法(DPG),DDPG 加入了目标网络和软更新的方法,这对深度模型构建的价值网络和策略网络的稳定学习起到了关键的作用。

相关文章:

【强化学习】17 ——DDPG(Deep Deterministic Policy Gradient)

文章目录 前言DDPG特点 随机策略与确定性策略DDPG:深度确定性策略梯度伪代码代码实践 前言 之前的章节介绍了基于策略梯度的算法 REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO。这类算法有一个共同的特点:它们都是在线策略算法&#xff0c…...

驱动开发11-2 编写SPI驱动程序-点亮数码管

驱动程序 #include <linux/init.h> #include <linux/module.h> #include <linux/spi/spi.h>int m74hc595_probe(struct spi_device *spi) {printk("%s:%d\n",__FILE__,__LINE__);char buf[]{0XF,0X6D};spi_write(spi,buf,sizeof(buf));return 0; …...

Java使用pdfbox进行pdf和图片之间的转换

简介 pdfbox是Apache开源的一个项目,支持pdf文档操作功能。 官网地址: Apache PDFBox | A Java PDF Library 支持的功能如下图.引入依赖 <dependency><groupId>org.apache.pdfbox</groupId><artifactId>pdfbox-app</artifactId><version>…...

机器学习中的关键组件

机器学习中的关键组件 数据 每个数据集由一个个样本组成&#xff0c;大多时候&#xff0c;它们遵循独立同分布。样本有时也叫作数据点或数据实例&#xff0c;通常每个样本由一组称为特征或协变量的属性组成。机器学习会根据这些属性进行预测&#xff0c;预测得到的称为标签或…...

【JVM】JDBC案例打破双亲委派机制

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 JVM 打破双亲委派机制&#xff08;JDBC案例…...

每天五分钟计算机视觉:池化层的反向传播

本文重点 卷积神经网络(Convolutional Neural Network,CNN)作为一种强大的深度学习模型,在计算机视觉任务中取得了巨大成功。其中,池化层(Pooling Layer)在卷积层之后起到了信息压缩和特征提取的作用。然而,池化层的反向传播一直以来都是一个相对复杂和深奥的问题。本…...

Docker的安装、基础命令与项目部署

文章目录 前言一、docker安装与MySQL部署1.Linux环境下docker的安装&#xff08;1&#xff09;基于CentOS7&#xff08;2&#xff09;基于Ubuntu 二、docker基础1.常见命令&#xff08;1&#xff09;快速创建一个mysql容器&#xff08;MySQL得一键安装&#xff09;。&#xff0…...

Nodejs和npm的使用方法和教程

Nodejs简介 Node.js 是一个开源和跨平台的 JavaScript 运行时环境。 它几乎是任何类型项目的流行工具&#xff01; &#xff08; 运行环境&#xff0c;是不是很熟悉&#xff0c;对。就是 java JRE&#xff0c;Java 运行时环境&#xff09; Node.js 在浏览器之外运行 V8 Java…...

机器学习---支持向量机的初步理解

1. SVM的经典解释 改编自支持向量机解释得很好 |字节大小生物学 (bytesizebio.net) 话说&#xff0c;在遥远的从前&#xff0c;有一只贪玩爱搞破坏的妖怪阿布劫持了善良美丽的女主小美&#xff0c;智勇双全 的男主大壮挺身而出&#xff0c;大壮跟随阿布来到了妖怪的住处&…...

【unity实战】Unity实现2D人物双击疾跑

最终效果 前言 我们要实现的功能是双击疾跑&#xff0c;当玩家快速地按下同一个移动键两次时能进入跑步状态 我假设快速按下的定义为0.2秒内&#xff0c;按下同一按键两次 简单的分析一下需求&#xff0c;实现它的关键在于获得按键按下的时间&#xff0c;我们需要知道第一次…...

Spring面试题:(二)基于xml方式的Spring配置

xml配置Bean的常见属性 id属性 name属性 scope属性 lazy-init属性 init-method属性和destroy属性 initializingBean方法 Bean实例化方式 ApplicationContext底层调用BeanFactory创建Bean&#xff0c;BeanFactory可以利用反射机制调用构造方法实例化Bean&#xff0c;也可采用工…...

XR Interaction ToolKit

一、简介 XR Interaction Toolkit是unity官方的XR交互工具包。 官方XRI示例地址&#xff1a;https://github.com/Unity-Technologies/XR-Interaction-Toolkit-Examples 2023.3.14官方博客&#xff0c;XRIT v2.3 https://blog.unity.com/engine-platform/whats-new-in-xr-int…...

spring-boot中实现分片上传文件

一、上传文件基本实现 1、前端效果图展示&#xff0c;这里使用element-ui plus来展示样式效果 2、基础代码如下 <template><div><el-uploadref"uploadRef"class"upload-demo":limit"1":on-change"handleExceed":auto-…...

【ICN综述】信息中心网络隐私安全

ICN基本原理&#xff1a; 信息中心网络也是需要实现在不可信环境下可靠的信息交换和身份认证 信息中心网络采用以数据内容为中心的传输方式代替现有IP 网络中以主机为中心的通信方式&#xff0c;淡化信息数据物理或逻辑位置的重要性&#xff0c;以内容标识为代表实现数据的查找…...

基于STC12C5A60S2系列1T 8051单片机EEPROM应用

基于STC12C5A60S2系列1T 8051单片机EEPROM应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍STC12C5A60S2系列1T 8051单片机EEPROM介绍基于STC12C5A60S2系列1T 8051单…...

手撕排序之直接选择排序

前言&#xff1a; 直接选择排序是排序中比较简单的排序&#xff0c;同时也是时间复杂度不是很优的排序。 思想&#xff1a; 本文主要讲解直接选择排序的优化版本。 我们经过一次遍历直接将该数列中最大的和最小的值挑选出来&#xff0c;如果是升序&#xff0c;就将最小的和…...

洛谷 P1359 租用游艇

题目链接 P1359 租用游艇 普及 题目描述 长江游艇俱乐部在长江上设置了 n n n 个游艇出租站 1 , 2 , 3 , . . . , n 1,2,3,...,n 1,2,3,...,n&#xff0c;游客可在这些游艇出租站租用游艇&#xff0c;并在下游的任何一个游艇出租站归还游艇。游艇出租站 i i i 到游艇出租站…...

springboot中没有主清单属性解决办法

在执行一个 spring boot 启动类时&#xff0c;提示 没有主清单属性 一般这个问题是没加 spring-boot-maven-plugin 插件的问题&#xff0c;但是项目中已经加了 <build><plugins><plugin><groupId>org.springframework.boot</groupId><artifa…...

C/C++ static关键字详解(最全解析,static是什么,static如何使用,static的常考面试题)

目录 一、前言 二、static关键字是什么&#xff1f; 三、static关键字修饰的对象是什么&#xff1f; 四、C 语言中的 static &#x1f34e;static的C用法 &#x1f349;static的重点概念 &#x1f350;static修饰局部变量 &#x1f4a6;static在修饰局部变量和函数的作用 &a…...

windwos10搭建我的世界服务器,并通过内网穿透实现联机游戏Minecraft

文章目录 1. Java环境搭建2.安装我的世界Minecraft服务3. 启动我的世界服务4.局域网测试连接我的世界服务器5. 安装cpolar内网穿透6. 创建隧道映射内网端口7. 测试公网远程联机8. 配置固定TCP端口地址8.1 保留一个固定tcp地址8.2 配置固定tcp地址 9. 使用固定公网地址远程联机 …...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

在鸿蒙HarmonyOS 5中实现抖音风格的点赞功能

下面我将详细介绍如何使用HarmonyOS SDK在HarmonyOS 5中实现类似抖音的点赞功能&#xff0c;包括动画效果、数据同步和交互优化。 1. 基础点赞功能实现 1.1 创建数据模型 // VideoModel.ets export class VideoModel {id: string "";title: string ""…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下&#xff0c;无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作&#xff0c;还是游戏直播的画面实时传输&#xff0c;低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架&#xff0c;凭借其灵活的编解码、数据…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

在Ubuntu24上采用Wine打开SourceInsight

1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...

【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论

路径问题的革命性重构&#xff1a;基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中&#xff08;图1&#xff09;&#xff1a; mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...