当前位置: 首页 > article >正文

用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码)

用PyTorch从零实现DQN算法以CartPole游戏为例附完整代码在强化学习领域深度Q网络DQN算法无疑是一座重要的里程碑。它将深度学习的强大表征能力与强化学习的决策框架完美结合为解决复杂环境中的决策问题提供了新思路。对于已经掌握Python和PyTorch基础想要深入实践强化学习的开发者来说从零实现一个DQN算法并将其应用于经典控制问题CartPole是一次绝佳的学习机会。本文将带你一步步构建完整的DQN系统从网络架构设计到训练策略优化每个环节都配有详细的代码解析和实战技巧。不同于理论推导为主的教程我们更关注工程实现中的坑与解比如如何设置合理的奖励机制、调试探索率衰减策略、优化经验回放缓冲区等实际问题。通过这个项目你不仅能理解DQN的核心思想更能获得可直接复用的代码模板。1. 环境准备与问题定义在开始编码之前我们需要明确CartPole问题的具体定义。这是一个经典的强化学习测试环境一根杆子通过非驱动关节连接到小车上小车沿着无摩擦的轨道移动。系统的状态由四个连续变量描述小车位置-4.8到4.8小车速度无限制杆子角度约-24°到24°杆子顶端速度无限制动作空间是离散的向左施加力0或向右施加力1。每步的奖励为1当杆子倾斜超过15度、小车移动超出边界中心点2.4单位距离或持续200步时回合结束。安装必要依赖pip install gym torch numpy关键参数初始化import gym import torch import numpy as np env gym.make(CartPole-v1) state_size env.observation_space.shape[0] # 4 action_size env.action_space.n # 22. DQN核心组件实现2.1 Q网络架构设计DQN的核心是用神经网络近似Q函数。我们设计一个三层的全连接网络输入维度与状态空间匹配4输出维度与动作空间匹配2。隐藏层使用ReLU激活函数引入非线性。import torch.nn as nn import torch.nn.functional as F class QNetwork(nn.Module): def __init__(self, state_size, action_size, hidden_size24): super(QNetwork, self).__init__() self.fc1 nn.Linear(state_size, hidden_size) self.fc2 nn.Linear(hidden_size, hidden_size) self.fc3 nn.Linear(hidden_size, action_size) def forward(self, x): x F.relu(self.fc1(x)) x F.relu(self.fc2(x)) return self.fc3(x)提示隐藏层大小是重要的超参数。过小会导致欠拟合过大则可能过拟合。24-64之间的值对CartPole通常效果不错。2.2 经验回放机制经验回放是DQN稳定训练的关键技术它通过存储并随机采样过往经验打破数据间的相关性。from collections import deque import random class ReplayBuffer: def __init__(self, capacity2000): self.buffer deque(maxlencapacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) def __len__(self): return len(self.buffer)经验回放的三个优势提高数据效率每条经验可被多次使用减少相关性随机采样打破时序依赖稳定训练平滑学习过程3. DQN智能体实现3.1 智能体核心逻辑DQN智能体需要管理探索与利用的平衡ε-greedy策略、目标网络更新和经验回放等关键功能。class DQNAgent: def __init__(self, state_size, action_size): self.state_size state_size self.action_size action_size self.memory ReplayBuffer() self.gamma 0.95 # 未来奖励折扣因子 self.epsilon 1.0 # 初始探索率 self.epsilon_min 0.01 self.epsilon_decay 0.995 self.learning_rate 0.001 self.model QNetwork(state_size, action_size) self.target_model QNetwork(state_size, action_size) self.optimizer torch.optim.Adam(self.model.parameters(), lrself.learning_rate) self.update_target_model() def update_target_model(self): self.target_model.load_state_dict(self.model.state_dict()) def act(self, state): if np.random.rand() self.epsilon: return random.randrange(self.action_size) state torch.FloatTensor(state) with torch.no_grad(): q_values self.model(state) return torch.argmax(q_values).item() def train(self, batch_size): if len(self.memory) batch_size: return minibatch self.memory.sample(batch_size) states torch.FloatTensor([t[0] for t in minibatch]) actions torch.LongTensor([t[1] for t in minibatch]) rewards torch.FloatTensor([t[2] for t in minibatch]) next_states torch.FloatTensor([t[3] for t in minibatch]) dones torch.FloatTensor([t[4] for t in minibatch]) current_q self.model(states).gather(1, actions.unsqueeze(1)) next_q self.target_model(next_states).max(1)[0].detach() target rewards (1 - dones) * self.gamma * next_q loss F.mse_loss(current_q.squeeze(), target) self.optimizer.zero_grad() loss.backward() self.optimizer.step() if self.epsilon self.epsilon_min: self.epsilon * self.epsilon_decay3.2 训练流程优化训练过程中有几个关键点需要特别注意奖励设计CartPole默认每步1奖励但可以调整终止惩罚探索策略ε的初始值和衰减率需要调优目标网络更新可以定期更新或软更新def train_agent(env, agent, episodes1000, batch_size32): scores [] for e in range(episodes): state env.reset() total_reward 0 for t in range(500): # 最大步数 action agent.act(state) next_state, reward, done, _ env.step(action) # 自定义终止惩罚 reward reward if not done else -10 agent.memory.push(state, action, reward, next_state, done) state next_state total_reward reward agent.train(batch_size) if done: break scores.append(total_reward) # 定期更新目标网络 if e % 10 0: agent.update_target_model() print(fEpisode: {e}, Score: {total_reward}, Epsilon: {agent.epsilon:.2f}) return scores4. 高级技巧与性能优化4.1 双重DQNDouble DQN原始DQN存在Q值高估问题。双重DQN通过解耦动作选择和Q值评估来缓解这个问题# 在DQNAgent类的train方法中修改目标Q计算 next_actions self.model(next_states).max(1)[1].unsqueeze(1) next_q self.target_model(next_states).gather(1, next_actions).squeeze() target rewards (1 - dones) * self.gamma * next_q4.2 优先级经验回放不是所有经验都同等重要。可以为缓冲区中的经验分配优先级更频繁地回放重要经验class PrioritizedReplayBuffer: def __init__(self, capacity2000, alpha0.6): self.buffer deque(maxlencapacity) self.priorities deque(maxlencapacity) self.alpha alpha def push(self, state, action, reward, next_state, done): max_prio max(self.priorities) if self.priorities else 1.0 self.buffer.append((state, action, reward, next_state, done)) self.priorities.append(max_prio) def sample(self, batch_size, beta0.4): prios np.array(self.priorities) probs prios ** self.alpha probs / probs.sum() indices np.random.choice(len(self.buffer), batch_size, pprobs) samples [self.buffer[idx] for idx in indices] weights (len(self.buffer) * probs[indices]) ** (-beta) weights / weights.max() return samples, indices, np.array(weights, dtypenp.float32) def update_priorities(self, indices, priorities): for idx, prio in zip(indices, priorities): self.priorities[idx] prio4.3 超参数调优指南DQN性能对超参数敏感。以下是经过实验验证的推荐范围超参数推荐值作用γ (gamma)0.9-0.99未来奖励折扣因子ε初始值1.0初始探索率ε最小值0.01-0.1最小探索率ε衰减率0.99-0.999探索率衰减速度学习率1e-4到1e-3优化器步长批量大小32-128每次训练样本数目标网络更新频率每10-100步稳定训练在实际项目中我发现ε衰减策略对最终性能影响显著。一个实用的技巧是在训练初期保持较高探索率ε1.0然后随着训练逐步衰减但不要降得太低保持在0.01左右以保留一定的探索能力。

