当前位置: 首页 > article >正文

PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)

在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念,并使用 PyTorch 实现一个经典的深度 Q 网络(DQN)来解决强化学习中的经典问题——CartPole。

一、强化学习基础

强化学习(Reinforcement Learning, RL)是机器学习的一个重要分支,它通过智能体(Agent)与环境(Environment)的交互来学习策略,以最大化累积奖励。强化学习的核心思想是通过试错来学习,智能体在环境中采取行动,观察结果,并根据奖励信号调整策略。

1. 强化学习的基本要素

  • 智能体(Agent):学习并做出决策的主体。

  • 环境(Environment):智能体交互的外部世界。

  • 状态(State):环境在某一时刻的描述。

  • 动作(Action):智能体在某一状态下采取的行动。

  • 奖励(Reward):智能体采取动作后,环境返回的反馈信号。

  • 策略(Policy):智能体在给定状态下选择动作的规则。

  • 价值函数(Value Function):评估在某一状态下采取某一动作的长期回报。

2. Q-Learning 与深度 Q 网络(DQN)

Q-Learning 是一种经典的强化学习算法,它通过学习一个 Q 函数来评估在某一状态下采取某一动作的长期回报。Q 函数的更新公式为:

深度 Q 网络(DQN)将 Q-Learning 与深度学习结合,使用神经网络来近似 Q 函数。DQN 通过经验回放(Experience Replay)和目标网络(Target Network)来稳定训练过程。

二、CartPole 问题实战

CartPole 是强化学习中的经典问题,目标是控制一个小车(Cart)使其上的杆子(Pole)保持直立。我们将使用 PyTorch 实现一个 DQN 来解决这个问题。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。

  2. 定义 DQN 模型。

  3. 定义经验回放缓冲区。

  4. 定义 DQN 训练过程。

  5. 测试模型并评估性能。

3. 代码实现

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
​
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
​
# 1. 安装并导入必要的库
env = gym.make('CartPole-v1')
​
# 2. 定义 DQN 模型
class DQN(nn.Module):def __init__(self, state_size, action_size):super(DQN, self).__init__()self.fc1 = nn.Linear(state_size, 64)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, action_size)
​def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x
​
# 3. 定义经验回放缓冲区
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)
​def push(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))
​def sample(self, batch_size):state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)
​def __len__(self):return len(self.buffer)
​
# 4. 定义 DQN 训练过程
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
model = DQN(state_size, action_size)
target_model = DQN(state_size, action_size)
target_model.load_state_dict(model.state_dict())
optimizer = optim.Adam(model.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)
​
def train(batch_size, gamma=0.99):if len(buffer) < batch_size:returnstate, action, reward, next_state, done = buffer.sample(batch_size)state = torch.FloatTensor(state)next_state = torch.FloatTensor(next_state)action = torch.LongTensor(action)reward = torch.FloatTensor(reward)done = torch.FloatTensor(done)
​q_values = model(state)next_q_values = target_model(next_state)q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)next_q_value = next_q_values.max(1)[0]expected_q_value = reward + gamma * next_q_value * (1 - done)
​loss = nn.MSELoss()(q_value, expected_q_value.detach())optimizer.zero_grad()loss.backward()optimizer.step()
​
# 5. 测试模型并评估性能
def test(env, model, episodes=10):total_reward = 0for _ in range(episodes):state = env.reset()done = Falsewhile not done:state = torch.FloatTensor(state).unsqueeze(0)action = model(state).max(1)[1].item()next_state, reward, done, _ = env.step(action)total_reward += rewardstate = next_statereturn total_reward / episodes
​
# 训练过程
episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
rewards = []
​
for episode in range(episodes):state = env.reset()done = Falsetotal_reward = 0
​while not done:if random.random() < epsilon:action = env.action_space.sample()else:state_tensor = torch.FloatTensor(state).unsqueeze(0)action = model(state_tensor).max(1)[1].item()
​next_state, reward, done, _ = env.step(action)buffer.push(state, action, reward, next_state, done)state = next_statetotal_reward += reward
​train(batch_size, gamma)
​epsilon = max(epsilon_min, epsilon * epsilon_decay)rewards.append(total_reward)
​if (episode + 1) % 50 == 0:avg_reward = test(env, model)print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")
​
# 6. 可视化训练结果
plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("DQN 训练过程")
plt.show()

