当前位置: 首页 > news >正文

PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

相关文章:

PyTorch训练深度卷积生成对抗网络DCGAN

文章目录 DCGAN介绍代码结果参考 DCGAN介绍 将CNN和GAN结合起来&#xff0c;把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN) DCGAN的生成器结构&#xff1a; 图片来源&#xff1a;https://arxiv.org/abs/1511.06434 代码 model.py impor…...

Spring-4-掌握Spring事务传播机制

今日目标 能够掌握Spring事务配置 Spring事务管理 1 Spring事务简介【重点】 1.1 Spring事务作用 事务作用&#xff1a;在数据层保障一系列的数据库操作同成功同失败 Spring事务作用&#xff1a;在数据层或业务层保障一系列的数据库操作同成功同失败 1.2 案例分析Spring…...

[PyTorch][chapter 49][创建自己的数据集 1]

前言&#xff1a; 后面几章主要利用DataSet 创建自己的数据集&#xff0c;实现建模&#xff0c; 训练&#xff0c;迁移等功能。 目录: pokemon 数据集深度学习工程步骤 一 pokemon 数据集介绍 1.1 pokemon: 数据集地址&#xff1a; 百度网盘路径: https://pan.baidu.com/s/1…...

中间件(二)dubbo负载均衡介绍

一、负载均衡概述 支持轮询、随机、一致性hash和最小活跃数等。 1、轮询 ① sequences&#xff1a;内部的序列计数器 ② 服务器接口方法权重一样&#xff1a;&#xff08;sequences1&#xff09;%服务器的数量&#xff08;决定调用&#xff09;哪个服务器的服务。 ③ 服务器…...

springboot异步文件上传获取输入流提示找不到文件java.io.FileNotFoundException

springboot上传文件&#xff0c;使用异步操作处理上传的文件数据&#xff0c;出现异常如下&#xff1a; 这个是在异步之后使用传过来的MultipartFile对象尝试调用getInputStream方法发生的异常。 java.io.FileNotFoundException: C:\Users\Administrator\AppData\Local\Temp\to…...

安装jenkins-cli

1、要在 Linux 操作系统上安装 jcli curl -L https://github.com/jenkins-zh/jenkins-cli/releases/latest/download/jcli-linux-amd64.tar.gz|tar xzv sudo mv jcli /usr/local/bin/ 在用户根目录下&#xff0c;增加 jcli 的配置文件&#xff1a; jcli config gen -ifalse …...

linux通过NC工具启动临时端口监听

1.安装nc工具 yum install nc -y2. 启动监听指定端口 #例如监听8080端口 nc -lk 8080#后台监听 nc -lk 8080 &3. 验证 #通过另外一台网络能通的机器&#xff0c;telnet 该机器ip 监听端口能通&#xff0c;并且能接手数据 telnet 192.xxx.xxx.xx 8080...

开源语音聊天软件Mumble

网友 大气 告诉我&#xff0c;Openblocks在国内还有个版本叫 码匠&#xff0c;更贴合国内软件开发的需求&#xff0c;如接入了国内常用的身份认证&#xff0c;接入了国内的数据库和云服务&#xff0c;也对小程序、企微 sdk 等场景做了适配。 在 https://majiang.co/docs/docke…...

JDK 1.6与JDK 1.8的区别

ArrayList使用默认的构造方式实例 jdk1.6默认初始值为10jdk1.8为0,第一次放入值才初始化&#xff0c;属于懒加载 Hashmap底层 jdk1.6与jdk1.8都是数组链表 jdk1.8是链表超过8时&#xff0c;自动转为红黑树 静态方式不同 jdk1.6是先初始化static后执行main方法。 jdk1.8是懒加…...

单片机实训报告

这周我们进行了单片机实训&#xff0c;一周中我们通过七个项目1&#xff1a;P1 口输入/输出 2&#xff1a;继电器控制 3 音频控制 4&#xff1a;子程序设计 5&#xff1a;字符碰头程序设计 6&#xff1a;外部中断 7&#xff1a; 急救车与交通信号灯&#xff0c;练习编写了子程…...

【编织时空四:探究顺序表与链表的数据之旅】

