生成对抗网络(GAN)手写数字生成
文章目录
- 一、前言
- 二、前期工作
- 1. 设置GPU(如果使用的是CPU可以忽略这步)
- 二、什么是生成对抗网络
- 1. 简单介绍
- 2. 应用领域
- 三、网络结构
- 四、构建生成器
- 五、构建鉴别器
- 六、训练模型
- 1. 保存样例图片
- 2. 训练模型
- 七、生成动图
一、前言
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
- 卷积神经网络(CNN)实现mnist手写数字识别
- 卷积神经网络(CNN)多种图片分类的实现
- 卷积神经网络(CNN)衣服图像分类的实现
- 卷积神经网络(CNN)鲜花识别
- 卷积神经网络(CNN)天气识别
- 卷积神经网络(VGG-16)识别海贼王草帽一伙
- 卷积神经网络(ResNet-50)鸟类识别
- 卷积神经网络(AlexNet)鸟类识别
- 卷积神经网络(CNN)识别验证码
- 卷积神经网络(Inception-ResNet-v2)交通标志识别
来自专栏:机器学习与深度学习算法推荐
二、前期工作
1. 设置GPU(如果使用的是CPU可以忽略这步)
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy as np
import sys,os,pathlib
img_shape = (28, 28, 1)
latent_dim = 200
二、什么是生成对抗网络
1. 简单介绍
生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。
生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。
2. 应用领域
GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。
1)风格迁移
图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。
2)图像生成
GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。
三、网络结构
简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