三、代码解析

  1. 环境与模型定义

    • 使用 gym 创建 CartPole 环境。

    • 定义 DQN 模型,包含三个全连接层。

  2. 经验回放缓冲区

    • 使用 deque 实现经验回放缓冲区,存储状态、动作、奖励等信息。

  3. 训练过程

    • 使用 epsilon-greedy 策略进行探索与利用。

    • 通过经验回放缓冲区采样数据进行训练,更新模型参数。

  4. 测试过程

    • 在测试环境中评估模型性能,计算平均奖励。

  5. 可视化

    • 绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。

  • 训练结束后,绘制训练过程中的总奖励曲线。

五、总结

本文介绍了强化学习的基本概念,并使用 PyTorch 实现了一个深度 Q 网络(DQN)来解决 CartPole 问题。通过这个例子,我们学习了如何定义 DQN 模型、使用经验回放缓冲区、训练模型以及评估性能。

在下一篇文章中,我们将探讨更复杂的强化学习算法,如 Actor-Critic 和 Proximal Policy Optimization (PPO)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解强化学习的基础知识!如果有任何问题,欢迎在评论区留言讨论。

相关文章:

PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)

在之前的文章中&#xff0c;我们介绍了神经网络、卷积神经网络&#xff08;CNN&#xff09;、循环神经网络&#xff08;RNN&#xff09;、Transformer 等多种深度学习模型&#xff0c;并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念&#xff0…...

在Eclipse 中使用 MyBatis 进行开发,通常需要以下步骤:

在Eclipse 中使用 MyBatis 进行开发&#xff0c;通常需要以下步骤&#xff1a; 1. 创建 Maven 项目 首先&#xff0c;在 Eclipse 中创建一个 Maven 项目。如果你还没有安装 Maven 插件&#xff0c;可以通过 Eclipse Marketplace 安装 Maven 插件。 打开 Eclipse&#xff0c;选…...

Python学习第十九天

Django-分页 后端分页 Django提供了Paginator类来实现后端分页。Paginator类可以将一个查询集&#xff08;QuerySet&#xff09;分成多个页面&#xff0c;每个页面包含指定数量的对象。 from django.shortcuts import render, redirect, get_object_or_404 from .models impo…...

Adobe Premiere Pro2023配置要求

Windows 系统 最低配置 处理器&#xff1a;Intel 第六代或更新版本的 CPU&#xff0c;或 AMD Ryzen™ 1000 系列或更新版本的 CPU&#xff0c;需要支持 Advanced Vector Extensions 2&#xff08;AVX2&#xff09;。操作系统&#xff1a;Windows 10&#xff08;64 位&#xff…...

面试求助:接口测试用例设计主要考虑哪些方面?

一、基础功能验证 1. 正常场景覆盖 关键点&#xff1a;验证接口在合法输入下的正确响应&#xff08;状态码、数据结构、业务逻辑&#xff09;。 案例&#xff1a; json 复制 // 用户登录接口 输入&#xff1a;{"username": "合法用户", "password…...

C语言——变量与常量

C语言中的变量与常量&#xff1a;简洁易懂的指南 在C语言编程中&#xff0c;变量和常量是最基本的概念之一。理解它们的区别和使用方法对于编写高效、可维护的代码至关重要。本文将详细介绍C语言中的变量和常量&#xff0c;并通过图表和代码示例帮助你更好地理解。 目录 什么…...

考研408-数据结构完整代码 线性表的顺序存储结构 - 顺序表

线性表的顺序存储结构 - 顺序表 1. 顺序表的定义 ​ 用一组地址连续的存储单元依次存储线性表的数据元素&#xff0c;从而使逻辑上相邻的两个元素在物理位置上也相邻 2. 顺序表的特点 随机访问&#xff1a; 即通过首地址和元素序号可以在O(1) 时间内找到指定元素&#xff0…...

Windows环境下安装部署dzzoffice+onlyoffice的私有网盘和在线协同系统

安装前需要准备好Docker Desktop环境&#xff0c;可查看我的另一份亲测安装文章 https://blog.csdn.net/qq_43003203/article/details/146283915?spm1001.2014.3001.5501 1、安装配置onlyoffice 1、Docker 拉取onlyoffice容器镜像 管理员身份运行Windows PowerShell&#x…...

解决 Docker 镜像拉取超时问题:配置国内镜像源

