当前位置: 首页 > news >正文

强化学习之DDPG算法

前言:
在正文开始之前,首先给大家介绍一个不错的人工智能学习教程:https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程,感兴趣的读者可以自行查阅。


一、算法介绍

深度确定性策略梯度 (Deep Deterministic Policy Gradient,简称DDPG) 算法是一种基于策略梯度的方法,结合了深度神经网络和确定性策略的优势。它特别适用于具有连续动作空间的控制任务,如机械臂控制、自动驾驶等。DDPG算法通过同时训练一个演员网络(Actor)和一个评论家网络(Critic),实现对策略的优化。

主要特点包括:

  • 确定性策略:与随机策略不同,DDPG使用确定性策略,直接输出给定状态下的最优动作。
  • 经验回放(Replay Buffer):通过存储经验样本,打破样本间的相关性,提升训练稳定性。
  • 目标网络(Target Networks):使用延迟更新的目标网络,减少训练过程中的震荡和不稳定。

二、算法原理

2.1 网络结构

DDPG算法由两个主要网络组成:

  • 演员网络(Actor):参数为 θ μ \theta^\mu θμ,用于确定性地选择动作。

    a = μ ( s ∣ θ μ ) a = \mu(s|\theta^\mu) a=μ(sθμ)

  • 评论家网络(Critic):参数为 θ Q \theta^Q θQ,用于估计给定状态-动作对的Q值。

    Q ( s , a ∣ θ Q ) Q(s,a|\theta^Q) Q(s,aθQ)

此外,还存在两个目标网络,分别对应演员和评论家网络,参数为 θ μ ′ \theta^{\mu'} θμ θ Q ′ \theta^{Q'} θQ,用于计算目标Q值。

2.2 经验回放

经验回放池 D \mathcal{D} D用于存储经验元组 ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1)。在每次训练迭代中,算法从 D \mathcal{D} D中随机采样一个小批量样本,打破数据间的相关性,提高训练效率和稳定性。

2.3 目标网络的更新

目标网络的参数通过软更新方式更新:

θ μ ′ ← τ θ μ + ( 1 − τ ) θ μ ′ \theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau) \theta^{\mu'} θμτθμ+(1τ)θμ

θ Q ′ ← τ θ Q + ( 1 − τ ) θ Q ′ \theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'} θQτθQ+(1τ)θQ

其中, τ \tau τ是软更新的步长,通常取值较小,如 0.001 0.001 0.001

2.4 损失函数与优化

  • 评论家网络的损失函数采用均方误差(MSE):

    L = 1 N ∑ i = 1 N ( y i − Q ( s i , a i ∣ θ Q ) ) 2 L = \frac{1}{N} \sum_{i=1}^N \left( y_i - Q(s_i, a_i|\theta^Q) \right)^2 L=N1i=1N(yiQ(si,aiθQ))2

    其中,

    y i = r i + γ Q ′ ( s i + 1 , μ ′ ( s i + 1 ∣ θ μ ′ ) ∣ θ Q ′ ) y_i = r_i + \gamma Q'(s_{i+1}, \mu'(s_{i+1}|\theta^{\mu'})|\theta^{Q'}) yi=ri+γQ(si+1,μ(si+1θμ)θQ)

  • 演员网络的损失函数通过最大化Q值来优化策略:

    J = − 1 N ∑ i = 1 N Q ( s i , μ ( s i ∣ θ μ ) ∣ θ Q ) J = -\frac{1}{N} \sum_{i=1}^N Q(s_i, \mu(s_i|\theta^\mu)|\theta^Q) J=N1i=1NQ(si,μ(siθμ)θQ)

2.5 算法流程

  1. 初始化演员网络 μ ( s ∣ θ μ ) \mu(s|\theta^\mu) μ(sθμ)和评论家网络 Q ( s , a ∣ θ Q ) Q(s,a|\theta^Q) Q(s,aθQ),以及对应的目标网络 μ ′ \mu' μ Q ′ Q' Q
  2. 初始化经验回放池 D \mathcal{D} D
  3. 对于每个回合:
    • 在环境中选择动作 a t = μ ( s t ∣ θ μ ) + N t a_t = \mu(s_t|\theta^\mu) + \mathcal{N}_t at=μ(stθμ)+Nt,其中 N t \mathcal{N}_t Nt为噪声,用于探索。
    • 执行动作 a t a_t at,观察奖励 r t r_t rt和下一个状态 s t + 1 s_{t+1} st+1
    • 存储经验 ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1) D \mathcal{D} D
    • D \mathcal{D} D中随机采样一个小批量样本。
    • 计算目标Q值 y i y_i yi
    • 更新评论家网络参数 θ Q \theta^Q θQ,最小化损失 L L L
    • 更新演员网络参数 θ μ \theta^\mu θμ,最大化 J J J
    • 软更新目标网络参数 θ μ ′ \theta^{\mu'} θμ θ Q ′ \theta^{Q'} θQ
  4. 重复以上步骤,直至收敛。

