当前位置: 首页 > news >正文

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络(DQN)回顾

DQN通过神经网络近似状态-动作值函数(Q函数),在训练过程中使用经验回放(Experience Replay)和固定目标网络(Fixed Target Network)来稳定训练过程。DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma \max_{a'} Q(s', a') - Q(s, a)]

2. Double DQN算法

原理

DQN存在一个问题,即在更新Q值时,使用同一个Q网络选择和评估动作,容易导致过高估计(overestimation)问题。Double DQN(Double Deep Q-Network, DDQN)通过引入两个Q网络,分别用于选择动作和评估动作,来缓解这一问题。

公式推导

Double DQN的更新公式为:

Q(s, a) \leftarrow Q(s, a) + \alpha [r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^- ) - Q(s, a)]

其中:

  • \theta 是当前Q网络的参数。
  • \theta^{-} 是目标Q网络的参数。
代码实现

我们以经典的OpenAI Gym中的CartPole环境为例,展示Double DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DoubleDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.model.predict(next_state)t_ = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * t_[0][np.argmax(t[0])]self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DoubleDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Double DQN训练完成")

3. Dueling DQN算法

原理

Dueling DQN通过将Q值函数拆分为状态价值(Value)和优势函数(Advantage),分别估计某一状态下所有动作的价值和某一动作相对于其他动作的优势。这样可以更好地评估状态的价值,从而提高算法性能。

公式推导

Dueling DQN的Q值函数定义为:

Q(s, a; \theta, \alpha, \beta) = V(s; \theta, \beta) + (A(s, a; \theta, \alpha) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a'; \theta, \alpha)) 

其中:

  • V(s; \theta, \beta)是状态价值函数。
  • A(s, a; \theta, \alpha)是优势函数。
代码实现

以CartPole环境为例,展示Dueling DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizersclass DuelingDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = []self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()def _build_model(self):input = layers.Input(shape=(self.state_size,))dense1 = layers.Dense(24, activation='relu')(input)dense2 = layers.Dense(24, activation='relu')(dense1)value_fc = layers.Dense(24, activation='relu')(dense2)value = layers.Dense(1, activation='linear')(value_fc)advantage_fc = layers.Dense(24, activation='relu')(dense2)advantage = layers.Dense(self.action_size, activation='linear')(advantage_fc)q_values = layers.Lambda(lambda x: x[0] + (x[1] - tf.reduce_mean(x[1], axis=1, keepdims=True)))([value, advantage])model = models.Model(inputs=input, outputs=q_values)model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):minibatch = np.random.choice(self.memory, batch_size)for state, action, reward, next_state, done in minibatch:target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0)if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayenv = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DuelingDQNAgent(state_size, action_size)episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("Dueling DQN训练完成")

4. 优先经验回放DQN(PER DQN)

原理

优先经验回放(Prioritized Experience Replay, PER)通过赋予不同经验样本不同的优先级来增强经验回放机制。优先级高的样本更有可能被再次抽取,从而加速学习过程。

公式推导

优先经验回放基于TD误差计算优先级,定义为:

p_i = | \delta_i | + \epsilon

其中:

  • \delta_i 是TD误差。
  • \epsilon 是一个小的正数,防止优先级为零。

然后根据优先级分布概率来采样,使用重要性采样权重来修正梯度更新,定义为:

w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta

代码实现

