当前位置: 首页 > news >正文

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络(DQN)回顾

DQN通过神经网络近似状态-动作值函数(Q函数),在训练过程中使用经验回放(Experience Replay)和固定目标网络(Fixed Target Network)来稳定训练过程。DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma \max_{a'} Q(s', a') - Q(s, a)]

2. Double DQN算法

原理

DQN存在一个问题,即在更新Q值时,使用同一个Q网络选择和评估动作,容易导致过高估计(overestimation)问题。Double DQN(Double Deep Q-Network, DDQN)通过引入两个Q网络,分别用于选择动作和评估动作,来缓解这一问题。

公式推导

Double DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^- ) - Q(s, a)]

其中:

  • \theta 是当前Q网络的参数。
  • \theta^{-} 是目标Q网络的参数。
代码实现

我们以经典的OpenAI Gym中的CartPole环境为例,展示Double DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DoubleDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.model.predict(next_state)t_ = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * t_[0][np.argmax(t[0])]self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DoubleDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Double DQN训练完成")

3. Dueling DQN算法

原理

Dueling DQN通过将Q值函数拆分为状态价值(Value)和优势函数(Advantage),分别估计某一状态下所有动作的价值和某一动作相对于其他动作的优势。这样可以更好地评估状态的价值,从而提高算法性能。

公式推导

Dueling DQN的Q值函数定义为:

Q(s, a; \theta, \alpha, \beta) = V(s; \theta, \beta) + (A(s, a; \theta, \alpha) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a'; \theta, \alpha)) 

其中:

  • V(s; \theta, \beta)是状态价值函数。
  • A(s, a; \theta, \alpha)是优势函数。
代码实现

以CartPole环境为例,展示Dueling DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DuelingDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):input = layers.Input(shape=(self.state_size,))dense1 = layers.Dense(24, activation='relu')(input)dense2 = layers.Dense(24, activation='relu')(dense1)value_fc = layers.Dense(24, activation='relu')(dense2)value = layers.Dense(1, activation='linear')(value_fc)advantage_fc = layers.Dense(24, activation='relu')(dense2)advantage = layers.Dense(self.action_size, activation='linear')(advantage_fc)q_values = layers.Lambda(lambda x: x[0] + (x[1] - tf.reduce_mean(x[1], axis=1, keepdims=True)))([value, advantage])model = models.Model(inputs=input, outputs=q_values)model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DuelingDQNAgent(state_size, action_size)episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Dueling DQN训练完成")

4. 优先经验回放DQN(PER DQN)

原理

优先经验回放(Prioritized Experience Replay, PER)通过赋予不同经验样本不同的优先级来增强经验回放机制。优先级高的样本更有可能被再次抽取,从而加速学习过程。

公式推导

优先经验回放基于TD误差计算优先级,定义为:

p_i = | \delta_i | + \epsilon

其中:

  • \delta_i 是TD误差。
  • \epsilon 是一个小的正数,防止优先级为零。

然后根据优先级分布概率来采样,使用重要性采样权重来修正梯度更新,定义为:

w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta

代码实现

以CartPole环境为例,展示PER DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
import random
import collectionsclass PERDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = collections.deque(maxlen=2000)self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()self.priority = []self.alpha = 0.6self.beta = 0.4self.beta_increment_per_sampling = 0.001def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))self.priority.append(max(self.priority, default=1))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):if len(self.memory) < batch_size:returnpriorities = np.array(self.priority)sampling_probabilities = priorities ** self.alphasampling_probabilities /= sampling_probabilities.sum()indices = np.random.choice(len(self.memory), batch_size, p=sampling_probabilities)minibatch = [self.memory[i] for i in indices]importance_sampling_weights = (len(self.memory) * sampling_probabilities[indices]) ** (-self.beta)importance_sampling_weights /= importance_sampling_weights.max()for i, (state, action, reward, next_state, done) in enumerate(minibatch):target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0, sample_weight=importance_sampling_weights[i])self.priority[indices[i]] = abs(target[0][action] - self.model.predict(state)[0][action]) + 1e-6if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayself.beta = min(1.0, self.beta + self.beta_increment_per_sampling)env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = PERDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("PER DQN训练完成")

5. 总结

Double DQN、Dueling DQN和优先经验回放DQN(PER DQN)都是对原始DQN的改进,各有其优点和适用场景。Double DQN通过减少过高估计提高了算法的稳定性;Dueling DQN通过分离状态价值和优势函数更好地评估状态;PER DQN通过优先采样重要经验加速了学习过程。这些改进算法在不同的应用场景下,可以选择合适的算法来提升强化学习的效果。

