当前位置: 首页 > article >正文

用PyTorch手搓DDPG算法:从Actor-Critic到目标网络,一步步搞定连续控制

用PyTorch手搓DDPG算法从Actor-Critic到目标网络一步步搞定连续控制在强化学习领域连续控制问题一直是极具挑战性的研究方向。想象一下训练机器人完成精细操作或者让自动驾驶车辆在复杂环境中平稳行驶——这些场景都需要算法能够输出连续范围内的动作值。Deep Deterministic Policy GradientDDPG正是为解决这类问题而生的算法它巧妙地将深度神经网络与传统的Actor-Critic框架相结合成为攻克连续控制任务的利器。本文将带您从零开始实现DDPG算法重点解决实际编码中的三个核心难题如何设计四个协同工作的神经网络Actor、Critic及其目标网络、如何处理动作空间的连续输出、以及如何平衡探索与利用的关系。我们将以MountainCarContinuous-v0环境为实验场通过PyTorch代码逐层拆解算法实现细节让您不仅理解DDPG的工作原理更能掌握其工程实现的关键技巧。1. DDPG算法核心架构解析DDPG算法的精妙之处在于它融合了DQN和策略梯度的优点形成独特的双网络双目标结构。与离散动作空间的DQN不同DDPG的Actor网络直接输出连续动作值这使其特别适合控制类任务。1.1 四大神经网络分工class PolicyNet(nn.Module): # Actor主网络 def __init__(self, n_states, n_hiddens, n_actions, action_bound): super().__init__() self.fc1 nn.Linear(n_states, n_hiddens) self.fc2 nn.Linear(n_hiddens, n_actions) self.action_bound action_bound def forward(self, x): x torch.tanh(self.fc2(F.relu(self.fc1(x)))) return x * self.action_boundDDPG包含四个关键神经网络Actor网络策略函数μ(s|θ^μ)输入状态输出确定性动作Critic网络价值函数Q(s,a|θ^Q)评估状态-动作对的价值Target Actor策略目标网络μ(s|θ^μ)用于稳定训练Target Critic价值目标网络Q(s,a|θ^Q)提供TD目标基准这种分离设计解决了移动目标问题——当使用同一个网络既计算预测值又计算目标值时会导致训练过程不稳定。通过引入目标网络我们相当于为算法提供了一个相对固定的参考系。1.2 关键数学原理DDPG的核心更新规则建立在贝尔曼方程基础上Critic更新目标y r γ(1-done)Q(s,μ(s|θ^μ)|θ^Q)Actor更新策略∇θ^μ J ≈ E[∇a Q(s,a|θ^Q)|aμ(s) ∇θ^μ μ(s|θ^μ)]这两个公式揭示了DDPG的双重学习机制Critic学习准确评估动作价值而Actor则朝着提升Critic评分的方向优化策略。这种分工协作的模式使得算法既能处理连续动作空间又能保持较高的样本效率。2. 经验回放机制实现经验回放是DDPG稳定训练的关键组件它通过存储和重复利用历史经验打破了样本间的时序相关性。我们实现了一个高效的回放缓冲区2.1 回放缓冲区设计class ReplayBuffer: def __init__(self, capacity): self.buffer collections.deque(maxlencapacity) # 固定容量队列 def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions random.sample(self.buffer, batch_size) return zip(*transitions)技术细节说明使用collections.deque实现固定大小的循环缓冲区每个经验元组存储(s_t, a_t, r_{t1}, s_{t1}, done)五要素采样时随机抽取batch_size个独立样本打破时序相关性2.2 优先经验回放改进基础实现采用均匀采样但我们可以进一步优化class PrioritizedReplayBuffer(ReplayBuffer): def __init__(self, capacity, alpha0.6): super().__init__(capacity) self.priorities np.zeros(capacity) self.alpha alpha # 控制优先程度 self.pos 0 def add(self, *args): max_prio self.priorities.max() if self.buffer else 1.0 self.priorities[self.pos] max_prio super().add(*args) self.pos (self.pos 1) % self.buffer.maxlen优先回放根据TD误差调整采样概率使对学习更有价值的经验更频繁地被回放。这种改进可以显著提升样本利用率特别是在稀疏奖励场景中。3. 噪声探索策略实现确定性策略的一个固有问题是缺乏探索能力。DDPG通过添加噪声解决这一问题使Agent能够探索动作空间。3.1 高斯噪声实现def take_action(self, state): state torch.FloatTensor(state).unsqueeze(0).to(self.device) action self.actor(state).cpu().detach().numpy()[0] # 添加高斯噪声 noise self.sigma * np.random.randn(self.n_actions) return np.clip(action noise, -self.action_bound, self.action_bound)参数调节技巧sigma控制噪声强度通常从0.1开始逐步衰减训练初期可设置较大噪声增强探索训练后期减小噪声使策略趋于稳定使用np.clip确保动作不超出环境允许范围3.2 噪声退火策略更高级的实现可以采用噪声退火机制self.sigma max(0.01, self.sigma * 0.995) # 每步衰减噪声这种线性或指数衰减策略能够在训练初期充分探索在后期稳定策略。实际应用中还可以采用Ornstein-Uhlenbeck过程噪声它特别适合物理系统的惯性特性。4. 网络更新机制详解DDPG的训练过程涉及四种网络的协同更新这是算法实现中最复杂的部分。我们将拆解每个更新步骤的代码实现。4.1 Critic网络更新Critic的目标是最小化TD误差# 计算目标Q值 next_actions self.target_actor(next_states) next_q_values self.target_critic(next_states, next_actions) q_targets rewards self.gamma * (1 - dones) * next_q_values # 计算当前Q值 q_values self.critic(states, actions) # 计算损失并更新 critic_loss F.mse_loss(q_values, q_targets) self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step()关键点使用目标网络计算下一状态的动作和价值通过贝尔曼方程构造TD目标最小化预测值与目标值的均方误差4.2 Actor网络更新Actor的目标是最大化预期回报# 计算策略梯度 actor_actions self.actor(states) actor_loss -self.critic(states, actor_actions).mean() self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step()这里使用负号是因为PyTorch优化器默认执行最小化。本质上我们是在沿着Critic评估的梯度方向提升策略性能。4.3 目标网络软更新DDPG采用软更新而非硬更新def soft_update(self, net, target_net): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( self.tau * param.data (1 - self.tau) * target_param.data )参数选择建议tau通常设为0.001-0.01较小的tau使目标网络更新更平缓过大的tau可能导致训练不稳定5. 完整训练流程实现现在我们将所有组件整合构建完整的训练循环。以MountainCarContinuous-v0环境为例5.1 环境初始化env gym.make(MountainCarContinuous-v0) n_states env.observation_space.shape[0] n_actions env.action_space.shape[0] action_bound env.action_space.high[0] agent DDPG( n_statesn_states, n_hiddens64, n_actionsn_actions, action_boundaction_bound, sigma0.1, actor_lr1e-3, critic_lr1e-3, tau0.005, gamma0.99, devicedevice )5.2 训练循环for episode in range(200): state env.reset() episode_return 0 while True: action agent.take_action(state) next_state, reward, done, _ env.step(action) replay_buffer.add(state, action, reward, next_state, done) state next_state episode_return reward if len(replay_buffer) batch_size: transitions replay_buffer.sample(batch_size) agent.update(transitions) if done: break5.3 训练曲线分析典型的训练过程会呈现以下特征初期回报波动较大Agent在探索阶段随着经验积累策略逐渐稳定后期回报趋于收敛找到较优策略建议监控以下指标每回合总回报Critic损失值Actor策略更新幅度噪声强度变化6. 实战技巧与调优策略在实际实现DDPG时有几个关键点需要特别注意6.1 网络结构设计Actor网络架构建议最后一层使用tanh激活输出范围[-1,1]通过action_bound缩放输出到环境范围隐藏层不宜过深2-3层通常足够Critic网络设计要点输入为状态和动作的拼接最后一层线性输出无激活函数可考虑使用Layer Normalization稳定训练6.2 超参数调优参数典型值调节建议回放缓冲区大小1e5-1e6越大越好但受内存限制批量大小64-256太小导致不稳定太大降低样本效率Actor学习率1e-4-1e-3通常小于Critic学习率Critic学习率1e-3-3e-3可适当增大折扣因子γ0.95-0.99长周期任务取较大值软更新系数τ0.001-0.01越小更新越平缓6.3 常见问题排查训练不收敛的可能原因Critic损失爆炸尝试减小学习率梯度裁剪Actor策略退化检查噪声是否足够增大探索回报波动大增大回放缓冲区减小批量大小目标网络更新过快减小τ值调试技巧可视化网络权重分布监控梯度幅度检查动作值是否合理验证TD误差是否逐渐减小7. 进阶改进方向基础DDPG实现后可以考虑以下改进方案提升性能7.1 Twin Delayed DDPG (TD3)TD3算法针对DDPG的三个主要弱点进行了改进目标策略平滑减少Critic估计误差双Critic网络取最小值避免过估计延迟策略更新Critic更稳定后再更新Actor# TD3的双Critic实现示例 self.critic1 QValueNet(n_states, n_hiddens, n_actions).to(device) self.critic2 QValueNet(n_states, n_hiddens, n_actions).to(device) self.target_critic1 QValueNet(n_states, n_hiddens, n_actions).to(device) self.target_critic2 QValueNet(n_states, n_hiddens, n_actions).to(device)7.2 分布式DDPG通过多个Agent并行收集经验加速训练过程每个Worker有独立的探索策略共享中心化回放缓冲区定期同步主网络参数7.3 分层DDPG对于复杂任务可以设计分层控制高层策略制定子目标底层DDPG执行具体动作通过目标重标定连接不同层次在MountainCarContinuous环境中基础DDPG通常能在100-200个训练回合后找到解决方案。实际测试中一个配置得当的DDPG Agent可以将小车在约110步内推到目标位置远优于随机策略的300步表现。