以CartPole环境为例,展示PER DQN算法的实现。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
import random
import collectionsclass PERDQNAgent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.memory = collections.deque(maxlen=2000)self.gamma = 0.95self.epsilon = 1.0self.epsilon_decay = 0.995self.epsilon_min = 0.01self.learning_rate = 0.001self.model = self._build_model()self.target_model = self._build_model()self.update_target_model()self.priority = []self.alpha = 0.6self.beta = 0.4self.beta_increment_per_sampling = 0.001def _build_model(self):model = models.Sequential()model.add(layers.Dense(24, input_dim=self.state_size, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(self.action_size, activation='linear'))model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=self.learning_rate))return modeldef update_target_model(self):self.target_model.set_weights(self.model.get_weights())def remember(self, state, action, reward, next_state, done):self.memory.append((state, action, reward, next_state, done))self.priority.append(max(self.priority, default=1))def act(self, state):if np.random.rand() <= self.epsilon:return np.random.choice(self.action_size)q_values = self.model.predict(state)return np.argmax(q_values[0])def replay(self, batch_size):if len(self.memory) < batch_size:returnpriorities = np.array(self.priority)sampling_probabilities = priorities ** self.alphasampling_probabilities /= sampling_probabilities.sum()indices = np.random.choice(len(self.memory), batch_size, p=sampling_probabilities)minibatch = [self.memory[i] for i in indices]importance_sampling_weights = (len(self.memory) * sampling_probabilities[indices]) ** (-self.beta)importance_sampling_weights /= importance_sampling_weights.max()for i, (state, action, reward, next_state, done) in enumerate(minibatch):target = self.model.predict(state)if done:target[0][action] = rewardelse:t = self.target_model.predict(next_state)target[0][action] = reward + self.gamma * np.amax(t[0])self.model.fit(state, target, epochs=1, verbose=0, sample_weight=importance_sampling_weights[i])self.priority[indices[i]] = abs(target[0][action] - self.model.predict(state)[0][action]) + 1e-6if self.epsilon > self.epsilon_min:self.epsilon *= self.epsilon_decayself.beta = min(1.0, self.beta + self.beta_increment_per_sampling)env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = PERDQNAgent(state_size, action_size)
episodes = 1000for e in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])done = Falsetime = 0while not done:action = agent.act(state)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])reward = reward if not done else -10agent.remember(state, action, reward, next_state, done)state = next_statetime += 1if done:agent.update_target_model()print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")if len(agent.memory) > 32:agent.replay(32)env.close()
print("PER DQN训练完成")

5. 总结

Double DQN、Dueling DQN和优先经验回放DQN(PER DQN)都是对原始DQN的改进,各有其优点和适用场景。Double DQN通过减少过高估计提高了算法的稳定性;Dueling DQN通过分离状态价值和优势函数更好地评估状态;PER DQN通过优先采样重要经验加速了学习过程。这些改进算法在不同的应用场景下,可以选择合适的算法来提升强化学习的效果。

相关文章:

强化学习中的Double DQN、Dueling DQN和PER DQN算法详解及实战

1. 深度Q网络&#xff08;DQN&#xff09;回顾 DQN通过神经网络近似状态-动作值函数&#xff08;Q函数&#xff09;&#xff0c;在训练过程中使用经验回放&#xff08;Experience Replay&#xff09;和固定目标网络&#xff08;Fixed Target Network&#xff09;来稳定训练过程…...

前端八股文 说一说样式优先级的规则是什么?

标准的回答 CSS样式的优先级应该分成四大类 第一类 !important&#xff1a; &#x1f604;无论引入方式是什么&#xff0c;选择器是什么&#xff0c;它的优先级都是最高的。 第二类 引入方式&#xff1a; &#x1f604;行内样式的优先级要高于嵌入和外链&#xff0c;嵌入和外链…...

洞察国内 AI 绘画行业的璀璨前景

在科技的浪潮中&#xff0c;AI 绘画如同一颗璀璨的新星&#xff0c;正在国内的艺术与技术领域绽放出耀眼的光芒。 近年来&#xff0c;国内 AI 绘画行业发展迅猛&#xff0c;展现出巨大的潜力。随着人工智能技术的不断突破&#xff0c;AI 绘画算法日益精进&#xff0c;能够生成…...

socket编程

文章目录 套接字网路字节序列TCP和UDP套接字 本文章主要介绍Linux下套接字的相关接口&#xff0c;和一些基础知识。 套接字 所有网络通信的行为本质都是进程间进行通信&#xff0c;网络通信也是进程间通信&#xff0c;只不过是不同主机上的两个进程之间的通信。网络通信对于双…...

