当前位置: 首页 > news >正文

生成对抗网络DCGAN学习

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers# 定义生成器模型
def build_generator():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model# 定义判别器模型
def build_discriminator():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)def generator_loss(fake_output):return loss_fn(tf.ones_like(fake_output), fake_output)def discriminator_loss(real_output, fake_output):real_loss = loss_fn(tf.ones_like(real_output), real_output)fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss# 定义训练循环
@tf.function  #这个是tensorflow的装饰器,标记后可提升性能,不加此标记也可
def train_step(images):# 生成噪声向量noise = tf.random.normal([BATCH_SIZE, 100])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:# 使用生成器生成假图片generated_images = generator(noise, training=True)# 使用判别器判断真假图片real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)# 计算损失函数gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)# 计算梯度并更新生成器和判别器的参数gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def generate_and_save_images(model, epoch, test_input):predictions = model(test_input, training=False)print("predictions.shape:", predictions.shape)num_images = predictions.shape[0]rows = int(num_images ** 0.5) # 计算行数cols = num_images // rows # 计算列数fig = plt.figure(figsize=(8, 8))for i in range(num_images):plt.subplot(rows, cols, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))#plt.show()# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)for epoch in range(EPOCHS):for i,image_batch in enumerate(dataset):print("sub i",i)train_step(image_batch)print("------------------------------------------------------epoch:", epoch)# 每个 epoch 结束后生成并保存一组图像if (epoch + 1) % 5 == 0:seed = tf.random.normal([BATCH_SIZE, 100])generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

相关文章:

生成对抗网络DCGAN学习

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面&#…...

error: #5: cannot open source input file “core_cmInstr.h“

GD32F103VET6和STM32F103VET6引脚兼容。 GD32F103VET6工程模板需要包含头文件:core_cmInstr.h和core_cmFunc.h,这个和STM32F103还是有区别的,否则会报错,如下: error: #5: cannot open source input file "core…...

FastAPI 教程、结合vue实现前后端分离

英文版文档:https://fastapi.tiangolo.com/ 中文版文档:https://fastapi.tiangolo.com/zh/ 1、FastAPI 教程 简 介 FastAPI 和 Sanic 类似,都是 Python 中的异步 web 框架。相比 Sanic,FastAPI 更加的成熟、社区也更加的活跃。 …...

算法通关村第四关——如何基于数组(链表)实现栈

