当前位置: 首页 > news >正文

人工智能基础部分20-生成对抗网络(GAN)的实现应用

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的实现应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述:生成对抗网络的概念、GAN的原理、GAN的实验设计。

一、前言

随着近年来人工智能发展的不断加速,尤其是深度学习的出现,使得计算机视觉领域取得了许多重要突破。生成对抗网络(Generative Adversarial Networks, GAN)是其中一种具有广泛应用前景的技术。GAN是一种生成式模型,它的主要原理是通过博弈论的方式,将生成模型与判别模型进行对抗训练,从而实现生成图像、音频等数据的任务。本文将对GAN 的工作原理进行详细解释,并通过一个图像生成示例项目,展示如何使用 PyTorch 框架实现 GAN,并给出实验结果与完整代码。

二、生成对抗网络(GAN)原理

GAN的核心思想是让两个网络(生成器和判别器)进行博弈,最终迭代得到一个高质量的生成器。生成器的任务是生成与真实数据分布相近的伪数据,而判别器的任务则是判断输入数据是来源于真实数据还是伪数据。通过优化生成器与判别器的博弈过程,使得生成器逐渐改进,能够生成越来越接近真实数据的伪数据。

2.1 生成器

生成器的主要作用是以随机噪声为输入,输出生成的伪数据。随机噪声是一个高斯分布的向量,我们可以通过一个深度神经网络模型(如卷积神经网络、前馈神经网络等)将这个高斯分布的向量映射成我们想要输出的伪数据。

2.2 判别器

判别器是一个二分类神经网络模型,输入可能来自生成器也可能来自真实数据。其任务是对输入数据进行分类,输出一个概率值以判断输入数据是来自真实数据集还是生成器生成的伪数据。

2.3 博弈过程

生成器与判别器博弈的过程即是各自的训练过程。生成器训练的目标是使得判别器对其生成的数据预测为真实数据的概率最大;判别器训练的目标是使得自身对真实数据与生成的数据的分类准确率最高。通过反复迭代这个过程,最终生成器能够生成越来越接近真实数据的伪数据。

2.4 数学原理

生成对抗网络(Generative Adversarial Networks,简称 GAN)是一种基于博弈论的生成模型,其数学原理可以用以下公式表示:

假设p_{data}(x)表示真实数据的分布,p_z(z) 表示生成器输入随机噪声z 的分布,G(z;\theta_g)表示生成器的输出,其中 \theta_g是生成器的参数,D(x;\theta_d) 表示判别器的输出,其中\theta_d是判别器的参数。

GAN 的目标是最小化以下损失函数:

\min_G\max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

其中 \mathbb{E} 表示期望值,\log表示自然对数。

这个损失函数的含义是:最小化生成器生成的数据与真实数据之间的差距,同时最大化判别器对生成器生成的数据和真实数据的区分度。具体来说,第一项\mathbb{E}{x \sim p{data}(x)}[\log D(x)]表示真实数据被判别为真实数据的概率,第二项 \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] 表示生成器生成的虚构数据被判别为虚构数据的概率。

在训练过程中,GAN 会交替训练生成器和判别器,通过最小化损失函数 V(D,G)来优化模型参数。具体来说,对于每个训练迭代,我们首先固定生成器的参数,通过最大化损失函数V(D,G) 来优化判别器的参数。然后,我们固定判别器的参数,通过最小化损失函数V(D,G) 来优化生成器的参数。这个过程会一直迭代下去,直到达到预定的迭代次数或者损失函数收敛。

三、实验设计

本文使用 tensorflow  框架实现 GAN,并在图像生成任务上进行训练。实验workflow 分为以下五个步骤:数据准备\构建生成器与判别器\设置损失函数与优化器、训练过程,让我们先从数据准备开始。

四、代码实现

