PyTorch 深度学习实战(14):Deep Deterministic Policy Gradient (DDPG) 算法
在上一篇文章中,我们介绍了 Proximal Policy Optimization (PPO) 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Deep Deterministic Policy Gradient (DDPG) 算法,这是一种用于连续动作空间的强化学习算法。我们将使用 PyTorch 实现 DDPG 算法,并应用于经典的 Pendulum 问题。
一、DDPG 算法基础
DDPG 是一种基于 Actor-Critic 框架的算法,专门用于解决连续动作空间的强化学习问题。它结合了深度 Q 网络(DQN)和策略梯度方法的优点,能够高效地处理高维状态和动作空间。
1. DDPG 的核心思想
-
确定性策略:
-
DDPG 使用确定性策略(Deterministic Policy),即给定状态时,策略网络直接输出一个确定的动作,而不是动作的概率分布。
-
-
目标网络:
-
DDPG 使用目标网络(Target Network)来稳定训练过程,类似于 DQN 中的目标网络。
-
-
经验回放:
-
DDPG 使用经验回放缓冲区(Replay Buffer)来存储和重用过去的经验,从而提高数据利用率。
-
2. DDPG 的优势
-
适用于连续动作空间:
-
DDPG 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。
-
-
训练稳定:
-
通过目标网络和经验回放,DDPG 能够稳定地训练策略网络和价值网络。
-
-
高效采样:
-
DDPG 可以重复使用旧策略的采样数据,从而提高数据利用率。
-
3. DDPG 的算法流程
-
使用当前策略采样一批数据。
-
使用目标网络计算目标 Q 值。
-
更新 Critic 网络以最小化 Q 值的误差。
-
更新 Actor 网络以最大化 Q 值。
-
更新目标网络。
-
重复上述过程,直到策略收敛。
二、Pendulum 问题实战
我们将使用 PyTorch 实现 DDPG 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。
1. 问题描述
Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 −2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。
2. 实现步骤
-
安装并导入必要的库。
-
定义 Actor 网络和 Critic 网络。
-
定义 DDPG 训练过程。
-
测试模型并评估性能。
3. 代码实现
以下是完整的代码实现:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 环境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
# 随机种子设置
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# 定义 Actor 网络
class Actor(nn.Module):def __init__(self, state_dim, action_dim, max_action):super(Actor, self).__init__()self.fc1 = nn.Linear(state_dim, 512)self.ln1 = nn.LayerNorm(512) # 层归一化self.fc2 = nn.Linear(512, 512)self.ln2 = nn.LayerNorm(512)self.fc3 = nn.Linear(512, action_dim)self.max_action = max_action
def forward(self, x):x = F.relu(self.ln1(self.fc1(x)))x = F.relu(self.ln2(self.fc2(x)))return self.max_action * torch.tanh(self.fc3(x))
# 定义 Critic 网络
class Critic(nn.Module):def __init__(self, state_dim, action_dim):super(Critic, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, 1)
def forward(self, x, u):x = F.relu(self.fc1(torch.cat([x, u], 1)))x = F.relu(self.fc2(x))x = self.fc3(x)return x
# 添加OU噪声类
class OUNoise:def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):self.mu = mu * np.ones(action_dim)self.theta = thetaself.sigma = sigmaself.reset()
def reset(self):self.state = np.copy(self.mu)
def sample(self):dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))self.state += dxreturn self.state
# 定义 DDPG 算法
class DDPG:def __init__(self, state_dim, action_dim, max_action):self.actor = Actor(state_dim, action_dim, max_action).to(device)self.actor_target = Actor(state_dim, action_dim, max_action).to(device)self.actor_target.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
self.critic = Critic(state_dim, action_dim).to(device)self.critic_target = Critic(state_dim, action_dim).to(device)self.critic_target.load_state_dict(self.critic.state_dict())self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)self.noise = OUNoise(action_dim, sigma=0.2) # 示例:Ornstein-Uhlenbeck噪声
self.max_action = max_actionself.replay_buffer = deque(maxlen=1000000)self.batch_size = 64self.gamma = 0.99self.tau = 0.005self.noise_sigma = 0.5 # 初始噪声强度self.noise_decay = 0.995
self.actor_lr_scheduler = optim.lr_scheduler.StepLR(self.actor_optimizer, step_size=100, gamma=0.95)self.critic_lr_scheduler = optim.lr_scheduler.StepLR(self.critic_optimizer, step_size=100, gamma=0.95)
def select_action(self, state):state = torch.FloatTensor(state).unsqueeze(0).to(device)self.actor.eval()with torch.no_grad():action = self.actor(state).cpu().data.numpy().flatten()self.actor.train()return action
def train(self):if len(self.replay_buffer) < self.batch_size:return
# 从经验回放缓冲区中采样batch = random.sample(self.replay_buffer, self.batch_size)state = torch.FloatTensor(np.array([transition[0] for transition in batch])).to(device)action = torch.FloatTensor(np.array([transition[1] for transition in batch])).to(device)reward = torch.FloatTensor(np.array([transition[2] for transition in batch])).reshape(-1, 1).to(device)next_state = torch.FloatTensor(np.array([transition[3] for transition in batch])).to(device)done = torch.FloatTensor(np.array([transition[4] for transition in batch])).reshape(-1, 1).to(device)
# 计算目标 Q 值next_action = self.actor_target(next_state)target_Q = self.critic_target(next_state, next_action)target_Q = reward + (1 - done) * self.gamma * target_Q
# 更新 Critic 网络current_Q = self.critic(state, action)critic_loss = F.mse_loss(current_Q, target_Q.detach())self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()
# 更新 Actor 网络actor_loss = -self.critic(state, self.actor(state)).mean()self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()
# 更新目标网络for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def save(self, filename):torch.save(self.actor.state_dict(), filename + "_actor.pth")torch.save(self.critic.state_dict(), filename + "_critic.pth")
def load(self, filename):self.actor.load_state_dict(torch.load(filename + "_actor.pth"))self.critic.load_state_dict(torch.load(filename + "_critic.pth"))
# 训练流程
def train_ddpg(env, agent, episodes=500):rewards_history = []moving_avg = []
for ep in range(episodes):state,_ = env.reset()episode_reward = 0done = False
while not done:action = agent.select_action(state)next_state, reward, done, _, _ = env.step(action)agent.replay_buffer.append((state, action, reward, next_state, done))state = next_stateepisode_reward += rewardagent.train()
rewards_history.append(episode_reward)moving_avg.append(np.mean(rewards_history[-50:]))
if (ep + 1) % 50 == 0:print(f"Episode: {ep + 1}, Avg Reward: {moving_avg[-1]:.2f}")
return moving_avg, rewards_history
# 训练启动
ddpg_agent = DDPG(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_ddpg(env, ddpg_agent)
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('DDPG training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()
三、代码解析
-
Actor 和 Critic 网络:
-
Actor 网络输出连续动作,通过
tanh函数将动作限制在 −max_action,max_action 范围内。 -
Critic 网络输出状态-动作对的 Q 值。
-
-
DDPG 训练过程:
-
使用当前策略采样一批数据。
-
使用目标网络计算目标 Q 值。
-
更新 Critic 网络以最小化 Q 值的误差。
-
更新 Actor 网络以最大化 Q 值。
-
更新目标网络。
-
-
训练过程:
-
在训练过程中,每 50 个 episode 打印一次平均奖励。
-
训练结束后,绘制训练过程中的总奖励曲线。
-
四、运行结果
运行上述代码后,你将看到以下输出:
-
训练过程中每 50 个 episode 打印一次平均奖励。
-
训练结束后,绘制训练过程中的总奖励曲线。

五、总结
本文介绍了 DDPG 算法的基本原理,并使用 PyTorch 实现了一个简单的 DDPG 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 DDPG 算法进行连续动作空间的策略优化。
在下一篇文章中,我们将探讨更高级的强化学习算法,如 Twin Delayed DDPG (TD3)。敬请期待!
代码实例说明:
-
本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
-
如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:
actor = actor.to('cuda'),state = state.to('cuda')。
希望这篇文章能帮助你更好地理解 DDPG 算法!如果有任何问题,欢迎在评论区留言讨论。
相关文章:
PyTorch 深度学习实战(14):Deep Deterministic Policy Gradient (DDPG) 算法
在上一篇文章中,我们介绍了 Proximal Policy Optimization (PPO) 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Deep Deterministic Policy Gradient (DDPG) 算法,这是一种用于连续动作空间的强化学习算法。我们将使用 PyTorch 实现 D…...
Ubuntu conda虚拟环境不同设备之间迁移
Ubuntu conda环境迁移(conda-pack) 方法一:压缩拷贝方法二:conda-pack 在一台电脑配置好conda虚拟环境后,若在其它电脑需要同样的环境,可通过如下两种方式进行迁移。 方法一:压缩拷贝 找到Ubu…...
Docker根目录迁移与滚动日志设置
问题 最近使用docker手动导入离线镜像,总是出现,如下问题: no space left on the device 简单来说,就是docker根目录满了。 解决 查询当前docker info设置位置 使用如下命令,查询docker根目录位置: do…...
【JVM】GC 常见问题
GC 常见问题 哪些情况新生代会进入老年代 新生代 GC 后幸存区(survivor)不够存放存活下来的对象,会通过内存担保机制晋升到老年代。大对象直接进入老年代,因为大对象再新生代之间来会复制会影响 GC 性能。由 -XX:PretenureSizeT…...
「Unity3D」UGUI运行时设置元素的锚点Anchor,维持元素Rect的显示不变,即待在原处
在编辑器中,通过设置Raw edit mode,可以切换两种,元素锚点的改变模式: 一种是锚点单独改变,即:不开启原始模式,保持原样,改变anchoredPosition与sizeDelta。一种是锚点联动显示&…...
Angular由一个bug说起之十四:SCSS @import 警告与解决⽅案
SCSS import 警告与解决⽅案 ⚠ 警告信息 在 SCSS 中,使⽤ import 可能会产⽣以下警告: Deprecation Warning: Sass import rules are deprecated and will be removed in Dart Sass 3.0.0. ? 为什么会有这个警告? Sass 官⽅已经废弃 imp…...
PyTorch系列教程:基于LSTM构建情感分析模型
情感分析是一种强大的自然语言处理(NLP)技术,用于确定文本背后的情绪基调。它常用于理解客户对产品或服务的意见和反馈。本文将介绍如何使用PyTorch和长短期记忆网络(LSTMs)创建一个情感分析管道,LSTMs在处…...
SEO新手基础优化三步法
内容概要 在网站优化的初始阶段,新手常因缺乏系统性认知而陷入技术细节的误区。本文以“三步法”为核心框架,系统梳理从关键词定位到内容布局、再到外链构建的完整优化链路。通过拆解搜索引擎工作原理,重点阐明基础操作中容易被忽视的底层逻…...
Tcp网络通信的基本流程梳理
先来一张经典的流程图 接下介绍一下大概流程,各个函数的参数大家自己去了解加深一下印象 服务端流程 1.创建套接字:使用 socket 函数创建一个套接字,这个套接字后续会被用于监听客户端的连接请求。 需要注意的是,服务端一般有俩…...
PHP函数缺陷详解
无问社区-官网:http://www.wwlib.cn 本期无人投稿,欢迎大家投稿,投稿可获得无问社区AI大模型的使用红包哦! 无问社区:网安文章沉浸式免费看! 无问AI大模型不懂的问题随意问! 全网网安资源智…...
解决 Redis 后台持久化失败的问题:内存不足导致 fork 失败
文章目录 解决 Redis 后台持久化失败的问题:内存不足导致 fork 失败问题背景与成因解决方案修改内核参数 vm.overcommit_memory增加系统内存或 Swap 空间调整 Redis 配置 stop-writes-on-bgsave-error 在 Docker 环境中的注意事项总结 解决 Redis 后台持久化失败的问…...
深度学习GRU模型原理
一、介绍 门控循环单元(Gated Recurrent Unit, GRU) 是一种改进的循环神经网络(RNN),专为解决传统RNN的长期依赖问题(梯度消失/爆炸)而设计。其核心是通过门控机制动态控制信息的流动。与LSTM相…...
网络空间安全(31)安全巡检
一、定义与目的 定义: 安全巡检是指由专业人员或特定部门负责,对各类设施、设备、环境等进行全面或重点检查,及时发现潜在的安全隐患或问题。 目的: 预防事故发生:通过定期的安全巡检,及时发现并解决潜在的…...
基于Python+SQLite实现(Web)验室设备管理系统
实验室设备管理系统 应用背景 为方便实验室进行设备管理,某大学拟开发实验室设备管理系统 来管理所有实验室里的各种设备。系统可实现管理员登录,查看现有的所有设备, 增加设备等功能。 开发环境 Mac OSPyCharm IDEPython3Flaskÿ…...
面试系列|蚂蚁金服技术面【2】
今天继续分享一下蚂蚁金服的 Java 后端开发岗位真实社招面经,复盘面试过程中踩过的坑,整理面试过程中提到的知识点,希望能给正在准备面试的你一些参考和启发,希望对你有帮助,愿你能够获得心仪的 offer ! 第一轮面试完…...
【JavaEE】网络原理之初识
1.❤️❤️前言~🥳🎉🎉🎉 Hello, Hello~ 亲爱的朋友们👋👋,这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章,请别吝啬你的点赞❤️❤️和收藏📖📖。如果你对我的…...
WorkTool 技术解析:企业微信自动化办公的合规实现方案
引言:企业微信生态中的自动化需求 随着企业微信用户规模突破4亿(据腾讯2023年财报),其开放生态催生了自动化办公的技术需求。传统RPA(机器人流程自动化)工具在PC端已广泛应用,但移动端自动化仍…...
从0到1构建AI深度学习视频分析系统--基于YOLO 目标检测的动作序列检查系统:(2)消息队列与消息中间件
文章大纲 原始视频队列Python 内存视频缓存优化方案(4GB 以内)一、核心参数设计二、内存管理实现三、性能优化策略四、内存占用验证五、高级优化技巧六、部署建议检测结果队列YOLO检测结果队列技术方案一、技术选型矩阵二、核心实现代码三、性能优化策略四、可视化方案对比五…...
一文讲通锁标记对象std::adopt_lock盲点
一文讲通锁标记对象std::adopt_lock盲点 1. 核心概念2. 代码详解1. 单个锁2. 多重锁(可以用来预防死锁)3. 条件变量的互斥控制4. 复杂示例: 多生产者-多消费者模型(超纲了, 可不看,哈哈哈哈) 3. 小结 1. 核心概念 在C中, std::adopt_lock是一…...
Vscode工具开发Vue+ts项目时vue文件ts语法报错-红波浪线等
Vscode工具开发Vuets项目时vue文件ts语法报错-红波浪线等 解决方案 问题如题描述,主要原因是开发工具使用的代码检查与项目的中的ts不一致导导致,解决办法,修改 vscode 中, 快捷键:command shift p, 输入ÿ…...
Mac下安装Zed以及Zed对MCP(模型上下文协议)的支持
Zed是当前新流行的一种编辑器,支持MCP(模型上下文协议) Mac下安装Zed比较简单,直接有安装包,在这里: brew install --cask zedMac Monterey下是可以安装上的,亲测有效。 配置 使用CtrlShiftP…...
ROS实践(五)机器人自动导航(robot_navigation)
目录 一、知识点 1. 定位 2. 路径规划 (1)全局路径规划 (2)局部路径规划 3. 避障 二、常用工具和传感器 三、相关功能包 1. move_base(决策规划) 2. amcl(定位) 3. costmap_2d(代价地图) 4. global_planner(全局规划器) 5. local_planner(局部规划器…...
REDIS生产环境配置
REDIS生产环境配置 REDIS生产环境配置docker-compose文件redis.conf文件 REDIS生产环境配置 docker-compose模式部署生产环境 docker-compose文件 d_redis:image: redis:${REDIS_VERSION}container_name: d_redisvolumes:- ${REDIS_1_CONF_FILE}:/etc/redis.conf:ro- ${DATA_…...
【小沐学Web3D】three.js 加载三维模型(React)
文章目录 1、简介1.1 three.js1.2 react.js 2、three.js React结语 1、简介 1.1 three.js Three.js 是一款 webGL(3D绘图标准)引擎,可以运行于所有支持 webGL 的浏览器。Three.js 封装了 webGL 底层的 API ,为我们提供了高级的…...
软考教材重点内容 信息安全工程师 第19章 操作系统安全保护
19.1.1 操作系统安全概念 一般来说,操作系统的安全是指满足安全策略要求,具有相应的安全机制及安全功能,符合特定的安全标准,在一定约束条件下,能够抵御常见的网络安全威胁,保障自身的安全运行及资源安全。…...
【C++设计模式】第二十一篇:模板方法模式(Template Method)
注意:复现代码时,确保 VS2022 使用 C17/20 标准以支持现代特性。 算法骨架的标准化定义 1. 模式定义与用途 核心思想 模板方法模式:在父类中定义算法的骨架,将某些步骤延迟到子类实现,使得子类不改变算法结构即可…...
【机器学习】基于t-SNE的MNIST数据集可视化探索
一、前言 在机器学习和数据科学领域,高维数据的可视化是一个极具挑战但又至关重要的问题。高维数据难以直观地理解和分析,而有效的可视化方法能够帮助我们发现数据中的潜在结构、模式和关系。本文以经典的MNIST手写数字数据集为例,探讨如何利…...
【Pycharm】Pycharm无法复制粘贴,提示系统剪贴板不可用
我也没有用vim的插件,检查了本地和ubutnu上都没有。区别是我是远程到ubutnu的pycharm,我本地直接控制windowes的pycharm是没问题的。现象是可以从外部复制到pycharm反之则不行。 ctl c ctlv 以及右键 都不行 参考:Pycharm无法复制粘贴&…...
基于python+django+vue.js开发的医院门诊管理系统/医疗管理系统源码+运行
功能介绍 平台采用B/S结构,后端采用主流的Python语言进行开发,前端采用主流的Vue.js进行开发。源码 功能包括:医生管理、科室管理、护士管理、住院管理、药品管理、用户管理、日志管理、系统信息模块。 源码地址 https://github.com/geee…...
Spring Boot整合RabbitMQ极简教程
一、消息队列能解决什么问题? 异步处理:解耦耗时操作(如发短信、日志记录)流量削峰:应对突发请求,避免系统过载应用解耦:服务间通过消息通信,降低依赖 二、快速整合RabbitMQ 1. 环…...