相关文章:

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络&#xff08;DQN&#xff09;回顾 DQN通过神经网络近似状态-动作值函数&#xff08;Q函数&#xff09;&#xff0c;在训练过程中使用经验回放&#xff08;Experience Replay&#xff09;和固定目标网络&#xff08;Fixed Target Network&#xff09;来稳定训练过程…...

前端八股文 说一说样式优先级的规则是什么?

标准的回答 CSS样式的优先级应该分成四大类 第一类 !important&#xff1a; &#x1f604;无论引入方式是什么&#xff0c;选择器是什么&#xff0c;它的优先级都是最高的。 第二类 引入方式&#xff1a; &#x1f604;行内样式的优先级要高于嵌入和外链&#xff0c;嵌入和外链…...

洞察国内 AI 绘画行业的璀璨前景

在科技的浪潮中&#xff0c;AI 绘画如同一颗璀璨的新星&#xff0c;正在国内的艺术与技术领域绽放出耀眼的光芒。 近年来&#xff0c;国内 AI 绘画行业发展迅猛&#xff0c;展现出巨大的潜力。随着人工智能技术的不断突破&#xff0c;AI 绘画算法日益精进&#xff0c;能够生成…...

socket编程

文章目录 套接字网路字节序列TCP和UDP套接字 本文章主要介绍Linux下套接字的相关接口&#xff0c;和一些基础知识。 套接字 所有网络通信的行为本质都是进程间进行通信&#xff0c;网络通信也是进程间通信&#xff0c;只不过是不同主机上的两个进程之间的通信。网络通信对于双…...

python自动移除excel文件密码(升级v2版本)

欢迎查看第一版 https://blog.csdn.net/weixin_45631815/article/details/140013476?spm1001.2014.3001.5502 一功能改进 此版本主要改进功能有以下: 直接可以调用函数实现可以尝试多个密码没有加密的文件进行保存,可以按实际业务进行改进.思路来源:java 面向对象设计模式.…...

深入MOJO编程语言的单元测试世界

引言 在软件开发的历程中&#xff0c;单元测试扮演着至关重要的角色。单元测试不仅帮助开发者确保代码的每个部分都按预期工作&#xff0c;而且也是代码质量和维护性的关键保障。本文将引导读者了解如何在MOJO这一假想编程语言中编写单元测试&#xff0c;尽管MOJO并非真实存在…...

Canvas:掌握颜色线条与图像文字设置

想象一下&#xff0c;用几行代码就能创造出如此逼真的图像和动画&#xff0c;仿佛将艺术与科技完美融合&#xff0c;前端开发的Canvas技术正是这个数字化时代中最具魔力的一环&#xff0c;它不仅仅是网页的一部分&#xff0c;更是一个无限创意的画布&#xff0c;一个让你的想象…...

打包导入pyzbar的脚本时的注意事项

目录 前言问题问题的出现解决 总结 本文由Jzwalliser原创&#xff0c;发布在CSDN平台上&#xff0c;遵循CC 4.0 BY-SA协议。 因此&#xff0c;若需转载/引用本文&#xff0c;请注明作者并附原文链接&#xff0c;且禁止删除/修改本段文字。 违者必究&#xff0c;谢谢配合。 个人…...

02-android studio实现下拉列表+单选框+年月日功能

一、下拉列表功能 1.效果图 2.实现过程 1&#xff09;添加组件 <LinearLayoutandroid:layout_width"match_parent"android:layout_height"wrap_content"android:layout_marginLeft"20dp"android:layout_marginRight"20dp"android…...

曹操的五色棋布阵 - 工厂方法模式

定场诗 “兵无常势&#xff0c;水无常形&#xff0c;能因敌变化而取胜者&#xff0c;谓之神。” 在三国的战场上&#xff0c;兵法如棋&#xff0c;布阵如画。曹操的五色棋布阵&#xff0c;不正是今日软件设计中工厂方法模式的绝妙写照吗&#xff1f;让我们从这个神奇的布阵之…...

谷粒商城学习笔记-逆向工程错误记录

文章目录 1&#xff0c;Since Maven 3.8.1 http repositories are blocked.1.1 在maven的settings.xml文件中&#xff0c;新增如下配置&#xff1a;1.2&#xff0c;执行clean命令刷新maven配置 2&#xff0c;internal java compiler error3&#xff0c;启动逆向工程报错&#x…...

