当前位置: 首页 > news >正文

对抗式生成模仿学习(GAIL)

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy函数 GAIL_Training():初始化策略 π 的参数初始化判别器 D 的参数循环 直到收敛 或 达到最大迭代次数:# 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中生成 trajectories 使用 π 并存储在 D_policy 中# 判别器训练循环 discriminator_updates 次数:# 从策略缓冲区 D_policy 中采样数据采样 (s_policy, a_policy) 从 D_policy 中# 从专家示范数据 D_expert 中采样数据采样 (s_expert, a_expert) 从 D_expert 中# 更新判别器 D计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]使用梯度下降法更新判别器参数以最小化 L_D# 策略更新采样 (s, a, ...) 从 D_policy 中计算伪奖励 r = -log(1 - D(s, a))# 使用伪奖励 r 更新策略 π计算 L_π 使用 PPO 或 其他强化学习方法使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom .ppo import PPO
from gail_airl_ppo.network import GAILDiscrimclass GAIL(PPO):def __init__(self, buffer_exp, state_shape, action_shape, device, seed,gamma=0.995, rollout_length=50000, mix_buffer=1,batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,units_actor=(64, 64), units_critic=(64, 64),units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):super().__init__(state_shape, action_shape, device, seed, gamma, rollout_length,mix_buffer, lr_actor, lr_critic, units_actor, units_critic,epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm)# Expert's buffer.self.buffer_exp = buffer_exp# Discriminator.self.disc = GAILDiscrim(state_shape=state_shape,action_shape=action_shape,hidden_units=units_disc,hidden_activation=nn.Tanh()).to(device)self.learning_steps_disc = 0self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)self.batch_size = batch_sizeself.epoch_disc = epoch_discdef update(self, writer):self.learning_steps += 1for _ in range(self.epoch_disc):self.learning_steps_disc += 1# Samples from current policy's trajectories.states, actions = self.buffer.sample(self.batch_size)[:2]# Samples from expert's demonstrations.states_exp, actions_exp = \self.buffer_exp.sample(self.batch_size)[:2]# Update discriminator.self.update_disc(states, actions, states_exp, actions_exp, writer)# We don't use reward signals here,states, actions, _, dones, log_pis, next_states = self.buffer.get()# Calculate rewards.rewards = self.disc.calculate_reward(states, actions)# Update PPO using estimated rewards.self.update_ppo(states, actions, rewards, dones, log_pis, next_states, writer)def update_disc(self, states, actions, states_exp, actions_exp, writer):# Output of discriminator is (-inf, inf), not [0, 1].logits_pi = self.disc(states, actions)logits_exp = self.disc(states_exp, actions_exp)# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].loss_pi = -F.logsigmoid(-logits_pi).mean()loss_exp = -F.logsigmoid(logits_exp).mean()loss_disc = loss_pi + loss_expself.optim_disc.zero_grad()loss_disc.backward()self.optim_disc.step()if self.learning_steps_disc % self.epoch_disc == 0:writer.add_scalar('loss/disc', loss_disc.item(), self.learning_steps)# Discriminator's accuracies.with torch.no_grad():acc_pi = (logits_pi < 0).float().mean().item()acc_exp = (logits_exp > 0).float().mean().item()writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

相关文章:

对抗式生成模仿学习(GAIL)

目录 1 预先基础知识 1.1 对抗生成网络&#xff08;GAN&#xff09; 1.1.1 基本概念 1.1.2 损失函数 1.1.2.1 固定G&#xff0c;求解令损失函数最大的D 1.1.2.2 固定D&#xff0c;求解令损失函数最小的G 1.2 对抗式生成模仿学习特点 2 对抗式生成模仿学习&#xff08;…...

信息系统项目管理师 | 新一代信息技术

