生成对抗网络(GAN)手写数字生成
文章目录
- 一、前言
- 二、前期工作
- 1. 设置GPU(如果使用的是CPU可以忽略这步)
- 二、什么是生成对抗网络
- 1. 简单介绍
- 2. 应用领域
- 三、网络结构
- 四、构建生成器
- 五、构建鉴别器
- 六、训练模型
- 1. 保存样例图片
- 2. 训练模型
- 七、生成动图
一、前言
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
- 卷积神经网络(CNN)实现mnist手写数字识别
- 卷积神经网络(CNN)多种图片分类的实现
- 卷积神经网络(CNN)衣服图像分类的实现
- 卷积神经网络(CNN)鲜花识别
- 卷积神经网络(CNN)天气识别
- 卷积神经网络(VGG-16)识别海贼王草帽一伙
- 卷积神经网络(ResNet-50)鸟类识别
- 卷积神经网络(AlexNet)鸟类识别
- 卷积神经网络(CNN)识别验证码
- 卷积神经网络(Inception-ResNet-v2)交通标志识别
来自专栏:机器学习与深度学习算法推荐
二、前期工作
1. 设置GPU(如果使用的是CPU可以忽略这步)
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy as np
import sys,os,pathlib
img_shape = (28, 28, 1)
latent_dim = 200
二、什么是生成对抗网络
1. 简单介绍
生成对抗网络(GAN)
包含生成器和判别器,两个模型通过对抗训练不断学习、进化。
生成器(Generator)
:生成数据(大部分情况下是图像),目的是“骗过”判别器。鉴别器(Discriminator)
:判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。
2. 应用领域
GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。
1)风格迁移
图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。
2)图像生成
GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。
三、网络结构
简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:
GAN步骤:
- 1.生成器(Generator)接收随机数并返回生成图像。
- 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
- 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。
四、构建生成器
def build_generator():# ======================================= ## 生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2), # 高级一点的激活函数layers.BatchNormalization(momentum=0.8), # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)
五、构建鉴别器
def build_discriminator():# ===================================== ## 鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
六、训练模型
1. 保存样例图片
def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()
2. 训练模型
train_on_batch
:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。
def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内 train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx] # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)
七、生成动图
如果报错:ModuleNotFoundError: No module named 'imageio'
可以使用:pip install imageio
安装 imageio
库。
import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()
相关文章:

生成对抗网络(GAN)手写数字生成
文章目录 一、前言二、前期工作1. 设置GPU(如果使用的是CPU可以忽略这步) 二、什么是生成对抗网络1. 简单介绍2. 应用领域 三、网络结构四、构建生成器五、构建鉴别器六、训练模型1. 保存样例图片2. 训练模型 七、生成动图 一、前言 我的环境࿱…...

LeetCode Hot100 31.下一个排列
题目: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列…...

Redis主从与哨兵架构详解
目录 主从架构 主从环境搭建 主从复制流程 1. 全量复制 2. 部分复制 主从风暴 哨兵架构 概念 哨兵环境搭建 主从架构 主从环境搭建 1. 复制一份redis.conf文件, 修改下面几行配置 port 6380 pidfile /var/run/redis_6380.pid logfile "6380.log" dir /usr/…...

Linux:docker的数据管理(6)
数据管理操作*方便查看容器内产生的数据 *多容器间实现数据共享 两种管理方式数据卷 数据卷容器 1.数据卷 数据卷是一个供容器使用的特殊目录,位于容器中,可将宿主机的目录挂载到数据卷上,对数据卷的修改操作立刻可见,并且更新数…...

深入理解Zookeeper系列-1.初识Zoookeeper
👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话ÿ…...

芯片技术探索:了解构芯片的设计与制造之旅
芯片技术探索:了解构芯片的设计与制造之旅 一、引言 随着现代科技的飞速发展,芯片作为信息技术的核心,已经渗透到我们生活的方方面面。从智能手机、电视、汽车到医疗设备和工业控制系统,芯片在各个领域都发挥着至关重要的作用。然而,对于大多数人来说,芯片仍然是一个神秘…...