栈的基础知识 栈的特征 特征1 栈和队列是比较特殊的线性表,又被称为 访问受限的线性表。栈是很多表达式、符号等运算的基础,也是递归的底层实现(递归就是方法自己调用自己,在JVM的虚拟机栈中,一个线程中的栈帧就是…...

Postgresql警告日志的配置

文章目录 1.postgresql与日志有关的参数2.开启日志3.指定日志目录4.設置文件名format5.設置日志文件產出模式6.設置日志记录格式7.日誌輪換7.1非截斷式輪換7.2 截斷式輪換 8.日誌記錄內容8.1 log_statement8.2 log_min_duration_statement 9 輸出範本 1.postgresql与日志有关的…...

Java、JSAPI、 ssm架构 微信支付demo

1.前端 index.html <%page import"com.tenpay.configure.WxPayConfig"%> <% page language"java" contentType"text/html; charsetUTF-8" pageEncoding"UTF-8"%> <html><style>#fukuan{font-size: 50px;marg…...

MongoDB文档--基本安装-linux安装(mongodb环境搭建)-docker安装(挂载数据卷)-以及详细版本对比

阿丹&#xff1a; 前面了解了mongodb的一些基本概念。本节文章对安装mongodb进行讲解以及汇总。 官网教程如下&#xff1a; 安装 MongoDB - MongoDB-CN-Manual 版本特性 下面是各个版本的选择请在安装以及选择版本的时候参考一下&#xff1a; MongoDB 2.x 版本&#xff1a…...

tomcat限制IP访问

tomcat可以通过增加配置&#xff0c;来对来源ip进行限制&#xff0c;即只允许某些ip访问或禁止某些来源ip访问。 配置路径&#xff1a;server.xml 文件下 标签下。与同级 <Valve className"org.apache.catalina.valves.RemoteAddrValve" allow"192.168.x.x&…...

互联网宠物医院系统开发:数字化时代下宠物医疗的革新之路

随着人们对宠物关爱意识的提高&#xff0c;宠物医疗服务的需求也日益增加。传统的宠物医院存在排队等待、预约难、信息不透明等问题&#xff0c;给宠物主人带来了诸多不便。而互联网宠物医院系统的开发&#xff0c;则可以带来许多便利和好处。下面将介绍互联网宠物医院系统开发…...

docker镜像批量导出导入

docker镜像批量导出导入 image_tar为存储镜像目录 删除所有容器 一、首先需要停止所有运行中的容器 docker stopdocker ps -a -q docker ps -a -q 意思是列出所有容器&#xff08;包括未运行的&#xff09;&#xff0c;只显示容器编号&#xff0c;其中 -a : 显示所有的容器&…...

宇凡微2.4g遥控船开发方案,采用合封芯片

2.4GHz遥控船的开发方案是一个有趣且具有挑战性的项目。这样的遥控船可以通过无线2.4GHz频率进行远程控制&#xff0c;让用户在池塘或湖泊上畅游。以下是一个简要的2.4GHz遥控船开发方案&#xff1a; 基本构想如下 mcu驱动两个小电机&#xff0c;小电机上安装两个螺旋桨&#…...

RPC框架引入zookeeper服务注册与服务发现

Zookeeper概念及其作用 ZooKeeper是一个分布式的&#xff0c;开放源码的分布式应用程序协调服务&#xff0c;是Google的Chubby一个开源的实现&#xff0c;是大数据生态中的重要组件。它是集群的管理者&#xff0c;监视着集群中各个节点的状态根据节点提交的反馈进行下一步合理…...

MySQL用通配符过滤数据

简单的不使用通配符过滤数据的方式使用的值都是已知的&#xff0c;但是当搜索产品名中包含ashui的所有产品时&#xff0c;用简单的比较操作符肯定不行&#xff0c;必须使用通配符。利用通配符可以创建比较特定数据的搜索模式。 通配符&#xff1a;用来匹配值的一部分的特殊字符…...

低通、高通、带通、阻通滤波器

目录 低通、高通、带通、阻通滤波器 低通、高通、带通、带阻滤波器的区别 通俗理解&#xff1a; 1、低通滤波器 2、高通滤波器 3、带通滤波器 4、带阻滤波器 5、全通滤波器 低通、高通、带通、阻通滤波器 低通、高通、带通、带阻滤波器的区别 低通滤波器&#xff1a;只…...

IDEA SpringBoot Maven profiles 配置

IDEA SpringBoot Maven profiles 配置 IDEA版本&#xff1a; IntelliJ IDEA 2022.2.3 注意&#xff1a;切换环境之后务必点击一下刷新&#xff0c;推荐点击耗时更短。 application.yaml spring:profiles:active: env多环境文件名&#xff1a; application-dev.yaml、 applicat…...

微信小程序 背景图片如何占满整个屏幕

1. 在页面的wxss文件中&#xff0c;设置背景图片的样式&#xff1a; page{background-image: url(图片路径);background-size: 100% 100%;background-repeat: no-repeat; } 2. 在页面的json文件中&#xff0c;设置背景图片的样式&#xff1a; {"backgroundTextStyle&qu…...

邪恶版ChatGPT来了!

「邪恶版」ChatGPT 出现&#xff1a;每月 60 欧元&#xff0c;毫无道德限制&#xff0c;专为“网络罪犯”而生。 WormGPT 并不是一个人工智能聊天机器人&#xff0c;它的开发目的不是为了有趣地提供无脊椎动物的人工智能帮助&#xff0c;就像专注于猫科动物的CatGPT一样。相反&…...

一、Postfix[安装与配置、smtp认证、Python发送邮件以及防垃圾邮件方法、使用腾讯云邮件服务]

Debian 11 一、安装 apt install postfix 二、配置 1.dns配置 解释&#xff1a;搭建真实的邮件服务器需要在DNS提供商那里配置下面的dns 配置A记录mail.www.com-1.x.x.x配置MX记录www.com-mail.www.com 解释&#xff1a;按照上面的配置通常邮件格式就是adminwww.com其通过…...

React哲学——官方示例

在本篇技术博客中&#xff0c;我们将介绍React官方示例&#xff1a;React哲学。我们将深入探讨这个示例中使用的组件化、状态管理和数据流等核心概念。让我们一起开始吧&#xff01; 项目概览 React是一个流行的JavaScript库&#xff0c;用于构建用户界面。React的设计理念是…...

设计模式之开闭原则

什么是开闭原则? 开放封闭原则称为OCP原则&#xff08;Open Closed Principle&#xff09;是所有面向对象原则的核心。 “开闭原则”是面向对象编程中最基础和最重要的设计原则之一。 软件设计本身所追求的目标就是封装变化、降低耦合&#xff0c;而开放封闭原则正是对这一…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql

智慧工地管理云平台系统&#xff0c;智慧工地全套源码&#xff0c;java版智慧工地源码&#xff0c;支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求&#xff0c;提供“平台网络终端”的整体解决方案&#xff0c;提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI&#xff08;https://spring.io/projects/spring-ai&#xff09;作为Spring生态中的AI集成框架&#xff0c;其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似&#xff0c;但特别为多语…...

DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”

目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

【生成模型】视频生成论文调研

工作清单 上游应用方向&#xff1a;控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...