当前位置: 首页 > news >正文

解析生成对抗网络(GAN):原理与应用

目录

一、引言

二、生成对抗网络原理

(一)基本架构

(二)训练过程

三、生成对抗网络的应用

(一)图像生成

无条件图像生成:

(二)数据增强

(三)风格迁移

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

(二)梯度消失


一、引言

生成对抗网络(GAN)自 2014 年被 Goodfellow 等人提出以来,在深度学习领域引起了广泛的关注和研究热潮。它创新性地引入了一种对抗训练的思想,通过生成器和判别器的相互博弈,使得生成器能够学习到数据的潜在分布,从而生成逼真的样本数据。这种独特的机制使得 GAN 在图像生成、文本生成、音频生成等多个领域展现出了巨大的潜力,为人工智能技术的发展带来了新的突破和方向。

二、生成对抗网络原理

(一)基本架构

GAN 主要由两个核心组件构成:生成器(Generator)和判别器(Discriminator)。

  1. 生成器
    • 生成器的任务是接收一个随机噪声向量 (通常从一个简单的分布,如标准正态分布 N(0,1)采样得到),并通过一系列的神经网络层将其映射为与真实数据相似的生成数据G(z)
    • 例如,在图像生成任务中,生成器的输出将是一张与训练数据集中图像具有相似特征的合成图像。
    • 生成器通常采用多层的反卷积神经网络(Deconvolutional Neural Network)或转置卷积神经网络(Transposed Convolutional Neural Network)结构。以生成64*64其网络结构如下:
      import torch
      import torch.nn as nnclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 输入为 100 维的噪声向量self.fc = nn.Linear(100, 4 * 4 * 1024)self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)def forward(self, x):x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))x = torch.relu(self.bn2(self.deconv2(x)))x = torch.relu(self.bn3(self.deconv3(x)))x = torch.tanh(self.deconv4(x))return x

  2. 判别器
  • 判别器的作用是区分输入的数据是来自真实数据分布还是由生成器生成的数据。它接收真实数据 x 或生成数据 G(z),并输出一个表示数据真实性的概率值  D(x)或D(G(z)) ,取值范围在 0 到  1之间,接近  表示数据更可能是真实的,接近  表示数据更可能是生成的。

判别器通常采用卷积神经网络(Convolutional Neural Network)结构。例如,对于判断  彩色图像的判别器网络结构如下:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(128)self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(512)self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)def forward(self, x):x = torch.relu(self.bn1(self.conv1(x)))x = torch.relu(self.bn2(self.conv2(x)))x = torch.relu(self.bn3(self.conv3(x)))x = torch.sigmoid(self.conv4(x))return x.view(-1)

(二)训练过程

GAN 的训练过程是一个对抗性的迭代过程。

三、生成对抗网络的应用

(一)图像生成

1.无条件图像生成

GAN 可以用于生成各种类型的图像,如人脸图像、风景图像等。例如,在人脸图像生成任务中,通过在大规模人脸数据集上训练 GAN,生成器能够学习到人脸的各种特征,如五官的形状、肤色、表情等,从而生成全新的、逼真的人脸图像。

代码示例:

# 假设已经定义好生成器 G 和判别器 D,以及相关的优化器和损失函数
# 训练循环
num_epochs = 100
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 训练判别器# 采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算判别器损失real_loss = criterion(D(real_images), torch.ones(real_images.shape[0]).to(device))fake_loss = criterion(D(fake_images.detach()), torch.zeros(fake_images.shape[0]).to(device))d_loss = (real_loss + fake_loss) / 2# 更新判别器参数d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 再次采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算生成器损失g_loss = criterion(D(fake_images), torch.ones(fake_images.shape[0]).to(device))# 更新生成器参数g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()

2.条件图像生成

可以通过在生成器和判别器的输入中加入条件信息,实现条件图像生成。例如,根据给定的文本描述生成相应的图像,或者根据特定的类别标签生成属于该类别的图像。

以根据类别标签生成图像为例,在生成器的输入中除了噪声向量 ,还加入类别标签的编码向量 ,生成器的网络结构需要进行相应修改,如:

class ConditionalGenerator(nn.Module):def __init__(self, num_classes):super(ConditionalGenerator, self).__init__()# 输入为 100 维噪声向量和类别编码向量self.fc = nn.Linear(100 + num_classes, 4 * 4 * 1024)# 后续的反卷积层与无条件生成器类似self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)#...def forward(self, x, y):# 拼接噪声向量和类别编码向量x = torch.cat([x, y], dim=1)x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))#...return x

