当前位置: 首页 > news >正文

从零开始 - 在Python中构建和训练生成对抗网络(GAN)模型

生成对抗网络(GANs)是一种强大的生成模型,可以合成新的逼真图像。通过完整的实现过程,读者将对GANs在幕后的工作原理有深刻的理解。本教程首先导入必要的库并加载将用于训练GAN的Fashion-MNIST数据集。然后,提供了构建GAN核心组件(生成器和判别器模型)的代码示例。接下来的部分解释了如何构建一个组合模型,该模型训练生成器以欺骗判别器,以及如何设计一个训练函数来优化对抗过程。

目录:

1. 导入库和下载数据集

2. 构建生成器模型

3. 构建判别器模型

4. 构建组合模型

5. 构建训练函数

6. 训练和观察结果

  1. 导入库和下载数据集

让我们首先导入本文中将使用的重要库:

from __future__ import print_function, division
from keras.datasets import fashion_mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

在本文中,您将在Fashion-MNIST数据集上训练DCGAN。Fashion-MNIST包含60,000个用于训练的灰度图像和一个包含10,000个图像的测试集。每个28×28的灰度图像与10个类别中的一个标签相关联。Fashion-MNIST旨在作为原始MNIST数据集的直接替代品,用于对比机器学习算法的性能。与三通道的彩色图像相比,灰度图像在一通道上训练卷积网络时需要更少的计算能力,这使您更容易在没有GPU的个人计算机上进行训练。

a43e74d2137f4a31ce4d40fe66ab7a52.jpeg

数据集分为10个时尚类别。类别标签如下:

760b0174d7592e71606bec49bf3407a5.jpeg

您可以使用以下代码加载数据集:

(training_data, _), (_, _) = fashion_mnist.load_data()
X_train = training_data / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

要可视化数据集中的图像,可以使用以下代码:

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else="" 'black')=""  =""  
fig = plt.figure(figsize = (12,12))
ax = fig.add_subplot(111)
visualize_input(training_data[3343], ax)We also use batch normalization and a ReLU activation.
For each of these layers, the general scheme is convolution ⇒ batch normalization
⇒ ReLU. We keep stacking up layers like this until we get the final transposed
convolution layer with shape 28 × 28 × 1:

b001bcb6986483ef65aa3f19ef9b657e.jpeg

2. 构建生成器模型

正如我们在前面的文章中所探讨的,GANs由两个主要组件组成,即生成器和判别器。在这一部分中,我们将构建生成器模型,其输入将是一个噪声向量(z)。生成器的架构如下图所示。

第一层是一个全连接层,然后被重新塑造成深而窄的层,在原始的DCGAN论文中,作者将输入重新塑造为4×4×1024。在这里,我们将使用7×7×128。然后,我们使用上采样层将特征映射的维度从7×7加倍到14×14,然后再次加倍到28×28。在这个网络中,我们使用了三个卷积层。我们还将使用批归一化和ReLU激活。

对于每个层,通用方案是卷积 ⇒ 批归一化 ⇒ ReLU。我们不断地堆叠这样的层,直到得到最终的转置卷积层,形状为28×28×1。

4fabaa16f62175b0c474ff334293c279.jpeg

以下是构建上述生成器模型的Keras代码:

def build_generator():generator = Sequential()generator.add(Dense(6272, activation="relu", input_dim=100)) # Add dense layergenerator.add(Reshape((7, 7, 128)))  # reshape the imagegenerator.add(UpSampling2D()) # Upsampling layer to double the size of the imagegenerator.add(Conv2D(128, kernel_size=3, padding="same", activation="relu"))generator.add(BatchNormalization(momentum=0.8))generator.add(UpSampling2D())# convolutional + batch normalization layersgenerator.add(Conv2D(64, kernel_size=3, padding="same", activation="relu"))generator.add(BatchNormalization(momentum=0.8))# convolutional layer with filters = 1generator.add(Conv2D(1, kernel_size=3, padding="same", activation="relu"))generator.summary() # prints the model summary"""We don't add upsampling here because the image size of 28 × 28 is equal to the image size in the MNIST dataset. You can adjust this for your own problem."""noise = Input(shape=(100,))fake_image = generator(noise)# Returns a model that takes the noise vector as an input and outputs the fake imagereturn Model(inputs=noise, outputs=fake_image)

3. 构建判别器模型

GANs的第二个主要组件是判别器。判别器只是一个传统的卷积分类器。判别器的输入是28×28×1的图像。我们希望有一些卷积层,然后是输出的全连接层。

与之前一样,我们希望得到一个Sigmoid输出,并且我们需要返回logits。对于卷积层的深度,我们可以从第一层开始使用32或64个过滤器,然后在添加层时将深度加倍。在这个实现中,我们将从64层开始,然后是128,然后是256。对于降采样,我们不使用池化层。相反,我们只使用步幅卷积层进行降采样,类似于Radford等人的实现。