相关文章:

用PyTorch手搓DDPG算法:从Actor-Critic到目标网络,一步步搞定连续控制

用PyTorch手搓DDPG算法:从Actor-Critic到目标网络,一步步搞定连续控制 在强化学习领域,连续控制问题一直是极具挑战性的研究方向。想象一下训练机器人完成精细操作,或者让自动驾驶车辆在复杂环境中平稳行驶——这些场景都需要算法…...

通达信缠论指标插件:3分钟完成专业级技术分析部署指南

通达信缠论指标插件:3分钟完成专业级技术分析部署指南 【免费下载链接】Indicator 通达信缠论可视化分析插件 项目地址: https://gitcode.com/gh_mirrors/ind/Indicator 通达信缠论可视化分析插件是一款专为技术分析爱好者设计的C开发工具,能够自…...

PX4-Autopilot系统调用与API接口深度解析:构建自主飞行系统的技术架构

PX4-Autopilot系统调用与API接口深度解析:构建自主飞行系统的技术架构 【免费下载链接】PX4-Autopilot PX4 Autopilot Software 项目地址: https://gitcode.com/gh_mirrors/px/PX4-Autopilot PX4-Autopilot作为开源无人机飞控软件的标杆,其核心价…...

简单视频下载助手:轻松保存网页视频的终极解决方案

