当前位置: 首页 > news >正文

AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

AIGC实战——条件生成对抗网络

    • 0. 前言
    • 1. CGAN架构
    • 2. 模型训练
    • 3. CGAN 分析
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像。在本节中,我们将构建一个能够控制输出的 GAN,即条件生成对抗网络 (Conditional Generative Adversarial Net, GAN)。该模型最早由 MirzaOsindero2014 年提出,是对 GAN 架构的简单改进。

1. CGAN架构

在节中,我们将使用面部数据集中的头发颜色属性来设置 CGAN 的条件。也就是说,我们将能够明确指定是否要生成带有金发的图像。头发颜色标签作为 CelebA 数据集的一部分已在数据集中提供,CGAN 的架构如下图所示。

CGAN 架构

标准 GANCGAN 之间的关键区别在于:在 CGAN 中,我们需要向生成器和判别器传递与标签相关的额外信息。在生成器中,标签信息转化为独热编码 (one-hot) 向量后附加在潜空间样本之后。在判别器中,通过重复独热编码向量填充得到与输入图像相同形状的通道,将标签信息添加为 RGB 图像的额外通道。
CGAN 之所以能够生成指定类型的图像,是因为其判别器可以获得关于图像内容的额外信息,因此生成器必须确保其输出与提供的标签一致,以继续欺骗判别器。如果生成器生成了与图像标签不一致的图像,即使图像非常逼真,判别器会将它们判定为伪造图像,因为图像和标签并不匹配。
在本节所构建的 CGAN 中,因为有两个类别(金发和非金发),独热编码标签的长度是 2。但是,我们也可以根据需要拥有使用多个标签。例如,在 Fashion-MNIST 数据集上训练 CGAN 时,为了输出 10 种不同类型的 Fashion-MNIST 图像,可以通过将长度为 10 的独热编码标签向量并入生成器的输入,并将 10 个额外的独热编码标签通道并入判别器的输入。
综上,我们需要对标准 GAN 架构所进行的修改是,将标签信息与生成器和判别器的现有输入连接起来:

# 图像通道和标签通道分别传递给判别器,并进行连接
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CLASSES))
x = layers.Concatenate(axis=-1)([critic_input, label_input])
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)critic = models.Model([critic_input, label_input], critic_output)
print(critic.summary())
# 潜向量和标签类别分别传递给生成器,并在调整形状之前进行连接
generator_input = layers.Input(shape=(Z_DIM,))
label_input = layers.Input(shape=(CLASSES,))
x = layers.Concatenate(axis=-1)([generator_input, label_input])
x = layers.Reshape((1, 1, Z_DIM + CLASSES))(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model([generator_input, label_input], generator_output)
print(generator.summary())

2. 模型训练

调整 CGANtrain_step 方法,以令生成器和判别器适应新的输入格式:

    def train_step(self, data):# 从数据集中提取图像和标签real_images, one_hot_labels = data# 将独热编码向量扩展为具有与输入图像相同空间尺寸 (64×64) 的独热编码图像image_one_hot_labels = one_hot_labels[:, None, None, :]image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=1)image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=2)batch_size = tf.shape(real_images)[0]for i in range(self.critic_steps):random_latent_vectors = tf.random.normal( shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:# 生成器接受包含两个输入的列表——随机潜向量和独热编码的标签向量fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)# 判别器接受包含两个输入的列表——真实/生成图像和独热编码的标签通道fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)real_predictions = self.critic([real_images, image_one_hot_labels], training=True)c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)c_gp = self.gradient_penalty(batch_size, real_images, fake_images, image_one_hot_labels)# 梯度惩罚函数还需要通过独热编码的标签通道传递(由于其流经判别器)c_loss = c_wass_loss + c_gp * self.gp_weightc_gradient = tape.gradient(c_loss, self.critic.trainable_variables)self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:# 生成器训练过程的修改与判别器训练步骤的修改相同fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)g_loss = -tf.reduce_mean(fake_predictions)gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))self.c_loss_metric.update_state(c_loss)self.c_wass_loss_metric.update_state(c_wass_loss)self.c_gp_metric.update_state(c_gp)self.g_loss_metric.update_state(g_loss)return {m.name: m.result() for m in self.metrics}

3. CGAN 分析

我们可以通过将特定的独热编码标签传递到生成器的输入中来控制 CGAN 的输出。例如,要生成一张非金发的人脸图像,我们传入向量 [1, 0];要生成一张金发的人脸图像,我们传入向量 [0, 1]
CGAN 的输出如下图所示。可以看到,在保持随机潜向量不变的情况下,只改变条件标签向量,显然 CGAN 已经学会使用标签向量来控制图像的头发颜色属性,且图像的其余部分几乎没有改变。这证明了 GAN 能够以这种方式组织潜空间中的点,使得各个特征可以相互解耦。

生成结果

如果数据集中有标签可用,即使不一定需要将生成的输出与标签相关联,将它们作为 GAN 的输入通常也可以提高生成图像的质量,我们可以把标签看作是像素输入的信息扩展。

小结

在本节中,构建了一个条件生成对抗网络 (Conditional Generative Adversarial Net, CGAN),通过将标签作为输入传递给判别器和生成器,能够生成可控类别的图像,这是由于标签为网络提供了额外的信息,以便使生成的输出与给定的标签相关联。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)

