当前位置: 首页 > news >正文

强化学习之DDPG算法

前言:
在正文开始之前,首先给大家介绍一个不错的人工智能学习教程:https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程,感兴趣的读者可以自行查阅。


一、算法介绍

深度确定性策略梯度 (Deep Deterministic Policy Gradient,简称DDPG) 算法是一种基于策略梯度的方法,结合了深度神经网络和确定性策略的优势。它特别适用于具有连续动作空间的控制任务,如机械臂控制、自动驾驶等。DDPG算法通过同时训练一个演员网络(Actor)和一个评论家网络(Critic),实现对策略的优化。

主要特点包括:

  • 确定性策略:与随机策略不同,DDPG使用确定性策略,直接输出给定状态下的最优动作。
  • 经验回放(Replay Buffer):通过存储经验样本,打破样本间的相关性,提升训练稳定性。
  • 目标网络(Target Networks):使用延迟更新的目标网络,减少训练过程中的震荡和不稳定。

二、算法原理

2.1 网络结构

DDPG算法由两个主要网络组成:

  • 演员网络(Actor):参数为 θ μ \theta^\mu θμ,用于确定性地选择动作。

    a = μ ( s ∣ θ μ ) a = \mu(s|\theta^\mu) a=μ(sθμ)

  • 评论家网络(Critic):参数为 θ Q \theta^Q θQ,用于估计给定状态-动作对的Q值。

    Q ( s , a ∣ θ Q ) Q(s,a|\theta^Q) Q(s,aθQ)

此外,还存在两个目标网络,分别对应演员和评论家网络,参数为 θ μ ′ \theta^{\mu'} θμ θ Q ′ \theta^{Q'} θQ,用于计算目标Q值。

2.2 经验回放

经验回放池 D \mathcal{D} D用于存储经验元组 ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1)。在每次训练迭代中,算法从 D \mathcal{D} D中随机采样一个小批量样本,打破数据间的相关性,提高训练效率和稳定性。

2.3 目标网络的更新

目标网络的参数通过软更新方式更新:

θ μ ′ ← τ θ μ + ( 1 − τ ) θ μ ′ \theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau) \theta^{\mu'} θμτθμ+(1τ)θμ

θ Q ′ ← τ θ Q + ( 1 − τ ) θ Q ′ \theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'} θQτθQ+(1τ)θQ

其中, τ \tau τ是软更新的步长,通常取值较小,如 0.001 0.001 0.001

2.4 损失函数与优化

  • 评论家网络的损失函数采用均方误差(MSE):

    L = 1 N ∑ i = 1 N ( y i − Q ( s i , a i ∣ θ Q ) ) 2 L = \frac{1}{N} \sum_{i=1}^N \left( y_i - Q(s_i, a_i|\theta^Q) \right)^2 L=N1i=1N(yiQ(si,aiθQ))2

    其中,

    y i = r i + γ Q ′ ( s i + 1 , μ ′ ( s i + 1 ∣ θ μ ′ ) ∣ θ Q ′ ) y_i = r_i + \gamma Q'(s_{i+1}, \mu'(s_{i+1}|\theta^{\mu'})|\theta^{Q'}) yi=ri+γQ(si+1,μ(si+1θμ)θQ)

  • 演员网络的损失函数通过最大化Q值来优化策略:

    J = − 1 N ∑ i = 1 N Q ( s i , μ ( s i ∣ θ μ ) ∣ θ Q ) J = -\frac{1}{N} \sum_{i=1}^N Q(s_i, \mu(s_i|\theta^\mu)|\theta^Q) J=N1i=1NQ(si,μ(siθμ)θQ)

2.5 算法流程

  1. 初始化演员网络 μ ( s ∣ θ μ ) \mu(s|\theta^\mu) μ(sθμ)和评论家网络 Q ( s , a ∣ θ Q ) Q(s,a|\theta^Q) Q(s,aθQ),以及对应的目标网络 μ ′ \mu' μ Q ′ Q' Q
  2. 初始化经验回放池 D \mathcal{D} D
  3. 对于每个回合:
    • 在环境中选择动作 a t = μ ( s t ∣ θ μ ) + N t a_t = \mu(s_t|\theta^\mu) + \mathcal{N}_t at=μ(stθμ)+Nt,其中 N t \mathcal{N}_t Nt为噪声,用于探索。
    • 执行动作 a t a_t at,观察奖励 r t r_t rt和下一个状态 s t + 1 s_{t+1} st+1
    • 存储经验 ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1) D \mathcal{D} D
    • D \mathcal{D} D中随机采样一个小批量样本。
    • 计算目标Q值 y i y_i yi
    • 更新评论家网络参数 θ Q \theta^Q θQ,最小化损失 L L L
    • 更新演员网络参数 θ μ \theta^\mu θμ,最大化 J J J
    • 软更新目标网络参数 θ μ ′ \theta^{\mu'} θμ θ Q ′ \theta^{Q'} θQ
  4. 重复以上步骤,直至收敛。