GAN步骤:
- 1.生成器(Generator)接收随机数并返回生成图像。
- 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
- 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。
四、构建生成器
def build_generator():# ======================================= ## 生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2), # 高级一点的激活函数layers.BatchNormalization(momentum=0.8), # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)
五、构建鉴别器
def build_discriminator():# ===================================== ## 鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
六、训练模型
1. 保存样例图片
def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()
2. 训练模型
train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。
def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内 train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx] # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)
七、生成动图
如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。
import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()
相关文章:
生成对抗网络(GAN)手写数字生成
文章目录 一、前言二、前期工作1. 设置GPU(如果使用的是CPU可以忽略这步) 二、什么是生成对抗网络1. 简单介绍2. 应用领域 三、网络结构四、构建生成器五、构建鉴别器六、训练模型1. 保存样例图片2. 训练模型 七、生成动图 一、前言 我的环境࿱…...
LeetCode Hot100 31.下一个排列
题目: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列…...
Redis主从与哨兵架构详解
目录 主从架构 主从环境搭建 主从复制流程 1. 全量复制 2. 部分复制 主从风暴 哨兵架构 概念 哨兵环境搭建 主从架构 主从环境搭建 1. 复制一份redis.conf文件, 修改下面几行配置 port 6380 pidfile /var/run/redis_6380.pid logfile "6380.log" dir /usr/…...
Linux:docker的数据管理(6)
数据管理操作*方便查看容器内产生的数据 *多容器间实现数据共享 两种管理方式数据卷 数据卷容器 1.数据卷 数据卷是一个供容器使用的特殊目录,位于容器中,可将宿主机的目录挂载到数据卷上,对数据卷的修改操作立刻可见,并且更新数…...
深入理解Zookeeper系列-1.初识Zoookeeper
👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话ÿ…...
芯片技术探索:了解构芯片的设计与制造之旅
芯片技术探索:了解构芯片的设计与制造之旅 一、引言 随着现代科技的飞速发展,芯片作为信息技术的核心,已经渗透到我们生活的方方面面。从智能手机、电视、汽车到医疗设备和工业控制系统,芯片在各个领域都发挥着至关重要的作用。然而,对于大多数人来说,芯片仍然是一个神秘…...
STM32 超声波模块(HC-SR04)
HC-SR04介绍 典型工作电压:5v (如果你的超声波模块没有工作,可以看一下是不是电压不够)超小静态工作电流:<2mA 感应角度:<15 (超声波模块,是一个范围式的探…...
ELK+Filebeat
Filebeat概述 1.Filebeat简介 Filebeat是一款轻量级的日志收集工具,可以在非JAVA环境下运行。 因此,Filebeat常被用在非JAVAf的服务器上用于替代Logstash,收集日志信息。实际上,Filebeat几乎可以起到与Logstash相同的作用&…...
MySql之锁表、锁行解决方案
查询正在使用的表,没有跑业务,一般情况下是锁表了 show open tables where in_use > 0 ;查看进程,可以看到Command类型(Sleep为阻塞线程) show processlist;kill事务,kill 进程Id kill 8193583;其他 …...
2023年第十六届山东省职业院校技能大赛中职组“网络安全”赛项竞赛正式试题
第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题 目录 一、竞赛时间 二、竞赛阶段 三、竞赛任务书内容 (一)拓扑图 (二)A模块基础设施设置/安全加固(200分) (三…...
JAVA 整合 AWS S3(Amazon Simple Storage Service)文件上传,分片上传,删除,下载
依赖 因为aws需要发送请求上传、下载等api,所以需要加上httpclient相关的依赖 <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.11.628</version> </dependency&…...
记录:Unity脚本的编写9.0
目录 射线一些准备工作编写代码 突然发现好像没有写过关于射线的内容,我就说怎么总感觉好像少了什么东西(心虚 那就在这里写一下关于射线的内容吧,将在这里实现射线检测鼠标点击的功能 射线 射线是一种在Unity中检测碰撞器或触发器的方法&am…...
共享单车停放(简单的struct结构运用)
本来不想写这题的,但是想想最近沉迷玩雨世界,班长又问我这题,就草草写了一下 代码如下: #include<stdio.h> #include<math.h> struct parking{int distance;int remain;int speed;int time;int jud; }parking[50]; …...
【Java8系列07】Java8日期处理
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
为什么做CSGO搬砖的不直接去炒股呢?
首先,CS2并非只有一个交易平台,阿阳个人觉得像IGXE等交易平台一样是交易,况且我记得很早的时候我就开始用IGXE了,我记得最早的时候还是机器人发货,后来因为V社对于很多开箱网站的管控,所以让这种发货的方式…...
12月01日,每日信息差//阿里国际发布3款AI设计生态工具//美团买菜升级为“小象超市”//外国人永居证换新、6国游客免签来华
_灵感 🎖 阿里国际发布3款AI设计生态工具 🎄 AITO问界系列11月交付新车18827辆 🌍 美团买菜升级为“小象超市” 🌋 全球首个金融风控大模型国际标准出炉,由腾讯牵头制定 🎁 支付宝:支持外国人…...
ChatGPT探索:提示工程详解—程序员效率提升必备技能【文末送书】
文章目录 一.人工智能-ChatGPT1.1 ChatGPT简介1.2 ChatGPT探索:提示工程详解1.2 提示工程的优势 二.提示工程探索2.1 提示工程实例:2.2 英语学习助手2.3 Active-Prompt思维链(CoT)方法2.4 提示工程总结 三.文末推荐与福利3.1《Cha…...
Pytest做性能测试?
Pytest其实也是可以做性能测试或者基准测试的。是非常方便的。 可以考虑使用Pytest-benchmark类库进行。 安装pytest-benchmark 首先,确保已经安装了pytest和pytest-benchmark插件。可以使用以下命令安装插件: pip install pytest pytest-benchmark …...
Swagger各版本访问地址
2.9.x 访问地址: http://ip:port/{context-path}/swagger-ui.html 3.0.x 访问地址: http://ip:port/{context-path}/swagger-ui/index.html 3.0集成knife4j 访问地址: http://ip:port/{context-path}/doc.html...
docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor
docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor 文章目录 docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbordocker-compose私有镜像仓库harbor搭建镜像推送到私有仓库harbor docker-compose D…...
深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录
ASP.NET Core 是一个跨平台的开源框架,用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录,以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...
Oracle查询表空间大小
1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...
基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...
java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别
UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
【2025年】解决Burpsuite抓不到https包的问题
环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...
CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)
漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...