STM32 超声波模块(HC-SR04)
HC-SR04介绍 典型工作电压:5v (如果你的超声波模块没有工作,可以看一下是不是电压不够)超小静态工作电流:<2mA 感应角度:<15 (超声波模块,是一个范围式的探…...

ELK+Filebeat
Filebeat概述 1.Filebeat简介 Filebeat是一款轻量级的日志收集工具,可以在非JAVA环境下运行。 因此,Filebeat常被用在非JAVAf的服务器上用于替代Logstash,收集日志信息。实际上,Filebeat几乎可以起到与Logstash相同的作用&…...

MySql之锁表、锁行解决方案
查询正在使用的表,没有跑业务,一般情况下是锁表了 show open tables where in_use > 0 ;查看进程,可以看到Command类型(Sleep为阻塞线程) show processlist;kill事务,kill 进程Id kill 8193583;其他 …...

2023年第十六届山东省职业院校技能大赛中职组“网络安全”赛项竞赛正式试题
第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题 目录 一、竞赛时间 二、竞赛阶段 三、竞赛任务书内容 (一)拓扑图 (二)A模块基础设施设置/安全加固(200分) (三…...

JAVA 整合 AWS S3(Amazon Simple Storage Service)文件上传,分片上传,删除,下载
依赖 因为aws需要发送请求上传、下载等api,所以需要加上httpclient相关的依赖 <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.11.628</version> </dependency&…...

记录:Unity脚本的编写9.0
目录 射线一些准备工作编写代码 突然发现好像没有写过关于射线的内容,我就说怎么总感觉好像少了什么东西(心虚 那就在这里写一下关于射线的内容吧,将在这里实现射线检测鼠标点击的功能 射线 射线是一种在Unity中检测碰撞器或触发器的方法&am…...

共享单车停放(简单的struct结构运用)
本来不想写这题的,但是想想最近沉迷玩雨世界,班长又问我这题,就草草写了一下 代码如下: #include<stdio.h> #include<math.h> struct parking{int distance;int remain;int speed;int time;int jud; }parking[50]; …...

【Java8系列07】Java8日期处理
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

为什么做CSGO搬砖的不直接去炒股呢?
首先,CS2并非只有一个交易平台,阿阳个人觉得像IGXE等交易平台一样是交易,况且我记得很早的时候我就开始用IGXE了,我记得最早的时候还是机器人发货,后来因为V社对于很多开箱网站的管控,所以让这种发货的方式…...

12月01日,每日信息差//阿里国际发布3款AI设计生态工具//美团买菜升级为“小象超市”//外国人永居证换新、6国游客免签来华
_灵感 🎖 阿里国际发布3款AI设计生态工具 🎄 AITO问界系列11月交付新车18827辆 🌍 美团买菜升级为“小象超市” 🌋 全球首个金融风控大模型国际标准出炉,由腾讯牵头制定 🎁 支付宝:支持外国人…...

ChatGPT探索:提示工程详解—程序员效率提升必备技能【文末送书】
文章目录 一.人工智能-ChatGPT1.1 ChatGPT简介1.2 ChatGPT探索:提示工程详解1.2 提示工程的优势 二.提示工程探索2.1 提示工程实例:2.2 英语学习助手2.3 Active-Prompt思维链(CoT)方法2.4 提示工程总结 三.文末推荐与福利3.1《Cha…...

Pytest做性能测试?
Pytest其实也是可以做性能测试或者基准测试的。是非常方便的。 可以考虑使用Pytest-benchmark类库进行。 安装pytest-benchmark 首先,确保已经安装了pytest和pytest-benchmark插件。可以使用以下命令安装插件: pip install pytest pytest-benchmark …...

Swagger各版本访问地址
2.9.x 访问地址: http://ip:port/{context-path}/swagger-ui.html 3.0.x 访问地址: http://ip:port/{context-path}/swagger-ui/index.html 3.0集成knife4j 访问地址: http://ip:port/{context-path}/doc.html...

docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor
docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor 文章目录 docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbordocker-compose私有镜像仓库harbor搭建镜像推送到私有仓库harbor docker-compose D…...

OpenTSDB(CVE-202035476)漏洞复现及利用
任务一: 复现环境中的命令注入漏洞。 任务二: 利用命令注入执行whoami,使用DNS外带技术获取结果 任务三:使用反弹shell,将漏洞环境中的shell反弹到宿主机或者vps服务器。 任务一: 1.搭建好环境 2.先去了…...

Maven无法拉取依赖/构建失败操作步骤(基本都能解决)
首先检查配置文件,确认配置文件没有问题(也可以直接用同事的配置文件(记得修改文件里的本地仓库地址)) 1.file->Invalidate Caches清除缓存重启(简单粗暴,但最有效) 2.刷新maven以及mvn clean,多刷几次,看看还有没有报红的依赖…...

【数据库】数据库并发控制的目标,可串行化序列的分析,并发控制调度器模型
数据库并发控制 专栏内容: 手写数据库toadb 本专栏主要介绍如何从零开发,开发的步骤,以及开发过程中的涉及的原理,遇到的问题等,让大家能跟上并且可以一起开发,让每个需要的人成为参与者。 本专栏会定期更…...

带头结点的双向循环链表
目录 带头结点的双向循环链表 1.存储定义 2.结点的创建 3.结点的初始化 4.尾插结点 5.尾删结点 6.头插结点 7.头删结点 8.查找并返回结点 9.在pos结点前插入结点 10.删除pos结点 11.打印链表 12.销毁链表 13.头插结点2.0版 14.尾插结点2.0版 前言: 当…...

2023年11月下旬大模型新动向集锦
2023年11月下旬大模型新动向集锦 2023.12.1版权声明:本文为博主chszs的原创文章,未经博主允许不得转载。 1、微软将向中国大陆开放Windows Copilot服务 据微软发布的消息,微软将在 2023 年 12 月 1 日面向中国大陆的企业和教育机构推出 We…...

有IP没有域名可以申请证书吗?
一、IP证书是什么? ip证书是用于公网ip地址的SSL证书,与我们通常所讲的SSL证书并无本质上的区别,但由于SSL证书通常颁发给域名,而组织机构需要公共ip地址的SSL证书,这类SSL证书就是我们所说的ip证书。ip证书具有安全、…...

【软件推荐】卸载360软件geek;护眼软件flux;
卸载360软件geek f.lux: software to make your life better (justgetflux.com) 卸载完扫描残留 护眼软件 hf.lux: software to make your life better (justgetflux.com)https://justgetflux.com/https://justgetflux.com/...

Module build failed: Error: ENOENT: no such file or directory
前言 这个错误通常发生在Node.js 和 vue,js项目中,当你试图访问一个不存在的文件或目录时。在大多数情况下,这是因为你的代码试图打开一个不存在的文件,或者你的构建系统(例如Webpack)需要一个配置文件,但找…...

Postgresql BatchInsert唯一键冲突及解决
Postgresql BatchInsert唯一键冲突及解决 当有唯一键冲突时,批量插入可能会报错; insert into tableA(sno,name,age,emp) values(),(),(); 会报错 insert into tableA(sno,name,age,emp) values(),(),() on conflict on contraint tableA_unique_key do …...

腾讯云AMD服务器标准型SA5实例AMD EPYC Bergamo处理器
腾讯云服务器标准型SA5实例是最新一代的标准型实例,CPU采用AMD EPYC™ Bergamo全新处理器,采用最新DDR5内存,默认网络优化,最高内网收发能力达4500万pps。腾讯云百科txybk.com分享腾讯云标准型SA5云服务器CPU、内存、网络、性能、…...