当前位置: 首页 > news >正文

PyTorch训练简单的生成对抗网络GAN

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboardclass Discriminator(nn.Module):def __init__(self, img_dim):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Linear(img_dim, 128),nn.LeakyReLU(0.1),nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, img_dim): # z_dim 噪声的维度super(Generator, self).__init__()self.gen = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.1),nn.Linear(256, img_dim), # 28x28 -> 784nn.Tanh(),)def forward(self, x):return self.gen(x)# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)[transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0for epoch in range(num_epochs):for batch_idx, (real, _) in enumerate(loader):real = real.view(-1, 784).to(device) # view相当于reshapebatch_size = real.shape[0]### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))noise = torch.randn(batch_size, z_dim).to(device)fake = gen(noise) # G(z)disc_real = disc(real).view(-1) # flatten# BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]# max log(D(real)) 相当于 min -log(D(real))# ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项# 得到 min -lnx,这里的x就是我们的real图片lossD_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake).view(-1)# max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))# zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项# 得到 min -ln(1-x),这里的x就是我们的fake噪声lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))lossD = (lossD_real + lossD_fake) / 2disc.zero_grad()lossD.backward(retain_graph=True)opt_disc.step()### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))# 依然可使用BCEloss来做output = disc(fake).view(-1)lossG = criterion(output, torch.ones_like(output))gen.zero_grad()lossG.backward()opt_gen.step()if batch_idx == 0:print(f"Epoch [{epoch}/{num_epochs}] \ "f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")with torch.no_grad():fake = gen(fixed_noise).reshape(-1, 1, 28, 28)data = real.reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)img_grid_real = torchvision.utils.make_grid(data, normalize=True)writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

相关文章:

PyTorch训练简单的生成对抗网络GAN

文章目录 原理代码结果参考 原理 同时训练两个网络&#xff1a;辨别器Discriminator 和 生成器Generator Generator是 造假者&#xff0c;用来生成假数据。 Discriminator 是警察&#xff0c;尽可能的分辨出来哪些是造假的&#xff0c;哪些是真实的数据。 目的&#xff1a;使…...

django实现文件上传

在django中实现文件上传有三种方法可以实现&#xff1a; 自己手动写使用Form组件使用ModelForm组件 其中使用ModelForm组件实现是最简单的。 1、自己手写 先写一个上传的页面 upload_file.html enctype"multipart/form-data 一定要加这个&#xff0c;不然只会上传文件名…...

Java地图专题课 基本API BMapGLLib 地图找房案例 MongoDB

本课程基于百度地图技术&#xff0c;由基础入门开始到应用实战&#xff0c;适合零基础入门学习。将企业项目中地图相关常见应用场景的落地实战&#xff0c;包括有地图找房、轻骑小程序、金运物流等。同时讲了基于Netty实现高性能的web服务&#xff0c;来处理高并发的问题。还讲…...

vue实现可缩放拖拽盒子(亲测可用)

特征 没有依赖 使用可拖动&#xff0c;可调整大小或两者兼备定义用于调整大小的句柄限制大小和移动到父元素或自定义选择器将元素捕捉到自定义网格将拖动限制为垂直或水平轴保持纵横比启用触控功能使用自己的样式为句柄提供自己的样式 安装和基本用法 npm install --save vue-d…...

python一次性导出项目用到的依赖

导出依赖列表 如果你用到了Anaconda&#xff0c;记得先激活环境!!!! 下载pipreqs pip install pipreqs 在项目的根目录新建一个run_pipreqs.py文件&#xff0c;复制一下代码&#xff1a; # -*- coding: utf-8 -*- import os import subprocessos.environ["PYTHONIOE…...

移动端网页中的前端视频技术探索

引言 随着移动设备的普及和网络速度的提升&#xff0c;移动端网页中的视频播放已经成为了越来越重要的功能需求。本篇博客将介绍一些在移动端网页中实现前端视频播放的技术探索&#xff0c;并提供详细的代码示例。 1. 基本视频标签 在移动端网页中实现视频播放最基本的方法就…...

题解:ABC277C - Ladder Takahashi

题解&#xff1a;ABC277C - Ladder Takahashi 题目 链接&#xff1a;Atcoder。 链接&#xff1a;洛谷。 难度 算法难度&#xff1a;普及。 思维难度&#xff1a;入门。 调码难度&#xff1a;入门。 综合评价&#xff1a;简单。 算法 深度优先搜索简单图论 思路 把每…...

7.11 Java方法重写

7.11 Java方法重写 这里首先要确定的是重写跟属性没有关系&#xff0c;重写都是方法的重写&#xff0c;与属性无关 带有关键字Static修饰的方法的重写实例 父类实例 package com.baidu.www.oop.demo05;public class B {public static void test(){System.out.println("这…...