三、案例分析

在本节中,我们将通过在Pendulum-v0环境中应用DDPG算法,展示其具体实现过程。该环境的目标是让倒立摆尽可能长时间地保持直立状态,涉及连续动作空间。

3.1 环境简介

  • 状态空间:摆锤的角度、角速度,共3个维度。
  • 动作空间:施加的力矩,范围为 [ − 2 , 2 ] [-2, 2] [2,2]

3.2 实现代码

以下是使用PyTorch实现的DDPG算法在Pendulum-v0环境中的部分代码。

# 经验回放池
class ReplayBuffer:def __init__(self, buffer_size, batch_size, seed):self.memory = deque(maxlen=buffer_size)self.batch_size = batch_sizeself.seed = random.seed(seed)def add(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def sample(self):experiences = random.sample(self.memory, k=self.batch_size)states = torch.FloatTensor([e[0] for e in experiences]).to(device)actions = torch.FloatTensor([e[1] for e in experiences]).to(device)rewards = torch.FloatTensor([e[2] for e in experiences]).unsqueeze(1).to(device)next_states = torch.FloatTensor([e[3] for e in experiences]).to(device)dones = torch.FloatTensor([float(e[4]) for e in experiences]).unsqueeze(1).to(device)return states, actions, rewards, next_states, donesdef __len__(self):return len(self.memory)# 神经网络定义
def hidden_init(layer):fan_in = layer.weight.data.size()[0]lim = 1. / np.sqrt(fan_in)return (-lim, lim)class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)self.action_bound = action_bound  # 动作最大值# 初始化权重self.fc1.weight.data.uniform_(*hidden_init(self.fc1))self.fc2.weight.data.uniform_(-3e-3, 3e-3)def forward(self, x):x = F.relu(self.fc1(x))return torch.tanh(self.fc2(x)) * self.action_boundclass QValueNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, 1)# 初始化权重self.fc1.weight.data.uniform_(*hidden_init(self.fc1))self.fc2.weight.data.uniform_(*hidden_init(self.fc2))self.fc_out.weight.data.uniform_(-3e-3, 3e-3)def forward(self, x, a):cat = torch.cat([x, a], dim=1)  # 拼接状态和动作x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)# DDPG智能体
class DDPGAgent:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound, sigma, actor_lr, critic_lr, tau, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 初始化目标网络并设置和主网络相同的参数self.target_critic.load_state_dict(self.critic.state_dict())self.target_actor.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=WEIGHT_DECAY)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差self.tau = tau  # 目标网络软更新参数self.action_dim = action_dimself.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)self.actor.eval()with torch.no_grad():action = self.actor(state).cpu().data.numpy().flatten()self.actor.train()# 给动作添加噪声,增加探索action += self.sigma * np.random.randn(self.action_dim)return np.clip(action, -self.actor.action_bound, self.actor.action_bound)def soft_update(self, net, target_net):for target_param, param in zip(target_net.parameters(), net.parameters()):target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))def update(self, replay_buffer):if len(replay_buffer) < BATCH_SIZE:returnstates, actions, rewards, next_states, dones = replay_buffer.sample()# 更新Critic网络with torch.no_grad():next_actions = self.target_actor(next_states)Q_targets_next = self.target_critic(next_states, next_actions)Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))Q_expected = self.critic(states, actions)critic_loss = F.mse_loss(Q_expected, Q_targets)self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 更新Actor网络actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 软更新目标网络self.soft_update(self.critic, self.target_critic)self.soft_update(self.actor, self.target_actor)

