生成对抗网络入门案例
前言
生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。
下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:
实现过程
1、导入必要的库和模块
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
2、加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)
3、定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))
4、定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
5、编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
6、冻结判别器模型的权重
discriminator.trainable = False
7、定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
8、编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
9、定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)
10、保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()
11、训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000
完整代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])# 冻结判别器模型的权重
discriminator.trainable = False# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)# 保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000train_gan(epochs, batch_size, sample_interval)
训练结果:
这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。
相关文章:

生成对抗网络入门案例
前言 生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试…...

多头注意力机制
1、什么是多头注意力机制 从多头注意力的结构图中,貌似这个所谓的多个头就是指多组线性变换,但是并不是,只使用了一组线性变换层,即三个变换张量对 Q、K、V 分别进行线性变换,这些变化不会改变原有张量的尺寸…...

Qt + FFmpeg 搭建 Windows 开发环境
Qt FFmpeg 搭建 Windows 开发环境 Qt FFmpeg 搭建 Windows 开发环境安装 Qt Creator下载 FFmpeg 编译包测试 Qt FFmpeg踩坑解决方法1:换一个 FFmpeg 库解决方法2:把项目改成 64 位 后记 官方博客:https://www.yafeilinux.com/ Qt开源社区…...

[网鼎杯 2020 白虎组]PicDown python反弹shell proc/self目录的信息
[网鼎杯 2020 白虎组]PicDown - 知乎 这里确实完全不会 第一次遇到一个只有文件读取思路的题目 这里也确实说明还是要学学一些其他的东西了 首先打开环境 只存在一个框框 我们通过 目录扫描 抓包 注入 发现没有用 我们测试能不能任意文件读取 ?url../../../../etc/passwd …...

SDL2绘制ffmpeg解析的mp4文件
文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示6.SDL2事件响应补充6.1 处理方式-016.2 处理方式-02 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将m…...

决策树C4.5算法的技术深度剖析、实战解读
目录 一、简介决策树(Decision Tree)例子: 信息熵(Information Entropy)与信息增益(Information Gain)例子: 信息增益比(Gain Ratio)例子: 二、算…...

LLMs Python解释器程序辅助语言模型(PAL)Program-aided language models (PAL)
正如您在本课程早期看到的,LLM执行算术和其他数学运算的能力是有限的。虽然您可以尝试使用链式思维提示来克服这一问题,但它只能帮助您走得更远。即使模型正确地通过了问题的推理,对于较大的数字或复杂的运算,它仍可能在个别数学操…...

【12】c++设计模式——>单例模式练习(任务队列)
属性: (1)存储任务的容器,这个容器可以选择使用STL中的队列(queue) (2)互斥锁,多线程访问的时候用于保护任务队列中的数据 方法:主要是对任务队列中的任务进行操作 &…...

Python之函数、模块、包库
函数、模块、包库基础概念和作用 A、函数 减少代码重复 将复杂问题代码分解成简单模块 提高代码可读性 复用老代码 """ 函数 """# 定义一个函数 def my_fuvtion():# 函数执行部分print(这是一个函数)# 定义带有参数的函数 def say_hello(n…...
SQL创建与删除索引
索引创建、删除与使用: 1.1 create方式创建索引:CREATE [UNIQUE – 唯一索引 | FULLTEXT – 全文索引 ] INDEX index_name ON table_name – 不指定唯一或全文时默认普通索引 (column1[(length) [DESC|ASC]] [,column2,…]) – 可以对多列建立组合索引 …...

网络协议--链路层
2.1 引言 从图1-4中可以看出,在TCP/IP协议族中,链路层主要有三个目的: (1)为IP模块发送和接收IP数据报; (2)为ARP模块发送ARP请求和接收ARP应答; (3…...
HDLbits: Count clock
目前写过最长的verilog代码,用了将近三个小时,编写12h显示的时钟,改来改去,估计只有我自己看得懂(吐血) module top_module(input clk,input reset,input ena,output pm,output [7:0] hh,output [7:0] mm,…...
【1day】用友移动管理系统任意文件上传漏洞学习
注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现...
【c++】向webrtc学习容器操作
std::map的key为std::pair 时的查找 std::map<RemoteAndLocalNetworkId, size_t> in_flight_bytes_RTC_GUARDED_BY(&lock_);private:using RemoteAndLocalNetworkId = std::pair<uint16_t, uint16_t...

SpringBoot+Vue3外卖项目构思
SpringBoot的学习: SpringBoot的学习_明里灰的博客-CSDN博客 实现功能 前台 用户注册,邮箱登录,地址管理,历史订单,菜品规格,购物车,下单,菜品浏览,评价,…...

【AI视野·今日NLP 自然语言处理论文速览 第四十七期】Wed, 4 Oct 2023
AI视野今日CS.NLP 自然语言处理论文速览 Wed, 4 Oct 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Contrastive Post-training Large Language Models on Data Curriculum Authors Canwen Xu, Corby Rosset, Luc…...
c++的lambda表达式
文章目录 1 lambda表达式2 捕捉列表 vs 参数列表3 lambda表达式的传递3.1 函数作为形参3.2 场景1:条件表达式3.3 场景2:线程的运行表达式 1 lambda表达式 lambda表达式可以理解为匿名函数,也就是没有名字的函数,既然是函数&#…...

电梯安全监测丨S271W无线水浸传感器用于电梯机房/电梯基坑水浸监测
城市化进程中,电梯与我们的生活息息相关。高层住宅、医院、商场、学校、车站等各种商业体建筑、公共建筑中电梯为我们生活工作提供了诸多便利。 保障电梯系统的安全至关重要!特别是电梯机房和电梯基坑可通过智能化改造提高其安全性和稳定性。例如在暴风…...
Java异常:基本概念、分类和处理
Java异常:基本概念、分类和处理 在Java编程中,异常处理是一个非常重要的部分。了解如何识别、处理和避免异常对于编写健壮、可维护的代码至关重要。本文将介绍Java异常的基本概念、分类和处理方法,并通过简单的代码示例进行说明。 一、什么…...

小谈设计模式(19)—备忘录模式
小谈设计模式(19)—备忘录模式 专栏介绍专栏地址专栏介绍 备忘录模式主要角色发起人(Originator)备忘录(Memento)管理者(Caretaker) 应用场景结构实现步骤Java程序实现首先ÿ…...

通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词
Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...

云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

面向无人机海岸带生态系统监测的语义分割基准数据集
描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

手机平板能效生态设计指令EU 2023/1670标准解读
手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读,综合法规核心要求、最新修正及企业合规要点: 一、法规背景与目标 生效与强制时间 发布于2023年8月31日(OJ公报&…...
LangFlow技术架构分析
🔧 LangFlow 的可视化技术栈 前端节点编辑器 底层框架:基于 (一个现代化的 React 节点绘图库) 功能: 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...
Docker拉取MySQL后数据库连接失败的解决方案
在使用Docker部署MySQL时,拉取并启动容器后,有时可能会遇到数据库连接失败的问题。这种问题可能由多种原因导致,包括配置错误、网络设置问题、权限问题等。本文将分析可能的原因,并提供解决方案。 一、确认MySQL容器的运行状态 …...

VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...

恶补电源:1.电桥
一、元器件的选择 搜索并选择电桥,再multisim中选择FWB,就有各种型号的电桥: 电桥是用来干嘛的呢? 它是一个由四个二极管搭成的“桥梁”形状的电路,用来把交流电(AC)变成直流电(DC)。…...