Android Stodio编译JNI项目,Cmake出错:Detecting C compiler ABI info - failed

在使用Android Stodio编译JNI项目时出现Cmake错误&#xff0c;报错如下&#xff1a; Execution failed for task :app:configureCMakeDebug[arm64-v8a]. > [CXX1429] error when building with cmake using C:\Users\Dell\AndroidStudioProjects\MyApplication2\app\src\ma…...

6.2 Spring Boot整合MyBatis

1、基于Spring BootMyBatis的学生信息系统的设计与实现案例 基于Spring BootMyBatis实现学生信息的新增、修改、删除、查询功能&#xff0c;并实现MySQL数据库的操作。 MySQL数据库创建学生表&#xff08;t_student&#xff09;&#xff0c;有主键、姓名、年龄、性别、出生日…...

在CentOS 7上使用kubeadm部署Kubernetes集群

如有错误&#xff0c;敬请谅解&#xff01; 此文章仅为本人学习笔记&#xff0c;仅供参考&#xff0c;如有冒犯&#xff0c;请联系作者删除&#xff01;&#xff01; 前言&#xff1a; Kubernetes是一个开源的容器编排平台&#xff0c;用于管理和自动化部署容器化的应用程序。…...

这6个免费设计素材网站,设计师都在用,马住

新手设计师不知道去哪里找素材&#xff0c;那就看看这几个设计师都在用的网站吧&#xff0c;免费、付费、商用素材都有&#xff0c;可根据需求选择&#xff0c;赶紧收藏~ 菜鸟图库 https://www.sucai999.com/?vNTYxMjky 菜鸟图库是一个非常大的素材库&#xff0c;站内包含设…...

uni-app引入sortable列表拖拽,兼容App和H5,拖拽排序。

效果: 拖拽排序 背景&#xff1a; 作为一名前端开发人员&#xff0c;在工作中难免会遇到拖拽功能&#xff0c;分享一个github上一个不错的拖拽js库&#xff0c;能满足我们在项目开发中的需要&#xff0c;下面是我在uniapp中使用SortableJS的使用详细流程&#xff1b; vue开发…...

Redis-内存淘汰算法

Redis可以存多少数据 32位的操作系统默认3G 谁现在用32位啊?我们说64位的 一般来讲是不设上限的 但是我们也可以主动配置maxmemory, maxmemory支持各单位: maxmemory 1024 (默认字节) maxmemory 1024KB maxmemory 1024MB maxmemory 1204GB 当Redis存储超过这个配置值&#…...

Git 合并分支时允许合并不相关的历史

git fetch git fetch 是 Git 的一个命令&#xff0c;用于从远程仓库中获取最新的提交和数据&#xff0c;同时更新本地仓库的远程分支指针。 使用 git fetch 命令可以获取远程仓库的最新提交&#xff0c;但并不会自动合并或修改本地分支。它会将远程仓库的提交和引用&#xff…...

世界上最著名的密码学夫妻的历史

Alice和Bob是密码学领域里最著名的虚拟夫妻&#xff0c;自1978年“诞生”以来&#xff0c;到走进二十一世纪的移动互联网时代&#xff0c;作为虚构的故事主角&#xff0c;Alice和Bob不仅在计算机理论、逻辑学、量子计算等与密码学相关的领域中得到应用&#xff0c;他们的名字也…...

二维码网络钓鱼攻击泛滥!美国著名能源企业成主要攻击目标

近日&#xff0c;Cofense发现了一次专门针对美国能源公司的网络钓鱼攻击活动&#xff0c;攻击者利用二维码将恶意电子邮件塞进收件箱并绕过安全系统。 Cofense 方面表示&#xff0c;这是首次发现网络钓鱼行为者如此大规模的使用二维码进行钓鱼攻击&#xff0c;这表明他们可能正…...

前端面试题-CSS

1. 盒模型 ⻚⾯渲染时&#xff0c; dom 元素所采⽤的 布局模型。可通过 box-sizing 进⾏设置。根据计算宽⾼的区域可分为 content-box ( W3C 标准盒模型)border-box ( IE 盒模型)padding-boxmargin-box (浏览器未实现) 2. BFC 块级格式化上下⽂&#xff0c;是⼀个独⽴的渲染…...

6.1 安全漏洞与网络攻击

数据参考&#xff1a;CISP官方 目录 安全漏洞及产生原因信息收集与分析网络攻击实施后门设置与痕迹清除 一、安全漏洞及产生原因 什么是安全漏洞 安全漏洞也称脆弱性&#xff0c;是计算机系统存在的缺陷 漏洞的形式 安全漏洞以不同形式存在漏洞数量逐年递增 漏洞产生的…...

STM32--EXTI外部中断

