生成对抗网络入门案例
前言
生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。
下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:
实现过程
1、导入必要的库和模块
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
2、加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)
3、定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))
4、定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
5、编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
6、冻结判别器模型的权重
discriminator.trainable = False
7、定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
8、编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
9、定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)
10、保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()
11、训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000
完整代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])# 冻结判别器模型的权重
discriminator.trainable = False# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)# 保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000train_gan(epochs, batch_size, sample_interval)
训练结果:
这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。
相关文章:

生成对抗网络入门案例
前言 生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试…...

多头注意力机制
1、什么是多头注意力机制 从多头注意力的结构图中,貌似这个所谓的多个头就是指多组线性变换,但是并不是,只使用了一组线性变换层,即三个变换张量对 Q、K、V 分别进行线性变换,这些变化不会改变原有张量的尺寸…...

Qt + FFmpeg 搭建 Windows 开发环境
Qt FFmpeg 搭建 Windows 开发环境 Qt FFmpeg 搭建 Windows 开发环境安装 Qt Creator下载 FFmpeg 编译包测试 Qt FFmpeg踩坑解决方法1:换一个 FFmpeg 库解决方法2:把项目改成 64 位 后记 官方博客:https://www.yafeilinux.com/ Qt开源社区…...

[网鼎杯 2020 白虎组]PicDown python反弹shell proc/self目录的信息
[网鼎杯 2020 白虎组]PicDown - 知乎 这里确实完全不会 第一次遇到一个只有文件读取思路的题目 这里也确实说明还是要学学一些其他的东西了 首先打开环境 只存在一个框框 我们通过 目录扫描 抓包 注入 发现没有用 我们测试能不能任意文件读取 ?url../../../../etc/passwd …...

SDL2绘制ffmpeg解析的mp4文件
文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示6.SDL2事件响应补充6.1 处理方式-016.2 处理方式-02 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将m…...

决策树C4.5算法的技术深度剖析、实战解读
目录 一、简介决策树(Decision Tree)例子: 信息熵(Information Entropy)与信息增益(Information Gain)例子: 信息增益比(Gain Ratio)例子: 二、算…...

LLMs Python解释器程序辅助语言模型(PAL)Program-aided language models (PAL)
正如您在本课程早期看到的,LLM执行算术和其他数学运算的能力是有限的。虽然您可以尝试使用链式思维提示来克服这一问题,但它只能帮助您走得更远。即使模型正确地通过了问题的推理,对于较大的数字或复杂的运算,它仍可能在个别数学操…...

【12】c++设计模式——>单例模式练习(任务队列)
属性: (1)存储任务的容器,这个容器可以选择使用STL中的队列(queue) (2)互斥锁,多线程访问的时候用于保护任务队列中的数据 方法:主要是对任务队列中的任务进行操作 &…...

Python之函数、模块、包库
函数、模块、包库基础概念和作用 A、函数 减少代码重复 将复杂问题代码分解成简单模块 提高代码可读性 复用老代码 """ 函数 """# 定义一个函数 def my_fuvtion():# 函数执行部分print(这是一个函数)# 定义带有参数的函数 def say_hello(n…...
SQL创建与删除索引
索引创建、删除与使用: 1.1 create方式创建索引:CREATE [UNIQUE – 唯一索引 | FULLTEXT – 全文索引 ] INDEX index_name ON table_name – 不指定唯一或全文时默认普通索引 (column1[(length) [DESC|ASC]] [,column2,…]) – 可以对多列建立组合索引 …...

网络协议--链路层
2.1 引言 从图1-4中可以看出,在TCP/IP协议族中,链路层主要有三个目的: (1)为IP模块发送和接收IP数据报; (2)为ARP模块发送ARP请求和接收ARP应答; (3…...
HDLbits: Count clock
目前写过最长的verilog代码,用了将近三个小时,编写12h显示的时钟,改来改去,估计只有我自己看得懂(吐血) module top_module(input clk,input reset,input ena,output pm,output [7:0] hh,output [7:0] mm,…...
【1day】用友移动管理系统任意文件上传漏洞学习
注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现...
【c++】向webrtc学习容器操作
std::map的key为std::pair 时的查找 std::map<RemoteAndLocalNetworkId, size_t> in_flight_bytes_RTC_GUARDED_BY(&lock_);private:using RemoteAndLocalNetworkId = std::pair<uint16_t, uint16_t...

SpringBoot+Vue3外卖项目构思
SpringBoot的学习: SpringBoot的学习_明里灰的博客-CSDN博客 实现功能 前台 用户注册,邮箱登录,地址管理,历史订单,菜品规格,购物车,下单,菜品浏览,评价,…...

【AI视野·今日NLP 自然语言处理论文速览 第四十七期】Wed, 4 Oct 2023
AI视野今日CS.NLP 自然语言处理论文速览 Wed, 4 Oct 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Contrastive Post-training Large Language Models on Data Curriculum Authors Canwen Xu, Corby Rosset, Luc…...
c++的lambda表达式
文章目录 1 lambda表达式2 捕捉列表 vs 参数列表3 lambda表达式的传递3.1 函数作为形参3.2 场景1:条件表达式3.3 场景2:线程的运行表达式 1 lambda表达式 lambda表达式可以理解为匿名函数,也就是没有名字的函数,既然是函数&#…...

电梯安全监测丨S271W无线水浸传感器用于电梯机房/电梯基坑水浸监测
城市化进程中,电梯与我们的生活息息相关。高层住宅、医院、商场、学校、车站等各种商业体建筑、公共建筑中电梯为我们生活工作提供了诸多便利。 保障电梯系统的安全至关重要!特别是电梯机房和电梯基坑可通过智能化改造提高其安全性和稳定性。例如在暴风…...
Java异常:基本概念、分类和处理
Java异常:基本概念、分类和处理 在Java编程中,异常处理是一个非常重要的部分。了解如何识别、处理和避免异常对于编写健壮、可维护的代码至关重要。本文将介绍Java异常的基本概念、分类和处理方法,并通过简单的代码示例进行说明。 一、什么…...

小谈设计模式(19)—备忘录模式
小谈设计模式(19)—备忘录模式 专栏介绍专栏地址专栏介绍 备忘录模式主要角色发起人(Originator)备忘录(Memento)管理者(Caretaker) 应用场景结构实现步骤Java程序实现首先ÿ…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
基于大模型的 UI 自动化系统
基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别
一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?
论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...