下面我们将使用MNIST(手写数字化)这一经典的数据集来展示GANs的实际应用效果。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 设置随机种子以获得可重现的结果
np.random.seed(42)
tf.random.set_seed(42)# 加载MNIST数据集
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()# 将数据规范化到[-1, 1]范围内
x_train = x_train.astype(np.float32) / 127.5 - 1# 将数据集重塑为(-1, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)# 创建生成器模型
def create_generator():generator = keras.Sequential()generator.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Reshape((7, 7, 256)))generator.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias = False))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias = False))generator.add(layers.BatchNormalization())generator.add(layers.LeakyReLU(alpha=0.2))generator.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias = False, activation ='tanh'))return generatorgenerator = create_generator()# 创建鉴别器模型
def create_discriminator():discriminator = keras.Sequential()discriminator.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape = (28, 28, 1)))discriminator.add(layers.LeakyReLU(alpha=0.2))discriminator.add(layers.BatchNormalization())discriminator.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))discriminator.add(layers.LeakyReLU(alpha=0.2))discriminator.add(layers.BatchNormalization())discriminator.add(layers.Flatten())discriminator.add(layers.Dropout(0.2))discriminator.add(layers.Dense(1, activation='sigmoid'))return discriminatordiscriminator = create_discriminator()# 编译鉴别器
discriminator_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy', metrics = ['accuracy'])# 创建和编译整体GAN结构
discriminator.trainable = False
gan_input = keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = keras.Model(gan_input, gan_output)gan_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')# 模型训练函数
def train_gan(epochs=100, batch_size=128):num_examples = x_train.shape[0]num_batches = num_examples // batch_sizefor epoch in range(epochs):for batch_idx in range(num_batches):noise = np.random.normal(size=(batch_size, 100))generated_images = generator.predict(noise)real_images = x_train[(batch_idx * batch_size):((batch_idx + 1) * batch_size)]all_images = np.concatenate([generated_images, real_images])labels = np.zeros(2 * batch_size)labels[batch_size:] = 1# 在噪声上加一点随机数,提高生成器的鲁棒性labels += 0.05 * np.random.rand(2 * batch_size)discriminator_loss = discriminator.train_on_batch(all_images, labels)noise = np.random.randn(batch_size, 100)misleading_targets = np.ones(batch_size)generator_loss = gan.train_on_batch(noise, misleading_targets)if (batch_idx + 1) % 50 == 0:print(f"Epoch:{epoch + 1}/{epochs} Batch:{batch_idx + 1}/{num_batches} Discriminator Loss: {discriminator_loss[0]} Generator Loss:{generator_loss}")train_gan()

以上实现了生成对抗网络是训练过程,实际中我们可以替换数据训练自己的数据模型。

相关文章:

人工智能基础部分20-生成对抗网络(GAN)的实现应用

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的实现应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述&#xff1…...

JavaScript表单事件(上篇)

目录 一、input: 当表单元素的值发生改变时触发,适用于大多数表单元素。 二、change: 当表单元素的值发生改变且失去焦点时触发,适用于输入框、下拉列表等。 三、submit: 当表单提交时触发,适用于 form 元素。 四、reset: 当表单重置时触…...

vb6 Webview2微软Edge Chromium内核执行JS取网页数据测速

微软Edge Chromium内核执行JS获取网页数据测试 ExcuteScript eval(document.body.innerHTML) from : https://www.163.com 采集的网页HTM字符串占用字节空间1.2MB ExcuteScript回调事件中取得JS执行结果,用时 54 毫秒 其中JSON转字符13.5209毫秒 jSON数据长度: 增…...

编码,Part 1:ASCII、汉字及 Unicode 标准

个人博客 编码的历史由来就懒得介绍了,只需要知道人类处理文本信息是以字符为基本单位,而计算机在最底层只认识 0/1,所以当计算机要为人类存储/呈现字符时,就需要有一个规则,在字符和 0/1 序列之间建立映射关系&#…...

C++ Eigen库矩阵操作

C Eigen库 序号功能例子1赋值Eigen::MatrixXf mat (12,1); \\% mat << 1, 2, 3, 4,5,6,7,8,9,10,11,12;2Inplace操作 \\% resizemat.resize(4, 3); \\% 1 5 9 \\% 2 6 10 \\% 3 7 11 \\% 4 8 123转置 \\% transposeInPlacemat.transposeInPlace(); \\% 1 2 3 4 \\% 5…...

Linux-0.11 boot目录bootsect.s详解

Linux-0.11 boot目录bootsect.s详解 模块简介 bootsect.s是磁盘启动的引导程序&#xff0c;其概括起来就是代码的搬运工&#xff0c;将代码搬到合适的位置。下图是对搬运过程的概括&#xff0c;可以有个印象&#xff0c;后面将详细讲解。 bootsect.s主要做了如下的三件事: 搬…...

django组件552

前言&#xff1a;相信看到这篇文章的小伙伴都或多或少有一些编程基础&#xff0c;懂得一些linux的基本命令了吧&#xff0c;本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python&#xff1a;一种编程语言&…...

【枚举算法的Java实现及其应用】

文章目录 枚举算法概述枚举算法的实现步骤Java实现枚举算法枚举算法的底层工作原理枚举算法的底层代码讲解枚举算法的实际应用场景枚举算法在场景中解决的问题总结 枚举算法概述 枚举算法是一种通过列举所有可能情况来解决问题的方法。这种算法在解决一些特定类型的问题时非常…...