相关文章:

AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

AIGC实战——条件生成对抗网络 0. 前言1. CGAN架构2. 模型训练3. CGAN 分析小结系列链接 0. 前言 我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控…...

高性能计算HPC与统一存储

高性能计算(HPC)广泛应用于处理大量数据的复杂计算,提供更精确高效的计算结果,在石油勘探、基因分析、气象预测等领域,是企业科研机构进行研发的有效手段。为了分析复杂和大量的数据,存储方案需要响应更快&…...

秋招上岸记录咕咕咕了。

思考了一下,感觉并没有单独写这样一篇博客的必要。 能够写出来的,一些可能会对人有帮助的东西都做进了视频里面,未来会在blbl发布,目前剪辑正在施工中(?) 另外就是,那个视频里面使…...

vue模板语法

一、插值 1、文本 &#xff08;1&#xff09;v-text语法 缩写&#xff1a; {{…}}&#xff08;双大括号&#xff09;的文本插值 方法一&#xff1a; <template><h1> hello </h1><p v-text"data.name"></p><!-- v-text的简写--&…...

Pytorch神经网络的模型架构(nn.Module和nn.Sequential的用法)

一、层和块 在构造自定义块之前&#xff0c;我们先回顾一下多层感知机的代码。下面的代码生成一个网络&#xff0c;其中包含一个具有256个单元和ReLU激活函数的全连接隐藏层&#xff0c;然后是一个具有10个隐藏单元且不带激活函数的全连接输出层。 import torch from torch im…...

JS数组之展开运算符

展开运算符是什么&#xff1f;有什么作用&#xff1f; 展开运算符可以将一个数组展开 const arr [1,2,3,4,5]// 我们使用...展开数组console.log(...arr) //1 2 3 4 5它不会修改原数组 典型运用场景&#xff1a;求数组最大值、最小值、合并数组等 会让我们代码更加简洁 最大值…...

读书笔记:《汽车构造与原理》

《透视汽车会跑的奥秘》《汽车为什么会跑&#xff1a;底盘图解》《汽车为什么会跑&#xff1a;图解汽车构造与原理》 一、心脏&#xff1a;发动机 活塞往复运动转化为曲轴的旋转运动 活塞&#xff1a;膝关节活塞连杆&#xff1a;小腿曲轴&#xff1a;自行车脚踏板 四冲程&…...

INS 量测更新