3.3 运行结果

Episode 10	Average Score: -1623.12
Episode 20	Average Score: -1536.40
Episode 30	Average Score: -1287.98
Episode 40	Average Score: -1021.30
Episode 50	Average Score: -995.55
Episode 60	Average Score: -401.11
Episode 70	Average Score: -311.09
Episode 80	Average Score: -433.98
Episode 90	Average Score: -122.43
Episode 100	Average Score: -125.27
Episode 110	Average Score: -122.54
Episode 120	Average Score: -122.86
Episode 130	Average Score: -122.51
Episode 140	Average Score: -123.11
Episode 150	Average Score: -122.93
Episode 160	Average Score: -127.22
Episode 170	Average Score: -146.53
Episode 180	Average Score: -138.31
Episode 190	Average Score: -119.34
Episode 200	Average Score: -118.65

Pendulum-v0环境中,DDPG智能体经过200个回合的训练后,奖励曲线应逐渐上升,表明智能体的策略在不断优化。滑动平均曲线更平滑,能够更清晰地反映训练趋势。

四、总结

DDPG算法通过结合演员-评论家架构、经验回放和目标网络等技术,有效地解决了连续动作空间中的强化学习问题。在Pendulum-v0环境中的应用展示了其强大的学习能力和策略优化效果。随着研究的深入,DDPG及其衍生算法在更多复杂任务中的应用前景广阔。

相关文章:

强化学习之DDPG算法

前言&#xff1a; 在正文开始之前&#xff0c;首先给大家介绍一个不错的人工智能学习教程&#xff1a;https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程&#xff0c;感兴趣的读者可以自行查阅。 一、算法介绍 深度确定性策略梯度 &#xff0…...

【进阶OpenCV】 (16)-- 人脸识别 -- FisherFaces算法

文章目录 FisherFaces算法一、算法原理二、算法优势与局限三、算法实现1. 图像预处理2. 创建FisherFace人脸特征识别器3. 训练模型4. 测试图像 总结 FisherFaces算法 PCA方法是EigenFaces人脸识别的核心&#xff0c;但是其具有明显的缺点&#xff0c;在操作过程中会损失许多人…...

电脑主机配置

显卡&#xff1a; 查看显卡&#xff1a;设备管理器--显示适配器 RTX4060 RTX和GTX区别&#xff1a; GTX是NVIDIA公司旧款显卡&#xff0c;RTX比GTX好但是贵 处理器CPU&#xff1a; Intel(R) Core(TM) i5-10400F CPU 2.90GHz 2.90 GHz 10400F&#xff1a;10指的是第几代…...

图书借阅小程序开源独立版

图书借阅微信小程序&#xff0c;多书馆切换模式&#xff0c;书馆一键同步图书信息&#xff0c;开通会员即可在线借书&#xff0c;一书一码书馆员工手机扫码出入库从会员到书馆每一步信息把控图书借阅小程序&#xff0c;让阅读触手可及在这个快节奏的时代&#xff0c;你是否渴望…...

flutter TextField限制中文,ios自带中文输入法变英文输入问题解决