三、案例分析

在本节中,我们将通过在Pendulum-v0环境中应用DDPG算法,展示其具体实现过程。该环境的目标是让倒立摆尽可能长时间地保持直立状态,涉及连续动作空间。

3.1 环境简介

  • 状态空间:摆锤的角度、角速度,共3个维度。
  • 动作空间:施加的力矩,范围为 [ − 2 , 2 ] [-2, 2] [2,2]

3.2 实现代码

以下是使用PyTorch实现的DDPG算法在Pendulum-v0环境中的部分代码。

# 经验回放池
class ReplayBuffer:def __init__(self, buffer_size, batch_size, seed):self.memory = deque(maxlen=buffer_size)self.batch_size = batch_sizeself.seed = random.seed(seed)def add(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def sample(self):experiences = random.sample(self.memory, k=self.batch_size)states = torch.FloatTensor([e[0] for e in experiences]).to(device)actions = torch.FloatTensor([e[1] for e in experiences]).to(device)rewards = torch.FloatTensor([e[2] for e in experiences]).unsqueeze(1).to(device)next_states = torch.FloatTensor([e[3] for e in experiences]).to(device)dones = torch.FloatTensor([float(e[4]) for e in experiences]).unsqueeze(1).to(device)return states, actions, rewards, next_states, donesdef __len__(self):return len(self.memory)# 神经网络定义
def hidden_init(layer):fan_in = layer.weight.data.size()[0]lim = 1. / np.sqrt(fan_in)return (-lim, lim)class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)self.action_bound = action_bound  # 动作最大值# 初始化权重self.fc1.weight.data.uniform_(*hidden_init(self.fc1))self.fc2.weight.data.uniform_(-3e-3, 3e-3)def forward(self, x):x = F.relu(self.fc1(x))return torch.tanh(self.fc2(x)) * self.action_boundclass QValueNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, 1)# 初始化权重self.fc1.weight.data.uniform_(*hidden_init(self.fc1))self.fc2.weight.data.uniform_(*hidden_init(self.fc2))self.fc_out.weight.data.uniform_(-3e-3, 3e-3)def forward(self, x, a):cat = torch.cat([x, a], dim=1)  # 拼接状态和动作x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)# DDPG智能体
class DDPGAgent:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound, sigma, actor_lr, critic_lr, tau, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 初始化目标网络并设置和主网络相同的参数self.target_critic.load_state_dict(self.critic.state_dict())self.target_actor.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=WEIGHT_DECAY)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差self.tau = tau  # 目标网络软更新参数self.action_dim = action_dimself.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)self.actor.eval()with torch.no_grad():action = self.actor(state).cpu().data.numpy().flatten()self.actor.train()# 给动作添加噪声,增加探索action += self.sigma * np.random.randn(self.action_dim)return np.clip(action, -self.actor.action_bound, self.actor.action_bound)def soft_update(self, net, target_net):for target_param, param in zip(target_net.parameters(), net.parameters()):target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))def update(self, replay_buffer):if len(replay_buffer) < BATCH_SIZE:returnstates, actions, rewards, next_states, dones = replay_buffer.sample()# 更新Critic网络with torch.no_grad():next_actions = self.target_actor(next_states)Q_targets_next = self.target_critic(next_states, next_actions)Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))Q_expected = self.critic(states, actions)critic_loss = F.mse_loss(Q_expected, Q_targets)self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 更新Actor网络actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 软更新目标网络self.soft_update(self.critic, self.target_critic)self.soft_update(self.actor, self.target_actor)

3.3 运行结果