FastAPI+SQLAlchemy数据库连接

FastAPISQLAlchemy数据库连接 目录 FastAPISQLAlchemy数据库连接配置数据库连接创建表模型创建alembic迁移文件安装初始化编辑env.py编辑alembic.ini迁移数据库 视图函数查询 配置数据库连接 # db.py from sqlalchemy import create_engine from sqlalchemy.orm import sessio…...

Android中的适配器,你知道是做什么的吗?

&#x1f604;作者简介&#xff1a; 小曾同学.com,一个致力于测试开发的博主⛽️&#xff0c;主要职责&#xff1a;测试开发、CI/CD&#xff0c;日常还会涉及Android开发工作。 如果文章知识点有错误的地方&#xff0c;还请大家指正&#xff0c;让我们一起学习&#xff0c;一起…...

GitHub详解:代码托管与协作开发平台

文章目录 一、GitHub简介二、GitHub的核心功能2.1 仓库&#xff08;Repository&#xff09;2.2 版本控制与分支&#xff08;Branch&#xff09;2.3 Pull Request2.4 Issues与Projects2.5 GitHub Actions 三、GitHub的使用方法3.1 注册与登录3.2 创建和管理仓库3.3 使用Git进行代…...

【植物大战僵尸杂交版】获取+存档插件

文章目录 一、还记得《植物大战僵尸》吗&#xff1f;二、在哪下载&#xff0c;怎么安装&#xff1f;三、杂交版如何进行存档功能概述 一、还记得《植物大战僵尸》吗&#xff1f; 最近&#xff0c;一款曾经在15年前风靡一时的经典游戏《植物大战僵尸》似乎迎来了它的"文艺复…...

BP神经网络与反向传播算法在深度学习中的应用

BP神经网络与反向传播算法在深度学习中的应用 在神经网络的发展历史中&#xff0c;BP神经网络&#xff08;Backpropagation Neural Network&#xff09;占有重要地位。BP神经网络通过反向传播算法进行训练&#xff0c;这种算法在神经网络中引入了一种高效的学习方式。随着深度…...

【数据结构与算法】插入排序

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《数据结构与算法》 期待您的关注 ​...

MySQL如何实现数据排序

根据explain的执行计划来看&#xff0c;MySQL可以分为索引排序和filesort 索引排序 如果查询中的order by字句包含的字段已经在索引中&#xff0c;且索引的排列顺序和order by子句一致&#xff0c;则可直接利用索引进行排序&#xff0c;由于索引有序&#xff0c;所以排序效率…...

给我的 IM 系统加上监控两件套:【Prometheus + Grafana】

监控是一个系统必不可少的组成部分&#xff0c;实时&#xff0c;准确的监控&#xff0c;将会大大有助于我们排查问题。而当今微服务系统的话有一个监控组合很火那就是 Prometheus Grafana&#xff0c;嘿你别说 这俩兄弟配合的相当完美&#xff0c;Prometheus负责数据采集&…...

【Python】基于动态规划和K聚类的彩色图片压缩算法

引言 当想要压缩一张彩色图像时&#xff0c;彩色图像通常由数百万个颜色值组成&#xff0c;每个颜色值都由红、绿、蓝三个分量组成。因此&#xff0c;如果我们直接对图像的每个像素进行编码&#xff0c;会导致非常大的数据量。为了减少数据量&#xff0c;我们可以尝试减少颜色…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

循环冗余码校验CRC码 算法步骤+详细实例计算

通信过程&#xff1a;&#xff08;白话解释&#xff09; 我们将原始待发送的消息称为 M M M&#xff0c;依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)&#xff08;意思就是 G &#xff08; x ) G&#xff08;x) G&#xff08;x) 是已知的&#xff09;&#xff0…...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用&#xff0c;通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试&#xff0c;通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版&#xff0c;柱状图PPT模版&#xff0c;线状图PPT模版&#xff0c;折线图PPT模版&#xff0c;饼状图PPT模版&#xff0c;雷达图PPT模版&#xff0c;树状图PPT模版 图表类系列各种样式PPT模版分享&#xff1a;图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

是否存在路径(FIFOBB算法)

题目描述 一个具有 n 个顶点e条边的无向图&#xff0c;该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序&#xff0c;确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数&#xff0c;分别表示n 和 e 的值&#xff08;1…...

大数据学习(132)-HIve数据分析

​​​​&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4…...