相关文章:

用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码)

用PyTorch从零实现DQN算法:以CartPole游戏为例(附完整代码) 在强化学习领域,深度Q网络(DQN)算法无疑是一座重要的里程碑。它将深度学习的强大表征能力与强化学习的决策框架完美结合,为解决复杂环…...

别再让CPU等外设了!用Multi-Layer AHB搭建一个不堵车的片上‘高速公路网’

用Multi-Layer AHB构建片上系统的高效数据通道 堵在早高峰的高架桥上时,你有没有想过——芯片里的数据流其实也面临着类似的拥堵问题?当多个处理器核心、DMA控制器同时争抢总线带宽时,传统的单层AHB架构就像只有两条车道的城市主干道&#xf…...

深度解密Jsxer:JSXBIN反编译器的技术原理与工程实现

深度解密Jsxer:JSXBIN反编译器的技术原理与工程实现 【免费下载链接】jsxer A fast and accurate JSXBIN decompiler. 项目地址: https://gitcode.com/gh_mirrors/js/jsxer 在Adobe创意套件生态中,ExtendScript二进制格式(JSXBIN&…...

Linux性能优化之内存管理基础知识

写在前面 本文看下Linux内存管理相关基础内容。 1:linux是如何管理的内存的? 我们平时所说的内存多大的内存,指的是物理内存,物理上就是一个内存条:物理内存,也叫主存,现在的主存一般是动态随机…...