(二)数据增强

  • 图像数据增强
    • 在图像分类、目标检测等任务中,数据量不足可能导致模型过拟合。GAN 可以用于生成额外的图像数据来扩充数据集。通过在原始图像数据集上训练 GAN,生成与原始图像相似但又有一定变化的图像,如不同角度、光照条件下的图像,从而增加数据的多样性,提高模型的泛化能力。
  • 其他数据类型的数据增强
    • 除了图像数据,GAN 也可以应用于其他数据类型的数据增强,如文本数据。例如,通过生成与原始文本相似的新文本,扩充文本数据集,有助于训练更强大的文本处理模型,如文本分类、机器翻译等模型。

(三)风格迁移

  • 图像风格迁移原理
    • GAN 可以实现图像风格迁移,即将一幅图像的内容与另一幅图像的风格进行融合。其原理是通过定义内容损失和风格损失,利用生成器生成具有目标风格的图像,同时判别器用于判断生成图像的质量和风格一致性。
    • 例如,使用预训练的 VGG 网络来计算内容损失和风格损失。内容损失衡量生成图像与原始内容图像在特征表示上的差异,风格损失衡量生成图像与目标风格图像在风格特征(如纹理、颜色分布等)上的差异。

代码示例实现风格迁移

import torchvision.models as models
import torch.nn.functional as F# 加载预训练的 VGG 网络
vgg = models.vgg19(pretrained=True).features.eval().to(device)# 定义内容损失函数
def content_loss(content_features, generated_features):return F.mse_loss(content_features, generated_features)# 定义风格损失函数
def style_loss(style_features, generated_features):style_loss = 0for s_feat, g_feat in zip(style_features, generated_features):# 计算 Gram 矩阵s_gram = gram_matrix(s_feat)g_gram = gram_matrix(g_feat)style_loss += F.mse_loss(s_gram, g_gram)return style_loss# Gram 矩阵计算函数
def gram_matrix(x):b, c, h, w = x.size()features = x.view(b * c, h * w)gram = torch.mm(features, features.t())return gram.div(b * c * h * w)

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

问题描述

模式崩溃是 GAN 训练中常见的问题之一,表现为生成器生成的样本多样性不足,往往集中在少数几种模式上。例如,在生成人脸图像时,可能生成的人脸都具有相似的特征,而不能涵盖人脸的多种可能形态。

解决策略

Wasserstein GAN(WGAN):WGAN 对 GAN 的损失函数进行了改进,采用 Wasserstein 距离来衡量真实数据分布和生成数据分布之间的差异,而不是传统的 JS 散度。这使得训练过程更加稳定,减少了模式崩溃的发生。其关键代码修改如下:

# 判别器的最后一层不再使用 Sigmoid 激活函数
self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)
# 定义 WGAN 的损失函数
def wgan_loss(real_pred, fake_pred):return -torch.mean(real_pred) + torch.mean(fake_pred)

模式正则化:通过在生成器的损失函数中加入正则化项,鼓励生成器生成更多样化的样本。例如,在生成器的损失函数中加入对生成样本的熵约束,使得生成样本的分布更加均匀。

(二)梯度消失

  • 问题描述
    • 在 GAN 训练初期,当判别器的性能非常好时,生成器的梯度可能会变得非常小,导致生成器难以更新参数,无法有效地学习到数据的分布。这是因为判别器能够很容易地区分真实数据和生成数据,使得生成器的损失函数接近饱和,梯度趋近于 。
  • 解决策略
    • 梯度惩罚(Gradient Penalty):在判别器的损失函数中加入梯度惩罚项,限制判别器的梯度大小,使得判别器不会过于强大,从而缓解生成器的梯度消失问题。例如,在 WGAN-GP(Wasserstein GAN with Gradient Penalty)中,梯度惩罚项的计算如下:
      def gradient_penalty(critic, real, fake, device):BATCH_SIZE, C, H, W = real.shape# 随机采样插值系数alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)# 计算插值数据interpolated_images = real * alpha + fake * (1 - alpha)# 计算判别器对插值数据的输出mixed_scores = critic(interpolated_images)# 计算梯度gradient = torch.autograd.grad(inputs=interpolated_images,outputs=mixed_scores,grad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]# 计算梯度惩罚项gradient = gradient.view(gradient.shape[0], -1)gradient_norm = gradient.norm(2, dim=1)gradient_penalty = torch.mean((gradient_norm - 1) ** 2)return gradient_penalty

    • 使用 Leaky ReLU 激活函数:在判别器和生成器的网络中使用 Leaky ReLU 激活函数替代传统的 ReLU 激活函数。Leaky ReLU 允许在负半轴有一个较小的斜率,从而避免了在某些情况下神经元完全不激活导致的梯度消失问题。