Episode 10	Average Score: -1623.12
Episode 20	Average Score: -1536.40
Episode 30	Average Score: -1287.98
Episode 40	Average Score: -1021.30
Episode 50	Average Score: -995.55
Episode 60	Average Score: -401.11
Episode 70	Average Score: -311.09
Episode 80	Average Score: -433.98
Episode 90	Average Score: -122.43
Episode 100	Average Score: -125.27
Episode 110	Average Score: -122.54
Episode 120	Average Score: -122.86
Episode 130	Average Score: -122.51
Episode 140	Average Score: -123.11
Episode 150	Average Score: -122.93
Episode 160	Average Score: -127.22
Episode 170	Average Score: -146.53
Episode 180	Average Score: -138.31
Episode 190	Average Score: -119.34
Episode 200	Average Score: -118.65

Pendulum-v0环境中,DDPG智能体经过200个回合的训练后,奖励曲线应逐渐上升,表明智能体的策略在不断优化。滑动平均曲线更平滑,能够更清晰地反映训练趋势。

四、总结

DDPG算法通过结合演员-评论家架构、经验回放和目标网络等技术,有效地解决了连续动作空间中的强化学习问题。在Pendulum-v0环境中的应用展示了其强大的学习能力和策略优化效果。随着研究的深入,DDPG及其衍生算法在更多复杂任务中的应用前景广阔。

相关文章:

强化学习之DDPG算法

前言&#xff1a; 在正文开始之前&#xff0c;首先给大家介绍一个不错的人工智能学习教程&#xff1a;https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程&#xff0c;感兴趣的读者可以自行查阅。 一、算法介绍 深度确定性策略梯度 &#xff0…...

【进阶OpenCV】 (16)-- 人脸识别 -- FisherFaces算法

文章目录 FisherFaces算法一、算法原理二、算法优势与局限三、算法实现1. 图像预处理2. 创建FisherFace人脸特征识别器3. 训练模型4. 测试图像 总结 FisherFaces算法 PCA方法是EigenFaces人脸识别的核心&#xff0c;但是其具有明显的缺点&#xff0c;在操作过程中会损失许多人…...

电脑主机配置

显卡&#xff1a; 查看显卡&#xff1a;设备管理器--显示适配器 RTX4060 RTX和GTX区别&#xff1a; GTX是NVIDIA公司旧款显卡&#xff0c;RTX比GTX好但是贵 处理器CPU&#xff1a; Intel(R) Core(TM) i5-10400F CPU 2.90GHz 2.90 GHz 10400F&#xff1a;10指的是第几代…...

图书借阅小程序开源独立版

图书借阅微信小程序&#xff0c;多书馆切换模式&#xff0c;书馆一键同步图书信息&#xff0c;开通会员即可在线借书&#xff0c;一书一码书馆员工手机扫码出入库从会员到书馆每一步信息把控图书借阅小程序&#xff0c;让阅读触手可及在这个快节奏的时代&#xff0c;你是否渴望…...

flutter TextField限制中文,ios自带中文输入法变英文输入问题解决

由于业务需求&#xff0c;要限制TextField只能输入中文&#xff0c;但是测试在iOS测试机发现自带中文输入法会变英文输入问题&#xff0c;安卓没有问题&#xff0c;并且只有iOS自带输入法有问题&#xff0c;搜狗等输入法没问题。我们目前使用flutter2.5.3版本&#xff0c;高版本…...

ThreadLocal的应用场景

ThreadLocal介绍 ThreadLocal为每个线程都提供了变量的副本&#xff0c;使得每个线程访问各自独立的对象&#xff0c;这样就隔离了多个线程对数据的共享&#xff0c;使得线程安全。ThreadLocal有如下方法&#xff1a; 方法声明 描述public void set(T value)设置当前线程绑定的…...

Python--plt.errorbar学习笔记

plt.errorbar 是 Matplotlib 库中的一个函数&#xff0c;用于绘制带有误差条的图形。下面给出的代码行的详细解释&#xff1a; import numpy as np from scipy.special import kv, erfc from scipy.integrate import dblquad import matplotlib.pyplot as plt import scipy.in…...

文件信息类QFileInfo