数字IC设计中的TCL实战:用列表操作实现引脚自动排序

数字IC设计中的TCL实战:用列表操作实现引脚自动排序 在数字集成电路设计流程中,处理海量引脚信息是每位工程师的日常挑战。当面对数百个需要按特定规则排序的引脚时,手动操作不仅效率低下,还容易引入人为错误。TCL脚本作为EDA工具…...

XINGLIGHT成兴光 0603 球头正贴 LED 聚光透镜凸头球灯珠 高亮定向指示贴片 LED

XINGLIGHT 0603 球头正贴 LED 产品图 发光颜色 型号 红色 XL-TD1608SURC 黄色 XL-TD1608UYC 普绿 XL-TD1608SYGC 翠绿 XL-TD1608UGC 蓝色 XL-TD1608UBC XINGLIGHT 0603 球头正贴 LED,标准 0603 正装基底 顶部球面透镜一体封装,光线聚焦定向射出、视角集…...

高端工厂生产线储能与削峰系统功率器件选型方案:高效可靠能量转换系统适配指南

随着工业智能化与绿色制造的持续升级,工厂生产线储能与削峰填谷系统已成为保障连续生产、降低用能成本、提升电网韧性的核心设施。其功率转换系统作为整机“心脏”,需为电池管理、双向变流、负载切换等关键环节提供高效、可靠的电能变换,而功…...

告别结构体!手把手教你用Simulink.Signal配置汽车软件输入输出信号(含代码生成实战)

告别结构体!手把手教你用Simulink.Signal配置汽车软件输入输出信号(含代码生成实战) 在汽车电子控制单元(ECU)开发中,Simulink模型到C代码的转换是核心环节。许多工程师第一次生成代码时会发现,…...

OLED字库的构建与移植:从点阵数据到嵌入式显示

1. OLED字库的基础概念与工作原理 第一次接触OLED字库时,我也被那一串串十六进制数字搞得头晕眼花。直到后来才发现,这些看似复杂的数据背后,其实是一套非常直观的图形表达方式。OLED字库本质上就是字符的图形化表示,每个字符都被…...

从面试官视角看嵌入式C/C++:那些年我们踩过的坑与避开的雷

嵌入式C/C面试官的深度思考:技术考察背后的逻辑与实战智慧 在嵌入式开发领域,技术面试往往是一场无声的博弈。作为面试官,我们设计的每一个问题都像精心布置的棋盘,等待着候选人展示他们的思维路径。但这场博弈的目的不是难倒对方…...

别再死磕卡尔曼滤波了!用RBPF粒子滤波搞定机器人SLAM建图(附避坑指南)

粒子滤波实战:用RBPF突破SLAM建图瓶颈的工程指南 当你在ROS中运行gmapping节点时,是否遇到过地图突然扭曲变形的情况?或是发现粒子群在重采样后迅速退化,导致定位完全失败?这些正是传统卡尔曼滤波方法在复杂环境中暴露…...

