当前位置: 首页 > news >正文

G1 GAN生成MNIST手写数字图像

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

G1 GAN生成MNIST手写数字图像

1. 生成对抗网络 (GAN) 简介

生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:

  • 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
  • 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。

GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。

GAN 的基本流程

  1. 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
  2. 生成器生成假的数据,并试图欺骗判别器;
  3. 判别器输出接近0的值,表示为假;
  4. 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。

GAN 的目标是使得生成器生成的假数据,能骗过判别器。

GAN 的损失函数

GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。

判别器的损失函数定义为:

L D = − [ E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=[Expdata[logD(x)]+Ezpz[log(1D(G(z)))]]

生成器的损失函数定义为:

L G = − E z ∼ p z [ log ⁡ D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=Ezpz[logD(G(z))]

其中:

  • ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
  • ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
  • ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。

2. PyTorch 实现

下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。

2.1 导入库与超参数设置

import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False

2.2 数据预处理

使用 torchvision.datasets.MNIST 下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader 转化为可迭代数据集。

# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

2.3 定义生成器模型

生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)

2.4 定义判别器模型

判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

2.5 定义优化器与损失函数

generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()

2.6 训练过程

2.6.1 训练判别器

判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。

real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器

生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。

z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

在这里插入图片描述

2.7 保存与可视化生成图像

if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述

4. 总结

这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。

相关文章:

G1 GAN生成MNIST手写数字图像

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...

WPFDeveloper正式版发布

WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...

实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)

要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...

【GAMES101笔记速查——Lecture 17 Materials and Appearances】

目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数&#xff1…...

对于从vscode ssh到virtualBox的timeout记录

如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...

鸿蒙原生应用扬帆起航

就在2024年6月21日华为在开发者大会上发布了全新操作的系统HarmonyOS Next开发测试版,网友们把它称之为“称之为纯血鸿蒙”。因为在此之前鸿蒙系统底层式有两套基础架构的,一套是是Android的AOSP,一套是鸿蒙的Open Harmony,因为早…...

《计算机视觉》—— 表情识别

根据计算眼睛、嘴巴的变化,判断是什么表情结合以下两篇文章来理解表情识别的实现方法 基于 dilib 库的人脸检测 https://blog.csdn.net/weixin_73504499/article/details/142977202?spm1001.2014.3001.5501 基于 dlib 库的人脸关键点定位 https://blog.csdn.net/we…...

NVIDIA Aerial Omniverse

NVIDIA Aerial Omniverse 数字孪生助力打造新一代无线网络 文章目录 前言一、从链路级仿真到系统级仿真二、转变无线研发方式1. 开放且可定制的模块化平台2. 适用于 6G 标准化的 3GPP 兼容平台3. 部署前测试4. AI 和 ML 在数字孪生中的应用5. 高级物理精准的电磁求解器6. 合作伙…...

QT程序报错解决方案:Cannot queue arguments of type ‘QTextCharFormat‘ 或 ‘QTextCursor‘

项目场景: 项目场景:基于QT实现的C某程序,搭载在Linux环境中。 问题描述 执行程序时,发现log中报错如下内容: QObject::connect: Cannot queue arguments of type QTextCharFormat (Make sure QTextCharFormat is r…...

MySQL知识点_03

MySQL 命令大全 基础命令 操作命令连接到 MySQL 数据库mysql -u 用户名 -p查看所有数据库SHOW DATABASES;选择一个数据库USE 数据库名;查看所有表SHOW TABLES;查看表结构DESCRIBE 表名; 或 SHOW COLUMNS FROM 表名;创建一个新数据库CREATE DATABASE 数据库名;删除一个数据库D…...

leetcode:744. 寻找比目标字母大的最小字母(python3解法)

难度:简单 给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 target。letters 里至少有两个不同的字符。 返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。 示例 1&a…...

2015年-2016年 软件工程程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析

文章目录 2015年1.c语言程序设计部分2.数据结构程序设计部分 2016年1.c语言程序设计部分2.数据结构程序设计部分 2015年 1.c语言程序设计部分 1.从一组数据中选择最大的和最小的输出。 void print_maxandmin(double a[],int length) //在一组数据中选择最大的或者最小的输出…...

整理一下实际开发和工作中Git工具的使用 (持续更新中)

介绍一下Git 在实际开发和工作中,Git工具的使用可以说是至关重要的,它不仅提高了团队协作的效率,还帮助开发者有效地管理代码版本。以下是对Git工具使用的扩展描述: 版本控制:Git能够跟踪代码的每一个修改记录&#x…...

Axios 的基本使用与 Fetch 的比较、在 Vue 项目中使用 Axios 的最佳实践

文章目录 1. 引言2. Axios 的基本使用2.1 安装 Axios2.2 发起 GET 请求2.3 发起 POST 请求2.4 请求拦截器2.5 设置全局配置 3. Axios 与 Fetch 的比较3.1 Axios 与 Fetch 的异同点3.2 Fetch 的基本使用3.3 使用 Fetch 处理 POST 请求 4. 讨论在 Vue 项目中使用 Axios 的最佳实践…...

Dockerfile样例

一、基础jar镜像制作 ## Dockerfile FROM registry.openanolis.cn/openanolis/anolisos:8.9 RUN mkdir /work ADD jdk17.tar.gz fonts.tar.gz /work/ RUN yum install fontconfig ttmkfdir -y && yum clean all && \chmod -R 755 /work/fonts ADD fonts.conf …...

MYSQL-多表查询

一、概述 1、定义 多表查询,也称为关联查询,指两个或更多个表一起完成查询操作。 2、前提条件 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段,这个关联字段可能建立了外键…...

MySQL改密码后不生效问题

MySQL修改密码后连接报密码错误 1.mysql修改密码命令: 这两种连接方式密码都必须修改 修改远程连接密码 ALTER USER ‘root’‘%’ IDENTIFIED BY ‘password’; 修改本地连接密码 ALTER USER ‘root’‘localhost’ IDENTIFIED BY ‘password’; 修改完后必须刷新…...

15分钟学Go 第1天:Go语言简介与特点

Go语言简介与特点 1. Go语言概述 Go语言(又称Golang)是由谷歌于2007年开发并在2009年正式发布的一种开源编程语言。它旨在简单、高效地进行软件开发,尤其适合于网络编程和分布式系统。 1.1 发展背景 多核处理器:随着计算机硬件…...

UDP/TCP协议

网络层只负责将数据包送达至目标主机,并不负责将数据包上交给上层的哪一个应用程序,这是传输层需要干的事,传输层通过端口来区分不同的应用程序。传输层协议主要分为UDP(用户数据报协议)和TCP(传输控制协议…...

gitee建立/取消关联仓库

目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令,把数据流转换成Message,状态转变流程是:State::Created 》 St…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求&#xff…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

Java面试专项一-准备篇

一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如&#xff1a…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、👨‍🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨‍&#x1f…...

JavaScript基础-API 和 Web API

在学习JavaScript的过程中,理解API(应用程序接口)和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能,使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

python爬虫——气象数据爬取

一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用: 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests:发送 …...

Neko虚拟浏览器远程协作方案:Docker+内网穿透技术部署实践

前言:本文将向开发者介绍一款创新性协作工具——Neko虚拟浏览器。在数字化协作场景中,跨地域的团队常需面对实时共享屏幕、协同编辑文档等需求。通过本指南,你将掌握在Ubuntu系统中使用容器化技术部署该工具的具体方案,并结合内网…...