当前位置: 首页 > news >正文

生成对抗网络(GAN)手写数字生成

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、网络结构
  • 四、构建生成器
  • 五、构建鉴别器
  • 六、训练模型
    • 1. 保存样例图片
    • 2. 训练模型
  • 七、生成动图

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

在这里插入图片描述

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

def build_generator():# ======================================= ##     生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数layers.BatchNormalization(momentum=0.8),   # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)

五、构建鉴别器

def build_discriminator():# ===================================== ##   鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内   train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx]      # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths    = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()

相关文章:

生成对抗网络(GAN)手写数字生成

文章目录 一、前言二、前期工作1. 设置GPU(如果使用的是CPU可以忽略这步) 二、什么是生成对抗网络1. 简单介绍2. 应用领域 三、网络结构四、构建生成器五、构建鉴别器六、训练模型1. 保存样例图片2. 训练模型 七、生成动图 一、前言 我的环境&#xff1…...

LeetCode Hot100 31.下一个排列

题目: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列…...

Redis主从与哨兵架构详解

目录 主从架构 主从环境搭建 主从复制流程 1. 全量复制 2. 部分复制 主从风暴 哨兵架构 概念 哨兵环境搭建 主从架构 主从环境搭建 1. 复制一份redis.conf文件, 修改下面几行配置 port 6380 pidfile /var/run/redis_6380.pid logfile "6380.log" dir /usr/…...

Linux:docker的数据管理(6)

数据管理操作*方便查看容器内产生的数据 *多容器间实现数据共享 两种管理方式数据卷 数据卷容器 1.数据卷 数据卷是一个供容器使用的特殊目录,位于容器中,可将宿主机的目录挂载到数据卷上,对数据卷的修改操作立刻可见,并且更新数…...

深入理解Zookeeper系列-1.初识Zoookeeper

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话&#xff…...

芯片技术探索:了解构芯片的设计与制造之旅

芯片技术探索:了解构芯片的设计与制造之旅 一、引言 随着现代科技的飞速发展,芯片作为信息技术的核心,已经渗透到我们生活的方方面面。从智能手机、电视、汽车到医疗设备和工业控制系统,芯片在各个领域都发挥着至关重要的作用。然而,对于大多数人来说,芯片仍然是一个神秘…...

STM32 超声波模块(HC-SR04)

HC-SR04介绍 典型工作电压&#xff1a;5v &#xff08;如果你的超声波模块没有工作&#xff0c;可以看一下是不是电压不够&#xff09;超小静态工作电流&#xff1a;<2mA 感应角度&#xff1a;<15 &#xff08;超声波模块&#xff0c;是一个范围式的探…...

ELK+Filebeat

Filebeat概述 1.Filebeat简介 Filebeat是一款轻量级的日志收集工具&#xff0c;可以在非JAVA环境下运行。 因此&#xff0c;Filebeat常被用在非JAVAf的服务器上用于替代Logstash&#xff0c;收集日志信息。实际上&#xff0c;Filebeat几乎可以起到与Logstash相同的作用&…...

MySql之锁表、锁行解决方案

查询正在使用的表&#xff0c;没有跑业务&#xff0c;一般情况下是锁表了 show open tables where in_use > 0 ;查看进程&#xff0c;可以看到Command类型&#xff08;Sleep为阻塞线程&#xff09; show processlist;kill事务&#xff0c;kill 进程Id kill 8193583;其他 …...

2023年第十六届山东省职业院校技能大赛中职组“网络安全”赛项竞赛正式试题

第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题 目录 一、竞赛时间 二、竞赛阶段 三、竞赛任务书内容 &#xff08;一&#xff09;拓扑图 &#xff08;二&#xff09;A模块基础设施设置/安全加固&#xff08;200分&#xff09; &#xff08;三&#xf…...

JAVA 整合 AWS S3(Amazon Simple Storage Service)文件上传,分片上传,删除,下载

依赖 因为aws需要发送请求上传、下载等api&#xff0c;所以需要加上httpclient相关的依赖 <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.11.628</version> </dependency&…...

