G1 GAN生成MNIST手写数字图像
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
G1 GAN生成MNIST手写数字图像
1. 生成对抗网络 (GAN) 简介
生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:
- 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
- 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。
GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。
GAN 的基本流程
- 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
- 生成器生成假的数据,并试图欺骗判别器;
- 判别器输出接近0的值,表示为假;
- 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。
GAN 的目标是使得生成器生成的假数据,能骗过判别器。
GAN 的损失函数
GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。
判别器的损失函数定义为:
L D = − [ E x ∼ p data [ log D ( x ) ] + E z ∼ p z [ log ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=−[Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]]
生成器的损失函数定义为:
L G = − E z ∼ p z [ log D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=−Ez∼pz[logD(G(z))]
其中:
- ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
- ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
- ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。
2. PyTorch 实现
下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。
2.1 导入库与超参数设置
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False
2.2 数据预处理
使用 torchvision.datasets.MNIST
下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader
转化为可迭代数据集。
# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
2.3 定义生成器模型
生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)
2.4 定义判别器模型
判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
2.5 定义优化器与损失函数
generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()
2.6 训练过程
2.6.1 训练判别器
判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。
real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器
生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。
z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
2.7 保存与可视化生成图像
if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)
4. 总结
这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。
相关文章:

G1 GAN生成MNIST手写数字图像
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...

WPFDeveloper正式版发布
WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...

实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)
要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...

【GAMES101笔记速查——Lecture 17 Materials and Appearances】
目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数࿱…...

对于从vscode ssh到virtualBox的timeout记录
如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...

鸿蒙原生应用扬帆起航
就在2024年6月21日华为在开发者大会上发布了全新操作的系统HarmonyOS Next开发测试版,网友们把它称之为“称之为纯血鸿蒙”。因为在此之前鸿蒙系统底层式有两套基础架构的,一套是是Android的AOSP,一套是鸿蒙的Open Harmony,因为早…...

《计算机视觉》—— 表情识别
根据计算眼睛、嘴巴的变化,判断是什么表情结合以下两篇文章来理解表情识别的实现方法 基于 dilib 库的人脸检测 https://blog.csdn.net/weixin_73504499/article/details/142977202?spm1001.2014.3001.5501 基于 dlib 库的人脸关键点定位 https://blog.csdn.net/we…...

NVIDIA Aerial Omniverse
NVIDIA Aerial Omniverse 数字孪生助力打造新一代无线网络 文章目录 前言一、从链路级仿真到系统级仿真二、转变无线研发方式1. 开放且可定制的模块化平台2. 适用于 6G 标准化的 3GPP 兼容平台3. 部署前测试4. AI 和 ML 在数字孪生中的应用5. 高级物理精准的电磁求解器6. 合作伙…...