关注WX&#xff1a;CodingTechWork 物联网 定义 The Internet of Things是指通过信息传感设备&#xff0c;按约定的协议&#xff0c;将任何物品与互联网连接&#xff0c;进行信息交互和通信&#xff0c;以实现智能化识别。定位、跟踪、监控和管理的一种网络。物联网主要解决…...

安全宣传咨询日活动向媒体投稿记住这个投稿好方法

在信息爆炸的时代,作为单位的信息宣传员,我肩负着将每一次重要活动,特别是像“安全宣传咨询日”这样的公益活动,有效传达给公众的重任。这份工作看似简单,实则充满了挑战,尤其是在我初涉此领域时,那段曲折而又难忘的投稿经历,至今记忆犹新。 初探投稿之海,遭遇重重困难 起初,我…...

第7章:系统架构设计基础知识-软件架构风格

由于历史原因&#xff0c;研究者和工程人员对Sofiware Architecture(简称SA)的翻译不尽相同&#xff0c;其软件的“体系结构”和“架构”具有相同的含义。 系统架构其实就是系统的结构&#xff0c;系统架构设计其实就是要给相关利益方说清楚通过什么样的结构来解决需求中功能和…...

自制调色小工具给图片加滤镜,修改图片红、绿、蓝通道及亮度,修改图片颜色

上篇&#xff1a; 上篇我们给地图添加了锐化、模糊等滤镜&#xff0c;这篇来写一个小工具给图片调色。 调色比锐化等滤镜要简单许多&#xff0c;直接拿到像素值修改即可。不需要用到卷积核。。。(*^▽^*) 核心原理就是图像结构&#xff0c;使用context.getImageData获取图像像…...

【Redis】java客户端(SpringData和jedis)

https://www.oz6.cn/articles/58 https://www.bilibili.com/video/BV1cr4y1671t/?p16 redis官网客户端介绍&#xff1a;https://redis.io/docs/latest/develop/connect/clients/ jedis maven引入依赖 <dependencies><!--引入Jedis依赖--><dependency><…...

大数据安全经典面试题及回答(上)

目录 一、大数据安全的主要挑战及应对策略 二、大数据安全中的“五个V”及其影响 三、在Hadoop集群中实施数据加密的步骤和注意事项 四、在大数据环境中实施访问控制和身份认证 五、大数据环境中数据备份和恢复的策略 六、大数据处理过程中保护用户隐私的策略 七、大数据…...

vi/vim使用命令

你是否在编辑文件时以为键盘坏了&#xff0c;为什么不能删除呢&#xff0c;为什么不能敲代码呢&#xff0c;当你初识vi&#xff0c;会觉得这个东西设计很难用&#xff0c;这篇教程带你熟练得用上这款经典的工具&#xff0c;当你熟练了这款工具就会真正体会到高效率打码 Vi 是在…...

webpack打包gz文件,nginx开启gzip压缩

