当前位置: 首页 > news >正文

对抗式生成模仿学习(GAIL)

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy函数 GAIL_Training():初始化策略 π 的参数初始化判别器 D 的参数循环 直到收敛 或 达到最大迭代次数:# 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中生成 trajectories 使用 π 并存储在 D_policy 中# 判别器训练循环 discriminator_updates 次数:# 从策略缓冲区 D_policy 中采样数据采样 (s_policy, a_policy) 从 D_policy 中# 从专家示范数据 D_expert 中采样数据采样 (s_expert, a_expert) 从 D_expert 中# 更新判别器 D计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]使用梯度下降法更新判别器参数以最小化 L_D# 策略更新采样 (s, a, ...) 从 D_policy 中计算伪奖励 r = -log(1 - D(s, a))# 使用伪奖励 r 更新策略 π计算 L_π 使用 PPO 或 其他强化学习方法使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom .ppo import PPO
from gail_airl_ppo.network import GAILDiscrimclass GAIL(PPO):def __init__(self, buffer_exp, state_shape, action_shape, device, seed,gamma=0.995, rollout_length=50000, mix_buffer=1,batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,units_actor=(64, 64), units_critic=(64, 64),units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):super().__init__(state_shape, action_shape, device, seed, gamma, rollout_length,mix_buffer, lr_actor, lr_critic, units_actor, units_critic,epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm)# Expert's buffer.self.buffer_exp = buffer_exp# Discriminator.self.disc = GAILDiscrim(state_shape=state_shape,action_shape=action_shape,hidden_units=units_disc,hidden_activation=nn.Tanh()).to(device)self.learning_steps_disc = 0self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)self.batch_size = batch_sizeself.epoch_disc = epoch_discdef update(self, writer):self.learning_steps += 1for _ in range(self.epoch_disc):self.learning_steps_disc += 1# Samples from current policy's trajectories.states, actions = self.buffer.sample(self.batch_size)[:2]# Samples from expert's demonstrations.states_exp, actions_exp = \self.buffer_exp.sample(self.batch_size)[:2]# Update discriminator.self.update_disc(states, actions, states_exp, actions_exp, writer)# We don't use reward signals here,states, actions, _, dones, log_pis, next_states = self.buffer.get()# Calculate rewards.rewards = self.disc.calculate_reward(states, actions)# Update PPO using estimated rewards.self.update_ppo(states, actions, rewards, dones, log_pis, next_states, writer)def update_disc(self, states, actions, states_exp, actions_exp, writer):# Output of discriminator is (-inf, inf), not [0, 1].logits_pi = self.disc(states, actions)logits_exp = self.disc(states_exp, actions_exp)# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].loss_pi = -F.logsigmoid(-logits_pi).mean()loss_exp = -F.logsigmoid(logits_exp).mean()loss_disc = loss_pi + loss_expself.optim_disc.zero_grad()loss_disc.backward()self.optim_disc.step()if self.learning_steps_disc % self.epoch_disc == 0:writer.add_scalar('loss/disc', loss_disc.item(), self.learning_steps)# Discriminator's accuracies.with torch.no_grad():acc_pi = (logits_pi < 0).float().mean().item()acc_exp = (logits_exp > 0).float().mean().item()writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

相关文章:

对抗式生成模仿学习(GAIL)

目录 1 预先基础知识 1.1 对抗生成网络&#xff08;GAN&#xff09; 1.1.1 基本概念 1.1.2 损失函数 1.1.2.1 固定G&#xff0c;求解令损失函数最大的D 1.1.2.2 固定D&#xff0c;求解令损失函数最小的G 1.2 对抗式生成模仿学习特点 2 对抗式生成模仿学习&#xff08;…...

信息系统项目管理师 | 新一代信息技术

关注WX&#xff1a;CodingTechWork 物联网 定义 The Internet of Things是指通过信息传感设备&#xff0c;按约定的协议&#xff0c;将任何物品与互联网连接&#xff0c;进行信息交互和通信&#xff0c;以实现智能化识别。定位、跟踪、监控和管理的一种网络。物联网主要解决…...