QT程序报错解决方案:Cannot queue arguments of type ‘QTextCharFormat‘ 或 ‘QTextCursor‘
项目场景: 项目场景:基于QT实现的C某程序,搭载在Linux环境中。 问题描述 执行程序时,发现log中报错如下内容: QObject::connect: Cannot queue arguments of type QTextCharFormat (Make sure QTextCharFormat is r…...

MySQL知识点_03
MySQL 命令大全 基础命令 操作命令连接到 MySQL 数据库mysql -u 用户名 -p查看所有数据库SHOW DATABASES;选择一个数据库USE 数据库名;查看所有表SHOW TABLES;查看表结构DESCRIBE 表名; 或 SHOW COLUMNS FROM 表名;创建一个新数据库CREATE DATABASE 数据库名;删除一个数据库D…...

leetcode:744. 寻找比目标字母大的最小字母(python3解法)
难度:简单 给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 target。letters 里至少有两个不同的字符。 返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。 示例 1&a…...

2015年-2016年 软件工程程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析
文章目录 2015年1.c语言程序设计部分2.数据结构程序设计部分 2016年1.c语言程序设计部分2.数据结构程序设计部分 2015年 1.c语言程序设计部分 1.从一组数据中选择最大的和最小的输出。 void print_maxandmin(double a[],int length) //在一组数据中选择最大的或者最小的输出…...

整理一下实际开发和工作中Git工具的使用 (持续更新中)
介绍一下Git 在实际开发和工作中,Git工具的使用可以说是至关重要的,它不仅提高了团队协作的效率,还帮助开发者有效地管理代码版本。以下是对Git工具使用的扩展描述: 版本控制:Git能够跟踪代码的每一个修改记录&#x…...

Axios 的基本使用与 Fetch 的比较、在 Vue 项目中使用 Axios 的最佳实践
文章目录 1. 引言2. Axios 的基本使用2.1 安装 Axios2.2 发起 GET 请求2.3 发起 POST 请求2.4 请求拦截器2.5 设置全局配置 3. Axios 与 Fetch 的比较3.1 Axios 与 Fetch 的异同点3.2 Fetch 的基本使用3.3 使用 Fetch 处理 POST 请求 4. 讨论在 Vue 项目中使用 Axios 的最佳实践…...

Dockerfile样例
一、基础jar镜像制作 ## Dockerfile FROM registry.openanolis.cn/openanolis/anolisos:8.9 RUN mkdir /work ADD jdk17.tar.gz fonts.tar.gz /work/ RUN yum install fontconfig ttmkfdir -y && yum clean all && \chmod -R 755 /work/fonts ADD fonts.conf …...

MYSQL-多表查询
一、概述 1、定义 多表查询,也称为关联查询,指两个或更多个表一起完成查询操作。 2、前提条件 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段,这个关联字段可能建立了外键…...

MySQL改密码后不生效问题
MySQL修改密码后连接报密码错误 1.mysql修改密码命令: 这两种连接方式密码都必须修改 修改远程连接密码 ALTER USER ‘root’‘%’ IDENTIFIED BY ‘password’; 修改本地连接密码 ALTER USER ‘root’‘localhost’ IDENTIFIED BY ‘password’; 修改完后必须刷新…...

15分钟学Go 第1天:Go语言简介与特点
Go语言简介与特点 1. Go语言概述 Go语言(又称Golang)是由谷歌于2007年开发并在2009年正式发布的一种开源编程语言。它旨在简单、高效地进行软件开发,尤其适合于网络编程和分布式系统。 1.1 发展背景 多核处理器:随着计算机硬件…...

UDP/TCP协议
网络层只负责将数据包送达至目标主机,并不负责将数据包上交给上层的哪一个应用程序,这是传输层需要干的事,传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP(用户数据报协议)和TCP(传输控制协议…...

gitee建立/取消关联仓库
目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…...

在 Windows 环境下,Git 默认会自动处理 CRLF 和 LF 之间的转换。
在 Windows 环境下,Git 默认会自动处理 CRLF 和 LF 之间的转换。 要确保这一点并自动处理换行符差异,你可以按照以下步骤配置 1. 配置 Git 自动转换 CRLF 使用 Git Bash 或命令行执行以下命令,设置 Git 自动处理换行符: git con…...

Kibana可视化Dashboard如何基于字段是否包含某关键词进行过滤
kinana是一个功能强大、可对Elasticsearch数据进行可视化的开源工具。 我们在dashboard创建可视化时,有时需要将某个index里数据的某个字段根据是否包含某些特定关键词进行过滤,这个时候就可以用到lens里的filter功能很方便地进行操作。 如上图所示&…...

架构师之路-学渣到学霸历程-23
实战:NFS安装部署 接早上了解过了NFS的一些基本原理,咋们就看看一些实战; 尝试自己部署一下实验;然后实验成功了是我们最大的鼓励来着; 实战过程中,我们也面临了很多报错;所以每个实战的报错我…...

怎么修改编辑PDF的内容,有这4个工具就行了。
PDF 软件在现代的办公或者是学习当中的应用非常广泛,编辑PDF内容对很多人来说也是一件常有的事情。如果有了PDF 编辑软件,查看,编辑,修改,分享也会变得更加方便简单,所以今天要给大家介绍几款这样的工具。 …...

腾讯云宝塔面板前后端项目发版
后端发版 1. 打开“网站”页面,找到java项目,点击状态暂停服务 2.打开“文件”页面,进入jar包目录,删除原有的jar包,上传新jar包 3. 再回到第一步中的网站页面,找到jar项目,启动项目即可 前端发…...

C语言的结构体定义 java赋值关系运算符
1. /*#include<stdio.h> struct student { int num; //成员列表 int score; float avg; }; int main(void) { struct student Tom;//Tom结构体变量 struct student class4[50];//结构体数组 return 0; }*/ struct student { int nu…...

重学SpringBoot3-Spring WebFlux简介
更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-Spring WebFlux简介 1. 什么是 WebFlux?2. WebFlux 与 Spring MVC 的区别3. WebFlux 的用处3.1 非阻塞 I/O 操作3.2 响应式编程模型3.3 更高…...

distinct 和 group by
最近生产加了一个新字段 a、然后将主键赋值给 a 然后投产后验证是否有漏网之鱼。当时使用的是 select count(distinct pk),count(distinct a) from tableName当时在想这样子跟 group by 有啥区别 select a from tableName group by a having count(a) > 1所以查一下两者…...

RTThread-Nano学习一-基于MDK移植
一、简介 关于RTThread-nano的介绍,这里不做过多解释,官方文档已经介绍的非常详细了,有兴趣的可以参考如下文档:RT-Thread 文档中心 二、移植 1.准备一个能正常运行的代码 手头有M0内核的板子,那就以…...

Vue中v-bind对样式控制的增强—(详解v-bind操作class以及操作style属性,附有案例+代码)
文章目录 v-bind对样式控制的增强2.1 操作class2.1.1 语法2.1.2 对象语法2.1.3 数组语法2.1.4 使用2.1.5 Test 2.2 操作style2.2.1 语法2.2.2 使用2.2.3 Test v-bind对样式控制的增强 2.1 操作class 2.1.1 语法 <div> :class "对象/数组">这是一个div&l…...