相关文章:

解析生成对抗网络(GAN):原理与应用

目录 一、引言 二、生成对抗网络原理 (一)基本架构 (二)训练过程 三、生成对抗网络的应用 (一)图像生成 无条件图像生成: (二)数据增强 (三&#xff…...

CodeIgniter URL结构

CodeIgniter 的URL 结构设计得简洁且易于管理。通常遵循以下模式&#xff1a; http://<domain>/<index_page>/<controller>/<method>/<parameters> 下面是每个部分的详细说明&#xff1a; <domain>&#xff1a; 这是你的网站域名&#…...

从 App Search 到 Elasticsearch — 挖掘搜索的未来

作者&#xff1a;来自 Elastic Nick Chow App Search 将在 9.0 版本中停用&#xff0c;但 Elasticsearch 拥有你构建强大的 AI 搜索体验所需的一切。以下是你需要了解的内容。 生成式人工智能的最新进展正在改变用户行为&#xff0c;激励开发人员创造更具活力、更直观、更引人入…...

鸿蒙本地模拟器 模拟TCP服务端的过程

鸿蒙模拟器模拟TCP服务端的过程涉及几个关键步骤&#xff0c;主要包括创建TCPSocketServer实例、绑定IP地址和端口、监听连接请求、接收和发送数据以及处理连接事件。以下是详细的模拟过程&#xff1a; **1.创建TCPSocketServer实例&#xff1a;**首先&#xff0c;需要导入鸿蒙…...

Qt/C++基于重力模拟的像素点水平堆叠效果

本文将深入解析一个基于 Qt/C 的像素点模拟程序。程序通过 重力作用&#xff0c;将随机分布的像素点下落并水平堆叠&#xff0c;同时支持窗口动态拉伸后重新计算像素点分布。 程序功能概述 随机生成像素点&#xff1a;程序在初始化时随机生成一定数量的像素点&#xff0c;每个…...

Zookeeper学习心得

本人学zookeeper时按照此文路线学的 Zookeeper学习大纲 - 似懂非懂视为不懂 - 博客园 一、Zookeeper安装 ZooKeeper 入门教程 - Java陈序员 - 博客园 Docker安装Zookeeper教程&#xff08;超详细&#xff09;_docker 安装zk-CSDN博客 二、 zookeeper的数据模型 ZooKeepe…...

嵌入式开发工程师面试题 - 2024/11/24

原文嵌入式开发工程师面试题 - 2024/11/24 转载请注明来源 1.若有以下定义语句double a[8]&#xff0c;*pa&#xff1b;int i5&#xff1b;对数组元素错误的引用是&#xff1f; A *a B a[5] C *&#xff08;p1&#xff09; D p[8] 解析&#xff1a; 在 C 或 C 语言中&am…...

Python中打印当前目录文件树的脚本

效果图&#xff1a; 实现脚本&#xff1a; 1、显示所有文件和文件夹&#xff1a; import osdef list_files(startpath, prefix):items os.listdir(startpath)items.sort()for index, item in enumerate(items):item_path os.path.join(startpath, item)is_last index le…...

全景图像(Panorama Image)向透视图像(Perspective Image)的跨视图转化(Cross-view)

一、概念讲解 全景图像到透视图像的转化是一个复杂的图像处理过程&#xff0c;它涉及到将一个360度的全景图像转换为一个具有透视效果的图像&#xff0c;这种图像更接近于人眼观察世界的方式。全景图像通常是一个矩形图像&#xff0c;它通过将球面图像映射到平面上得到&#xf…...

Redis 中的 hcan 命令耗内存,有什么优化的方式吗 ?

Redis 中的 hcan 命令耗内存&#xff0c;有什么优化的方式吗 &#xff1f; 1. 使用合适的游标值&#xff1a;2. 控制每次迭代返回的键数量&#xff1a;3. 避免长时间运行的迭代&#xff1a;4. 使用HSCAN与SCAN命令结合&#xff1a;5. 优化哈希表结构&#xff1a;6. 监控和调整R…...

豆包MarsCode算法题:三数之和问题

问题描述 思路分析 1. 排序数组 目的: 将数组 arr 按升序排序&#xff0c;这样可以方便地使用双指针找到满足条件的三元组&#xff0c;同时避免重复的三元组被重复计算。优势: 数组有序后&#xff0c;处理两个数和 target - arr[i] 的问题可以通过双指针快速找到所有可能的组…...

【Android】AnimationDrawable帧动画的实现

目录 引言 一、AnimationDrawable常用方法 1.1 导包 1.2 addFrame 1.3 setOneShot 1.4 start 1.5 stop 1.6 isRunning 二、 从xml文件获取并播放帧动画 2.1 创建XML文件 2.2 在布局文件中使用帧动画资源 三、在代码中生成并播放帧动画 3.1 addFrame加入帧动画列…...

【消息序列】详解(7):剖析回环模式--设备测试的核心利器

目录 一、概述 1.1. 本地回环模式 1.2. 远程环回模式 二、本地回环模式&#xff08;Local Loopback mode&#xff09; 2.1. 步骤 1&#xff1a;主机进入本地环回模式 2.2. 本地回环测试 2.2.1. 步骤 2a&#xff1a;主机发送HCI数据包并接收环回数据 2.2.2. 步骤 2b&…...

解决Ubuntu 22.04系统中网络Ping问题的方法

在Ubuntu 22.04系统中&#xff0c;网络问题时有发生&#xff0c;尤其是当涉及到静态IP地址配置和网线直连的两台机器时。本文将探讨一种常见问题——断开并重新连接网线后&#xff0c;尽管网卡显示为UP状态&#xff0c;但无法立即ping通对方机器&#xff0c;以及如何解决这一问…...

【大数据学习 | Spark-SQL】Spark-SQL编程

上面的是SparkSQL的API操作。 1. 将RDD转化为DataFrame对象 DataFrame&#xff1a; DataFrame是一种以RDD为基础的分布式数据集&#xff0c;类似于传统数据库中的二维表格。带有schema元信息&#xff0c;即DataFrame所表示的二维表数据集的每一列都带有名称和类型。这样的数…...

15分钟做完一个小程序,腾讯这个工具有点东西

我记得很久之前&#xff0c;我们都在讲什么低代码/无代码平台&#xff0c;这个概念很久了&#xff0c;但是&#xff0c;一直没有很好的落地&#xff0c;整体的效果也不算好。 自从去年 ChatGPT 这类大模型大火以来&#xff0c;各大科技公司也都推出了很多 AI 代码助手&#xff…...

manim动画编程(安装+入门)

文章目录 1.基本介绍2.效果展示3.安装步骤3.1安装manba软件3.2配置环境变量3.3查看是否成功3.4什么是mamba3.5创建虚拟环境3.6尝试进入虚拟环境 4.vscode操作4.1默认配置文件 5.安装ffmpeg6.安装manim软件6.vscode制作7.我的学习收获 1.基本介绍 这个manim就是一款软件&#x…...

STL算法之数值算法<stl_numeric.h>

这一节介绍的算法&#xff0c;统称为数值(numeric)算法。STL规定&#xff0c;欲使用它们&#xff0c;客户端必须包含头文件<numeric>.SGI将它们实现与<stl_numeric.h>文件中。 目录 运用实例 accumulate adjacent_difference inner_product partial_sum pow…...

Oracle如何记录登录用户IP

在运维场景中&#xff0c;在定位到某个SQL引起系统故障之后&#xff0c;想知道是哪台机器发过来的&#xff0c;方便定位源头&#xff0c;该如何解决&#xff1f; 在 Oracle 数据库中记录登录用户的 IP 地址可以通过多种方法实现。以下是几种常见的方法&#xff0c;包括使用触发…...

Python图像处理:打造平滑液化效果动画

液化动画中的强度变化是通过在每一帧中逐渐调整液化效果的强度参数来实现的。在提供的代码示例中&#xff0c;强度变化是通过一个简单的线性插值方法来控制的&#xff0c;即随着动画帧数的增加&#xff0c;液化效果的强度也逐渐增加。 def liquify_image(image, center, radius…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接&#xff1a;A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串&#xff0c;只有在同时为 o 时输出 Yes 并结束程序&#xff0c;否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言&#xff1a; 在人工智能快速发展的浪潮中&#xff0c;快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型&#xff08;LLM&#xff09;。该模型代表着该领域的重大突破&#xff0c;通过独特方式融合思考与非思考…...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

【C++进阶篇】智能指针

C内存管理终极指南&#xff1a;智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...