在使用 Docker 的过程中&#xff0c;经常会遇到镜像拉取超时的问题&#xff0c;尤其是在国内网络环境下。这不仅会浪费大量的时间&#xff0c;还可能导致一些项目无法顺利进行。今天&#xff0c;我将分享一个简单而有效的解决方法&#xff1a;配置国内镜像源。 环境 操作系统 c…...

如何解决 Three.js 物体渲染的锯齿问题

在 Three.js 中&#xff0c;如果模型看起来不够平滑&#xff0c;或者在旋转视角时出现锯齿&#xff08;aliasing&#xff09;&#xff0c;可以通过以下方法来优化渲染效果。 1. 启用抗锯齿&#xff08;MSAA&#xff09; 默认情况下&#xff0c;Three.js 渲染器不会开启抗锯齿&…...

嵌入式八股,为什么单片机中不使用malloc函数

1. 资源限制 单片机的内存资源通常非常有限&#xff0c;尤其是RAM的大小可能只有几KB到几十KB。在这种情况下&#xff0c;使用 malloc 进行动态内存分配可能会导致内存碎片化&#xff0c;使得程序在运行过程中逐渐耗尽可用内存。 2. 内存碎片问题 malloc 函数在分配和释放内…...

【计量地理学】实验一 地理数据的基本统计分析

阅前提示&#xff1a; 计量地理学实验课的实验报告为当堂提交&#xff0c;相较以往实验报告缺少打磨与整理的时间&#xff0c;因此内容中不可避免出现相关错误&#xff01;&#xff01;&#xff01; 出于个人完美主义的原则本不愿发布&#xff08;其实就是黑历史&#xff09;…...

ChatPromptTemplate的使用

ChatPromptTemplate 是 LangChain 中专门用于管理多角色对话结构的提示词模板工具。它的核心价值在于&#xff0c;开发者可以预先定义不同类型的对话角色消息&#xff08;如系统指令、用户提问、AI历史回复&#xff09;&#xff0c;并通过数据绑定动态生成完整对话上下文。 1.…...

SQL Server查询优化

最常用&#xff0c;最有效的数据库优化方式 查询语句层面 避免全表扫描 使用索引&#xff1a;确保查询条件中的字段有索引。例如&#xff0c;查询语句 SELECT * FROM users WHERE age > 20&#xff0c;若 age 字段有索引&#xff0c;数据库会利用索引快速定位符合条件的记…...

Blender插件NodeWrangler导入贴图报错解决方法

Blender用NodeWrangler插件 CtrlShiftT 导入贴图 直接报错 解决方法: 用CtrlshiftT打开需要导入的材质文件夹时&#xff0c;右边有一个默认勾选的相对路径&#xff0c;取消勾选就可以了。 开启node wrangler插件&#xff0c;然后在导入贴图是取消勾选"相对路径"&am…...

QT中的宏

Q_UNUSED(event); 是 Qt 提供的一个宏&#xff0c;用于标记某个变量或参数在当前作用域中未被使用。它的主要作用是避免编译器发出“未使用变量”的警告。 背景 在 C 中&#xff0c;如果一个函数参数或变量在代码中没有被使用&#xff0c;编译器会发出警告&#xff0c;例如&a…...

java项目之基于ssm的药店药品信息管理系统(源码+文档)

项目简介 药店药品信息管理系统实现了以下功能&#xff1a; 个人信息管理 负责管理个人用户的信息。 员工管理 负责管理药店或药品管理机构的员工信息。 药品管理 负责管理药品的详细信息&#xff0c;可能包括药品名称、成分、剂量、价格、库存等。 进货管理 负责管理药品…...

论文分享 | HE-Nav: 一种适用于复杂环境中空地机器人的高性能高效导航系统

阿木实验室始终致力于通过开源项目和智能无人机产品&#xff0c;为全球无人机开发者提供强有力的技术支持&#xff0c;并推出了开源项目校园赞助活动&#xff0c;助力高校学子在学术研究与技术创新中取得更大突破。近日&#xff0c;香港大学王俊铭同学&#xff0c;基于阿木实验…...

【大语言模型】【个人知识库正式内容】提示工程:如何设计模型的提示语

知识库条目&#xff1a;提示工程&#xff0c;如何构建提示词。 &#x1f3d6;️ 当人人都能使用AI时&#xff0c;你如何才能变得更出彩&#xff1f;……让 AI 带有自己的Tag —— 一、简介 什么是提示语 (Prompt)&#xff1a; 提示语是用户输入给AI系统的指令或信息, 简单来说…...