前文回顾---STM32--GPIO 相关回顾--有关中断系统简介 目录 STM32中断 NVIC EXTI外部中断 AFIO EXTI框图 旋转编码器简介 对射式红外传感器工程 代码&#xff1a; 旋转编码器工程 代码&#xff1a; STM32中断 先说一下基本原理&#xff1a; 1.中断请求发生&#xff1a…...

硬件身份伪装终极指南:3分钟掌握EASY-HWID-SPOOFER的深度伪装技术

硬件身份伪装终极指南&#xff1a;3分钟掌握EASY-HWID-SPOOFER的深度伪装技术 【免费下载链接】EASY-HWID-SPOOFER 基于内核模式的硬件信息欺骗工具 项目地址: https://gitcode.com/gh_mirrors/ea/EASY-HWID-SPOOFER 你是否曾经遇到过这样的情况&#xff1a;刚买的软件因…...

SillyTavern角色卡片系统:打造属于你的AI灵魂伴侣

SillyTavern角色卡片系统&#xff1a;打造属于你的AI灵魂伴侣 【免费下载链接】SillyTavern LLM Frontend for Power Users. 项目地址: https://gitcode.com/GitHub_Trending/si/SillyTavern 你是否曾经幻想过&#xff0c;能有一个真正理解你、陪伴你的AI伙伴&#xff1…...

告别单调终端:250+ Xshell配色方案让你的命令行焕然一新

告别单调终端&#xff1a;250 Xshell配色方案让你的命令行焕然一新 【免费下载链接】Xshell-ColorScheme 250 Xshell Color Schemes 项目地址: https://gitcode.com/gh_mirrors/xs/Xshell-ColorScheme 每天面对单调的黑白终端界面&#xff0c;是否感到视觉疲劳&#xff…...

PowerShdll源码深度分析:从DLL导出到控制台劫持的完整实现原理

PowerShdll源码深度分析&#xff1a;从DLL导出到控制台劫持的完整实现原理 【免费下载链接】PowerShdll Run PowerShell with rundll32. Bypass software restrictions. 项目地址: https://gitcode.com/gh_mirrors/po/PowerShdll PowerShdll是一个创新的PowerShell绕过工…...

wBlock Safari扩展架构详解:5个内容拦截扩展的协同工作原理

wBlock Safari扩展架构详解&#xff1a;5个内容拦截扩展的协同工作原理 【免费下载链接】wBlock The next-generation ad blocker for Safari. 项目地址: https://gitcode.com/gh_mirrors/wb/wBlock wBlock是一款下一代Safari广告拦截器&#xff0c;通过创新的多扩展架构…...

MemoryOS:开源时序知识图谱AI记忆系统

AI的记忆困局&#xff1a;为什么需要"时序"和"知识图谱"&#xff1f;用过ChatGPT或任何AI助手的人大概都有过这样的体验&#xff1a;昨天告诉AI自己住在北京&#xff0c;今天问它"我住哪儿"&#xff0c;它可能还能答对&#xff1b;但是过了两周&…...

点式玻璃幕墙及采光顶设计的一些想法

点式玻璃幕墙及采光顶设计的一些想法 点式玻璃幕墙是在主龙骨上面固定点支撑装置,由点支撑装置支撑玻璃面板的一种常用幕墙表现形式,他最早起源于国外。因为玻璃的通透性,建筑内外有效融合,空间感增强,开阔了视野,增加了建筑物的现代感。 点式玻璃幕墙最主要的组成部分是…...

magnetW磁力聚合搜索工具:一站式资源发现神器

magnetW磁力聚合搜索工具&#xff1a;一站式资源发现神器 【免费下载链接】magnetW [已失效&#xff0c;不再维护] 项目地址: https://gitcode.com/gh_mirrors/ma/magnetW 磁力搜索工具magnetW是一款基于Electron框架开发的跨平台桌面应用&#xff0c;专为技术爱好者和普…...

Hackintool:黑苹果配置不再复杂,这款工具让你轻松搞定所有难题

Hackintool&#xff1a;黑苹果配置不再复杂&#xff0c;这款工具让你轻松搞定所有难题 【免费下载链接】Hackintool The Swiss army knife of vanilla Hackintoshing 项目地址: https://gitcode.com/gh_mirrors/ha/Hackintool 还在为黑苹果的配置问题头疼吗&#xff1f;…...

后端架构师转型AI智能体落地:收藏这份3个月进阶指南,轻松玩转不确定性系统

本文为后端/全栈/架构师提供了一条从零到一掌握AI智能体落地的技术路径。文章首先分析了架构师在AI智能体落地中的核心优势&#xff0c;如分布式系统设计、数据库设计、API封装等&#xff1b;接着&#xff0c;提出了一个分四阶段的三个月进阶计划&#xff0c;包括掌握核心范式、…...