本章重点 链表的分类 带头双向循环链表接口实现 顺序表和链表的区别 缓存利用率参考存储体系结构 以及 局部原理性。 一、链表的分类 实际中链表的结构非常多样&#xff0c;以下情况组合起来就有8种链表结构&#xff1a; 1. 单向或者双向 2. 带头或者不带头 3. 循环或者非…...

PHP8的字符串操作1-PHP8知识详解

字符串是php中最重要的数据之一&#xff0c;字符串的操作在PHP编程占有重要的地位。在使用PHP语言开发web项目的过程中&#xff0c;为了实现某些功能&#xff0c;经常需要对某些字符串进行特殊的处理&#xff0c;比如字符串的格式化、字符串的连接与分割、字符串的比较、查找等…...

电脑提示msvcp140.dll丢失的解决方法,dll组件怎么处理

Windows系统有时在打开游戏或者软件时&#xff0c; 系统会弹窗提示缺少“msvcp140.dll.dll”文件 或者类似错误提示怎么办&#xff1f; 错误背景&#xff1a; msvcp140.dll是Microsoft Visual C Redistributable Package中的一个动态链接库文件&#xff0c;它在运行软件时提…...

stable diffusion基础

整合包下载&#xff1a;秋叶大佬 【AI绘画8月最新】Stable Diffusion整合包v4.2发布&#xff01; 参照&#xff1a;基础04】目前全网最贴心的Lora基础知识教程&#xff01; VAE 作用&#xff1a;滤镜微调 VAE下载地址&#xff1a;C站&#xff08;https://civitai.com/models…...

Greiner–Hormann裁剪算法深度探索:C++实现与应用案例

介绍 在计算几何中&#xff0c;裁剪是一个核心的主题。特别是&#xff0c;多边形裁剪已经被广泛地应用于计算机图形学&#xff0c;地理信息系统和许多其他领域。Greiner-Hormann裁剪算法是其中之一&#xff0c;提供了一个高效的方式来计算两个多边形的交集、并集等。在本文中&…...

Automatically Correcting Large Language Models

本文是大模型相关领域的系列文章&#xff0c;针对《Automatically Correcting Large Language Models: Surveying the landscape of diverse self-correction strategies》的翻译。 自动更正大型语言模型&#xff1a;综述各种自我更正策略的前景 摘要1 引言2 自动反馈校正LLM的…...

【学习FreeRTOS】第8章——FreeRTOS列表和列表项

1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构&#xff0c;概念上和链表有点类似&#xff0c;列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表&#xff0c;列表项相当于节点&#xff0c;FreeRTOS 中的列表是一个双向环形链表列表的…...

分布式图数据库 NebulaGraph v3.6.0 正式发布,强化全文索引能力

本次 v3.6.0 版本&#xff0c;主要强化全文索引能力&#xff0c;以及优化部分场景下的 MATCH 性能。 强化 强化增强全文索引功能&#xff0c;具体 pr 参见&#xff1a;#5567、#5575、#5577、#5580、#5584、#5587 优化 支持使用 MATCH 子句检索 VID 或属性索引时使用变量&am…...

在 ubuntu 18.04 上使用源码升级 OpenSSH_7.6p1到 OpenSSH_9.3p1

1、检查系统已安装的当前 SSH 版本 使用命令 ssh -V 查看当前 ssh 版本&#xff0c;输出如下&#xff1a; OpenSSH_7.6p1 Ubuntu-4ubuntu0.7, OpenSSL 1.0.2n 7 Dec 20172、安装依赖&#xff0c;依次执行以下命令 sudo apt update sudo apt install build-essential zlib1g…...

python中可以处理word文档的模块:docx模块

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 话不多说&#xff0c;直接开搞&#xff0c;如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 一.docx模块 Python可以利用python-docx模块处理word文档&#xff0c;处理方式是面向对象的。 也就是说python-docx模块…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)

说明&#xff1a; 想象一下&#xff0c;你正在用eNSP搭建一个虚拟的网络世界&#xff0c;里面有虚拟的路由器、交换机、电脑&#xff08;PC&#xff09;等等。这些设备都在你的电脑里面“运行”&#xff0c;它们之间可以互相通信&#xff0c;就像一个封闭的小王国。 但是&#…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

OpenLayers 可视化之热力图

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 热力图&#xff08;Heatmap&#xff09;又叫热点图&#xff0c;是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...