生成对抗网络入门案例
前言
生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。
下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:
实现过程
1、导入必要的库和模块
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
2、加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)
3、定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))
4、定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
5、编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
6、冻结判别器模型的权重
discriminator.trainable = False
7、定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
8、编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
9、定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)
10、保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()
11、训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000
完整代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])# 冻结判别器模型的权重
discriminator.trainable = False# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)# 保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000train_gan(epochs, batch_size, sample_interval)
训练结果:

这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。
相关文章:
生成对抗网络入门案例
前言 生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试…...
多头注意力机制
1、什么是多头注意力机制 从多头注意力的结构图中,貌似这个所谓的多个头就是指多组线性变换,但是并不是,只使用了一组线性变换层,即三个变换张量对 Q、K、V 分别进行线性变换,这些变化不会改变原有张量的尺寸…...
Qt + FFmpeg 搭建 Windows 开发环境
Qt FFmpeg 搭建 Windows 开发环境 Qt FFmpeg 搭建 Windows 开发环境安装 Qt Creator下载 FFmpeg 编译包测试 Qt FFmpeg踩坑解决方法1:换一个 FFmpeg 库解决方法2:把项目改成 64 位 后记 官方博客:https://www.yafeilinux.com/ Qt开源社区…...
[网鼎杯 2020 白虎组]PicDown python反弹shell proc/self目录的信息
[网鼎杯 2020 白虎组]PicDown - 知乎 这里确实完全不会 第一次遇到一个只有文件读取思路的题目 这里也确实说明还是要学学一些其他的东西了 首先打开环境 只存在一个框框 我们通过 目录扫描 抓包 注入 发现没有用 我们测试能不能任意文件读取 ?url../../../../etc/passwd …...
SDL2绘制ffmpeg解析的mp4文件
文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示6.SDL2事件响应补充6.1 处理方式-016.2 处理方式-02 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将m…...
决策树C4.5算法的技术深度剖析、实战解读
目录 一、简介决策树(Decision Tree)例子: 信息熵(Information Entropy)与信息增益(Information Gain)例子: 信息增益比(Gain Ratio)例子: 二、算…...
LLMs Python解释器程序辅助语言模型(PAL)Program-aided language models (PAL)
正如您在本课程早期看到的,LLM执行算术和其他数学运算的能力是有限的。虽然您可以尝试使用链式思维提示来克服这一问题,但它只能帮助您走得更远。即使模型正确地通过了问题的推理,对于较大的数字或复杂的运算,它仍可能在个别数学操…...
【12】c++设计模式——>单例模式练习(任务队列)
属性: (1)存储任务的容器,这个容器可以选择使用STL中的队列(queue) (2)互斥锁,多线程访问的时候用于保护任务队列中的数据 方法:主要是对任务队列中的任务进行操作 &…...
Python之函数、模块、包库
函数、模块、包库基础概念和作用 A、函数 减少代码重复 将复杂问题代码分解成简单模块 提高代码可读性 复用老代码 """ 函数 """# 定义一个函数 def my_fuvtion():# 函数执行部分print(这是一个函数)# 定义带有参数的函数 def say_hello(n…...
SQL创建与删除索引
索引创建、删除与使用: 1.1 create方式创建索引:CREATE [UNIQUE – 唯一索引 | FULLTEXT – 全文索引 ] INDEX index_name ON table_name – 不指定唯一或全文时默认普通索引 (column1[(length) [DESC|ASC]] [,column2,…]) – 可以对多列建立组合索引 …...
网络协议--链路层
2.1 引言 从图1-4中可以看出,在TCP/IP协议族中,链路层主要有三个目的: (1)为IP模块发送和接收IP数据报; (2)为ARP模块发送ARP请求和接收ARP应答; (3…...
HDLbits: Count clock
目前写过最长的verilog代码,用了将近三个小时,编写12h显示的时钟,改来改去,估计只有我自己看得懂(吐血) module top_module(input clk,input reset,input ena,output pm,output [7:0] hh,output [7:0] mm,…...
【1day】用友移动管理系统任意文件上传漏洞学习
注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现...
【c++】向webrtc学习容器操作
std::map的key为std::pair 时的查找 std::map<RemoteAndLocalNetworkId, size_t> in_flight_bytes_RTC_GUARDED_BY(&lock_);private:using RemoteAndLocalNetworkId = std::pair<uint16_t, uint16_t...
SpringBoot+Vue3外卖项目构思
SpringBoot的学习: SpringBoot的学习_明里灰的博客-CSDN博客 实现功能 前台 用户注册,邮箱登录,地址管理,历史订单,菜品规格,购物车,下单,菜品浏览,评价,…...
【AI视野·今日NLP 自然语言处理论文速览 第四十七期】Wed, 4 Oct 2023
AI视野今日CS.NLP 自然语言处理论文速览 Wed, 4 Oct 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Contrastive Post-training Large Language Models on Data Curriculum Authors Canwen Xu, Corby Rosset, Luc…...
c++的lambda表达式
文章目录 1 lambda表达式2 捕捉列表 vs 参数列表3 lambda表达式的传递3.1 函数作为形参3.2 场景1:条件表达式3.3 场景2:线程的运行表达式 1 lambda表达式 lambda表达式可以理解为匿名函数,也就是没有名字的函数,既然是函数&#…...
电梯安全监测丨S271W无线水浸传感器用于电梯机房/电梯基坑水浸监测
城市化进程中,电梯与我们的生活息息相关。高层住宅、医院、商场、学校、车站等各种商业体建筑、公共建筑中电梯为我们生活工作提供了诸多便利。 保障电梯系统的安全至关重要!特别是电梯机房和电梯基坑可通过智能化改造提高其安全性和稳定性。例如在暴风…...
Java异常:基本概念、分类和处理
Java异常:基本概念、分类和处理 在Java编程中,异常处理是一个非常重要的部分。了解如何识别、处理和避免异常对于编写健壮、可维护的代码至关重要。本文将介绍Java异常的基本概念、分类和处理方法,并通过简单的代码示例进行说明。 一、什么…...
小谈设计模式(19)—备忘录模式
小谈设计模式(19)—备忘录模式 专栏介绍专栏地址专栏介绍 备忘录模式主要角色发起人(Originator)备忘录(Memento)管理者(Caretaker) 应用场景结构实现步骤Java程序实现首先ÿ…...
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造,完美适配AGV和无人叉车。同时,集成以太网与语音合成技术,为各类高级系统(如MES、调度系统、库位管理、立库等)提供高效便捷的语音交互体验。 L…...
vscode里如何用git
打开vs终端执行如下: 1 初始化 Git 仓库(如果尚未初始化) git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...
地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...
Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
网络编程(UDP编程)
思维导图 UDP基础编程(单播) 1.流程图 服务器:短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...
在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案
这个问题我看其他博主也写了,要么要会员、要么写的乱七八糟。这里我整理一下,把问题说清楚并且给出代码,拿去用就行,照着葫芦画瓢。 问题 在继承QWebEngineView后,重写mousePressEvent或event函数无法捕获鼠标按下事…...