安全宣传咨询日活动向媒体投稿记住这个投稿好方法

在信息爆炸的时代,作为单位的信息宣传员,我肩负着将每一次重要活动,特别是像“安全宣传咨询日”这样的公益活动,有效传达给公众的重任。这份工作看似简单,实则充满了挑战,尤其是在我初涉此领域时,那段曲折而又难忘的投稿经历,至今记忆犹新。 初探投稿之海,遭遇重重困难 起初,我…...

第7章:系统架构设计基础知识-软件架构风格

由于历史原因&#xff0c;研究者和工程人员对Sofiware Architecture(简称SA)的翻译不尽相同&#xff0c;其软件的“体系结构”和“架构”具有相同的含义。 系统架构其实就是系统的结构&#xff0c;系统架构设计其实就是要给相关利益方说清楚通过什么样的结构来解决需求中功能和…...

自制调色小工具给图片加滤镜,修改图片红、绿、蓝通道及亮度,修改图片颜色

上篇&#xff1a; 上篇我们给地图添加了锐化、模糊等滤镜&#xff0c;这篇来写一个小工具给图片调色。 调色比锐化等滤镜要简单许多&#xff0c;直接拿到像素值修改即可。不需要用到卷积核。。。(*^▽^*) 核心原理就是图像结构&#xff0c;使用context.getImageData获取图像像…...

【Redis】java客户端(SpringData和jedis)

https://www.oz6.cn/articles/58 https://www.bilibili.com/video/BV1cr4y1671t/?p16 redis官网客户端介绍&#xff1a;https://redis.io/docs/latest/develop/connect/clients/ jedis maven引入依赖 <dependencies><!--引入Jedis依赖--><dependency><…...

大数据安全经典面试题及回答(上)

目录 一、大数据安全的主要挑战及应对策略 二、大数据安全中的“五个V”及其影响 三、在Hadoop集群中实施数据加密的步骤和注意事项 四、在大数据环境中实施访问控制和身份认证 五、大数据环境中数据备份和恢复的策略 六、大数据处理过程中保护用户隐私的策略 七、大数据…...

vi/vim使用命令

你是否在编辑文件时以为键盘坏了&#xff0c;为什么不能删除呢&#xff0c;为什么不能敲代码呢&#xff0c;当你初识vi&#xff0c;会觉得这个东西设计很难用&#xff0c;这篇教程带你熟练得用上这款经典的工具&#xff0c;当你熟练了这款工具就会真正体会到高效率打码 Vi 是在…...

webpack打包gz文件,nginx开启gzip压缩

