当前位置: 首页 > news >正文

人工智能基础部分20-生成对抗网络(GAN)的实现应用

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的实现应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述:生成对抗网络的概念、GAN的原理、GAN的实验设计。

一、前言

随着近年来人工智能发展的不断加速,尤其是深度学习的出现,使得计算机视觉领域取得了许多重要突破。生成对抗网络(Generative Adversarial Networks, GAN)是其中一种具有广泛应用前景的技术。GAN是一种生成式模型,它的主要原理是通过博弈论的方式,将生成模型与判别模型进行对抗训练,从而实现生成图像、音频等数据的任务。本文将对GAN 的工作原理进行详细解释,并通过一个图像生成示例项目,展示如何使用 PyTorch 框架实现 GAN,并给出实验结果与完整代码。

二、生成对抗网络(GAN)原理

GAN的核心思想是让两个网络(生成器和判别器)进行博弈,最终迭代得到一个高质量的生成器。生成器的任务是生成与真实数据分布相近的伪数据,而判别器的任务则是判断输入数据是来源于真实数据还是伪数据。通过优化生成器与判别器的博弈过程,使得生成器逐渐改进,能够生成越来越接近真实数据的伪数据。

2.1 生成器

生成器的主要作用是以随机噪声为输入,输出生成的伪数据。随机噪声是一个高斯分布的向量,我们可以通过一个深度神经网络模型(如卷积神经网络、前馈神经网络等)将这个高斯分布的向量映射成我们想要输出的伪数据。

2.2 判别器

判别器是一个二分类神经网络模型,输入可能来自生成器也可能来自真实数据。其任务是对输入数据进行分类,输出一个概率值以判断输入数据是来自真实数据集还是生成器生成的伪数据。

2.3 博弈过程

生成器与判别器博弈的过程即是各自的训练过程。生成器训练的目标是使得判别器对其生成的数据预测为真实数据的概率最大;判别器训练的目标是使得自身对真实数据与生成的数据的分类准确率最高。通过反复迭代这个过程,最终生成器能够生成越来越接近真实数据的伪数据。

2.4 数学原理

生成对抗网络(Generative Adversarial Networks,简称 GAN)是一种基于博弈论的生成模型,其数学原理可以用以下公式表示:

假设p_{data}(x)表示真实数据的分布,p_z(z) 表示生成器输入随机噪声z 的分布,G(z;\theta_g)表示生成器的输出,其中 \theta_g是生成器的参数,D(x;\theta_d) 表示判别器的输出,其中\theta_d是判别器的参数。

GAN 的目标是最小化以下损失函数:

\min_G\max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

其中 \mathbb{E} 表示期望值,\log表示自然对数。

这个损失函数的含义是:最小化生成器生成的数据与真实数据之间的差距,同时最大化判别器对生成器生成的数据和真实数据的区分度。具体来说,第一项\mathbb{E}{x \sim p{data}(x)}[\log D(x)]表示真实数据被判别为真实数据的概率,第二项 \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] 表示生成器生成的虚构数据被判别为虚构数据的概率。

在训练过程中,GAN 会交替训练生成器和判别器,通过最小化损失函数 V(D,G)来优化模型参数。具体来说,对于每个训练迭代,我们首先固定生成器的参数,通过最大化损失函数V(D,G) 来优化判别器的参数。然后,我们固定判别器的参数,通过最小化损失函数V(D,G) 来优化生成器的参数。这个过程会一直迭代下去,直到达到预定的迭代次数或者损失函数收敛。

三、实验设计

本文使用 tensorflow  框架实现 GAN,并在图像生成任务上进行训练。实验workflow 分为以下五个步骤:数据准备\构建生成器与判别器\设置损失函数与优化器、训练过程,让我们先从数据准备开始。

四、代码实现