由于业务需求&#xff0c;要限制TextField只能输入中文&#xff0c;但是测试在iOS测试机发现自带中文输入法会变英文输入问题&#xff0c;安卓没有问题&#xff0c;并且只有iOS自带输入法有问题&#xff0c;搜狗等输入法没问题。我们目前使用flutter2.5.3版本&#xff0c;高版本…...

ThreadLocal的应用场景

ThreadLocal介绍 ThreadLocal为每个线程都提供了变量的副本&#xff0c;使得每个线程访问各自独立的对象&#xff0c;这样就隔离了多个线程对数据的共享&#xff0c;使得线程安全。ThreadLocal有如下方法&#xff1a; 方法声明 描述public void set(T value)设置当前线程绑定的…...

Python--plt.errorbar学习笔记

plt.errorbar 是 Matplotlib 库中的一个函数&#xff0c;用于绘制带有误差条的图形。下面给出的代码行的详细解释&#xff1a; import numpy as np from scipy.special import kv, erfc from scipy.integrate import dblquad import matplotlib.pyplot as plt import scipy.in…...

文件信息类QFileInfo

常用方法&#xff1a; 构造函数 //参数&#xff1a;文件的绝对路径或相对路径 [explicit] QFileInfo::QFileInfo(const QString &path) 设置文件路径 可构造一个空的QFileInfo的对象&#xff0c;然后设置路径 //参数&#xff1a;文件的绝对路径或相对路径 void QFileI…...

堆排序(C++实现)

参考&#xff1a; 面试官&#xff1a;请写一个堆排序_哔哩哔哩_bilibiliC实现排序算法_c从小到大排序-CSDN博客 堆的基本概念 堆排实际上是利用堆的性质来进行排序。堆可以看做一颗完全二叉树。 堆分为两类&#xff1a; 最大堆&#xff08;大顶堆&#xff09;&#xff1a;除根…...

Qt中加入UI文件

将 UI 文件整合到 Qt 项目 使用 Qt Designer 创建 UI 文件&#xff1a; 在 Qt Creator 中使用 Qt Designer 创建 UI 文件&#xff0c;设计所需的界面。确保在设计中包含所需的控件&#xff08;如按钮、文本框等&#xff09;&#xff0c;并为每个控件设置明确的对象名称&#xf…...

Redisson使用全解

redisson使用全解——redisson官方文档注释&#xff08;上篇&#xff09;_redisson官网中文-CSDN博客 redisson使用全解——redisson官方文档注释&#xff08;中篇&#xff09;-CSDN博客 redisson使用全解——redisson官方文档注释&#xff08;下篇&#xff09;_redisson官网…...

Go4 和对 Go 的贡献

本篇内容是根据2017年4月份Go4 and Contributing to Go音频录制内容的整理与翻译, Brad Fitzpatrick 加入节目谈论成为开源 Go 的代言人、让社区参与 bug 分类、Go 的潜在未来以及其他有趣的 Go 项目和新闻。 过程中为符合中文惯用表达有适当删改, 版权归原作者所有. Erik St…...

区间动态规划

区间动态规划&#xff08;Interval DP&#xff09;是动态规划的一种重要变种&#xff0c;特别适用于解决一类具有区间性质的问题。典型的应用场景是给定一个区间&#xff0c;要求我们在满足某些条件下进行最优划分或合并。本文将从区间DP的基本思想、常见问题模型以及算法实现几…...

什么情况下需要使用电压探头

高压探头是一种专门设计用于测量高压电路或设备的探头&#xff0c;其作用是在电路测试和测量中提供安全、准确的信号捕获&#xff0c;并确保操作人员的安全。这些探头通常用于测量高压电源、变压器、电力系统、医疗设备以及其他需要处理高电压的设备或系统。 而高压差分探头差分…...

数据结构——八大排序(下)

