生成对抗网络(GAN)手写数字生成
文章目录
- 一、前言
- 二、前期工作
- 1. 设置GPU(如果使用的是CPU可以忽略这步)
- 二、什么是生成对抗网络
- 1. 简单介绍
- 2. 应用领域
- 三、网络结构
- 四、构建生成器
- 五、构建鉴别器
- 六、训练模型
- 1. 保存样例图片
- 2. 训练模型
- 七、生成动图
一、前言
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
- 卷积神经网络(CNN)实现mnist手写数字识别
- 卷积神经网络(CNN)多种图片分类的实现
- 卷积神经网络(CNN)衣服图像分类的实现
- 卷积神经网络(CNN)鲜花识别
- 卷积神经网络(CNN)天气识别
- 卷积神经网络(VGG-16)识别海贼王草帽一伙
- 卷积神经网络(ResNet-50)鸟类识别
- 卷积神经网络(AlexNet)鸟类识别
- 卷积神经网络(CNN)识别验证码
- 卷积神经网络(Inception-ResNet-v2)交通标志识别
来自专栏:机器学习与深度学习算法推荐
二、前期工作
1. 设置GPU(如果使用的是CPU可以忽略这步)
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy as np
import sys,os,pathlib
img_shape = (28, 28, 1)
latent_dim = 200
二、什么是生成对抗网络
1. 简单介绍
生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。
生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。
2. 应用领域
GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。
1)风格迁移
图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。
2)图像生成
GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。
三、网络结构
简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

