当前位置: 首页 > news >正文

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络(DQN)回顾

DQN通过神经网络近似状态-动作值函数(Q函数),在训练过程中使用经验回放(Experience Replay)和固定目标网络(Fixed Target Network)来稳定训练过程。DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma \max_{a'} Q(s', a') - Q(s, a)]

2. Double DQN算法

原理

DQN存在一个问题,即在更新Q值时,使用同一个Q网络选择和评估动作,容易导致过高估计(overestimation)问题。Double DQN(Double Deep Q-Network, DDQN)通过引入两个Q网络,分别用于选择动作和评估动作,来缓解这一问题。

公式推导

Double DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^- ) - Q(s, a)]

其中:

  • \theta 是当前Q网络的参数。
  • \theta^{-} 是目标Q网络的参数。
代码实现

我们以经典的OpenAI Gym中的CartPole环境为例,展示Double DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DoubleDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.model.predict(next_state)t_ = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * t_[0][np.argmax(t[0])]self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DoubleDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Double DQN训练完成")

3. Dueling DQN算法

原理

Dueling DQN通过将Q值函数拆分为状态价值(Value)和优势函数(Advantage),分别估计某一状态下所有动作的价值和某一动作相对于其他动作的优势。这样可以更好地评估状态的价值,从而提高算法性能。

公式推导

Dueling DQN的Q值函数定义为:

Q(s, a; \theta, \alpha, \beta) = V(s; \theta, \beta) + (A(s, a; \theta, \alpha) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a'; \theta, \alpha)) 

其中:

  • V(s; \theta, \beta)是状态价值函数。
  • A(s, a; \theta, \alpha)是优势函数。
代码实现

以CartPole环境为例,展示Dueling DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DuelingDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):input = layers.Input(shape=(self.state_size,))dense1 = layers.Dense(24, activation='relu')(input)dense2 = layers.Dense(24, activation='relu')(dense1)value_fc = layers.Dense(24, activation='relu')(dense2)value = layers.Dense(1, activation='linear')(value_fc)advantage_fc = layers.Dense(24, activation='relu')(dense2)advantage = layers.Dense(self.action_size, activation='linear')(advantage_fc)q_values = layers.Lambda(lambda x: x[0] + (x[1] - tf.reduce_mean(x[1], axis=1, keepdims=True)))([value, advantage])model = models.Model(inputs=input, outputs=q_values)model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DuelingDQNAgent(state_size, action_size)episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Dueling DQN训练完成")

4. 优先经验回放DQN(PER DQN)

原理

优先经验回放(Prioritized Experience Replay, PER)通过赋予不同经验样本不同的优先级来增强经验回放机制。优先级高的样本更有可能被再次抽取,从而加速学习过程。

公式推导

优先经验回放基于TD误差计算优先级,定义为:

p_i = | \delta_i | + \epsilon

其中:

  • \delta_i 是TD误差。
  • \epsilon 是一个小的正数,防止优先级为零。

然后根据优先级分布概率来采样,使用重要性采样权重来修正梯度更新,定义为:

w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta

代码实现

以CartPole环境为例,展示PER DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
import random
import collectionsclass PERDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = collections.deque(maxlen=2000)self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()self.priority = []self.alpha = 0.6self.beta = 0.4self.beta_increment_per_sampling = 0.001def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))self.priority.append(max(self.priority, default=1))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):if len(self.memory) < batch_size:returnpriorities = np.array(self.priority)sampling_probabilities = priorities ** self.alphasampling_probabilities /= sampling_probabilities.sum()indices = np.random.choice(len(self.memory), batch_size, p=sampling_probabilities)minibatch = [self.memory[i] for i in indices]importance_sampling_weights = (len(self.memory) * sampling_probabilities[indices]) ** (-self.beta)importance_sampling_weights /= importance_sampling_weights.max()for i, (state, action, reward, next_state, done) in enumerate(minibatch):target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0, sample_weight=importance_sampling_weights[i])self.priority[indices[i]] = abs(target[0][action] - self.model.predict(state)[0][action]) + 1e-6if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayself.beta = min(1.0, self.beta + self.beta_increment_per_sampling)env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = PERDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("PER DQN训练完成")

5. 总结