数据结构中的八大排序算法是计算机科学领域经典的排序方法&#xff0c;它们各自具有不同的特点和适用场景。以下是这八大排序算法的详细介绍&#xff1a; 五、选择排序&#xff08;Selection Sort&#xff09; 核心思想&#xff1a;每一轮从未排序的元素中选择最小&#xff0…...

Linux系统:Ubuntu上安装Chrome浏览器

Ubuntu系统版本&#xff1a;23.04 在Ubuntu系统上安装Google Chrome浏览器&#xff0c;可以通过以下步骤进行&#xff1a; 终端输入以下命令&#xff0c;先更新软件源&#xff1a; sudo apt update 或 sudo apt upgrade终端输入以下命令&#xff0c;下载最新的Google Chrome .…...

Redis位图BitMap

一、为什么使用位图&#xff1f; 使用位图能有效实现 用户签到 等行为&#xff0c;用数据库表记录签到&#xff0c;将占用很多存储&#xff1b;但使用 位图BitMap&#xff0c;就能 大大减少存储占用 二、关于位图 本质上是String类型&#xff0c;最小长度8位&#xff08;一个字…...

YOLOv11改进策略【卷积层】| ParNet 即插即用模块 二次创新C3k2

一、本文介绍 本文记录的是利用ParNet中的基础模块优化YOLOv11的目标检测网络模型。 ParNet block是一个即插即用模块,能够在不增加深度的情况下增加感受野,更好地处理图像中的不同尺度特征,有助于网络对输入数据更全面地理解和学习,从而提升网络的特征提取能力和分类性能…...

学习threejs,网格深度材质MeshDepthMaterial

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️网格深度材质MeshDepthMate…...

算法时间、空间复杂度(二)

目录 大O渐进表示法 一、时间复杂度量级的判断 定义&#xff1a; 例一&#xff1a;执行2*N&#xff0b;1次 例二&#xff1a;执行MN次 例三&#xff1a;执行已知次数 例四:存在最好情况和最坏情况 顺序查找 冒泡排序 二分查找 例五&#xff1a;阶乘递归 ​编辑 例…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

leetcodeSQL解题:3564. 季节性销售分析

leetcodeSQL解题&#xff1a;3564. 季节性销售分析 题目&#xff1a; 表&#xff1a;sales ---------------------- | Column Name | Type | ---------------------- | sale_id | int | | product_id | int | | sale_date | date | | quantity | int | | price | decimal | -…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

Kubernetes 网络模型深度解析:Pod IP 与 Service 的负载均衡机制,Service到底是什么?

Pod IP 的本质与特性 Pod IP 的定位 纯端点地址&#xff1a;Pod IP 是分配给 Pod 网络命名空间的真实 IP 地址&#xff08;如 10.244.1.2&#xff09;无特殊名称&#xff1a;在 Kubernetes 中&#xff0c;它通常被称为 “Pod IP” 或 “容器 IP”生命周期&#xff1a;与 Pod …...

深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏

一、引言 在深度学习中&#xff0c;我们训练出的神经网络往往非常庞大&#xff08;比如像 ResNet、YOLOv8、Vision Transformer&#xff09;&#xff0c;虽然精度很高&#xff0c;但“太重”了&#xff0c;运行起来很慢&#xff0c;占用内存大&#xff0c;不适合部署到手机、摄…...

小木的算法日记-多叉树的递归/层序遍历

&#x1f332; 从二叉树到森林&#xff1a;一文彻底搞懂多叉树遍历的艺术 &#x1f680; 引言 你好&#xff0c;未来的算法大神&#xff01; 在数据结构的世界里&#xff0c;“树”无疑是最核心、最迷人的概念之一。我们中的大多数人都是从 二叉树 开始入门的&#xff0c;它…...

如何配置一个sql server使得其它用户可以通过excel odbc获取数据

要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据&#xff0c;你需要完成以下配置步骤&#xff1a; ✅ 一、在 SQL Server 端配置&#xff08;服务器设置&#xff09; 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到&#xff1a;SQL Server 网络配…...