我们还使用批归一化和dropout来优化训练。对于四个卷积层的每一层,通用方案是卷积 ⇒ 批归一化 ⇒ 泄漏的ReLU。

c99ea77aec1203923646688e02c6e1d6.jpeg

现在,让我们构建build_discriminator函数:

def build_discriminator():discriminator = Sequential()discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28,28,1), padding="same"))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(64, kernel_size=3, strides=2,padding="same"))discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Flatten())discriminator.add(Dense(1, activation='sigmoid'))img = Input(shape=(28,28,1))probability = discriminator(img)return Model(inputs=img, outputs=probability)

4. 构建组合模型

正如本系列的第二篇文章中所解释的,为了训练生成器,我们需要构建一个包含生成器和判别器的组合网络。组合模型以噪声信号(z)作为输入,并将判别器的预测输出作为虚假或真实输出。

e90e9c2335ae20998fab73b192b20485.jpeg

重要的是要记住,我们希望在组合模型中禁用判别器的训练,正如本系列的第二篇文章中所解释的那样。在训练生成器时,我们不希望判别器更新权重,但我们仍然希望将判别器模型包含在生成器训练中。因此,我们创建一个包含两个模型的组合网络,但在组合网络中冻结判别器模型的权重:

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
discriminator.trainable = False# Build the generator
generator = build_generator()
z = Input(shape=(100,))
img = generator(z)
valid = discriminator(img)
combined = Model(inputs=z, outputs=valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

5. 构建训练函数

为了训练GAN模型,我们训练两个网络:判别器和我们在前面部分创建的组合网络。让我们构建train函数,该函数接受以下参数:

  • epoch

  • batch size 大小

  • save_interval,以指定多久保存一次结果

def train(epochs, batch_size=128, save_interval=50):valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):  # Train Discriminator networkidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, 100))gen_imgs = generator.predict(noise)d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)g_loss = combined.train_on_batch(noise, valid)# printing progressprint("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %(epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % save_interval == 0:plot_generated_images(epoch, generator)

我们还将创建另一个函数`plot_generated_images()` 来绘制生成的图像。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),figsize=(10, 10)):noise = np.random.normal(0, 1, size=[examples, latent_dim])generated_images = generator.predict(noise)generated_images = generated_images.reshape(examples, 28, 28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')plt.axis('off')plt.tight_layout()plt.savefig('gan_generated_image_epoch_%d.png' % epoch

最后,让我们为训练GAN模型定义重要的变量和参数:

# Input shape
img_shape = (28,28,1)
channels = 1
latent_dim = 100
optimizer = Adam(0.0002, 0.5)# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Build the generator
generator = build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(latent_dim,))
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid = discriminator(img)
# The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

6. 训练和观察结果

此时,代码实现已经完成,我们准备开始DCGAN的训练。要训练模型,请运行以下代码行:

train(epochs=1000, batch_size=32, save_interval=50)

这将在1,000个epochs上运行训练,并每50个epochs保存一次图像。当运行`train()` 函数时,训练进度将如下所示:

86d990d67af3b9ee259b9424b3e1e521.jpeg

如下图所示,在epoch = 0时,图像只是随机噪声,没有明确的模式或有意义的数据。到了第50个epoch,图案已经开始形成。

80fb00ada0dc22c60488b9d4fda559aa.jpeg

在训练过程的后期,到了第1,000个epoch,您可以看到清晰的形状,可能能够猜测输入到GAN模型的训练数据的类型。

49de38a46bd9065cb03bb8125b1a990e.jpeg

再快进到第10,000个epoch,您会发现生成器已经非常擅长重新创建训练数据集中不存在的新图像。

de6db2898ea32036dd85c216a275c842.jpeg

·  END  ·

HAPPY LIFE

aeccadbe0b4d2dc12a3db6eea9e70b49.png

本文仅供学习交流使用,如有侵权请联系作者删除

相关文章:

从零开始 - 在Python中构建和训练生成对抗网络(GAN)模型

生成对抗网络&#xff08;GANs&#xff09;是一种强大的生成模型&#xff0c;可以合成新的逼真图像。通过完整的实现过程&#xff0c;读者将对GANs在幕后的工作原理有深刻的理解。本教程首先导入必要的库并加载将用于训练GAN的Fashion-MNIST数据集。然后&#xff0c;提供了构建…...

OfficeWeb365 Indexs 任意文件读取漏洞复现