简单视频下载助手:轻松保存网页视频的终极解决方案 【免费下载链接】VideoDownloadHelper Chrome Extension to Help Download Video for Some Video Sites. 项目地址: https://gitcode.com/gh_mirrors/vi/VideoDownloadHelper 你是否经常遇到想要保存网页视…...

5大核心功能带你探索Xournal++:跨平台数字手写笔记的无限可能

5大核心功能带你探索Xournal:跨平台数字手写笔记的无限可能 【免费下载链接】xournalpp Xournal is a handwriting notetaking software with PDF annotation support. Written in C with GTK3, supporting Linux (e.g. Ubuntu, Debian, Arch, SUSE), macOS and Win…...

Windows微信批量消息发送工具:5分钟快速上手指南

Windows微信批量消息发送工具:5分钟快速上手指南 【免费下载链接】WeChat-mass-msg 微信自动发送信息,微信群发消息,Windows系统微信客户端(PC端 项目地址: https://gitcode.com/gh_mirrors/we/WeChat-mass-msg 还在为逐个…...

TFT Overlay终极指南:云顶之弈玩家的免费战术悬浮助手

TFT Overlay终极指南:云顶之弈玩家的免费战术悬浮助手 【免费下载链接】TFT-Overlay Overlay for Teamfight Tactics 项目地址: https://gitcode.com/gh_mirrors/tf/TFT-Overlay 你是否在云顶之弈对局中因为记不住复杂的装备合成公式而错失胜利机会&#xff…...

网络小白也能看懂的CDP和LLDP:手把手教你用它们快速摸清网络家底

网络小白也能看懂的CDP和LLDP:手把手教你用它们快速摸清网络家底 刚接手一个陌生网络时,最让人头疼的就是搞不清楚设备之间的连接关系。就像搬进新家却找不到水电总闸,每次排查故障都像在迷宫里打转。其实网络设备自带了"自动名片交换&q…...

别只盯着Trace了!CANoe Analysis功能区这3个隐藏功能,让你的测试报告更专业

别只盯着Trace了!CANoe Analysis功能区这3个隐藏功能,让你的测试报告更专业 在汽车电子测试领域,CANoe早已成为工程师们不可或缺的利器。但大多数用户仅仅停留在Trace窗口的基础使用上,殊不知Analysis功能区还隐藏着诸多能显著提升…...

Ollama Colab V4:云端免费部署大语言模型的完整指南

1. 项目概述:在云端免费运行大语言模型的“瑞士军刀” 如果你对运行像 Llama、Mistral 这类开源大语言模型(LLM)感兴趣,但又苦于没有足够性能的本地显卡,或者不想在环境配置上耗费大量时间,那么 Ollama C…...

通过用量看板清晰掌握各模型 API 调用成本

通过用量看板清晰掌握各模型 API 调用成本 1. 用量看板的核心价值 对于需要同时接入多个大模型的团队而言,成本透明度和资源分配合理性是技术决策的重要依据。Taotoken 控制台提供的用量看板功能,能够将分散在不同模型供应商的调用数据聚合到统一视图&…...

如何快速解锁电脑隐藏性能:UXTU电脑性能优化终极指南

如何快速解锁电脑隐藏性能:UXTU电脑性能优化终极指南 【免费下载链接】Universal-x86-Tuning-Utility Unlock the full potential of your Intel/AMD based device. 项目地址: https://gitcode.com/gh_mirrors/un/Universal-x86-Tuning-Utility 你是否曾经疑…...

终极解决:TranslucentTB任务栏透明工具依赖问题完整指南

终极解决:TranslucentTB任务栏透明工具依赖问题完整指南 【免费下载链接】TranslucentTB A lightweight utility that makes the Windows taskbar translucent/transparent. 项目地址: https://gitcode.com/gh_mirrors/tr/TranslucentTB TranslucentTB是一款…...

VULK Skills:为AI编程助手注入团队编码规范与最佳实践

1. 项目概述:为AI编码助手注入“肌肉记忆” 如果你用过Claude Code、Cursor或者Windsurf这类AI编程助手,大概率有过这样的体验:你让它“写一个登录表单”,它确实能给你生成代码,但结果往往千差万别。有时候它用了一堆…...

ESP32-S3实现0.7秒手势识别:嵌入式AI实战指南

1. 项目概述在嵌入式AI领域,将深度学习模型部署到资源受限的微控制器上一直是个挑战。最近Ali Hassan Shah成功在ESP32-S3-EYE开发板上实现了基于ESP-DL库的手势识别系统,整个推理过程仅需0.7秒。这个项目展示了如何在边缘设备上运行自定义的卷积神经网络…...

3分钟上手:如何用开源可视化工具将数据变成精美图表

3分钟上手:如何用开源可视化工具将数据变成精美图表 【免费下载链接】ArchivePasswordTestTool 利用7zip测试压缩包的功能 对加密压缩包进行自动化测试密码 项目地址: https://gitcode.com/gh_mirrors/ar/ArchivePasswordTestTool 你是否曾经面对一堆复杂数据…...

网盘直链下载助手:一键获取9大网盘真实下载地址的完整指南

网盘直链下载助手:一键获取9大网盘真实下载地址的完整指南 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中国移动云盘 / …...

手把手教你搞定杰理AC695 SDK v1.4.0的编译与下载(附常见错误修复)

杰理AC695 SDK v1.4.0开发实战:从环境搭建到固件烧录全指南 第一次接触杰理AC695芯片的开发者,往往会被其丰富的功能和相对复杂的开发环境所困扰。作为一款广泛应用于蓝牙音频、智能家居等领域的MCU,AC695的性能和灵活性确实令人印象深刻&…...

告别刹车油!聊聊汽车EMB电子机械制动,它真能干掉用了百年的液压系统吗?

告别刹车油!汽车EMB电子机械制动能否终结百年液压时代? 想象一下,你的爱车不再需要定期更换刹车油,维修时不再有液压管路漏液的烦恼,制动响应速度比传统系统快3倍——这就是EMB电子机械制动技术带来的未来图景。在特斯…...

量子电路优化中的黎曼几何与随机子空间方法

1. 量子电路优化与黎曼几何方法概述 量子计算领域近年来在NISQ(含噪声中等规模量子)时代面临的核心挑战之一,是如何高效优化参数化量子电路(PQC)。变分量子算法(VQA)作为当前主流的解决方案&…...

3步轻松安装KK-HF Patch:解锁Koikatsu游戏200+模组与完整翻译体验

3步轻松安装KK-HF Patch:解锁Koikatsu游戏200模组与完整翻译体验 【免费下载链接】KK-HF_Patch Automatically translate, uncensor and update Koikatu! and Koikatsu Party! 项目地址: https://gitcode.com/gh_mirrors/kk/KK-HF_Patch 还在为Koikatu或Koik…...

08-MLOps与工程落地——02. 实验追踪:Weights Biases

02. 实验追踪:Weights & Biases 一、W&B概述 1.1 产品定位与特点 Weights & Biases(W&B)是一个专注于机器学习实验管理的平台,提供云端实验追踪、可视化、超参数搜索和协作功能。 核心特点: 轻量…...

终极魔兽地图转换指南:3分钟解决地图版本兼容性问题

终极魔兽地图转换指南:3分钟解决地图版本兼容性问题 【免费下载链接】w3x2lni 魔兽地图格式转换工具 项目地址: https://gitcode.com/gh_mirrors/w3/w3x2lni 你是否遇到过精心制作的魔兽地图在新版本游戏中无法运行?或者老地图在1.32.8版本中频频…...

5分钟掌握Upscayl:免费开源AI图像放大工具实战指南

5分钟掌握Upscayl:免费开源AI图像放大工具实战指南 【免费下载链接】upscayl 🆙 Upscayl - #1 Free and Open Source AI Image Upscaler for Linux, MacOS and Windows. 项目地址: https://gitcode.com/GitHub_Trending/up/upscayl 还在为模糊的老…...

Flowstep 1.0 技术深度解析:AI 设计引擎的架构、渲染与工程化实现

摘要 Flowstep 1.0 是一款面向开发者与技术设计师的 AI 设计工程化工具,核心解决 “设计 - 代码” 重复转换的低效痛点。本文从技术底层出发,系统拆解 Flowstep 1.0 的核心架构设计、无限画布渲染引擎、AI 生成模型体系、代码导出引擎、MCP 协议集成五大…...

AI 免费获客结束进入商业化验证,豆包付费测试能否破解盈利难题?

【AI 商业化新阶段开启】免费获客阶段结束,AI 应用开始进入“成本分层 用户分层 商业化验证”阶段。最近,豆包 App Store 页面出现了付费订阅信息,除免费基础版外,可能有 68 元/月标准版、200 元/月加强版、500 元/月专业版&…...

【MCP 2026边缘部署性能优化权威指南】:基于17个工业现场POC数据,提炼出的3.2μs级时序收敛公式

更多请点击: https://intelliparadigm.com 第一章:MCP 2026边缘部署性能优化的工程意义与边界定义 MCP 2026(Multi-Controller Protocol 2026)作为新一代边缘协同控制协议,其在资源受限设备上的高效部署直接决定工业物…...

WSA-Pacman:Windows安卓子系统图形化包管理的终极解决方案

WSA-Pacman:Windows安卓子系统图形化包管理的终极解决方案 【免费下载链接】wsa_pacman A GUI package manager and package installer for Windows Subsystem for Android (WSA) 项目地址: https://gitcode.com/gh_mirrors/ws/wsa_pacman 在Windows 11上运行…...

AXOrderBook:构建微秒级A股高频交易订单簿系统的完整指南

AXOrderBook:构建微秒级A股高频交易订单簿系统的完整指南 【免费下载链接】AXOrderBook A股订单簿工具,使用逐笔行情进行订单簿重建、千档快照发布、各档委托队列展示等,包括python模型和FPGA HLS实现。 项目地址: https://gitcode.com/gh_…...

开发极简主义运动实践指南手册:软件测试从业者的效率跃升之路

一、测试困境与极简主义的觉醒在软件开发快速迭代的浪潮中,软件测试从业者正陷入一场前所未有的“数字喧嚣”困境。每天,我们穿梭于海量的需求文档、日益庞杂的技术栈、数不胜数的测试用例以及永不停歇的通知流之间。当“更多”成为下意识的追求——更多…...