常用方法&#xff1a; 构造函数 //参数&#xff1a;文件的绝对路径或相对路径 [explicit] QFileInfo::QFileInfo(const QString &path) 设置文件路径 可构造一个空的QFileInfo的对象&#xff0c;然后设置路径 //参数&#xff1a;文件的绝对路径或相对路径 void QFileI…...

堆排序(C++实现)

参考&#xff1a; 面试官&#xff1a;请写一个堆排序_哔哩哔哩_bilibiliC实现排序算法_c从小到大排序-CSDN博客 堆的基本概念 堆排实际上是利用堆的性质来进行排序。堆可以看做一颗完全二叉树。 堆分为两类&#xff1a; 最大堆&#xff08;大顶堆&#xff09;&#xff1a;除根…...

Qt中加入UI文件

将 UI 文件整合到 Qt 项目 使用 Qt Designer 创建 UI 文件&#xff1a; 在 Qt Creator 中使用 Qt Designer 创建 UI 文件&#xff0c;设计所需的界面。确保在设计中包含所需的控件&#xff08;如按钮、文本框等&#xff09;&#xff0c;并为每个控件设置明确的对象名称&#xf…...

Redisson使用全解

redisson使用全解——redisson官方文档注释&#xff08;上篇&#xff09;_redisson官网中文-CSDN博客 redisson使用全解——redisson官方文档注释&#xff08;中篇&#xff09;-CSDN博客 redisson使用全解——redisson官方文档注释&#xff08;下篇&#xff09;_redisson官网…...

Go4 和对 Go 的贡献

本篇内容是根据2017年4月份Go4 and Contributing to Go音频录制内容的整理与翻译, Brad Fitzpatrick 加入节目谈论成为开源 Go 的代言人、让社区参与 bug 分类、Go 的潜在未来以及其他有趣的 Go 项目和新闻。 过程中为符合中文惯用表达有适当删改, 版权归原作者所有. Erik St…...

区间动态规划

区间动态规划&#xff08;Interval DP&#xff09;是动态规划的一种重要变种&#xff0c;特别适用于解决一类具有区间性质的问题。典型的应用场景是给定一个区间&#xff0c;要求我们在满足某些条件下进行最优划分或合并。本文将从区间DP的基本思想、常见问题模型以及算法实现几…...

什么情况下需要使用电压探头

高压探头是一种专门设计用于测量高压电路或设备的探头&#xff0c;其作用是在电路测试和测量中提供安全、准确的信号捕获&#xff0c;并确保操作人员的安全。这些探头通常用于测量高压电源、变压器、电力系统、医疗设备以及其他需要处理高电压的设备或系统。 而高压差分探头差分…...

数据结构——八大排序(下)

数据结构中的八大排序算法是计算机科学领域经典的排序方法&#xff0c;它们各自具有不同的特点和适用场景。以下是这八大排序算法的详细介绍&#xff1a; 五、选择排序&#xff08;Selection Sort&#xff09; 核心思想&#xff1a;每一轮从未排序的元素中选择最小&#xff0…...

Linux系统:Ubuntu上安装Chrome浏览器

Ubuntu系统版本&#xff1a;23.04 在Ubuntu系统上安装Google Chrome浏览器&#xff0c;可以通过以下步骤进行&#xff1a; 终端输入以下命令&#xff0c;先更新软件源&#xff1a; sudo apt update 或 sudo apt upgrade终端输入以下命令&#xff0c;下载最新的Google Chrome .…...

Redis位图BitMap

一、为什么使用位图&#xff1f; 使用位图能有效实现 用户签到 等行为&#xff0c;用数据库表记录签到&#xff0c;将占用很多存储&#xff1b;但使用 位图BitMap&#xff0c;就能 大大减少存储占用 二、关于位图 本质上是String类型&#xff0c;最小长度8位&#xff08;一个字…...

YOLOv11改进策略【卷积层】| ParNet 即插即用模块 二次创新C3k2

一、本文介绍 本文记录的是利用ParNet中的基础模块优化YOLOv11的目标检测网络模型。 ParNet block是一个即插即用模块,能够在不增加深度的情况下增加感受野,更好地处理图像中的不同尺度特征,有助于网络对输入数据更全面地理解和学习,从而提升网络的特征提取能力和分类性能…...

学习threejs,网格深度材质MeshDepthMaterial

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️网格深度材质MeshDepthMate…...

算法时间、空间复杂度(二)