GAN步骤:
- 1.生成器(Generator)接收随机数并返回生成图像。
- 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
- 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。
四、构建生成器
def build_generator():# ======================================= ## 生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2), # 高级一点的激活函数layers.BatchNormalization(momentum=0.8), # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)
五、构建鉴别器
def build_discriminator():# ===================================== ## 鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
六、训练模型
1. 保存样例图片
def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()
2. 训练模型
train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。
def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内 train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx] # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)
七、生成动图
如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。
import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()
相关文章:
生成对抗网络(GAN)手写数字生成
文章目录 一、前言二、前期工作1. 设置GPU(如果使用的是CPU可以忽略这步) 二、什么是生成对抗网络1. 简单介绍2. 应用领域 三、网络结构四、构建生成器五、构建鉴别器六、训练模型1. 保存样例图片2. 训练模型 七、生成动图 一、前言 我的环境࿱…...
LeetCode Hot100 31.下一个排列
题目: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列…...
Redis主从与哨兵架构详解
目录 主从架构 主从环境搭建 主从复制流程 1. 全量复制 2. 部分复制 主从风暴 哨兵架构 概念 哨兵环境搭建 主从架构 主从环境搭建 1. 复制一份redis.conf文件, 修改下面几行配置 port 6380 pidfile /var/run/redis_6380.pid logfile "6380.log" dir /usr/…...
Linux:docker的数据管理(6)
数据管理操作*方便查看容器内产生的数据 *多容器间实现数据共享 两种管理方式数据卷 数据卷容器 1.数据卷 数据卷是一个供容器使用的特殊目录,位于容器中,可将宿主机的目录挂载到数据卷上,对数据卷的修改操作立刻可见,并且更新数…...
深入理解Zookeeper系列-1.初识Zoookeeper
👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话ÿ…...
芯片技术探索:了解构芯片的设计与制造之旅
芯片技术探索:了解构芯片的设计与制造之旅 一、引言 随着现代科技的飞速发展,芯片作为信息技术的核心,已经渗透到我们生活的方方面面。从智能手机、电视、汽车到医疗设备和工业控制系统,芯片在各个领域都发挥着至关重要的作用。然而,对于大多数人来说,芯片仍然是一个神秘…...
STM32 超声波模块(HC-SR04)
HC-SR04介绍 典型工作电压:5v (如果你的超声波模块没有工作,可以看一下是不是电压不够)超小静态工作电流:<2mA 感应角度:<15 (超声波模块,是一个范围式的探…...
ELK+Filebeat
Filebeat概述 1.Filebeat简介 Filebeat是一款轻量级的日志收集工具,可以在非JAVA环境下运行。 因此,Filebeat常被用在非JAVAf的服务器上用于替代Logstash,收集日志信息。实际上,Filebeat几乎可以起到与Logstash相同的作用&…...
MySql之锁表、锁行解决方案
查询正在使用的表,没有跑业务,一般情况下是锁表了 show open tables where in_use > 0 ;查看进程,可以看到Command类型(Sleep为阻塞线程) show processlist;kill事务,kill 进程Id kill 8193583;其他 …...
2023年第十六届山东省职业院校技能大赛中职组“网络安全”赛项竞赛正式试题
第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题 目录 一、竞赛时间 二、竞赛阶段 三、竞赛任务书内容 (一)拓扑图 (二)A模块基础设施设置/安全加固(200分) (三…...
JAVA 整合 AWS S3(Amazon Simple Storage Service)文件上传,分片上传,删除,下载
依赖 因为aws需要发送请求上传、下载等api,所以需要加上httpclient相关的依赖 <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.11.628</version> </dependency&…...
记录:Unity脚本的编写9.0
目录 射线一些准备工作编写代码 突然发现好像没有写过关于射线的内容,我就说怎么总感觉好像少了什么东西(心虚 那就在这里写一下关于射线的内容吧,将在这里实现射线检测鼠标点击的功能 射线 射线是一种在Unity中检测碰撞器或触发器的方法&am…...
共享单车停放(简单的struct结构运用)
本来不想写这题的,但是想想最近沉迷玩雨世界,班长又问我这题,就草草写了一下 代码如下: #include<stdio.h> #include<math.h> struct parking{int distance;int remain;int speed;int time;int jud; }parking[50]; …...
【Java8系列07】Java8日期处理
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
为什么做CSGO搬砖的不直接去炒股呢?
首先,CS2并非只有一个交易平台,阿阳个人觉得像IGXE等交易平台一样是交易,况且我记得很早的时候我就开始用IGXE了,我记得最早的时候还是机器人发货,后来因为V社对于很多开箱网站的管控,所以让这种发货的方式…...
12月01日,每日信息差//阿里国际发布3款AI设计生态工具//美团买菜升级为“小象超市”//外国人永居证换新、6国游客免签来华
_灵感 🎖 阿里国际发布3款AI设计生态工具 🎄 AITO问界系列11月交付新车18827辆 🌍 美团买菜升级为“小象超市” 🌋 全球首个金融风控大模型国际标准出炉,由腾讯牵头制定 🎁 支付宝:支持外国人…...
ChatGPT探索:提示工程详解—程序员效率提升必备技能【文末送书】
文章目录 一.人工智能-ChatGPT1.1 ChatGPT简介1.2 ChatGPT探索:提示工程详解1.2 提示工程的优势 二.提示工程探索2.1 提示工程实例:2.2 英语学习助手2.3 Active-Prompt思维链(CoT)方法2.4 提示工程总结 三.文末推荐与福利3.1《Cha…...
Pytest做性能测试?
Pytest其实也是可以做性能测试或者基准测试的。是非常方便的。 可以考虑使用Pytest-benchmark类库进行。 安装pytest-benchmark 首先,确保已经安装了pytest和pytest-benchmark插件。可以使用以下命令安装插件: pip install pytest pytest-benchmark …...
Swagger各版本访问地址
2.9.x 访问地址: http://ip:port/{context-path}/swagger-ui.html 3.0.x 访问地址: http://ip:port/{context-path}/swagger-ui/index.html 3.0集成knife4j 访问地址: http://ip:port/{context-path}/doc.html...
docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor
docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor 文章目录 docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbordocker-compose私有镜像仓库harbor搭建镜像推送到私有仓库harbor docker-compose D…...
2026制造业数字化转型:Agent委外加工成本智能核算功能详解与应用
站在2026年这个智能体(Agent)全面爆发的时间节点回望,企业数字化转型已从早期的“数据上云”演进到了“决策自动化”的深度应用阶段。根据IDC与Gartner联合发布的《2026年全球智能体产业发展白皮书》,具备深度业务逻辑处理能力的智…...
如何解决黑苹果USB端口识别问题:USBInjectAll内核扩展完整指南
如何解决黑苹果USB端口识别问题:USBInjectAll内核扩展完整指南 【免费下载链接】OS-X-USB-Inject-All Kext to inject all USB ports for the installed Intel EHCI/XHCI chipset automatically. 项目地址: https://gitcode.com/gh_mirrors/os/OS-X-USB-Inject-Al…...
jquery.inputmask插件介绍
目录 一、什么是 jQuery.inputmask? 主要应用场景 二、快速上手 1. 引入依赖文件 2. 基础用法 3. 掩码字符定义 三、高级功能 1. 自定义占位符 2. 完成回调 3. 扩展自定义字符 4. 重复掩码 5. 移除默认占位符 四、配合 Vue.js 使用 五、更多实用示例 …...
一种三菱MXF100-8 走CC LINK IE TSN 网络控制单轴伺服的功能块(可控30+轴)
三菱电机去年新推出了MX系列的PLC,其中最吸引人的应该就是本体网口支持CC Link TSN总线了。但MXF100系列的轴控功能,只有8轴和16轴两个版本,为了充分应用TSN的强大性能,作者手搓了一个直接读写对象字典实现单轴伺服定位控制的功能…...
告别ifconfig!用ip命令和ethtool搞定Linux网卡状态排查(附实战案例)
告别ifconfig!用ip命令和ethtool搞定Linux网卡状态排查(附实战案例) 在Linux服务器运维中,网络故障排查是最常见的任务之一。记得去年深夜处理一次线上事故时,面对一台突然失联的数据库服务器,我习惯性地敲…...
yt-fts高级配置技巧:数据库路径、Chroma设置与性能优化
yt-fts高级配置技巧:数据库路径、Chroma设置与性能优化 【免费下载链接】yt-fts YouTube Full Text Search - Search all of YouTube from the command line 项目地址: https://gitcode.com/gh_mirrors/yt/yt-fts yt-fts是一款强大的YouTube全文搜索工具&…...
软考中级《嵌入式系统设计师》全套备考资料(真题 + 教材 + 笔记)
大家好,今天给大家分享一份软考中级「嵌入式系统设计师」的完整备考资料包,从教材、真题到高频笔记全配齐,帮你省去整理资料的时间,直接进入高效备考状态! 📁 资料清单 这套资料覆盖了嵌入式系统设计师备考…...
GitHub史诗级泄露:3800个核心仓库被窃,TeamPCP如何通过VS Code扩展攻破全球最大代码平台
一、引言:全球开发者的至暗时刻 2026年5月20日,一则消息震惊了整个科技界:微软旗下全球最大代码托管平台GitHub确认,约3800个内部私有仓库被威胁组织TeamPCP窃取,涵盖GitHub Copilot、CodeQL、GitHub Actions、Codespa…...
(C语言)指针详解与应用
指针是C语言的灵魂,指针与底层硬件联系紧密,使用指针可操作数据的地址,实现数据的间接访问。指针即指针变量,用于存放其他数据单元,如变量、数组、结构体和函数的首地址。若指针存放了某个数据单元的首地址,…...
Sequin实战教程:构建企业级变更数据捕获管道
Sequin实战教程:构建企业级变更数据捕获管道 【免费下载链接】sequin Postgres change data capture to streams, queues, and search indexes like Kafka, SQS, Elasticsearch, HTTP endpoints, and more 项目地址: https://gitcode.com/gh_mirrors/se/sequin …...