wepback配置 webpack4配合"compression-webpack-plugin": "^6.1.2"打包压缩gz chain.plugin("compression").use(new CompressionPlugin({test: /\.js$|\.html$|\.css$/,threshold: 10240, // 超过10KB的压缩deleteOriginalAssets: false,// 保…...

微服务开发与实战Day11 - 微服务面试篇

一、分布式事务 1. CAP定理 1998年&#xff0c;加州大学的计算机科学及Eric Brewer提出&#xff0c;分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff09;Availability&#xff08;可用性&#xff09;Partition tolerance&#xff08;分区容错性&am…...

基于Spring Boot+VUE职称评审管理系统

1管理员功能模块 管理员登录&#xff0c;通过填写注册时输入的用户名、密码、角色进行登录&#xff0c;如图1所示。 图1管理员登录界面图 管理员登录进入职称评审管理系统可以查看首页、个人中心、用户管理、评审员管理、省份管理、评审条件管理、职称申请管理、结果公布管理、…...

MySQL 基本语法讲解及示例(上)

第一节&#xff1a;MySQL的基本操作 1. 创建数据库 在 MySQL 中&#xff0c;创建数据库的步骤如下&#xff1a; 命令行操作 打开 MySQL 命令行客户端或连接到 MySQL 服务器。 输入以下命令创建一个数据库&#xff1a; CREATE DATABASE database_name;例如&#xff0c;创建一…...

6.18作业

完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密码不匹配&#xff…...

Excel文件转换为HTML文件

文章目录 前言安装python包python代码 前言 将一个Excel文件转换为HTML文件 安装python包 使用pandas和openpyxl库来实现这个功能 pip install pandas openpyxlpython代码 1、首先使用tkinter库中的filedialog模块弹出一个对话框来选择要转换的Excel文件 2、使用pandas库…...

MySQL数据库入门

1、MySQL概述 MySQL官方网站 https://www.mysql.com/downloads/ MySQL被Oracle公司收购了&#xff0c;作者又重新编写了一个开源的数据库管理系统&#xff0c;Mariadb 2、MySQL产品&版本 2、数据库在网站架构中的角色 LAMP LNMP网站架构 3、安装MySQL-基于yum 查…...

vue element-ui 下拉框 以及 input 限制输入,小数点后保留两位 界面设计案例 和 例子:支持mp4和m3u8视频播放

vue input 限制输入&#xff0c;小数点后保留两位 以及 图片垂直居中显示 和 分享 git 小技巧-CSDN博客文章浏览阅读430次&#xff0c;点赞5次&#xff0c;收藏4次。error:Your local changes to the following files would be overwritten by merge:_error: your local change…...

Python基础用法 之 运算符

1.算数运算符 符号作用说明举例加与“”相同 - 减与“-”相同*乘 与“ ”相同 9*218/除 与“ ”相同 9/24.5 、6/32.0//求商&#xff08;整数部分&#xff09; 两个数据做除法的 商 9//24%取余&#xff08;余数部分&#xff09; 是两个数据做除法的 余数 9%21**幂、次方2**…...

事务所管理系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;客户管理&#xff0c;评论管理&#xff0c;基础数据管理&#xff0c;公告信息管理 客户账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;律师管理&#xff0…...

airsim安装

继续进行&#xff0c;遇到下面的报错 Cannot find path HKEY_CLASSES_ROOT\Unreal.ProjectFile\shell\rungenproj 在Git地址的issue中&#xff0c;搜到下面的解决方法&#xff0c;根因是安装Unreal Engine之后未重启电脑&#xff0c;文件未关联导致&#xff0c;或者出现重定向…...

打造精致UI界面:字体设计的妙招

字体设计是UI设计的关键模块之一。字体设计是否有效可能直接实现或破坏整个UI界面。那么&#xff0c;界面设计的字体设计有哪些规范呢&#xff1f;如何设计细节字体&#xff1f;本文将解释字体设计规范的可读性、可读性和可用性&#xff0c;并介绍UI界面中的字体设计技巧。 如…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

反射获取方法和属性

Java反射获取方法 在Java中&#xff0c;反射&#xff08;Reflection&#xff09;是一种强大的机制&#xff0c;允许程序在运行时访问和操作类的内部属性和方法。通过反射&#xff0c;可以动态地创建对象、调用方法、改变属性值&#xff0c;这在很多Java框架中如Spring和Hiberna…...

Robots.txt 文件

什么是robots.txt&#xff1f; robots.txt 是一个位于网站根目录下的文本文件&#xff08;如&#xff1a;https://example.com/robots.txt&#xff09;&#xff0c;它用于指导网络爬虫&#xff08;如搜索引擎的蜘蛛程序&#xff09;如何抓取该网站的内容。这个文件遵循 Robots…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

GC1808高性能24位立体声音频ADC芯片解析

1. 芯片概述 GC1808是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持8kHz~96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于高保真音频采集场景。 2. 核心特性 高精度&#xff1a;24位分辨率&#xff0c…...