目录 大O渐进表示法 一、时间复杂度量级的判断 定义&#xff1a; 例一&#xff1a;执行2*N&#xff0b;1次 例二&#xff1a;执行MN次 例三&#xff1a;执行已知次数 例四:存在最好情况和最坏情况 顺序查找 冒泡排序 二分查找 例五&#xff1a;阶乘递归 ​编辑 例…...

高级java每日一道面试题-2024年10月11日-数据库篇[Redis篇]-Redis都有哪些使用场景?

如果有遗漏,评论区告诉我进行补充 面试官: Redis都有哪些使用场景? 我回答: Redis 是一个开源的、基于键值对的数据结构存储系统&#xff0c;&#xff0c;它支持多种数据类型&#xff0c;包括字符串、散列、列表、集合和有序集合。它可以用作数据库、缓存和消息中间件。由于…...

0047__【python打包分发工具】setuptools详解

【python打包分发工具】setuptools详解-CSDN博客...

自定义拦截器处理token

目录 1、WebConfig 配置类 2、TSUserContext 把用户信息放到context中 3、自定义拦截器 4、在controller中可以使用 5、参考链接 1、WebConfig 配置类 @Configuration public class WebConfig implements WebMvcConfigurer {@Autowiredprivate AccessControlInterceptor …...

Scrapy | 使用Scrapy进行数据建模和请求

scrapy数据建模与请求 数据建模1.1 为什么建模1.2 如何建模1.3如何使用模板类1.4 开发流程总结 目标&#xff1a; 1.应用在scrapy项目中进行建模 2.应用构造Request对象&#xff0c;并发送请求 3.应用利用meta参数在不同的解析函数中传递数据 数据建模 | 通常在做项目的过程中…...

学习笔记——交换——STP(生成树)基本概念

三、基本概念 1、桥ID/网桥ID (Bridege ID&#xff0c;BID) 每一台运行STP的交换机都拥有一个唯一的桥ID(BID)&#xff0c;BID(Bridge ID/桥ID)。在STP里我们使用不同的桥ID标识不同的交换机。 (2)BID(桥ID)组成 BID(桥ID)组成(8个字节)&#xff1a;由16位(2字节)的桥优先级…...

机器学习笔记-2

文章目录 一、Linear model二、How to represent this function三、Function with unknown parameter四、ReLU总结、A fancy name 一、Linear model 线性模型过于简单&#xff0c;有很大限制&#xff0c;我们需要更多复杂模式 蓝色是线性模型&#xff0c;线性模型无法去表示…...

SpringSecurity(一)——认证实现

一、初步理解 SpringSecurity的原理其实就是一个过滤器链&#xff0c;内部包含了提供各种功能的过滤器。 当前系统中SpringSecurity过滤器链中有哪些过滤器及它们的顺序。 核心过滤器&#xff1a; &#xff08;认证&#xff09;UsernamePasswordAuthenticationFilter:负责处理…...

VMWare NAT 模式下 虚拟机上不了网原因排查

vmware 按照了Linux之后 无法上网&#xff0c;搞定后&#xff0c;记录一些信息。 window有两个虚拟网卡 VMnet1 对应的是 Host-Only&#xff08;仅主机模式&#xff09; VMnet8 对应的是 NAT&#xff08;网络地址转换模式&#xff09; 在NAT模式中&#xff0c;需要设置NAT和D…...

R语言手工实现主成分分析 PCA | 奇异值分解(svd) 与PCA | PCA的疑问和解答

几个问题: pca可以用相关系数矩阵做吗?效果比协方差矩阵比怎么样?pca做完后变量和样本的新坐标怎么旋转获得?pca做不做scale和center对结果有影响吗?pca用因子分解和奇异值分解有啥区别?后者怎么获得变量和样本的新坐标?1. 用R全手工实现 PCA(对比 prcomp() ) 不借助包…...

第三届OpenHarmony技术大会在上海成功举办

10月12日&#xff0c;以“技术引领筑生态&#xff0c;万物智联创未来”为主题的第三届OpenHarmony技术大会&#xff08;以下简称“大会”&#xff09;在上海成功举办。本次大会由OpenAtom OpenHarmony&#xff08;以下简称“OpenHarmony”&#xff09;项目群技术指导委员会&…...