Double DQN、Dueling DQN和优先经验回放DQN(PER DQN)都是对原始DQN的改进,各有其优点和适用场景。Double DQN通过减少过高估计提高了算法的稳定性;Dueling DQN通过分离状态价值和优势函数更好地评估状态;PER DQN通过优先采样重要经验加速了学习过程。这些改进算法在不同的应用场景下,可以选择合适的算法来提升强化学习的效果。

相关文章:

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络&#xff08;DQN&#xff09;回顾 DQN通过神经网络近似状态-动作值函数&#xff08;Q函数&#xff09;&#xff0c;在训练过程中使用经验回放&#xff08;Experience Replay&#xff09;和固定目标网络&#xff08;Fixed Target Network&#xff09;来稳定训练过程…...

前端八股文 说一说样式优先级的规则是什么?

标准的回答 CSS样式的优先级应该分成四大类 第一类 !important&#xff1a; &#x1f604;无论引入方式是什么&#xff0c;选择器是什么&#xff0c;它的优先级都是最高的。 第二类 引入方式&#xff1a; &#x1f604;行内样式的优先级要高于嵌入和外链&#xff0c;嵌入和外链…...

洞察国内 AI 绘画行业的璀璨前景

在科技的浪潮中&#xff0c;AI 绘画如同一颗璀璨的新星&#xff0c;正在国内的艺术与技术领域绽放出耀眼的光芒。 近年来&#xff0c;国内 AI 绘画行业发展迅猛&#xff0c;展现出巨大的潜力。随着人工智能技术的不断突破&#xff0c;AI 绘画算法日益精进&#xff0c;能够生成…...

socket编程

文章目录 套接字网路字节序列TCP和UDP套接字 本文章主要介绍Linux下套接字的相关接口&#xff0c;和一些基础知识。 套接字 所有网络通信的行为本质都是进程间进行通信&#xff0c;网络通信也是进程间通信&#xff0c;只不过是不同主机上的两个进程之间的通信。网络通信对于双…...

python自动移除excel文件密码(升级v2版本)

欢迎查看第一版 https://blog.csdn.net/weixin_45631815/article/details/140013476?spm1001.2014.3001.5502 一功能改进 此版本主要改进功能有以下: 直接可以调用函数实现可以尝试多个密码没有加密的文件进行保存,可以按实际业务进行改进.思路来源:java 面向对象设计模式.…...

深入MOJO编程语言的单元测试世界

引言 在软件开发的历程中&#xff0c;单元测试扮演着至关重要的角色。单元测试不仅帮助开发者确保代码的每个部分都按预期工作&#xff0c;而且也是代码质量和维护性的关键保障。本文将引导读者了解如何在MOJO这一假想编程语言中编写单元测试&#xff0c;尽管MOJO并非真实存在…...

Canvas:掌握颜色线条与图像文字设置

想象一下&#xff0c;用几行代码就能创造出如此逼真的图像和动画&#xff0c;仿佛将艺术与科技完美融合&#xff0c;前端开发的Canvas技术正是这个数字化时代中最具魔力的一环&#xff0c;它不仅仅是网页的一部分&#xff0c;更是一个无限创意的画布&#xff0c;一个让你的想象…...

打包导入pyzbar的脚本时的注意事项

目录 前言问题问题的出现解决 总结 本文由Jzwalliser原创&#xff0c;发布在CSDN平台上&#xff0c;遵循CC 4.0 BY-SA协议。 因此&#xff0c;若需转载/引用本文&#xff0c;请注明作者并附原文链接&#xff0c;且禁止删除/修改本段文字。 违者必究&#xff0c;谢谢配合。 个人…...

02-android studio实现下拉列表+单选框+年月日功能

一、下拉列表功能 1.效果图 2.实现过程 1&#xff09;添加组件 <LinearLayoutandroid:layout_width"match_parent"android:layout_height"wrap_content"android:layout_marginLeft"20dp"android:layout_marginRight"20dp"android…...

曹操的五色棋布阵 - 工厂方法模式

定场诗 “兵无常势&#xff0c;水无常形&#xff0c;能因敌变化而取胜者&#xff0c;谓之神。” 在三国的战场上&#xff0c;兵法如棋&#xff0c;布阵如画。曹操的五色棋布阵&#xff0c;不正是今日软件设计中工厂方法模式的绝妙写照吗&#xff1f;让我们从这个神奇的布阵之…...

