PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)
在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念,并使用 PyTorch 实现一个经典的深度 Q 网络(DQN)来解决强化学习中的经典问题——CartPole。
一、强化学习基础
强化学习(Reinforcement Learning, RL)是机器学习的一个重要分支,它通过智能体(Agent)与环境(Environment)的交互来学习策略,以最大化累积奖励。强化学习的核心思想是通过试错来学习,智能体在环境中采取行动,观察结果,并根据奖励信号调整策略。
1. 强化学习的基本要素
-
智能体(Agent):学习并做出决策的主体。
-
环境(Environment):智能体交互的外部世界。
-
状态(State):环境在某一时刻的描述。
-
动作(Action):智能体在某一状态下采取的行动。
-
奖励(Reward):智能体采取动作后,环境返回的反馈信号。
-
策略(Policy):智能体在给定状态下选择动作的规则。
-
价值函数(Value Function):评估在某一状态下采取某一动作的长期回报。
2. Q-Learning 与深度 Q 网络(DQN)
Q-Learning 是一种经典的强化学习算法,它通过学习一个 Q 函数来评估在某一状态下采取某一动作的长期回报。Q 函数的更新公式为:

深度 Q 网络(DQN)将 Q-Learning 与深度学习结合,使用神经网络来近似 Q 函数。DQN 通过经验回放(Experience Replay)和目标网络(Target Network)来稳定训练过程。
二、CartPole 问题实战
CartPole 是强化学习中的经典问题,目标是控制一个小车(Cart)使其上的杆子(Pole)保持直立。我们将使用 PyTorch 实现一个 DQN 来解决这个问题。
1. 问题描述
CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。
2. 实现步骤
-
安装并导入必要的库。
-
定义 DQN 模型。
-
定义经验回放缓冲区。
-
定义 DQN 训练过程。
-
测试模型并评估性能。
3. 代码实现
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 1. 安装并导入必要的库
env = gym.make('CartPole-v1')
# 2. 定义 DQN 模型
class DQN(nn.Module):def __init__(self, state_size, action_size):super(DQN, self).__init__()self.fc1 = nn.Linear(state_size, 64)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, action_size)
def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x
# 3. 定义经验回放缓冲区
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)
def __len__(self):return len(self.buffer)
# 4. 定义 DQN 训练过程
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
model = DQN(state_size, action_size)
target_model = DQN(state_size, action_size)
target_model.load_state_dict(model.state_dict())
optimizer = optim.Adam(model.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)
def train(batch_size, gamma=0.99):if len(buffer) < batch_size:returnstate, action, reward, next_state, done = buffer.sample(batch_size)state = torch.FloatTensor(state)next_state = torch.FloatTensor(next_state)action = torch.LongTensor(action)reward = torch.FloatTensor(reward)done = torch.FloatTensor(done)
q_values = model(state)next_q_values = target_model(next_state)q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)next_q_value = next_q_values.max(1)[0]expected_q_value = reward + gamma * next_q_value * (1 - done)
loss = nn.MSELoss()(q_value, expected_q_value.detach())optimizer.zero_grad()loss.backward()optimizer.step()
# 5. 测试模型并评估性能
def test(env, model, episodes=10):total_reward = 0for _ in range(episodes):state = env.reset()done = Falsewhile not done:state = torch.FloatTensor(state).unsqueeze(0)action = model(state).max(1)[1].item()next_state, reward, done, _ = env.step(action)total_reward += rewardstate = next_statereturn total_reward / episodes
# 训练过程
episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
rewards = []
for episode in range(episodes):state = env.reset()done = Falsetotal_reward = 0
while not done:if random.random() < epsilon:action = env.action_space.sample()else:state_tensor = torch.FloatTensor(state).unsqueeze(0)action = model(state_tensor).max(1)[1].item()
next_state, reward, done, _ = env.step(action)buffer.push(state, action, reward, next_state, done)state = next_statetotal_reward += reward
train(batch_size, gamma)
epsilon = max(epsilon_min, epsilon * epsilon_decay)rewards.append(total_reward)
if (episode + 1) % 50 == 0:avg_reward = test(env, model)print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")
# 6. 可视化训练结果
plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("DQN 训练过程")
plt.show()
三、代码解析
-
环境与模型定义:
-
使用
gym创建 CartPole 环境。 -
定义 DQN 模型,包含三个全连接层。
-
-
经验回放缓冲区:
-
使用
deque实现经验回放缓冲区,存储状态、动作、奖励等信息。
-
-
训练过程:
-
使用 epsilon-greedy 策略进行探索与利用。
-
通过经验回放缓冲区采样数据进行训练,更新模型参数。
-
-
测试过程:
-
在测试环境中评估模型性能,计算平均奖励。
-
-
可视化:
-
绘制训练过程中的总奖励曲线。
-
四、运行结果
运行上述代码后,你将看到以下输出:
-
训练过程中每 50 个 episode 打印一次平均奖励。
-
训练结束后,绘制训练过程中的总奖励曲线。