下面我们将使用MNIST(手写数字化)这一经典的数据集来展示GANs的实际应用效果。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 设置随机种子以获得可重现的结果
np.random.seed(42)
tf.random.set_seed(42)# 加载MNIST数据集
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()# 将数据规范化到[-1, 1]范围内
x_train = x_train.astype(np.float32) / 127.5 - 1# 将数据集重塑为(-1, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)# 创建生成器模型
def create_generator():generator = keras.Sequential()generator.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Reshape((7, 7, 256)))generator.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias = False))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias = False))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias = False, activation ='tanh'))return generatorgenerator = create_generator()# 创建鉴别器模型
def create_discriminator():discriminator = keras.Sequential()discriminator.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape = (28, 28, 1)))discriminator.add(layers.LeakyReLU(alpha=0.2))discriminator.add(layers.BatchNormalization())discriminator.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))discriminator.add(layers.LeakyReLU(alpha=0.2))discriminator.add(layers.BatchNormalization())discriminator.add(layers.Flatten())discriminator.add(layers.Dropout(0.2))discriminator.add(layers.Dense(1, activation='sigmoid'))return discriminatordiscriminator = create_discriminator()# 编译鉴别器
discriminator_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy', metrics = ['accuracy'])# 创建和编译整体GAN结构
discriminator.trainable = False
gan_input = keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = keras.Model(gan_input, gan_output)gan_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')# 模型训练函数
def train_gan(epochs=100, batch_size=128):num_examples = x_train.shape[0]num_batches = num_examples // batch_sizefor epoch in range(epochs):for batch_idx in range(num_batches):noise = np.random.normal(size=(batch_size, 100))generated_images = generator.predict(noise)real_images = x_train[(batch_idx * batch_size):((batch_idx + 1) * batch_size)]all_images = np.concatenate([generated_images, real_images])labels = np.zeros(2 * batch_size)labels[batch_size:] = 1# 在噪声上加一点随机数,提高生成器的鲁棒性labels += 0.05 * np.random.rand(2 * batch_size)discriminator_loss = discriminator.train_on_batch(all_images, labels)noise = np.random.randn(batch_size, 100)misleading_targets = np.ones(batch_size)generator_loss = gan.train_on_batch(noise, misleading_targets)if (batch_idx + 1) % 50 == 0:print(f"Epoch:{epoch + 1}/{epochs} Batch:{batch_idx + 1}/{num_batches} Discriminator Loss: {discriminator_loss[0]} Generator Loss:{generator_loss}")train_gan()

以上实现了生成对抗网络是训练过程,实际中我们可以替换数据训练自己的数据模型。

相关文章:

人工智能基础部分20-生成对抗网络(GAN)的实现应用

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的实现应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述&#xff1…...

JavaScript表单事件(上篇)

目录 一、input: 当表单元素的值发生改变时触发,适用于大多数表单元素。 二、change: 当表单元素的值发生改变且失去焦点时触发,适用于输入框、下拉列表等。 三、submit: 当表单提交时触发,适用于 form 元素。 四、reset: 当表单重置时触…...

vb6 Webview2微软Edge Chromium内核执行JS取网页数据测速

微软Edge Chromium内核执行JS获取网页数据测试 ExcuteScript eval(document.body.innerHTML) from : https://www.163.com 采集的网页HTM字符串占用字节空间1.2MB ExcuteScript回调事件中取得JS执行结果,用时 54 毫秒 其中JSON转字符13.5209毫秒 jSON数据长度: 增…...

编码,Part 1:ASCII、汉字及 Unicode 标准

个人博客 编码的历史由来就懒得介绍了,只需要知道人类处理文本信息是以字符为基本单位,而计算机在最底层只认识 0/1,所以当计算机要为人类存储/呈现字符时,就需要有一个规则,在字符和 0/1 序列之间建立映射关系&#…...

C++ Eigen库矩阵操作

C Eigen库 序号功能例子1赋值Eigen::MatrixXf mat (12,1); \\% mat << 1, 2, 3, 4,5,6,7,8,9,10,11,12;2Inplace操作 \\% resizemat.resize(4, 3); \\% 1 5 9 \\% 2 6 10 \\% 3 7 11 \\% 4 8 123转置 \\% transposeInPlacemat.transposeInPlace(); \\% 1 2 3 4 \\% 5…...

Linux-0.11 boot目录bootsect.s详解

Linux-0.11 boot目录bootsect.s详解 模块简介 bootsect.s是磁盘启动的引导程序&#xff0c;其概括起来就是代码的搬运工&#xff0c;将代码搬到合适的位置。下图是对搬运过程的概括&#xff0c;可以有个印象&#xff0c;后面将详细讲解。 bootsect.s主要做了如下的三件事: 搬…...

django组件552

前言&#xff1a;相信看到这篇文章的小伙伴都或多或少有一些编程基础&#xff0c;懂得一些linux的基本命令了吧&#xff0c;本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python&#xff1a;一种编程语言&…...

【枚举算法的Java实现及其应用】

文章目录 枚举算法概述枚举算法的实现步骤Java实现枚举算法枚举算法的底层工作原理枚举算法的底层代码讲解枚举算法的实际应用场景枚举算法在场景中解决的问题总结 枚举算法概述 枚举算法是一种通过列举所有可能情况来解决问题的方法。这种算法在解决一些特定类型的问题时非常…...

linux led 驱动

前言 今天是儿童节&#xff0c;挣个奖牌给小孩玩玩。 在 linux 驱动大家庭中&#xff0c;LED 驱动算是个儿童&#xff0c;今天就写写他吧。正好之前写过他的婴儿时期《i.MX6ULL 裸机点亮 LED》&#xff0c;记得那时候他还穿着开裆裤呢&#xff0c;裸鸡嘛。 ioremap() 裸机程…...