记录:Unity脚本的编写9.0

目录 射线一些准备工作编写代码 突然发现好像没有写过关于射线的内容&#xff0c;我就说怎么总感觉好像少了什么东西&#xff08;心虚 那就在这里写一下关于射线的内容吧&#xff0c;将在这里实现射线检测鼠标点击的功能 射线 射线是一种在Unity中检测碰撞器或触发器的方法&am…...

共享单车停放(简单的struct结构运用)

本来不想写这题的&#xff0c;但是想想最近沉迷玩雨世界&#xff0c;班长又问我这题&#xff0c;就草草写了一下 代码如下&#xff1a; #include<stdio.h> #include<math.h> struct parking{int distance;int remain;int speed;int time;int jud; }parking[50]; …...

【Java8系列07】Java8日期处理

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

为什么做CSGO搬砖的不直接去炒股呢?

首先&#xff0c;CS2并非只有一个交易平台&#xff0c;阿阳个人觉得像IGXE等交易平台一样是交易&#xff0c;况且我记得很早的时候我就开始用IGXE了&#xff0c;我记得最早的时候还是机器人发货&#xff0c;后来因为V社对于很多开箱网站的管控&#xff0c;所以让这种发货的方式…...

12月01日,每日信息差//阿里国际发布3款AI设计生态工具//美团买菜升级为“小象超市”//外国人永居证换新、6国游客免签来华

_灵感 &#x1f396; 阿里国际发布3款AI设计生态工具 &#x1f384; AITO问界系列11月交付新车18827辆 &#x1f30d; 美团买菜升级为“小象超市” &#x1f30b; 全球首个金融风控大模型国际标准出炉&#xff0c;由腾讯牵头制定 &#x1f381; 支付宝&#xff1a;支持外国人…...

ChatGPT探索:提示工程详解—程序员效率提升必备技能【文末送书】

文章目录 一.人工智能-ChatGPT1.1 ChatGPT简介1.2 ChatGPT探索&#xff1a;提示工程详解1.2 提示工程的优势 二.提示工程探索2.1 提示工程实例&#xff1a;2.2 英语学习助手2.3 Active-Prompt思维链&#xff08;CoT&#xff09;方法2.4 提示工程总结 三.文末推荐与福利3.1《Cha…...

Pytest做性能测试?

Pytest其实也是可以做性能测试或者基准测试的。是非常方便的。 可以考虑使用Pytest-benchmark类库进行。 安装pytest-benchmark 首先&#xff0c;确保已经安装了pytest和pytest-benchmark插件。可以使用以下命令安装插件&#xff1a; pip install pytest pytest-benchmark …...

Swagger各版本访问地址

2.9.x 访问地址: http://ip:port/{context-path}/swagger-ui.html 3.0.x 访问地址: http://ip:port/{context-path}/swagger-ui/index.html 3.0集成knife4j 访问地址: http://ip:port/{context-path}/doc.html...

docker-compose;私有镜像仓库harbor搭建;镜像推送到私有仓库harbor

docker-compose&#xff1b;私有镜像仓库harbor搭建&#xff1b;镜像推送到私有仓库harbor 文章目录 docker-compose&#xff1b;私有镜像仓库harbor搭建&#xff1b;镜像推送到私有仓库harbordocker-compose私有镜像仓库harbor搭建镜像推送到私有仓库harbor docker-compose D…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中&#xff0c;我们会遇到使用 java 调用 dll文件 的情况&#xff0c;此时大概率出现UnsatisfiedLinkError链接错误&#xff0c;原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用&#xff0c;结果 dll 未实现 JNI 协…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言&#xff1a; 最近在做行为检测相关的模型&#xff0c;用的是时空图卷积网络&#xff08;STGCN&#xff09;&#xff0c;但原有kinetic-400数据集数据质量较低&#xff0c;需要进行细粒度的标注&#xff0c;同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...

python爬虫——气象数据爬取

一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用&#xff1a; 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests&#xff1a;发送 …...