谷粒商城学习笔记-逆向工程错误记录

文章目录 1&#xff0c;Since Maven 3.8.1 http repositories are blocked.1.1 在maven的settings.xml文件中&#xff0c;新增如下配置&#xff1a;1.2&#xff0c;执行clean命令刷新maven配置 2&#xff0c;internal java compiler error3&#xff0c;启动逆向工程报错&#x…...

FastAPI+SQLAlchemy数据库连接

FastAPISQLAlchemy数据库连接 目录 FastAPISQLAlchemy数据库连接配置数据库连接创建表模型创建alembic迁移文件安装初始化编辑env.py编辑alembic.ini迁移数据库 视图函数查询 配置数据库连接 # db.py from sqlalchemy import create_engine from sqlalchemy.orm import sessio…...

Android中的适配器,你知道是做什么的吗?

&#x1f604;作者简介&#xff1a; 小曾同学.com,一个致力于测试开发的博主⛽️&#xff0c;主要职责&#xff1a;测试开发、CI/CD&#xff0c;日常还会涉及Android开发工作。 如果文章知识点有错误的地方&#xff0c;还请大家指正&#xff0c;让我们一起学习&#xff0c;一起…...

GitHub详解:代码托管与协作开发平台

文章目录 一、GitHub简介二、GitHub的核心功能2.1 仓库&#xff08;Repository&#xff09;2.2 版本控制与分支&#xff08;Branch&#xff09;2.3 Pull Request2.4 Issues与Projects2.5 GitHub Actions 三、GitHub的使用方法3.1 注册与登录3.2 创建和管理仓库3.3 使用Git进行代…...

【植物大战僵尸杂交版】获取+存档插件

文章目录 一、还记得《植物大战僵尸》吗&#xff1f;二、在哪下载&#xff0c;怎么安装&#xff1f;三、杂交版如何进行存档功能概述 一、还记得《植物大战僵尸》吗&#xff1f; 最近&#xff0c;一款曾经在15年前风靡一时的经典游戏《植物大战僵尸》似乎迎来了它的"文艺复…...

BP神经网络与反向传播算法在深度学习中的应用

BP神经网络与反向传播算法在深度学习中的应用 在神经网络的发展历史中&#xff0c;BP神经网络&#xff08;Backpropagation Neural Network&#xff09;占有重要地位。BP神经网络通过反向传播算法进行训练&#xff0c;这种算法在神经网络中引入了一种高效的学习方式。随着深度…...

【数据结构与算法】插入排序

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《数据结构与算法》 期待您的关注 ​...

MySQL如何实现数据排序

根据explain的执行计划来看&#xff0c;MySQL可以分为索引排序和filesort 索引排序 如果查询中的order by字句包含的字段已经在索引中&#xff0c;且索引的排列顺序和order by子句一致&#xff0c;则可直接利用索引进行排序&#xff0c;由于索引有序&#xff0c;所以排序效率…...

给我的 IM 系统加上监控两件套:【Prometheus + Grafana】

监控是一个系统必不可少的组成部分&#xff0c;实时&#xff0c;准确的监控&#xff0c;将会大大有助于我们排查问题。而当今微服务系统的话有一个监控组合很火那就是 Prometheus Grafana&#xff0c;嘿你别说 这俩兄弟配合的相当完美&#xff0c;Prometheus负责数据采集&…...

【Python】基于动态规划和K聚类的彩色图片压缩算法

引言 当想要压缩一张彩色图像时&#xff0c;彩色图像通常由数百万个颜色值组成&#xff0c;每个颜色值都由红、绿、蓝三个分量组成。因此&#xff0c;如果我们直接对图像的每个像素进行编码&#xff0c;会导致非常大的数据量。为了减少数据量&#xff0c;我们可以尝试减少颜色…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了&#xff1a;一行…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

剑指offer20_链表中环的入口节点

链表中环的入口节点 给定一个链表&#xff0c;若其中包含环&#xff0c;则输出环的入口节点。 若其中不包含环&#xff0c;则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式&#xff1a; 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...