5 量测更新 5.1 GNSS位置及速度更新 r ^ G P S , i n r ^ I M U n D R − 1 C b n l b v ^ G P S , i n v ^ I M U n ω i n n C b n l b − C b n ω i b b l b \begin{aligned} \hat{r}_{GPS,i}^{n} & \hat{r}_{IMU}^{n} D_{R}^{-1}C_{b}^{n} l^b\\ \hat{v}_{GPS…...

【ssh基础知识】

ssh基础知识 常用命令登录流程配置文件ssh密钥登录生成密钥上传公钥关闭密码登录 ssh服务管理查看日志ssh端口转发 ssh&#xff08;ssh客户端&#xff09;是一个用于登录到远程机器并在远程机器上执行命令的程序。 它旨在提供安全的加密通信在不安全的网络上的两个不受信任的主…...

04 开发第一个组件

概述 在Vue3中&#xff0c;一个组件就是一个.vue文件。 在本小节中&#xff0c;我们来开发第一个Vue3组件。这个组件的功能非常的简单&#xff0c;只需要在浏览器上输出一个固定的字符串”欢迎跟着Python私教一起学Vue3“即可。 实现步骤 第一步&#xff1a;新增src/compon…...

【Unity】如何让Unity程序一打开就运行命令行命令

【背景】 Unity程序有时依赖于某些服务去实现一些功能,此时可能需要类似打开程序就自动运行Windows命令行命令的功能。 【方法】 using UnityEngine; using System.Diagnostics; using System.Threading.Tasks; using System.IO; using System.Text...

Web前端-HTML(表格与表单)

文章目录 1.表格与表单1.1 概述 2.表格 table2.1 表格概述2.2. 创建表格2.3 表格属性2.4. 表头单元格标签th2.5 表格标题caption&#xff08;了解&#xff09;2.6 合并单元格(难点)2.7 总结表格 3. 表单标签(重点)3.1 概述3.2 form表单3.3 input 控件(重点)type 属性value属性值…...

Android RecycleView实现平滑滚动置顶和调整滚动速度

目录 一、滑动到指定位置&#xff08;target position&#xff09;并且置顶 1. RecycleView默认的几个实现方法及缺陷 2. 优化源码实现置顶方案 二、调整平移滑动速率 三、其他方案&#xff1a;置顶、置顶加偏移、居中 1. 其他置顶方案 2. 置顶加偏移 3. 滚动居中 在实…...

跳跃游戏 + 45. 跳跃游戏 II

给你一个非负整数数组 nums &#xff0c;你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标&#xff0c;如果可以&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输…...

在Django中使用多语言(i18n)

在Django中使用多语言 配置中间件 MIDDLEWARE [......django.contrib.sessions.middleware.SessionMiddleware,django.middleware.locale.LocaleMiddleware, # 此行重点django.middleware.common.CommonMiddleware,...... ]配置翻译文件目录 根目录下创建目录locale # 国…...

高性价比AWS Lambda无服务体验

前言 之前听到一个讲座说到AWS Lambda服务&#xff0c;基于Serverless无服务模型&#xff0c;另外官网还免费提供 100 万个请求 按月&#xff0c;包含在 AWS 免费套餐中是真的很香&#xff0c;对于一些小型的起步的网站或者用户量不大的网站&#xff0c;简直就是免费&#xff…...

【物联网】EMQX(二)——docker快速搭建EMQX 和 MQTTX客户端使用

一、前言 在上一篇文章中&#xff0c;小编向大家介绍了物联网必然会用到的消息服务器EMQ&#xff0c;相信大家也对EMQ有了一定的了解&#xff0c;那么接下来&#xff0c;小编从这篇文章正式开始展开对EMQ的学习教程&#xff0c;本章节来记录一下如何对EMQ进行安装。 二、使用…...

2023 亚马逊云科技 re:lnvent 大会探秘: Amazon Connect 全渠道云联络中心

2023 亚马逊云科技 re:lnvent 大会探秘: Amazon Connect 全渠道云联络中心 前言一. Amazon Connect 介绍 &#x1f5fa;️二. Amazon Connect 使用教程 &#x1f5fa;️1.我们打开URl链接找到对应服务2.输入Amazon Connect选中第一个点击进入即可&#xff1b;3.在进入之后我们就…...

鸿蒙开发之用户隐私权限申请

一、简介 鸿蒙开发过程中可用于请求的权限一共有两种&#xff1a;normal和system_basic。以下内容摘自官网&#xff1a; normal权限 normal 权限允许应用访问超出默认规则外的普通系统资源。这些系统资源的开放&#xff08;包括数据和功能&#xff09;对用户隐私以及其他应用带…...

Docker笔记:简单部署 nodejs 项目和 golang 项目

docker 简单的维护 nodejs 项目容器 1 &#xff09;Nodejs 程序 const express require(express) const app express()app.get(/, (req, res) > {res.send(首页) })app.get(/news, (req, res) > {res.send(news) })// dokcer 做端口映射不要指定ip app.listen(3000)2…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中&#xff0c;我们会遇到使用 java 调用 dll文件 的情况&#xff0c;此时大概率出现UnsatisfiedLinkError链接错误&#xff0c;原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用&#xff0c;结果 dll 未实现 JNI 协…...

【大模型RAG】Docker 一键部署 Milvus 完整攻略

本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装&#xff1b;只需暴露 19530&#xff08;gRPC&#xff09;与 9091&#xff08;HTTP/WebUI&#xff09;两个端口&#xff0c;即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

反射获取方法和属性

Java反射获取方法 在Java中&#xff0c;反射&#xff08;Reflection&#xff09;是一种强大的机制&#xff0c;允许程序在运行时访问和操作类的内部属性和方法。通过反射&#xff0c;可以动态地创建对象、调用方法、改变属性值&#xff0c;这在很多Java框架中如Spring和Hiberna…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲&#xff1a;核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用&#xff0c;还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作&#xff1a;验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化&#xff1a;测试aof和aof持久化机制&#xff0c;确保数据在开启后正确恢复。 事务&#xff1a;检查事务的原子性和回滚机制。 发布订阅&#xff1a;确保消息正确传递。 2、性…...

【安全篇】金刚不坏之身:整合 Spring Security + JWT 实现无状态认证与授权

摘要 本文是《Spring Boot 实战派》系列的第四篇。我们将直面所有 Web 应用都无法回避的核心问题&#xff1a;安全。文章将详细阐述认证&#xff08;Authentication) 与授权&#xff08;Authorization的核心概念&#xff0c;对比传统 Session-Cookie 与现代 JWT&#xff08;JS…...

算法打卡第18天

从中序与后序遍历序列构造二叉树 (力扣106题) 给定两个整数数组 inorder 和 postorder &#xff0c;其中 inorder 是二叉树的中序遍历&#xff0c; postorder 是同一棵树的后序遍历&#xff0c;请你构造并返回这颗 二叉树 。 示例 1: 输入&#xff1a;inorder [9,3,15,20,7…...