linux led 驱动

前言 今天是儿童节&#xff0c;挣个奖牌给小孩玩玩。 在 linux 驱动大家庭中&#xff0c;LED 驱动算是个儿童&#xff0c;今天就写写他吧。正好之前写过他的婴儿时期《i.MX6ULL 裸机点亮 LED》&#xff0c;记得那时候他还穿着开裆裤呢&#xff0c;裸鸡嘛。 ioremap() 裸机程…...

平面最近点对(分治算法)

文章目录 平面最近点对&#xff08;分治算法&#xff09;Solution流程完整模板代码 平面最近点对&#xff08;分治算法&#xff09; 文章首发于我的个人博客&#xff1a;欢迎大佬们来逛逛 平面最近点对&#xff08;加强版&#xff09; - 洛谷 给你一些点&#xff0c;求两点之…...

【基于前后端分离的博客系统】Servlet版本

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一. 项目简介 1. 项目背景 2. 项目用到的技…...

在线Excel绝配:SpreadJS 16.1.1+GcExcel 6.1.1 Crack

前端&#xff1a;SpreadJS 16.1.1 后端&#xff1a; GcExcel 6.1.1 全能 SpreadJS 16.1.1此版本的产品中包含以下功能和增强功能。 添加了各种输入掩码样式选项。 添加了在保护工作表时设置密码以及在取消保护时验证密码的支持。 增强了组合图以将其显示为仪表图。 添加了…...

一个轻量的登录鉴权工具Sa-Token 集成SpringBoot简要步骤

Sa-Token 集成SpringBoot简要步骤 1.1 简单介绍 Sa-Token是一个轻量级Java权限认证框架。 主要解决的问题如下&#xff1a; 登录认证 权限认证 单点登录 OAuth2.0 分布式Session会话 微服务网关鉴权等一系列权限相关问题。 1.2 登录认证 设计思路 对于一些登录之后…...

day 44 完全背包:518. 零钱兑换 II;377. 组合总和 Ⅳ

完全背包&#xff1a;物品可以使用多次 完全背包1. 与01背包区别 518. 零钱兑换 II1. dp数组以及下标名义2. 递归公式3. dp数组如何初始化4. 遍历顺序:不能颠倒两个for循环顺序5. 代码 377. 组合总和 Ⅳ:与零钱兑换类似&#xff0c;但是是求组合数1. dp数组以及下标名义2. 递归…...

K8s in Action 阅读笔记——【5】Services: enabling clients to discover and talk to pods

K8s in Action 阅读笔记——【5】Services: enabling clients to discover and talk to pods 你已了解Pod以及如何通过ReplicaSets等资源部署它们以确保持续运行。虽然某些Pod可以独立完成工作&#xff0c;但现今许多应用程序需要响应外部请求。例如&#xff0c;在微服务的情况…...

牛客网DAY2(编程题)

圣诞节来啦&#xff01;请用CSS给你的朋友们制作一颗圣诞树吧~这颗圣诞树描述起来是这样的&#xff1a; 1. "topbranch"是圣诞树的上枝叶&#xff0c;该上枝叶仅通过边框属性、左浮动、左外边距即可实现。边框的属性依次是&#xff1a;宽度为100px、是直线、颜色为gr…...

Java经典笔试题—day14

Java经典笔试题—day14 &#x1f50e;选择题&#x1f50e;编程题&#x1f36d;计算日期到天数转换&#x1f36d;幸运的袋子 &#x1f50e;结尾 &#x1f50e;选择题 (1)定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA &#xff09;&#xff08;其属性分别为学号、姓名、所…...

一个帮助写autoprefixer配置的网站

前端需要用到postcss的工具&#xff0c;用到一个插件叫autoprefixer&#xff0c;这个插件能够给css属性加上前缀&#xff0c;进行一些兼容的工作。 如何安装之类的问题在csdn上搜一下都能找到&#xff08;注意&#xff0c;vite是包含postcss的&#xff0c;不用在项目中安装pos…...

C语言中的类型转换

C语言中的类型转换 隐式类型转换 整型提升 概念&#xff1a; C语言的整型算术运算总是至少以缺省&#xff08;默认&#xff09;整型类型的精度来进行的为了获得这个精度&#xff0c;表达式中字符和短整型操作数在使用之前被转换为普通整型&#xff0c;这种转换成为整型提升 如…...

String底层详解(包括字符串常量池)

String a “abc”; &#xff0c;说一下这个过程会创建什么&#xff0c;放在哪里&#xff1f; JVM会使用常量池来管理字符串直接量。在执行这句话时&#xff0c;JVM会先检查常量池中是否已经存有"abc"&#xff0c;若没有则将"abc"存入常量池&#xff0c;否…...