平面最近点对(分治算法)

文章目录 平面最近点对&#xff08;分治算法&#xff09;Solution流程完整模板代码 平面最近点对&#xff08;分治算法&#xff09; 文章首发于我的个人博客&#xff1a;欢迎大佬们来逛逛 平面最近点对&#xff08;加强版&#xff09; - 洛谷 给你一些点&#xff0c;求两点之…...

【基于前后端分离的博客系统】Servlet版本

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一. 项目简介 1. 项目背景 2. 项目用到的技…...

在线Excel绝配:SpreadJS 16.1.1+GcExcel 6.1.1 Crack

前端&#xff1a;SpreadJS 16.1.1 后端&#xff1a; GcExcel 6.1.1 全能 SpreadJS 16.1.1此版本的产品中包含以下功能和增强功能。 添加了各种输入掩码样式选项。 添加了在保护工作表时设置密码以及在取消保护时验证密码的支持。 增强了组合图以将其显示为仪表图。 添加了…...

一个轻量的登录鉴权工具Sa-Token 集成SpringBoot简要步骤

Sa-Token 集成SpringBoot简要步骤 1.1 简单介绍 Sa-Token是一个轻量级Java权限认证框架。 主要解决的问题如下&#xff1a; 登录认证 权限认证 单点登录 OAuth2.0 分布式Session会话 微服务网关鉴权等一系列权限相关问题。 1.2 登录认证 设计思路 对于一些登录之后…...

day 44 完全背包:518. 零钱兑换 II;377. 组合总和 Ⅳ

完全背包&#xff1a;物品可以使用多次 完全背包1. 与01背包区别 518. 零钱兑换 II1. dp数组以及下标名义2. 递归公式3. dp数组如何初始化4. 遍历顺序:不能颠倒两个for循环顺序5. 代码 377. 组合总和 Ⅳ:与零钱兑换类似&#xff0c;但是是求组合数1. dp数组以及下标名义2. 递归…...

K8s in Action 阅读笔记——【5】Services: enabling clients to discover and talk to pods

K8s in Action 阅读笔记——【5】Services: enabling clients to discover and talk to pods 你已了解Pod以及如何通过ReplicaSets等资源部署它们以确保持续运行。虽然某些Pod可以独立完成工作&#xff0c;但现今许多应用程序需要响应外部请求。例如&#xff0c;在微服务的情况…...

牛客网DAY2(编程题)

圣诞节来啦&#xff01;请用CSS给你的朋友们制作一颗圣诞树吧~这颗圣诞树描述起来是这样的&#xff1a; 1. "topbranch"是圣诞树的上枝叶&#xff0c;该上枝叶仅通过边框属性、左浮动、左外边距即可实现。边框的属性依次是&#xff1a;宽度为100px、是直线、颜色为gr…...

Java经典笔试题—day14

Java经典笔试题—day14 &#x1f50e;选择题&#x1f50e;编程题&#x1f36d;计算日期到天数转换&#x1f36d;幸运的袋子 &#x1f50e;结尾 &#x1f50e;选择题 (1)定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA &#xff09;&#xff08;其属性分别为学号、姓名、所…...

一个帮助写autoprefixer配置的网站

前端需要用到postcss的工具&#xff0c;用到一个插件叫autoprefixer&#xff0c;这个插件能够给css属性加上前缀&#xff0c;进行一些兼容的工作。 如何安装之类的问题在csdn上搜一下都能找到&#xff08;注意&#xff0c;vite是包含postcss的&#xff0c;不用在项目中安装pos…...

C语言中的类型转换

C语言中的类型转换 隐式类型转换 整型提升 概念&#xff1a; C语言的整型算术运算总是至少以缺省&#xff08;默认&#xff09;整型类型的精度来进行的为了获得这个精度&#xff0c;表达式中字符和短整型操作数在使用之前被转换为普通整型&#xff0c;这种转换成为整型提升 如…...

String底层详解(包括字符串常量池)

String a “abc”; &#xff0c;说一下这个过程会创建什么&#xff0c;放在哪里&#xff1f; JVM会使用常量池来管理字符串直接量。在执行这句话时&#xff0c;JVM会先检查常量池中是否已经存有"abc"&#xff0c;若没有则将"abc"存入常量池&#xff0c;否…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

在rocky linux 9.5上在线安装 docker

前面是指南&#xff0c;后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

PL0语法,分析器实现!

简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展&#xff0c;光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域&#xff0c;IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选&#xff0c;但在长期运行中&#xff0c;例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...