0x01 产品简介 OfficeWeb365 是专注于 Office 文档在线预览及PDF文档在线预览云服务,包括 Microsoft Word 文档在线预览、Excel 表格在线预览、Powerpoint 演示文档在线预览,WPS 文字处理、WPS 表格、WPS 演示及 Adobe PDF 文档在线预览。 0x02 漏洞概述 OfficeWeb365 /Pi…...

Crypto的简单应用-前后端加密传输

最近遇到一个数据脱敏处理的需求&#xff0c;想要用一种轻量级的技术实现&#xff0c;必须足够简单并且适用于所有场合如前后端加密传输、路由加密、数据脱敏等。抽时间研究了一下Crypto加密库的一些API&#xff0c;发现完全符合上述需求&#xff0c;扩展也比较容易。 1、前端加…...

Vue3-32-路由-重定向路由

什么是重定向 路由的重定向 &#xff1a;将匹配到的路由 【替换】 为另一个路由。 redirect : 重定向的关键字。 重定向的特点 1、重定向是路由的直接替换,路由的地址是直接改变的&#xff1b; 2、在没有子路由配置的情况下&#xff0c;重定向的路由可以省略 component 属性的配…...

如何用js动态修改字体大小

在项目中&#xff0c;我们常常会遇到使用v-html渲染文本的情况。 如果需要点击大中小三个字号按钮&#xff0c;需要修改字体的大小。那我们应该怎么做呢 function fontSize(element, type) {let size {big: 22,middle: 16,small: 12};var result element.innerHTML.replac…...

【BIG_FG_CSDN】C++ 数组与指针 (个人向——学习笔记)

一维数组 在内存占用连续存储单元的相同类型数据序列的存储。 数组是静态存储器的块&#xff1b;在编译时确定大小后才能使用&#xff1b; 其声明格式如下&#xff1a; 元素类型 数组名[常量]&#xff1b;元素类型&#xff1a;数组中元素的数据类型&#xff1b; 常量&#…...

桌面天气预报软件 Weather Widget free mac特点介绍

Weather Widget free for Mac多种吸引人的小部件设计可供选择&#xff0c;可以随时了解天气&#xff01;还可以在Dock和菜单栏中为您提供简短的天气预报或当前状况的概述。 Weather Widget free for Mac软件介绍 始终在桌面上使用时尚的天气小部件来随时了解天气&#xff01;多…...

HarmonyOS应用开发-搭建开发环境

本文介绍如何搭建 HarmonyOS 应用的开发环境&#xff0c;介绍下载安装 DevEco Studio 开发工具和 SDK 的详细流程。华为鸿蒙 DevEco Studio 是面向全场景的一站式集成开发环境&#xff0c;面向全场景多设备&#xff0c;提供一站式的分布式应用开发平台&#xff0c;支持分布式多…...

<JavaEE> TCP 的通信机制(五) -- 延时应答、捎带应答、面向字节流

目录 TCP的通信机制的核心特性 七、延时应答 1&#xff09;什么是延时应答&#xff1f; 2&#xff09;延时应答的作用 八、捎带应答 1&#xff09;什么是捎带应答&#xff1f; 2&#xff09;捎带应答的作用 九、面向字节流 1&#xff09;沾包问题 2&#xff09;“沾包…...

电脑怎么设置代理IP上网?如何隐藏自己电脑的真实IP?

在现代互联网中&#xff0c;代理IP已成为许多用户保护隐私和上网安全的重要手段。通过设置代理IP&#xff0c;用户可以隐藏自己的真实IP地址&#xff0c;提高上网的安全性&#xff0c;同时保护个人信息不被泄露。本文将详细介绍如何设置代理IP上网以及如何隐藏电脑的真实IP地址…...

Django信号机制源码分析(观察者模式)

Django信号的实现原理本质是设计模式中的观察者模式&#xff0c;浅谈Python设计模式 -- 观察者模式&#xff0c;也可以叫做发布-订阅模式&#xff0c;信号对象维护一个订阅者列表&#xff0c;当信号被触发时&#xff0c;它会遍历订阅者&#xff0c;依次通知它们。 先来回顾一下…...

MyBatis-config.xml配置文件

1、基本介绍&#xff1a; mybatis的核心配置文件(mybatis-config.xml)&#xff0c;比如配置jdbc连接信息&#xff0c;注册mapper等等&#xff0c;我们需要对这个配置文件有详细的了解。 官网地址有详细介绍 mybatis – MyBatis 3 | 配置 2、properties属性 在通常的情况下&am…...

【Spring实战】17 REST服务介绍

文章目录 1. 为什么出现2. 拥有哪些优势3. Spring中的应用4. spring-boot-starter-data-rest总结 REST&#xff08;Representational State Transfer&#xff09;是一种软件架构风格&#xff0c;通常用于设计网络应用程序的服务接口。RESTful 服务是基于 REST 原则构建的网络服…...