借助快马平台优化蓝桥杯python解题代码,提升算法执行效率

最近在准备蓝桥杯Python组的比赛&#xff0c;发现很多题目对算法效率要求很高。就拿经典的"最大子序列和"问题来说&#xff0c;不同的解法效率差异巨大。今天分享一下我是如何借助InsCode(快马)平台来快速验证不同解法的效率的。 问题理解 最大子序列和问题要求在一个…...

Agent--多轮对话系统设计6道高频考题解析

去年面试某大厂AI岗位&#xff0c;多轮对话这块被追问了好几道题&#xff0c;有些问题当时答得磕磕绊绊&#xff0c;回来后我把相关知识点重新梳理了一遍。这次复盘把面试中遇到的核心问题分享出来&#xff0c;希望对准备面试的同学有点帮助。真题现场&#xff1a; 面试刚开始&…...

大数据领域数据预处理:优化数据分析结果的关键环节

大数据领域数据预处理:优化数据分析结果的关键环节 关键词:大数据、数据预处理、数据分析、优化、关键环节 摘要:本文深入探讨了大数据领域中数据预处理这一优化数据分析结果的关键环节。详细介绍了数据预处理的背景知识,包括目的、范围、预期读者等。通过生动形象的比喻解…...

智慧树自动学习助手:三分钟实现高效网课学习的完整指南

智慧树自动学习助手&#xff1a;三分钟实现高效网课学习的完整指南 【免费下载链接】zhihuishu 智慧树刷课插件&#xff0c;自动播放下一集、1.5倍速度、无声 项目地址: https://gitcode.com/gh_mirrors/zh/zhihuishu 还在为智慧树平台冗长的网课视频而烦恼吗&#xff1…...

别再只盯着CAN了!聊聊LIN总线在低成本IoT传感器网络里的那些‘骚操作’

LIN总线在低成本IoT传感器网络中的创新实践 当谈到工业物联网和传感器网络通信协议时&#xff0c;大多数人会立刻想到CAN、Modbus或以太网协议。但有一个被严重低估的选项正在悄然崛起——LIN总线。这个原本为汽车电子设计的轻量级协议&#xff0c;凭借其独特的成本优势和简洁架…...

SPSS加权处理实战:广告效果分析中的权重设置技巧(附详细步骤)

SPSS加权处理实战&#xff1a;广告效果分析中的权重设置技巧&#xff08;附详细步骤&#xff09; 当市场部门拿着厚厚一叠广告效果调研数据来找你时&#xff0c;最头疼的往往不是分析本身&#xff0c;而是那些看似简单却暗藏玄机的原始数据。上个月我就遇到这样一个案例&#x…...

高通平台实战:手把手教你解析和修改CDT中的board-id(附常见报错排查)

高通平台深度实战&#xff1a;CDT中board-id的解析与定制化修改指南 引言&#xff1a;为什么需要关注board-id&#xff1f; 在Android底层开发中&#xff0c;board-id就像设备的"身份证号"&#xff0c;它决定了系统如何识别硬件配置并加载对应的设备树和驱动。对于从…...

MySQL高频面试题(2026最新版):覆盖90%考点,小白也能直接背

很多开发者备考时&#xff0c;要么盲目刷题、记不住重点&#xff0c;要么只背答案、不懂原理&#xff0c;面试时被面试官追问一句就卡壳。其实MySQL面试没有那么复杂&#xff0c;核心考点就那么多&#xff0c;只要吃透高频题、理解底层逻辑&#xff0c;就能从容应对。本文整理了…...

Java调用动态库总崩溃?从SIGSEGV日志反向定位到C端ABI兼容性缺陷——一线故障复盘(含GDB+Java Core联合调试全流程)

第一章&#xff1a;Java调用动态库总崩溃&#xff1f;从SIGSEGV日志反向定位到C端ABI兼容性缺陷——一线故障复盘&#xff08;含GDBJava Core联合调试全流程&#xff09;某金融风控系统在JDK 17 Alpine Linux&#xff08;musl libc&#xff09;环境下频繁触发 JVM Crash&#…...

循环冷却水流量示意图设计 建筑水流量示意图绘制教程

一、引言 在建筑给排水、暖通空调及工业循环水系统设计中&#xff0c;循环冷却水流量示意图与建筑水流量示意图是核心技术图纸之一&#xff0c;其作用是直观呈现水流路径、管径规格、流量分配、设备连接关系及压力节点参数&#xff0c;为系统施工、调试、运维及故障排查提供可…...