G1 GAN生成MNIST手写数字图像
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
G1 GAN生成MNIST手写数字图像
1. 生成对抗网络 (GAN) 简介
生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:
- 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
- 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。
GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。
GAN 的基本流程
- 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
- 生成器生成假的数据,并试图欺骗判别器;
- 判别器输出接近0的值,表示为假;
- 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。
GAN 的目标是使得生成器生成的假数据,能骗过判别器。
GAN 的损失函数
GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。
判别器的损失函数定义为:
L D = − [ E x ∼ p data [ log D ( x ) ] + E z ∼ p z [ log ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=−[Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]]
生成器的损失函数定义为:
L G = − E z ∼ p z [ log D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=−Ez∼pz[logD(G(z))]
其中:
- ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
- ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
- ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。
2. PyTorch 实现
下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。
2.1 导入库与超参数设置
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False
2.2 数据预处理
使用 torchvision.datasets.MNIST
下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader
转化为可迭代数据集。
# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
2.3 定义生成器模型
生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)
2.4 定义判别器模型
判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
2.5 定义优化器与损失函数
generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()
2.6 训练过程
2.6.1 训练判别器
判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。
real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器
生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。
z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
2.7 保存与可视化生成图像
if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)
4. 总结
这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。
相关文章:

G1 GAN生成MNIST手写数字图像
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...

WPFDeveloper正式版发布
WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...
实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)
要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...

【GAMES101笔记速查——Lecture 17 Materials and Appearances】
目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数࿱…...

对于从vscode ssh到virtualBox的timeout记录
如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...

鸿蒙原生应用扬帆起航
就在2024年6月21日华为在开发者大会上发布了全新操作的系统HarmonyOS Next开发测试版,网友们把它称之为“称之为纯血鸿蒙”。因为在此之前鸿蒙系统底层式有两套基础架构的,一套是是Android的AOSP,一套是鸿蒙的Open Harmony,因为早…...
《计算机视觉》—— 表情识别
根据计算眼睛、嘴巴的变化,判断是什么表情结合以下两篇文章来理解表情识别的实现方法 基于 dilib 库的人脸检测 https://blog.csdn.net/weixin_73504499/article/details/142977202?spm1001.2014.3001.5501 基于 dlib 库的人脸关键点定位 https://blog.csdn.net/we…...

NVIDIA Aerial Omniverse
NVIDIA Aerial Omniverse 数字孪生助力打造新一代无线网络 文章目录 前言一、从链路级仿真到系统级仿真二、转变无线研发方式1. 开放且可定制的模块化平台2. 适用于 6G 标准化的 3GPP 兼容平台3. 部署前测试4. AI 和 ML 在数字孪生中的应用5. 高级物理精准的电磁求解器6. 合作伙…...
QT程序报错解决方案:Cannot queue arguments of type ‘QTextCharFormat‘ 或 ‘QTextCursor‘
项目场景: 项目场景:基于QT实现的C某程序,搭载在Linux环境中。 问题描述 执行程序时,发现log中报错如下内容: QObject::connect: Cannot queue arguments of type QTextCharFormat (Make sure QTextCharFormat is r…...
MySQL知识点_03
MySQL 命令大全 基础命令 操作命令连接到 MySQL 数据库mysql -u 用户名 -p查看所有数据库SHOW DATABASES;选择一个数据库USE 数据库名;查看所有表SHOW TABLES;查看表结构DESCRIBE 表名; 或 SHOW COLUMNS FROM 表名;创建一个新数据库CREATE DATABASE 数据库名;删除一个数据库D…...

leetcode:744. 寻找比目标字母大的最小字母(python3解法)
难度:简单 给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 target。letters 里至少有两个不同的字符。 返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。 示例 1&a…...

2015年-2016年 软件工程程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析
文章目录 2015年1.c语言程序设计部分2.数据结构程序设计部分 2016年1.c语言程序设计部分2.数据结构程序设计部分 2015年 1.c语言程序设计部分 1.从一组数据中选择最大的和最小的输出。 void print_maxandmin(double a[],int length) //在一组数据中选择最大的或者最小的输出…...
整理一下实际开发和工作中Git工具的使用 (持续更新中)
介绍一下Git 在实际开发和工作中,Git工具的使用可以说是至关重要的,它不仅提高了团队协作的效率,还帮助开发者有效地管理代码版本。以下是对Git工具使用的扩展描述: 版本控制:Git能够跟踪代码的每一个修改记录&#x…...
Axios 的基本使用与 Fetch 的比较、在 Vue 项目中使用 Axios 的最佳实践
文章目录 1. 引言2. Axios 的基本使用2.1 安装 Axios2.2 发起 GET 请求2.3 发起 POST 请求2.4 请求拦截器2.5 设置全局配置 3. Axios 与 Fetch 的比较3.1 Axios 与 Fetch 的异同点3.2 Fetch 的基本使用3.3 使用 Fetch 处理 POST 请求 4. 讨论在 Vue 项目中使用 Axios 的最佳实践…...

Dockerfile样例
一、基础jar镜像制作 ## Dockerfile FROM registry.openanolis.cn/openanolis/anolisos:8.9 RUN mkdir /work ADD jdk17.tar.gz fonts.tar.gz /work/ RUN yum install fontconfig ttmkfdir -y && yum clean all && \chmod -R 755 /work/fonts ADD fonts.conf …...

MYSQL-多表查询
一、概述 1、定义 多表查询,也称为关联查询,指两个或更多个表一起完成查询操作。 2、前提条件 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段,这个关联字段可能建立了外键…...
MySQL改密码后不生效问题
MySQL修改密码后连接报密码错误 1.mysql修改密码命令: 这两种连接方式密码都必须修改 修改远程连接密码 ALTER USER ‘root’‘%’ IDENTIFIED BY ‘password’; 修改本地连接密码 ALTER USER ‘root’‘localhost’ IDENTIFIED BY ‘password’; 修改完后必须刷新…...

15分钟学Go 第1天:Go语言简介与特点
Go语言简介与特点 1. Go语言概述 Go语言(又称Golang)是由谷歌于2007年开发并在2009年正式发布的一种开源编程语言。它旨在简单、高效地进行软件开发,尤其适合于网络编程和分布式系统。 1.1 发展背景 多核处理器:随着计算机硬件…...

UDP/TCP协议
网络层只负责将数据包送达至目标主机,并不负责将数据包上交给上层的哪一个应用程序,这是传输层需要干的事,传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP(用户数据报协议)和TCP(传输控制协议…...

gitee建立/取消关联仓库
目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…...
后进先出(LIFO)详解
LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子(…...
零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?
一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...
【Linux】shell脚本忽略错误继续执行
在 shell 脚本中,可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行,可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令,并忽略错误 rm somefile…...

VB.net复制Ntag213卡写入UID
本示例使用的发卡器:https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...
React Native在HarmonyOS 5.0阅读类应用开发中的实践
一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强,React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 (1)使用React Native…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

优选算法第十二讲:队列 + 宽搜 优先级队列
优选算法第十二讲:队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
音视频——I2S 协议详解
I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议,专门用于在数字音频设备之间传输数字音频数据。它由飞利浦(Philips)公司开发,以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...