ubuntu 24 安装 python3.x 教程

目录 注意事项 一、安装不同 Python 版本 1. 安装依赖 2. 下载 Python 源码 3. 解压并编译安装 二、管理多个 Python 版本 1. 查看已安装的 Python 版本 2. 配置环境变量 3. 使用 update-alternatives​ 管理 Python 版本 三、使用虚拟环境为项目指定特定 Python 版本…...

(十一) 人工智能 - Python 教程 - Python元组

更多系列教程&#xff0c;每天更新 更多教程关注&#xff1a;xxxueba.com 星星学霸 1 元组&#xff08;Tuple&#xff09; 元组是有序且不可更改的集合。在 Python 中&#xff0c;元组是用圆括号编写的。 实例 创建元组&#xff1a; thistuple ("apple", "b…...

【sql靶场】第13、14、17关-post提交报错注入保姆级教程

目录 【sql靶场】第13、14、17关-post提交报错注入保姆级教程 1.知识回顾 1.报错注入深解 2.报错注入格式 3.使用的函数 4.URL 5.核心组成部分 6.数据编码规范 7.请求方法 2.第十三关 1.测试闭合 2.列数测试 3.测试回显 4.爆出数据库名 5.爆出表名 6.爆出字段 …...

93.HarmonyOS NEXT窗口管理基础教程:深入理解WindowSizeManager

温馨提示&#xff1a;本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦&#xff01; HarmonyOS NEXT窗口管理基础教程&#xff1a;深入理解WindowSizeManager 文章目录 HarmonyOS NEXT窗口管理基础教程&#xff1a;深入理解WindowSiz…...

Python----数据分析(Pandas一:pandas库介绍,pandas操作文件读取和保存)

一、Pandas库 1.1、概念 Pandas是一个开源的、用于数据处理和分析的Python库&#xff0c;特别适合处理表格类数 据。它建立在NumPy数组之上&#xff0c;提供了高效的数据结构和数据分析工具&#xff0c;使得数据操作变得更加简单、便捷和高效。 Pandas 的目标是成为 Python 数据…...

基于WebRTC技术的EasyRTC嵌入式音视频SDK:多平台兼容与性能优化

在当今数字化、智能化的时代背景下&#xff0c;实时音视频通信技术已成为众多领域不可或缺的关键技术。基于WebRTC技术的EasyRTC嵌入式音视频SDK&#xff0c;凭借其在ARM、Linux、Windows、安卓、iOS等多平台上的兼容性&#xff0c;为开发者提供了强大的工具&#xff0c;推动了…...

【快速入门】MyBatis

一.基础操作 1.准备工作 1&#xff09;引入依赖 一个是mysql驱动包&#xff0c;一个是mybatis的依赖包&#xff1a; <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><vers…...

提升 React 应用性能:使用 React Profiler 进行性能调优

前言 在现代前端开发中&#xff0c;性能优化是一个不可忽视的重要环节。在 React 生态系统中&#xff0c;React Profiler 是一个强大的工具&#xff0c;它可以帮助我们检测和优化应用的性能。 本文将通过通俗易懂的语言介绍 React Profiler 的作用&#xff0c;并展示如何使用它…...

八、Prometheus 静态配置(Static Configuration)

所有的配置都可以用静态配置来监控,只不过用servicemonitor简单,但是域名需要静态配置 如果使用 Prometheus 静态配置(Static Configuration),确实不需要 ServiceMonitor、Service 和 Endpoints,但这也意味着失去了 Kubernetes 自动发现(Service Discovery, SD) 的能力…...

重生之我在学Vue--第16天 Vue 3 插件开发

重生之我在学Vue–第16天 Vue 3 插件开发 文章目录 重生之我在学Vue--第16天 Vue 3 插件开发前言一、插件的作用与开发思路1.1 插件能做什么&#xff1f;1.2 插件开发四部曲 二、开发全局通知插件2.1 插件基础结构2.2 完整插件代码&#xff08;带注释解析&#xff09;2.3 样式文…...

网络VLAN技术详解:原理、类型与实战配置

网络VLAN技术详解&#xff1a;原理、类型与实战配置 1. 什么是VLAN&#xff1f; VLAN&#xff08;Virtual Local Area Network&#xff0c;虚拟局域网&#xff09; 是一种通过逻辑划分而非物理连接隔离网络设备的技术。它允许管理员将同一物理网络中的设备划分为多个独立的广播…...