PyTorch训练深度卷积生成对抗网络DCGAN
文章目录
- DCGAN介绍
- 代码
- 结果
- 参考
DCGAN介绍
将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)
DCGAN的生成器结构:

图片来源:https://arxiv.org/abs/1511.06434
代码
model.py
import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()
训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录
图片如下图所示:

train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4 # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
# root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1
结果
训练5个epoch,部分结果如下:
Epoch [3/5] Batch 1500/1583 Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583 Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583 Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583 Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583 Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583 Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583 Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583 Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583 Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583 Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583 Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583 Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583 Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583 Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583 Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583 Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583 Loss D: 0.3814, loss G: 1.9391
使用
tensorboard --logdir=logs
打开tensorboard

参考
[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434
相关文章:
PyTorch训练深度卷积生成对抗网络DCGAN
文章目录 DCGAN介绍代码结果参考 DCGAN介绍 将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN) DCGAN的生成器结构: 图片来源:https://arxiv.org/abs/1511.06434 代码 model.py impor…...
Spring-4-掌握Spring事务传播机制
今日目标 能够掌握Spring事务配置 Spring事务管理 1 Spring事务简介【重点】 1.1 Spring事务作用 事务作用:在数据层保障一系列的数据库操作同成功同失败 Spring事务作用:在数据层或业务层保障一系列的数据库操作同成功同失败 1.2 案例分析Spring…...
[PyTorch][chapter 49][创建自己的数据集 1]
前言: 后面几章主要利用DataSet 创建自己的数据集,实现建模, 训练,迁移等功能。 目录: pokemon 数据集深度学习工程步骤 一 pokemon 数据集介绍 1.1 pokemon: 数据集地址: 百度网盘路径: https://pan.baidu.com/s/1…...
中间件(二)dubbo负载均衡介绍
一、负载均衡概述 支持轮询、随机、一致性hash和最小活跃数等。 1、轮询 ① sequences:内部的序列计数器 ② 服务器接口方法权重一样:(sequences1)%服务器的数量(决定调用)哪个服务器的服务。 ③ 服务器…...
springboot异步文件上传获取输入流提示找不到文件java.io.FileNotFoundException
springboot上传文件,使用异步操作处理上传的文件数据,出现异常如下: 这个是在异步之后使用传过来的MultipartFile对象尝试调用getInputStream方法发生的异常。 java.io.FileNotFoundException: C:\Users\Administrator\AppData\Local\Temp\to…...
安装jenkins-cli
1、要在 Linux 操作系统上安装 jcli curl -L https://github.com/jenkins-zh/jenkins-cli/releases/latest/download/jcli-linux-amd64.tar.gz|tar xzv sudo mv jcli /usr/local/bin/ 在用户根目录下,增加 jcli 的配置文件: jcli config gen -ifalse …...
linux通过NC工具启动临时端口监听
1.安装nc工具 yum install nc -y2. 启动监听指定端口 #例如监听8080端口 nc -lk 8080#后台监听 nc -lk 8080 &3. 验证 #通过另外一台网络能通的机器,telnet 该机器ip 监听端口能通,并且能接手数据 telnet 192.xxx.xxx.xx 8080...
开源语音聊天软件Mumble
网友 大气 告诉我,Openblocks在国内还有个版本叫 码匠,更贴合国内软件开发的需求,如接入了国内常用的身份认证,接入了国内的数据库和云服务,也对小程序、企微 sdk 等场景做了适配。 在 https://majiang.co/docs/docke…...
JDK 1.6与JDK 1.8的区别
ArrayList使用默认的构造方式实例 jdk1.6默认初始值为10jdk1.8为0,第一次放入值才初始化,属于懒加载 Hashmap底层 jdk1.6与jdk1.8都是数组链表 jdk1.8是链表超过8时,自动转为红黑树 静态方式不同 jdk1.6是先初始化static后执行main方法。 jdk1.8是懒加…...
单片机实训报告
这周我们进行了单片机实训,一周中我们通过七个项目1:P1 口输入/输出 2:继电器控制 3 音频控制 4:子程序设计 5:字符碰头程序设计 6:外部中断 7: 急救车与交通信号灯,练习编写了子程…...
【编织时空四:探究顺序表与链表的数据之旅】
本章重点 链表的分类 带头双向循环链表接口实现 顺序表和链表的区别 缓存利用率参考存储体系结构 以及 局部原理性。 一、链表的分类 实际中链表的结构非常多样,以下情况组合起来就有8种链表结构: 1. 单向或者双向 2. 带头或者不带头 3. 循环或者非…...
PHP8的字符串操作1-PHP8知识详解
字符串是php中最重要的数据之一,字符串的操作在PHP编程占有重要的地位。在使用PHP语言开发web项目的过程中,为了实现某些功能,经常需要对某些字符串进行特殊的处理,比如字符串的格式化、字符串的连接与分割、字符串的比较、查找等…...
电脑提示msvcp140.dll丢失的解决方法,dll组件怎么处理
Windows系统有时在打开游戏或者软件时, 系统会弹窗提示缺少“msvcp140.dll.dll”文件 或者类似错误提示怎么办? 错误背景: msvcp140.dll是Microsoft Visual C Redistributable Package中的一个动态链接库文件,它在运行软件时提…...
stable diffusion基础
整合包下载:秋叶大佬 【AI绘画8月最新】Stable Diffusion整合包v4.2发布! 参照:基础04】目前全网最贴心的Lora基础知识教程! VAE 作用:滤镜微调 VAE下载地址:C站(https://civitai.com/models…...
Greiner–Hormann裁剪算法深度探索:C++实现与应用案例
介绍 在计算几何中,裁剪是一个核心的主题。特别是,多边形裁剪已经被广泛地应用于计算机图形学,地理信息系统和许多其他领域。Greiner-Hormann裁剪算法是其中之一,提供了一个高效的方式来计算两个多边形的交集、并集等。在本文中&…...
Automatically Correcting Large Language Models
本文是大模型相关领域的系列文章,针对《Automatically Correcting Large Language Models: Surveying the landscape of diverse self-correction strategies》的翻译。 自动更正大型语言模型:综述各种自我更正策略的前景 摘要1 引言2 自动反馈校正LLM的…...
【学习FreeRTOS】第8章——FreeRTOS列表和列表项
1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构,概念上和链表有点类似,列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表,列表项相当于节点,FreeRTOS 中的列表是一个双向环形链表列表的…...
分布式图数据库 NebulaGraph v3.6.0 正式发布,强化全文索引能力
本次 v3.6.0 版本,主要强化全文索引能力,以及优化部分场景下的 MATCH 性能。 强化 强化增强全文索引功能,具体 pr 参见:#5567、#5575、#5577、#5580、#5584、#5587 优化 支持使用 MATCH 子句检索 VID 或属性索引时使用变量&am…...
在 ubuntu 18.04 上使用源码升级 OpenSSH_7.6p1到 OpenSSH_9.3p1
1、检查系统已安装的当前 SSH 版本 使用命令 ssh -V 查看当前 ssh 版本,输出如下: OpenSSH_7.6p1 Ubuntu-4ubuntu0.7, OpenSSL 1.0.2n 7 Dec 20172、安装依赖,依次执行以下命令 sudo apt update sudo apt install build-essential zlib1g…...
python中可以处理word文档的模块:docx模块
前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 话不多说,直接开搞,如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 一.docx模块 Python可以利用python-docx模块处理word文档,处理方式是面向对象的。 也就是说python-docx模块…...
使用VSCode开发Django指南
使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...
通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
【Oracle】分区表
个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...
.Net Framework 4/C# 关键字(非常用,持续更新...)
一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…...
九天毕昇深度学习平台 | 如何安装库?
pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子: 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
解读《网络安全法》最新修订,把握网络安全新趋势
《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...