python自动移除excel文件密码(升级v2版本)

欢迎查看第一版 https://blog.csdn.net/weixin_45631815/article/details/140013476?spm1001.2014.3001.5502 一功能改进 此版本主要改进功能有以下: 直接可以调用函数实现可以尝试多个密码没有加密的文件进行保存,可以按实际业务进行改进.思路来源:java 面向对象设计模式.…...

深入MOJO编程语言的单元测试世界

引言 在软件开发的历程中&#xff0c;单元测试扮演着至关重要的角色。单元测试不仅帮助开发者确保代码的每个部分都按预期工作&#xff0c;而且也是代码质量和维护性的关键保障。本文将引导读者了解如何在MOJO这一假想编程语言中编写单元测试&#xff0c;尽管MOJO并非真实存在…...

Canvas:掌握颜色线条与图像文字设置

想象一下&#xff0c;用几行代码就能创造出如此逼真的图像和动画&#xff0c;仿佛将艺术与科技完美融合&#xff0c;前端开发的Canvas技术正是这个数字化时代中最具魔力的一环&#xff0c;它不仅仅是网页的一部分&#xff0c;更是一个无限创意的画布&#xff0c;一个让你的想象…...

打包导入pyzbar的脚本时的注意事项

目录 前言问题问题的出现解决 总结 本文由Jzwalliser原创&#xff0c;发布在CSDN平台上&#xff0c;遵循CC 4.0 BY-SA协议。 因此&#xff0c;若需转载/引用本文&#xff0c;请注明作者并附原文链接&#xff0c;且禁止删除/修改本段文字。 违者必究&#xff0c;谢谢配合。 个人…...

02-android studio实现下拉列表+单选框+年月日功能

一、下拉列表功能 1.效果图 2.实现过程 1&#xff09;添加组件 <LinearLayoutandroid:layout_width"match_parent"android:layout_height"wrap_content"android:layout_marginLeft"20dp"android:layout_marginRight"20dp"android…...

曹操的五色棋布阵 - 工厂方法模式

定场诗 “兵无常势&#xff0c;水无常形&#xff0c;能因敌变化而取胜者&#xff0c;谓之神。” 在三国的战场上&#xff0c;兵法如棋&#xff0c;布阵如画。曹操的五色棋布阵&#xff0c;不正是今日软件设计中工厂方法模式的绝妙写照吗&#xff1f;让我们从这个神奇的布阵之…...

谷粒商城学习笔记-逆向工程错误记录

文章目录 1&#xff0c;Since Maven 3.8.1 http repositories are blocked.1.1 在maven的settings.xml文件中&#xff0c;新增如下配置&#xff1a;1.2&#xff0c;执行clean命令刷新maven配置 2&#xff0c;internal java compiler error3&#xff0c;启动逆向工程报错&#x…...

FastAPI+SQLAlchemy数据库连接

FastAPISQLAlchemy数据库连接 目录 FastAPISQLAlchemy数据库连接配置数据库连接创建表模型创建alembic迁移文件安装初始化编辑env.py编辑alembic.ini迁移数据库 视图函数查询 配置数据库连接 # db.py from sqlalchemy import create_engine from sqlalchemy.orm import sessio…...

Android中的适配器,你知道是做什么的吗?

&#x1f604;作者简介&#xff1a; 小曾同学.com,一个致力于测试开发的博主⛽️&#xff0c;主要职责&#xff1a;测试开发、CI/CD&#xff0c;日常还会涉及Android开发工作。 如果文章知识点有错误的地方&#xff0c;还请大家指正&#xff0c;让我们一起学习&#xff0c;一起…...

GitHub详解:代码托管与协作开发平台

文章目录 一、GitHub简介二、GitHub的核心功能2.1 仓库&#xff08;Repository&#xff09;2.2 版本控制与分支&#xff08;Branch&#xff09;2.3 Pull Request2.4 Issues与Projects2.5 GitHub Actions 三、GitHub的使用方法3.1 注册与登录3.2 创建和管理仓库3.3 使用Git进行代…...