五、总结
本文介绍了强化学习的基本概念,并使用 PyTorch 实现了一个深度 Q 网络(DQN)来解决 CartPole 问题。通过这个例子,我们学习了如何定义 DQN 模型、使用经验回放缓冲区、训练模型以及评估性能。
在下一篇文章中,我们将探讨更复杂的强化学习算法,如 Actor-Critic 和 Proximal Policy Optimization (PPO)。敬请期待!
代码实例说明:
-
本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
-
如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:
model = model.to('cuda'),state = state.to('cuda')。
希望这篇文章能帮助你更好地理解强化学习的基础知识!如果有任何问题,欢迎在评论区留言讨论。
相关文章:
PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)
在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念࿰…...
在Eclipse 中使用 MyBatis 进行开发,通常需要以下步骤:
在Eclipse 中使用 MyBatis 进行开发,通常需要以下步骤: 1. 创建 Maven 项目 首先,在 Eclipse 中创建一个 Maven 项目。如果你还没有安装 Maven 插件,可以通过 Eclipse Marketplace 安装 Maven 插件。 打开 Eclipse,选…...
Python学习第十九天
Django-分页 后端分页 Django提供了Paginator类来实现后端分页。Paginator类可以将一个查询集(QuerySet)分成多个页面,每个页面包含指定数量的对象。 from django.shortcuts import render, redirect, get_object_or_404 from .models impo…...
Adobe Premiere Pro2023配置要求
Windows 系统 最低配置 处理器:Intel 第六代或更新版本的 CPU,或 AMD Ryzen™ 1000 系列或更新版本的 CPU,需要支持 Advanced Vector Extensions 2(AVX2)。操作系统:Windows 10(64 位ÿ…...
面试求助:接口测试用例设计主要考虑哪些方面?
一、基础功能验证 1. 正常场景覆盖 关键点:验证接口在合法输入下的正确响应(状态码、数据结构、业务逻辑)。 案例: json 复制 // 用户登录接口 输入:{"username": "合法用户", "password…...
C语言——变量与常量
C语言中的变量与常量:简洁易懂的指南 在C语言编程中,变量和常量是最基本的概念之一。理解它们的区别和使用方法对于编写高效、可维护的代码至关重要。本文将详细介绍C语言中的变量和常量,并通过图表和代码示例帮助你更好地理解。 目录 什么…...
考研408-数据结构完整代码 线性表的顺序存储结构 - 顺序表
线性表的顺序存储结构 - 顺序表 1. 顺序表的定义 用一组地址连续的存储单元依次存储线性表的数据元素,从而使逻辑上相邻的两个元素在物理位置上也相邻 2. 顺序表的特点 随机访问: 即通过首地址和元素序号可以在O(1) 时间内找到指定元素࿰…...
Windows环境下安装部署dzzoffice+onlyoffice的私有网盘和在线协同系统
安装前需要准备好Docker Desktop环境,可查看我的另一份亲测安装文章 https://blog.csdn.net/qq_43003203/article/details/146283915?spm1001.2014.3001.5501 1、安装配置onlyoffice 1、Docker 拉取onlyoffice容器镜像 管理员身份运行Windows PowerShell&#x…...
解决 Docker 镜像拉取超时问题:配置国内镜像源
在使用 Docker 的过程中,经常会遇到镜像拉取超时的问题,尤其是在国内网络环境下。这不仅会浪费大量的时间,还可能导致一些项目无法顺利进行。今天,我将分享一个简单而有效的解决方法:配置国内镜像源。 环境 操作系统 c…...
如何解决 Three.js 物体渲染的锯齿问题
在 Three.js 中,如果模型看起来不够平滑,或者在旋转视角时出现锯齿(aliasing),可以通过以下方法来优化渲染效果。 1. 启用抗锯齿(MSAA) 默认情况下,Three.js 渲染器不会开启抗锯齿&…...
嵌入式八股,为什么单片机中不使用malloc函数
1. 资源限制 单片机的内存资源通常非常有限,尤其是RAM的大小可能只有几KB到几十KB。在这种情况下,使用 malloc 进行动态内存分配可能会导致内存碎片化,使得程序在运行过程中逐渐耗尽可用内存。 2. 内存碎片问题 malloc 函数在分配和释放内…...
【计量地理学】实验一 地理数据的基本统计分析
阅前提示: 计量地理学实验课的实验报告为当堂提交,相较以往实验报告缺少打磨与整理的时间,因此内容中不可避免出现相关错误!!! 出于个人完美主义的原则本不愿发布(其实就是黑历史)…...
ChatPromptTemplate的使用
ChatPromptTemplate 是 LangChain 中专门用于管理多角色对话结构的提示词模板工具。它的核心价值在于,开发者可以预先定义不同类型的对话角色消息(如系统指令、用户提问、AI历史回复),并通过数据绑定动态生成完整对话上下文。 1.…...
SQL Server查询优化
最常用,最有效的数据库优化方式 查询语句层面 避免全表扫描 使用索引:确保查询条件中的字段有索引。例如,查询语句 SELECT * FROM users WHERE age > 20,若 age 字段有索引,数据库会利用索引快速定位符合条件的记…...
Blender插件NodeWrangler导入贴图报错解决方法
Blender用NodeWrangler插件 CtrlShiftT 导入贴图 直接报错 解决方法: 用CtrlshiftT打开需要导入的材质文件夹时,右边有一个默认勾选的相对路径,取消勾选就可以了。 开启node wrangler插件,然后在导入贴图是取消勾选"相对路径"&am…...
QT中的宏
Q_UNUSED(event); 是 Qt 提供的一个宏,用于标记某个变量或参数在当前作用域中未被使用。它的主要作用是避免编译器发出“未使用变量”的警告。 背景 在 C 中,如果一个函数参数或变量在代码中没有被使用,编译器会发出警告,例如&a…...
java项目之基于ssm的药店药品信息管理系统(源码+文档)
项目简介 药店药品信息管理系统实现了以下功能: 个人信息管理 负责管理个人用户的信息。 员工管理 负责管理药店或药品管理机构的员工信息。 药品管理 负责管理药品的详细信息,可能包括药品名称、成分、剂量、价格、库存等。 进货管理 负责管理药品…...
论文分享 | HE-Nav: 一种适用于复杂环境中空地机器人的高性能高效导航系统
阿木实验室始终致力于通过开源项目和智能无人机产品,为全球无人机开发者提供强有力的技术支持,并推出了开源项目校园赞助活动,助力高校学子在学术研究与技术创新中取得更大突破。近日,香港大学王俊铭同学,基于阿木实验…...
【大语言模型】【个人知识库正式内容】提示工程:如何设计模型的提示语
知识库条目:提示工程,如何构建提示词。 🏖️ 当人人都能使用AI时,你如何才能变得更出彩?……让 AI 带有自己的Tag —— 一、简介 什么是提示语 (Prompt): 提示语是用户输入给AI系统的指令或信息, 简单来说…...
ubuntu 24 安装 python3.x 教程
目录 注意事项 一、安装不同 Python 版本 1. 安装依赖 2. 下载 Python 源码 3. 解压并编译安装 二、管理多个 Python 版本 1. 查看已安装的 Python 版本 2. 配置环境变量 3. 使用 update-alternatives 管理 Python 版本 三、使用虚拟环境为项目指定特定 Python 版本…...
(十一) 人工智能 - Python 教程 - Python元组
更多系列教程,每天更新 更多教程关注:xxxueba.com 星星学霸 1 元组(Tuple) 元组是有序且不可更改的集合。在 Python 中,元组是用圆括号编写的。 实例 创建元组: thistuple ("apple", "b…...
【sql靶场】第13、14、17关-post提交报错注入保姆级教程
目录 【sql靶场】第13、14、17关-post提交报错注入保姆级教程 1.知识回顾 1.报错注入深解 2.报错注入格式 3.使用的函数 4.URL 5.核心组成部分 6.数据编码规范 7.请求方法 2.第十三关 1.测试闭合 2.列数测试 3.测试回显 4.爆出数据库名 5.爆出表名 6.爆出字段 …...
93.HarmonyOS NEXT窗口管理基础教程:深入理解WindowSizeManager
温馨提示:本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦! HarmonyOS NEXT窗口管理基础教程:深入理解WindowSizeManager 文章目录 HarmonyOS NEXT窗口管理基础教程:深入理解WindowSiz…...
Python----数据分析(Pandas一:pandas库介绍,pandas操作文件读取和保存)
一、Pandas库 1.1、概念 Pandas是一个开源的、用于数据处理和分析的Python库,特别适合处理表格类数 据。它建立在NumPy数组之上,提供了高效的数据结构和数据分析工具,使得数据操作变得更加简单、便捷和高效。 Pandas 的目标是成为 Python 数据…...
基于WebRTC技术的EasyRTC嵌入式音视频SDK:多平台兼容与性能优化
在当今数字化、智能化的时代背景下,实时音视频通信技术已成为众多领域不可或缺的关键技术。基于WebRTC技术的EasyRTC嵌入式音视频SDK,凭借其在ARM、Linux、Windows、安卓、iOS等多平台上的兼容性,为开发者提供了强大的工具,推动了…...
【快速入门】MyBatis
一.基础操作 1.准备工作 1)引入依赖 一个是mysql驱动包,一个是mybatis的依赖包: <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><vers…...
提升 React 应用性能:使用 React Profiler 进行性能调优
前言 在现代前端开发中,性能优化是一个不可忽视的重要环节。在 React 生态系统中,React Profiler 是一个强大的工具,它可以帮助我们检测和优化应用的性能。 本文将通过通俗易懂的语言介绍 React Profiler 的作用,并展示如何使用它…...
八、Prometheus 静态配置(Static Configuration)
所有的配置都可以用静态配置来监控,只不过用servicemonitor简单,但是域名需要静态配置 如果使用 Prometheus 静态配置(Static Configuration),确实不需要 ServiceMonitor、Service 和 Endpoints,但这也意味着失去了 Kubernetes 自动发现(Service Discovery, SD) 的能力…...
重生之我在学Vue--第16天 Vue 3 插件开发
重生之我在学Vue–第16天 Vue 3 插件开发 文章目录 重生之我在学Vue--第16天 Vue 3 插件开发前言一、插件的作用与开发思路1.1 插件能做什么?1.2 插件开发四部曲 二、开发全局通知插件2.1 插件基础结构2.2 完整插件代码(带注释解析)2.3 样式文…...
网络VLAN技术详解:原理、类型与实战配置
网络VLAN技术详解:原理、类型与实战配置 1. 什么是VLAN? VLAN(Virtual Local Area Network,虚拟局域网) 是一种通过逻辑划分而非物理连接隔离网络设备的技术。它允许管理员将同一物理网络中的设备划分为多个独立的广播…...