Harness层接口签名:防篡改设计

Harness层接口签名:防篡改设计一、引言 (Introduction) 1.1 钩子:从微服务架构中那起“无声无息的100万元损失”说起 各位读者好,我是资深软件架构师、开源社区安全方向贡献者,同时也是「云原生与微服务安全实践」技术专栏的作者。…...

MAA自动化框架技术揭秘:计算机视觉驱动的游戏任务智能调度系统实现原理

MAA自动化框架技术揭秘:计算机视觉驱动的游戏任务智能调度系统实现原理 【免费下载链接】MaaAssistantArknights 《明日方舟》小助手,全日常一键长草!| A one-click tool for the daily tasks of Arknights, supporting all clients. 项目地…...

CloudCompare实战:点云二次曲面拟合精度分析与优化策略

1. 二次曲面拟合基础与CloudCompare实现 点云数据处理中,曲面拟合是个绕不开的话题。我第一次接触CloudCompare的二次曲面拟合功能时,就被它的简洁界面吸引,但实际用起来发现没那么简单。二次曲面拟合的本质,是用数学方程来描述点…...

从零部署MinerU文档解析服务:GPU加速、防OOM配置与Docker打包全攻略

从零部署MinerU文档解析服务:GPU加速、防OOM配置与Docker打包全攻略 在AI模型服务化的浪潮中,文档解析作为企业数字化转型的关键环节,正经历着从实验室Demo到生产级服务的蜕变。MinerU-OpenAPI以其多模态处理能力和工业级稳定性,成…...

PLC西门子杯比赛:三部十层电梯博图v15.1程序设计与WinCC界面展示

PLC西门子杯比赛,三部十层电梯博图v15.1程序,带wincc画面。凌晨三点的实验室里,咖啡杯在工控机旁边堆成了防御工事。我盯着博图V15.1里那三台虚拟电梯的运行轨迹,突然发现它们像极了三个不愿加班的打工人——总想着偷懒却又要假装…...

**发散创新:基于RBAC模型的权限管理系统在Python中的高效实现**在现代软件系统中,权限管理是保障数