java struts2教务管理系统Myeclipse开发mysql数据库struts2结构java编程计算机网页项目

一、源码特点 java struts2 教务管理系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助 struts2 框架开发&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境 为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库…...

跟着cherno手搓游戏引擎【3】事件系统和预编译头文件

不多说了直接上代码&#xff0c;课程中的架构讲的比较宽泛&#xff0c;而且有些方法写完之后并未测试。所以先把代码写完。理解其原理&#xff0c;未来使用时候会再此完善此博客。 文件架构&#xff1a; Event.h:核心基类 #pragma once #include"../Core.h" #inclu…...

排序算法之快速排序

快速排序是一种高效的排序算法&#xff0c;它的基本思想是采用分治策略&#xff0c;将一个无序数组分割成两个子数组&#xff0c;分别对子数组进行排序&#xff0c;然后将两个排序好的子数组合并成一个有序数组。快速排序的性能优于归并排序&#xff0c;尤其在处理大规模数据时…...

Docker 从入门到实践:Docker介绍

前言 在当今的软件开发和部署领域&#xff0c;Docker已经成为了一个不可或缺的工具。Docker以其轻量级、可移植性和标准化等特点&#xff0c;使得应用程序的部署和管理变得前所未有的简单。无论您是一名开发者、系统管理员&#xff0c;还是IT架构师&#xff0c;理解并掌握Dock…...

用IDEA创建/同步到gitee(码云)远程仓库(保姆级详细)

前言&#xff1a; 笔者最近在学习java&#xff0c;最开始在用很笨的方法&#xff1a;先克隆远程仓库到本地&#xff0c;再把自己练习的代码从本地仓库上传到远程仓库&#xff0c;很是繁琐。后发现可以IDEA只需要做些操作可以直接把代码上传到远程仓库&#xff0c;也在网上搜了些…...

【Linux】进程控制深度了解

> 作者简介&#xff1a;დ旧言~&#xff0c;目前大二&#xff0c;现在学习Java&#xff0c;c&#xff0c;c&#xff0c;Python等 > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟练掌握Linux下的进程控制 > 毒鸡汤&#xff…...

kbdnso.dll文件缺失,软件或游戏报错的快速修复方法

很多小伙伴遇到电脑报错&#xff0c;提示“kbdnso.dll文件缺失&#xff0c;程序无法启动执行”时&#xff0c;不知道应该怎样处理&#xff0c;还以为是程序出现了问题&#xff0c;想卸载重装。 首先&#xff0c;先要了解“kbdnso.dll文件”是什么&#xff1f; kbdnso.dll是Win…...

【Java学习笔记】Arrays类

Arrays 类 1. 导入包&#xff1a;import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序&#xff08;自然排序和定制排序&#xff09;Arrays.binarySearch()通过二分搜索法进行查找&#xff08;前提&#xff1a;数组是…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址&#xff1a;pdf 英文是纯手打的&#xff01;论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误&#xff0c;若有发现欢迎评论指正&#xff01;文章偏向于笔记&#xff0c;谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

宇树科技,改名了!

提到国内具身智能和机器人领域的代表企业&#xff0c;那宇树科技&#xff08;Unitree&#xff09;必须名列其榜。 最近&#xff0c;宇树科技的一项新变动消息在业界引发了不少关注和讨论&#xff0c;即&#xff1a; 宇树向其合作伙伴发布了一封公司名称变更函称&#xff0c;因…...

逻辑回归暴力训练预测金融欺诈

简述 「使用逻辑回归暴力预测金融欺诈&#xff0c;并不断增加特征维度持续测试」的做法&#xff0c;体现了一种逐步建模与迭代验证的实验思路&#xff0c;在金融欺诈检测中非常有价值&#xff0c;本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...

Kubernetes 网络模型深度解析:Pod IP 与 Service 的负载均衡机制,Service到底是什么?

Pod IP 的本质与特性 Pod IP 的定位 纯端点地址&#xff1a;Pod IP 是分配给 Pod 网络命名空间的真实 IP 地址&#xff08;如 10.244.1.2&#xff09;无特殊名称&#xff1a;在 Kubernetes 中&#xff0c;它通常被称为 “Pod IP” 或 “容器 IP”生命周期&#xff1a;与 Pod …...

数学建模-滑翔伞伞翼面积的设计,运动状态计算和优化 !

我们考虑滑翔伞的伞翼面积设计问题以及运动状态描述。滑翔伞的性能主要取决于伞翼面积、气动特性以及飞行员的重量。我们的目标是建立数学模型来描述滑翔伞的运动状态,并优化伞翼面积的设计。 一、问题分析 滑翔伞在飞行过程中受到重力、升力和阻力的作用。升力和阻力与伞翼面…...