wepback配置 webpack4配合"compression-webpack-plugin": "^6.1.2"打包压缩gz chain.plugin("compression").use(new CompressionPlugin({test: /\.js$|\.html$|\.css$/,threshold: 10240, // 超过10KB的压缩deleteOriginalAssets: false,// 保…...

微服务开发与实战Day11 - 微服务面试篇

一、分布式事务 1. CAP定理 1998年&#xff0c;加州大学的计算机科学及Eric Brewer提出&#xff0c;分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff09;Availability&#xff08;可用性&#xff09;Partition tolerance&#xff08;分区容错性&am…...

基于Spring Boot+VUE职称评审管理系统

1管理员功能模块 管理员登录&#xff0c;通过填写注册时输入的用户名、密码、角色进行登录&#xff0c;如图1所示。 图1管理员登录界面图 管理员登录进入职称评审管理系统可以查看首页、个人中心、用户管理、评审员管理、省份管理、评审条件管理、职称申请管理、结果公布管理、…...

MySQL 基本语法讲解及示例(上)

第一节&#xff1a;MySQL的基本操作 1. 创建数据库 在 MySQL 中&#xff0c;创建数据库的步骤如下&#xff1a; 命令行操作 打开 MySQL 命令行客户端或连接到 MySQL 服务器。 输入以下命令创建一个数据库&#xff1a; CREATE DATABASE database_name;例如&#xff0c;创建一…...

6.18作业

完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密码不匹配&#xff…...

Excel文件转换为HTML文件

文章目录 前言安装python包python代码 前言 将一个Excel文件转换为HTML文件 安装python包 使用pandas和openpyxl库来实现这个功能 pip install pandas openpyxlpython代码 1、首先使用tkinter库中的filedialog模块弹出一个对话框来选择要转换的Excel文件 2、使用pandas库…...

MySQL数据库入门

1、MySQL概述 MySQL官方网站 https://www.mysql.com/downloads/ MySQL被Oracle公司收购了&#xff0c;作者又重新编写了一个开源的数据库管理系统&#xff0c;Mariadb 2、MySQL产品&版本 2、数据库在网站架构中的角色 LAMP LNMP网站架构 3、安装MySQL-基于yum 查…...

vue element-ui 下拉框 以及 input 限制输入,小数点后保留两位 界面设计案例 和 例子:支持mp4和m3u8视频播放

vue input 限制输入&#xff0c;小数点后保留两位 以及 图片垂直居中显示 和 分享 git 小技巧-CSDN博客文章浏览阅读430次&#xff0c;点赞5次&#xff0c;收藏4次。error:Your local changes to the following files would be overwritten by merge:_error: your local change…...

Python基础用法 之 运算符

1.算数运算符 符号作用说明举例加与“”相同 - 减与“-”相同*乘 与“ ”相同 9*218/除 与“ ”相同 9/24.5 、6/32.0//求商&#xff08;整数部分&#xff09; 两个数据做除法的 商 9//24%取余&#xff08;余数部分&#xff09; 是两个数据做除法的 余数 9%21**幂、次方2**…...

事务所管理系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;客户管理&#xff0c;评论管理&#xff0c;基础数据管理&#xff0c;公告信息管理 客户账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;律师管理&#xff0…...

airsim安装

继续进行&#xff0c;遇到下面的报错 Cannot find path HKEY_CLASSES_ROOT\Unreal.ProjectFile\shell\rungenproj 在Git地址的issue中&#xff0c;搜到下面的解决方法&#xff0c;根因是安装Unreal Engine之后未重启电脑&#xff0c;文件未关联导致&#xff0c;或者出现重定向…...

打造精致UI界面:字体设计的妙招

字体设计是UI设计的关键模块之一。字体设计是否有效可能直接实现或破坏整个UI界面。那么&#xff0c;界面设计的字体设计有哪些规范呢&#xff1f;如何设计细节字体&#xff1f;本文将解释字体设计规范的可读性、可读性和可用性&#xff0c;并介绍UI界面中的字体设计技巧。 如…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK&#xff0c;开始写第二篇的内容了。这篇博客主要能写一下&#xff1a; 如何给一些三方库按照xmake方式进行封装&#xff0c;供调用如何按…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出&#xff1a;JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中&#xff0c;随机数的生成看似简单&#xff0c;却隐藏着许多玄机。无论是生成密码、加密密钥&#xff0c;还是创建安全令牌&#xff0c;随机数的质量直接关系到系统的安全性。Jav…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问&#xff1a;说说对 IP 数据报中 TTL 的理解&#xff1f;我们都知道&#xff0c;IP 数据报由首部和数据两部分组成&#xff0c;首部又分为两部分&#xff1a;固定部分和可变部分&#xff0c;共占 20 字节&#xff0c;而即将讨论的 TTL 就位于首…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

&#x1f4cc; LRU 缓存机制详解与实现&#xff08;Java版&#xff09; 一、&#x1f4d6; 问题背景 在日常开发中&#xff0c;我们经常会使用 缓存&#xff08;Cache&#xff09; 来提升性能。但由于内存有限&#xff0c;缓存不可能无限增长&#xff0c;于是需要策略决定&am…...

wpf在image控件上快速显示内存图像

wpf在image控件上快速显示内存图像https://www.cnblogs.com/haodafeng/p/10431387.html 如果你在寻找能够快速在image控件刷新大图像&#xff08;比如分辨率3000*3000的图像&#xff09;的办法&#xff0c;尤其是想把内存中的裸数据&#xff08;只有图像的数据&#xff0c;不包…...