发散创新:基于RBAC模型的权限管理系统在Python中的高效实现 在现代软件系统中,权限管理是保障数据安全和业务逻辑隔离的核心模块。传统的角色-权限绑定方式容易导致冗余与耦合,而**基于角色的访问控制(Role-Based Access Control,…...

Lv驱动库底层实际使用 Q8定点及其定点实现

目录 一、定点化 二、数据节点规划 三、Lv Q8定点计算代码实现 四、数据线性插值 ISP Pipeline中Lv实现方式探究之一ISP Pipeline中Lv实现方式探究之二ISP Pipeline中Lv实现方式探究之三--lv计算定点实现ISP Pipeline中Lv实现方式探究之四----正LV值定点实现 一、定点化 如上…...

**梯度压缩实战:用PyTorch实现高效分布式训练中的通信优化**在大规模深度学习模型训练中,**梯度同步**

梯度压缩实战:用PyTorch实现高效分布式训练中的通信优化 在大规模深度学习模型训练中,梯度同步是分布式训练的核心瓶颈之一。尤其是在多节点环境下,梯度数据传输消耗大量带宽和时间,严重影响训练效率。梯度压缩技术应运而生——它…...

直接撸代码才是硬道理!搞工控的都懂,IO监控画面最烦的就是一个个按钮指示灯拖到画面上。今天分享个骚操作——用下拉菜单+SCL动态绑定,直接一页搞定所有IO监控

西门子博途HMI监控1200或1500的IO状态时做成一页,IO监控画面做在一页显示,通过下拉菜单选择,方便快捷,不用一个一个去摆放了,是HMI及PLC源程序(SCL编写)先说PLC端的核心逻辑。用SCL搞个循环把IO状态打包成数组&#xf…...

从台球碰撞到火箭发射:用Python模拟动量守恒定律的5个趣味案例

从台球碰撞到火箭发射:用Python模拟动量守恒定律的5个趣味案例 物理学中的动量守恒定律看似抽象,但通过编程模拟,我们可以直观地观察这一原理在各类场景中的应用。本文将带你用Python实现5个经典案例,从台球碰撞到火箭发射&#x…...

Open WebUI:5分钟搭建你的专属AI助手,开启完全离线智能对话新时代

Open WebUI:5分钟搭建你的专属AI助手,开启完全离线智能对话新时代 【免费下载链接】open-webui User-friendly AI Interface (Supports Ollama, OpenAI API, ...) 项目地址: https://gitcode.com/GitHub_Trending/op/open-webui Open WebUI是一款…...

【每日一题】一文搞懂消费类电子的电池容量单位

我们平时使用移动充电宝,笔记本电脑,手机,智能穿戴设备,例如智能眼镜,经常看到标注的电池的容量大小,被五花八门的单位搞得晕头转向,今天我们就来看看这些单位,例如mA,mA…...

从一道ACM题看博弈论:当Alice和Bob开始‘吃瓜’比赛时,到底谁更占便宜?

从一道ACM题看博弈论:当Alice和Bob开始‘吃瓜’比赛时,到底谁更占便宜? 想象一下这样的场景:Alice和Bob面前摆着一堆西瓜,两人轮流拿取,每次可以拿任意数量的瓜,但必须花时间吃完才能继续拿。Al…...

终极glogg指南:如何用这款免费跨平台日志查看器快速分析海量日志文件

终极glogg指南:如何用这款免费跨平台日志查看器快速分析海量日志文件 【免费下载链接】glogg A fast, advanced log explorer. 项目地址: https://gitcode.com/gh_mirrors/gl/glogg glogg是一款专为程序员和系统管理员设计的跨平台GUI日志查看器,…...

收藏!SaaS小白必看:AI大模型落地实战路线图,从功能堆砌到价值创造

本文分析了SaaS公司在整合AI大模型时应避免“功能堆砌”陷阱,并介绍了三大AI技术路线:Prompt/RAG/微调的特点及适用场景。文章强调SaaSAI产品的成功关键在于技术路线与客户价值的适配,提出了分阶段组合策略,即初创期以提示词为主&…...

实战指南:如何高效配置VcXsrv实现Windows与Linux图形应用无缝连接

实战指南:如何高效配置VcXsrv实现Windows与Linux图形应用无缝连接 【免费下载链接】vcxsrv VcXsrv Windows X Server (X2Go/Arctica Builds) 项目地址: https://gitcode.com/gh_mirrors/vc/vcxsrv 在跨平台开发工作中,开发者经常面临一个核心挑战…...

5分钟快速上手Qwerty Learner:提升英语打字效率的终极指南

5分钟快速上手Qwerty Learner:提升英语打字效率的终极指南 【免费下载链接】qwerty-learner 为键盘工作者设计的单词记忆与英语肌肉记忆锻炼软件 / Words learning and English muscle memory training software designed for keyboard workers 项目地址: https:/…...

保姆级教程:从Vivado导出的XSA文件到Petalinux定制Linux系统(以AX7010开发板为例)

从XSA到嵌入式Linux:基于Petalinux的Zynq开发板全流程实战指南 第一次接触Zynq和Petalinux的开发者常会遇到这样的困惑:Vivado生成的硬件描述文件如何转化为可启动的Linux系统?本文将手把手带你完成从XSA文件到完整Linux系统的全流程构建&…...

Edge组策略避坑指南:当企业AD域遇到浏览器管控,这5个细节最容易翻车

Edge组策略避坑指南:企业AD域环境下的5个关键配置陷阱 1. 策略模板版本冲突:被忽视的兼容性杀手 在AD域环境中部署Edge浏览器管控时,策略模板版本与浏览器实际版本不匹配是最常见的翻车点。许多管理员直接从微软官网下载最新策略模板&#…...