【植物大战僵尸杂交版】获取+存档插件

文章目录 一、还记得《植物大战僵尸》吗&#xff1f;二、在哪下载&#xff0c;怎么安装&#xff1f;三、杂交版如何进行存档功能概述 一、还记得《植物大战僵尸》吗&#xff1f; 最近&#xff0c;一款曾经在15年前风靡一时的经典游戏《植物大战僵尸》似乎迎来了它的"文艺复…...

BP神经网络与反向传播算法在深度学习中的应用

BP神经网络与反向传播算法在深度学习中的应用 在神经网络的发展历史中&#xff0c;BP神经网络&#xff08;Backpropagation Neural Network&#xff09;占有重要地位。BP神经网络通过反向传播算法进行训练&#xff0c;这种算法在神经网络中引入了一种高效的学习方式。随着深度…...

【数据结构与算法】插入排序

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《数据结构与算法》 期待您的关注 ​...

MySQL如何实现数据排序

根据explain的执行计划来看&#xff0c;MySQL可以分为索引排序和filesort 索引排序 如果查询中的order by字句包含的字段已经在索引中&#xff0c;且索引的排列顺序和order by子句一致&#xff0c;则可直接利用索引进行排序&#xff0c;由于索引有序&#xff0c;所以排序效率…...

给我的 IM 系统加上监控两件套:【Prometheus + Grafana】

监控是一个系统必不可少的组成部分&#xff0c;实时&#xff0c;准确的监控&#xff0c;将会大大有助于我们排查问题。而当今微服务系统的话有一个监控组合很火那就是 Prometheus Grafana&#xff0c;嘿你别说 这俩兄弟配合的相当完美&#xff0c;Prometheus负责数据采集&…...

【Python】基于动态规划和K聚类的彩色图片压缩算法

引言 当想要压缩一张彩色图像时&#xff0c;彩色图像通常由数百万个颜色值组成&#xff0c;每个颜色值都由红、绿、蓝三个分量组成。因此&#xff0c;如果我们直接对图像的每个像素进行编码&#xff0c;会导致非常大的数据量。为了减少数据量&#xff0c;我们可以尝试减少颜色…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效&#xff0c;它能挖掘数据中的时序信息以及语义信息&#xff0c;但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN&#xff0c;但是…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

day36-多路IO复用

一、基本概念 &#xff08;服务器多客户端模型&#xff09; 定义&#xff1a;单线程或单进程同时监测若干个文件描述符是否可以执行IO操作的能力 作用&#xff1a;应用程序通常需要处理来自多条事件流中的事件&#xff0c;比如我现在用的电脑&#xff0c;需要同时处理键盘鼠标…...

PostgreSQL——环境搭建

一、Linux # 安装 PostgreSQL 15 仓库 sudo dnf install -y https://download.postgresql.org/pub/repos/yum/reporpms/EL-$(rpm -E %{rhel})-x86_64/pgdg-redhat-repo-latest.noarch.rpm# 安装之前先确认是否已经存在PostgreSQL rpm -qa | grep postgres# 如果存在&#xff0…...

【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)

前言&#xff1a; 双亲委派机制对于面试这块来说非常重要&#xff0c;在实际开发中也是经常遇见需要打破双亲委派的需求&#xff0c;今天我们一起来探索一下什么是双亲委派机制&#xff0c;在此之前我们先介绍一下类的加载器。 目录 ​编辑 前言&#xff1a; 类加载器 1. …...

协议转换利器,profinet转ethercat网关的两大派系,各有千秋

随着工业以太网的发展&#xff0c;其高效、便捷、协议开放、易于冗余等诸多优点&#xff0c;被越来越多的工业现场所采用。西门子SIMATIC S7-1200/1500系列PLC集成有Profinet接口&#xff0c;具有实时性、开放性&#xff0c;使用TCP/IP和IT标准&#xff0c;符合基于工业以太网的…...