G1 GAN生成MNIST手写数字图像
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
G1 GAN生成MNIST手写数字图像
1. 生成对抗网络 (GAN) 简介
生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:
- 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
- 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。
GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。
GAN 的基本流程
- 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
- 生成器生成假的数据,并试图欺骗判别器;
- 判别器输出接近0的值,表示为假;
- 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。
GAN 的目标是使得生成器生成的假数据,能骗过判别器。
GAN 的损失函数
GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。
判别器的损失函数定义为:
L D = − [ E x ∼ p data [ log D ( x ) ] + E z ∼ p z [ log ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=−[Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]]
生成器的损失函数定义为:
L G = − E z ∼ p z [ log D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=−Ez∼pz[logD(G(z))]
其中:
- ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
- ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
- ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。
2. PyTorch 实现
下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。
2.1 导入库与超参数设置
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False
2.2 数据预处理
使用 torchvision.datasets.MNIST 下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader 转化为可迭代数据集。
# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
2.3 定义生成器模型
生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)
2.4 定义判别器模型
判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
2.5 定义优化器与损失函数
generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()
2.6 训练过程
2.6.1 训练判别器
判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。
real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器
生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。
z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

2.7 保存与可视化生成图像
if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

4. 总结
这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。
相关文章:
G1 GAN生成MNIST手写数字图像
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...
WPFDeveloper正式版发布
WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...
实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)
要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...
【GAMES101笔记速查——Lecture 17 Materials and Appearances】
目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数࿱…...
对于从vscode ssh到virtualBox的timeout记录
如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...
鸿蒙原生应用扬帆起航
就在2024年6月21日华为在开发者大会上发布了全新操作的系统HarmonyOS Next开发测试版,网友们把它称之为“称之为纯血鸿蒙”。因为在此之前鸿蒙系统底层式有两套基础架构的,一套是是Android的AOSP,一套是鸿蒙的Open Harmony,因为早…...
《计算机视觉》—— 表情识别
根据计算眼睛、嘴巴的变化,判断是什么表情结合以下两篇文章来理解表情识别的实现方法 基于 dilib 库的人脸检测 https://blog.csdn.net/weixin_73504499/article/details/142977202?spm1001.2014.3001.5501 基于 dlib 库的人脸关键点定位 https://blog.csdn.net/we…...
NVIDIA Aerial Omniverse
NVIDIA Aerial Omniverse 数字孪生助力打造新一代无线网络 文章目录 前言一、从链路级仿真到系统级仿真二、转变无线研发方式1. 开放且可定制的模块化平台2. 适用于 6G 标准化的 3GPP 兼容平台3. 部署前测试4. AI 和 ML 在数字孪生中的应用5. 高级物理精准的电磁求解器6. 合作伙…...
QT程序报错解决方案:Cannot queue arguments of type ‘QTextCharFormat‘ 或 ‘QTextCursor‘
项目场景: 项目场景:基于QT实现的C某程序,搭载在Linux环境中。 问题描述 执行程序时,发现log中报错如下内容: QObject::connect: Cannot queue arguments of type QTextCharFormat (Make sure QTextCharFormat is r…...
MySQL知识点_03
MySQL 命令大全 基础命令 操作命令连接到 MySQL 数据库mysql -u 用户名 -p查看所有数据库SHOW DATABASES;选择一个数据库USE 数据库名;查看所有表SHOW TABLES;查看表结构DESCRIBE 表名; 或 SHOW COLUMNS FROM 表名;创建一个新数据库CREATE DATABASE 数据库名;删除一个数据库D…...
leetcode:744. 寻找比目标字母大的最小字母(python3解法)
难度:简单 给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 target。letters 里至少有两个不同的字符。 返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。 示例 1&a…...
2015年-2016年 软件工程程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析
文章目录 2015年1.c语言程序设计部分2.数据结构程序设计部分 2016年1.c语言程序设计部分2.数据结构程序设计部分 2015年 1.c语言程序设计部分 1.从一组数据中选择最大的和最小的输出。 void print_maxandmin(double a[],int length) //在一组数据中选择最大的或者最小的输出…...
整理一下实际开发和工作中Git工具的使用 (持续更新中)
介绍一下Git 在实际开发和工作中,Git工具的使用可以说是至关重要的,它不仅提高了团队协作的效率,还帮助开发者有效地管理代码版本。以下是对Git工具使用的扩展描述: 版本控制:Git能够跟踪代码的每一个修改记录&#x…...
Axios 的基本使用与 Fetch 的比较、在 Vue 项目中使用 Axios 的最佳实践
文章目录 1. 引言2. Axios 的基本使用2.1 安装 Axios2.2 发起 GET 请求2.3 发起 POST 请求2.4 请求拦截器2.5 设置全局配置 3. Axios 与 Fetch 的比较3.1 Axios 与 Fetch 的异同点3.2 Fetch 的基本使用3.3 使用 Fetch 处理 POST 请求 4. 讨论在 Vue 项目中使用 Axios 的最佳实践…...
Dockerfile样例
一、基础jar镜像制作 ## Dockerfile FROM registry.openanolis.cn/openanolis/anolisos:8.9 RUN mkdir /work ADD jdk17.tar.gz fonts.tar.gz /work/ RUN yum install fontconfig ttmkfdir -y && yum clean all && \chmod -R 755 /work/fonts ADD fonts.conf …...
MYSQL-多表查询
一、概述 1、定义 多表查询,也称为关联查询,指两个或更多个表一起完成查询操作。 2、前提条件 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段,这个关联字段可能建立了外键…...
MySQL改密码后不生效问题
MySQL修改密码后连接报密码错误 1.mysql修改密码命令: 这两种连接方式密码都必须修改 修改远程连接密码 ALTER USER ‘root’‘%’ IDENTIFIED BY ‘password’; 修改本地连接密码 ALTER USER ‘root’‘localhost’ IDENTIFIED BY ‘password’; 修改完后必须刷新…...
15分钟学Go 第1天:Go语言简介与特点
Go语言简介与特点 1. Go语言概述 Go语言(又称Golang)是由谷歌于2007年开发并在2009年正式发布的一种开源编程语言。它旨在简单、高效地进行软件开发,尤其适合于网络编程和分布式系统。 1.1 发展背景 多核处理器:随着计算机硬件…...
UDP/TCP协议
网络层只负责将数据包送达至目标主机,并不负责将数据包上交给上层的哪一个应用程序,这是传输层需要干的事,传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP(用户数据报协议)和TCP(传输控制协议…...
gitee建立/取消关联仓库
目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…...
IDEA运行Tomcat出现乱码问题解决汇总
最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…...
第19节 Node.js Express 框架
Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...
Docker 离线安装指南
参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...
练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...
现代密码学 | 椭圆曲线密码学—附py代码
Elliptic Curve Cryptography 椭圆曲线密码学(ECC)是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础,例如椭圆曲线数字签…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...
视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块,…...
CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)
漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...
