G1 GAN生成MNIST手写数字图像
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
G1 GAN生成MNIST手写数字图像
1. 生成对抗网络 (GAN) 简介
生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:
- 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
- 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。
GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。
GAN 的基本流程
- 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
- 生成器生成假的数据,并试图欺骗判别器;
- 判别器输出接近0的值,表示为假;
- 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。
GAN 的目标是使得生成器生成的假数据,能骗过判别器。
GAN 的损失函数
GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。
判别器的损失函数定义为:
L D = − [ E x ∼ p data [ log D ( x ) ] + E z ∼ p z [ log ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=−[Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]]
生成器的损失函数定义为:
L G = − E z ∼ p z [ log D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=−Ez∼pz[logD(G(z))]
其中:
- ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
- ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
- ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。
2. PyTorch 实现
下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。
2.1 导入库与超参数设置
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False
2.2 数据预处理
使用 torchvision.datasets.MNIST 下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader 转化为可迭代数据集。
# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
2.3 定义生成器模型
生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)
2.4 定义判别器模型
判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
2.5 定义优化器与损失函数
generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()
2.6 训练过程
2.6.1 训练判别器
判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。
real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器
生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。
z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

2.7 保存与可视化生成图像
if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

4. 总结
这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。
相关文章:
G1 GAN生成MNIST手写数字图像
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...
WPFDeveloper正式版发布
WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...
实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)
要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...
【GAMES101笔记速查——Lecture 17 Materials and Appearances】
目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数࿱…...
对于从vscode ssh到virtualBox的timeout记录
如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...
鸿蒙原生应用扬帆起航
就在2024年6月21日华为在开发者大会上发布了全新操作的系统HarmonyOS Next开发测试版,网友们把它称之为“称之为纯血鸿蒙”。因为在此之前鸿蒙系统底层式有两套基础架构的,一套是是Android的AOSP,一套是鸿蒙的Open Harmony,因为早…...
《计算机视觉》—— 表情识别
根据计算眼睛、嘴巴的变化,判断是什么表情结合以下两篇文章来理解表情识别的实现方法 基于 dilib 库的人脸检测 https://blog.csdn.net/weixin_73504499/article/details/142977202?spm1001.2014.3001.5501 基于 dlib 库的人脸关键点定位 https://blog.csdn.net/we…...
NVIDIA Aerial Omniverse
NVIDIA Aerial Omniverse 数字孪生助力打造新一代无线网络 文章目录 前言一、从链路级仿真到系统级仿真二、转变无线研发方式1. 开放且可定制的模块化平台2. 适用于 6G 标准化的 3GPP 兼容平台3. 部署前测试4. AI 和 ML 在数字孪生中的应用5. 高级物理精准的电磁求解器6. 合作伙…...
QT程序报错解决方案:Cannot queue arguments of type ‘QTextCharFormat‘ 或 ‘QTextCursor‘
项目场景: 项目场景:基于QT实现的C某程序,搭载在Linux环境中。 问题描述 执行程序时,发现log中报错如下内容: QObject::connect: Cannot queue arguments of type QTextCharFormat (Make sure QTextCharFormat is r…...
MySQL知识点_03
MySQL 命令大全 基础命令 操作命令连接到 MySQL 数据库mysql -u 用户名 -p查看所有数据库SHOW DATABASES;选择一个数据库USE 数据库名;查看所有表SHOW TABLES;查看表结构DESCRIBE 表名; 或 SHOW COLUMNS FROM 表名;创建一个新数据库CREATE DATABASE 数据库名;删除一个数据库D…...
leetcode:744. 寻找比目标字母大的最小字母(python3解法)
难度:简单 给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 target。letters 里至少有两个不同的字符。 返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。 示例 1&a…...
2015年-2016年 软件工程程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析
文章目录 2015年1.c语言程序设计部分2.数据结构程序设计部分 2016年1.c语言程序设计部分2.数据结构程序设计部分 2015年 1.c语言程序设计部分 1.从一组数据中选择最大的和最小的输出。 void print_maxandmin(double a[],int length) //在一组数据中选择最大的或者最小的输出…...
整理一下实际开发和工作中Git工具的使用 (持续更新中)
介绍一下Git 在实际开发和工作中,Git工具的使用可以说是至关重要的,它不仅提高了团队协作的效率,还帮助开发者有效地管理代码版本。以下是对Git工具使用的扩展描述: 版本控制:Git能够跟踪代码的每一个修改记录&#x…...
Axios 的基本使用与 Fetch 的比较、在 Vue 项目中使用 Axios 的最佳实践
文章目录 1. 引言2. Axios 的基本使用2.1 安装 Axios2.2 发起 GET 请求2.3 发起 POST 请求2.4 请求拦截器2.5 设置全局配置 3. Axios 与 Fetch 的比较3.1 Axios 与 Fetch 的异同点3.2 Fetch 的基本使用3.3 使用 Fetch 处理 POST 请求 4. 讨论在 Vue 项目中使用 Axios 的最佳实践…...
Dockerfile样例
一、基础jar镜像制作 ## Dockerfile FROM registry.openanolis.cn/openanolis/anolisos:8.9 RUN mkdir /work ADD jdk17.tar.gz fonts.tar.gz /work/ RUN yum install fontconfig ttmkfdir -y && yum clean all && \chmod -R 755 /work/fonts ADD fonts.conf …...
MYSQL-多表查询
一、概述 1、定义 多表查询,也称为关联查询,指两个或更多个表一起完成查询操作。 2、前提条件 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段,这个关联字段可能建立了外键…...
MySQL改密码后不生效问题
MySQL修改密码后连接报密码错误 1.mysql修改密码命令: 这两种连接方式密码都必须修改 修改远程连接密码 ALTER USER ‘root’‘%’ IDENTIFIED BY ‘password’; 修改本地连接密码 ALTER USER ‘root’‘localhost’ IDENTIFIED BY ‘password’; 修改完后必须刷新…...
15分钟学Go 第1天:Go语言简介与特点
Go语言简介与特点 1. Go语言概述 Go语言(又称Golang)是由谷歌于2007年开发并在2009年正式发布的一种开源编程语言。它旨在简单、高效地进行软件开发,尤其适合于网络编程和分布式系统。 1.1 发展背景 多核处理器:随着计算机硬件…...
UDP/TCP协议
网络层只负责将数据包送达至目标主机,并不负责将数据包上交给上层的哪一个应用程序,这是传输层需要干的事,传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP(用户数据报协议)和TCP(传输控制协议…...
gitee建立/取消关联仓库
目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…...
数学公式也能懂:gte-base-zh与MathType内容协同处理方案
数学公式也能懂:gte-base-zh与MathType内容协同处理方案 你有没有遇到过这样的烦恼?面对一份满是复杂数学公式和文字说明的学术论文或技术文档,想快速找到某个特定公式的推导过程,或者想检索所有提到“傅里叶变换”的地方&#x…...
Wux Weapp 终极国际化方案:打造多语言小程序完整指南
Wux Weapp 终极国际化方案:打造多语言小程序完整指南 【免费下载链接】wux-weapp :dog: 一套组件化、可复用、易扩展的微信小程序 UI 组件库 项目地址: https://gitcode.com/gh_mirrors/wu/wux-weapp 想要让你的微信小程序走向全球市场吗?Wux Wea…...
【C语言】-指针(1)
🦆 个人主页:深邃- ❄️专栏传送门:《C语言》《数据结构》 🌟Gitee仓库:《C语言》《数据结构》 目录内存和地址指针变量和地址指针变量和解引用操作符(*)指针变量的大小内存存放指针变量类型的…...
Golang如何把日志写到文件_Golang日志文件教程【秒懂】
Go log包默认只输出到stderr,需用os.OpenFile创建*os.File(实现io.Writer)传给log.SetOutput;并发写安全但格式易乱;需手动flush或用bufio.NewWriter;长期运行需日志轮转等高级功能。Go 标准库的 log 包默认…...
Embedded Coder vs Simulink Coder:如何为你的项目选择正确的代码生成工具?
Embedded Coder与Simulink Coder深度对比:从项目需求出发的选型指南 在嵌入式系统开发领域,代码生成工具的选择往往决定了项目的成败。当工程师面对MathWorks提供的两款核心代码生成工具——Embedded Coder和Simulink Coder时,如何做出明智决…...
电源环路分析仪不会用?2026年硬件工程师的必备技能该补上了
电源环路分析仪不会用?2026年硬件工程师的必备技能该补上了实验室里,Buck电源刚调通,输出纹波看着也不错,但一上动态负载,输出电压就开始剧烈振荡。换了几组补偿参数,还是没找到症结所在。这时候,旁边有经验的前辈说了一句:"你测过环路稳定性吗?"说实话,…...
应对“中年危机”的前置策略:留学生入职第一天就该考虑的事情——如何建立你的“被动求职”网络?
在 2026 年的北美科技职场,拿到全职 Offer 签下字的那一刻,许多留学生会如释重负地认为自己终于进入了“保险箱”。然而,在残酷的宏观经济周期和快速迭代的 AI 浪潮面前,传统的“绝对稳定”早已不复存在。 无论是硅谷巨头…...
OpenClaw+百川2-13B量化模型:个人知识库自动整理方案实测
OpenClaw百川2-13B量化模型:个人知识库自动整理方案实测 1. 为什么需要自动化知识管理 作为一个长期与技术文档打交道的开发者,我的电脑里堆积着超过200GB的未整理资料——从会议录音转写的文字稿、GitHub扒下来的开源项目说明,到随手保存的…...
OpenClaw实操指南09|云端部署实战:腾讯云+OpenClaw,打造7×24小时不断线AI助手
很多人第一次用OpenClaw,是在自己电脑上跑的。 用着挺爽——但只要关机,AI助手就断了。出门在路上,飞书消息发出去,没有回应。 本地部署的致命缺陷:你不在,它也不在。 这篇教程解决这个问题。用腾讯云轻…...
盘姬工具箱功能详解:百余款实用工具助力系统优化
盘姬工具箱最大的特点就是功能的全面性。 软件安装后即可直接使用,打开界面就能看到丰富多样的功能模块。 这些功能模块分类清晰,操作直观,即使是电脑新手也能快速上手。 从日常的小工具到高级的技术工具,盘姬工具箱几乎涵盖了…...
