当前位置: 首页 > news >正文

使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。

本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
在这里插入图片描述
DCGAN模型有以下特点:

  1. 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
  2. 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  3. 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  4. 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。

代码

model.py:

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True))else: # Final Layerreturn nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh())def unsqueeze_noise(self, noise):return noise.view(len(noise), self.z_dim, 1, 1)    # [b,c,h,w]def forward(self, noise):x = self.unsqueeze_noise(noise)return self.gen(x)class Discriminator(nn.Module):def __init__(self, im_chan=1, hidden_dim=16):super(Discriminator, self).__init__()self.disc = nn.Sequential(self.make_disc_block(im_chan, hidden_dim),self.make_disc_block(hidden_dim, hidden_dim * 2),self.make_disc_block(hidden_dim * 2, 1, final_layer=True),)def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True))else:  # Final Layerreturn nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride))def forward(self, image):disc_pred = self.disc(image)return disc_pred.view(len(disc_pred), -1)

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)## Update discriminator ##disc_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()## Update generator ##gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)disc_fake_pred = disc(fake_2)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))gen_loss.backward()gen_opt.step()# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step## Visualization code ##if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

每500个batch展示一次
每500个batch展示一次。
在这里插入图片描述
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
在这里插入图片描述
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
在这里插入图片描述
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。

下一篇构建WGAN_GP。

相关文章:

使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。 因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其…...

[AI]文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯

AI文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯 1.背景介绍 随着人工智能技术的不断发展,自然语言处理(NLP)技术在近年来取得了显著的进步。其中,聊天机器人技术作为NLP领域的一个重要应用,已经广…...

vue.js设计与实现(分支切换与cleanup)

如存在三元运算符时,怎么处理 // 原始数据 const data { text: hello world,ok:true}// 副作用函数存在三元运算符 effect(function effectFn(){document.body.innerText obj.ok ? obj.text : not })// 理解如此,obj.ok和obj.text都会绑定effectFn函…...

206基于matlab的无人机航迹规划(UAV track plannin)

基于matlab的无人机航迹规划(UAV track plannin)。输入输出参数包括 横滚、俯仰、航向角(单位:度);横滚速率、俯仰速率、航向角速率(单位:度/秒);飞机运动速度——X右翼、…...

【Linux 】查看veth-pair对的映射关系

1. 查看当前存在的ns ip netns add netns199 //新建一个命名空间 # ip netns show netns199 (id: 3)可以看到一个名称叫做netns199 的命名空间,其 id为3 2. 创建一个对,并加入其中一个到其他命名空间中 $ sudo ip link add veth100 type veth peer n…...

Cisco Firepower FMCv修改管理Ip方法

FMCv 是部署在VMWARE虚拟平台上的FMC 部署完成后,如何修改管理IP 1 查看当前版本 show version 可以看到是for VMware 2 修改管理IP步骤 2.1 进入expert模式 expert2.2 进入超级用户 sudo su并输入密码 2.3 查看当前网卡Ip 2.4 修改Ip 命令: /…...

PHP开发全新29网课交单平台源码修复全开源版本,支持聚合登陆易支付

这是一套最新版本的PHP开发的网课交单平台源代码,已进行全开源修复,支持聚合登录和易支付功能。 项目 地 址 : runruncode.com/php/19721.html 以下是对该套代码的主要更新和修复: 1. 移除了论文编辑功能。 2. 移除了强国接码…...

【Web前端】CSS基本语法规范和引入方式常见选择器用法常见元素属性

一、基本语法规范 选择器 {一条/N条声明} 选择器决定针对谁修改 (找谁) 声明决定修改什么.。(干什么) 声明的属性是键值对.。使用 &#xff1a; 区分键值对&#xff0c; 使用 &#xff1a; 区分键和值。 <!DOCTYPE html> <html lang"en"> <head>&…...

SnapGene 5 for Mac 分子生物学软件

SnapGene 5 for Mac是一款专为Mac操作系统设计的分子生物学软件&#xff0c;以其强大的功能和用户友好的界面&#xff0c;为科研人员提供了高效、便捷的基因克隆和分子实验设计体验。 软件下载&#xff1a;SnapGene 5 for Mac v5.3.1中文激活版 这款软件支持DNA构建和克隆设计&…...

本地部署大模型的几种工具(上-相关使用)

目录 前言 为什么本地部署 目前的工具 vllm 介绍 下载模型 安装vllm 运行 存在问题 chatglm.cpp 介绍 下载 安装 运行 命令行运行 webdemo运行 GPU推理 ollama 介绍 下载 运行 运行不同参数量的模型 存在问题 lmstudio 介绍 下载 使用 下载模型文件…...

Spring Boot集成itext实现html生成PDF功能

1.itext介绍 iText是著名的开放源码的站点sourceforge一个项目,是用于生成PDF文档的一个java类库。通过iText不仅可以生成PDF或rtf的文档,而且可以将XML、Html文件转化为PDF文件 iText 的特点 以下是 iText 库的显着特点 − Interactive − iText 为你提供类(API)来生成…...

Java 多态、包、final、权限修饰符、静态代码块

多态 Java多态是指一个对象可以具有多种形态。它是面向对象编程的一个重要特性&#xff0c;允许子类对象可以被当作父类对象使用。多态的实现主要依赖于继承、接口和方法重写。 在Java中&#xff0c;多态的实现主要通过以下两种方式&#xff1a; 继承&#xff1a;子类继承父类…...

基于Spring boot + Vue协同过滤算法的电影推荐系统

末尾获取源码作者介绍&#xff1a;大家好&#xff0c;我是墨韵&#xff0c;本人4年开发经验&#xff0c;专注定制项目开发 更多项目&#xff1a;CSDN主页YAML墨韵 学如逆水行舟&#xff0c;不进则退。学习如赶路&#xff0c;不能慢一步。 目录 一、项目简介 二、开发技术与环…...

Chrome之解决:浏览器插件不能使用问题(十三)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…...

【正版特惠】IDM 永久授权 优惠低至109元!

尽管小编有修改版IDM&#xff0c;但是由于软件太好用了&#xff0c;很多同学干脆就直接购买了正版&#xff0c;现在正版也不贵&#xff0c;并且授权码绑定自己的邮箱&#xff0c;直接官方下载激活&#xff0c;无需其他的绿化修改之类的操作&#xff0c;不喜欢那么麻烦的&#x…...

SpringBoot与Prometheus监控整合

参考&#xff1a; springboot实战之prometheus监控整合-腾讯云开发者社区-腾讯云 https://www.cnblogs.com/skevin/p/15874139.html https://www.jianshu.com/p/e5dc2b45c7a4...

Linux 系统 docker搭建LNMP环境

1、安装nginx docker pull nginx (默认安装的是最新版本) 2、运行nginx docker run --name nginx -p 80:80 -d nginx:latest 备注&#xff1a;--name nginx 表示容器名为 nginx -d 表示后台运行 -p 80:80 表示把本地80端口绑定到Nginx服务端的 80端口 nginx:lates…...

拉普拉斯变换

定义&#xff1a; 拉普拉斯变换是一种在信号处理、控制理论和其他领域中广泛使用的数学工具&#xff0c;用于将一个函数从时域转换到复频域。拉普拉斯变换将一个函数 f(t) 变换为一个复变量函数 F(s)&#xff0c;其中 s 是复数变量。下面是拉普拉斯变换的推导过程&#xff1a;…...

Mashup-Math_Topic_One

Tutorial and Introspection A Rudolf and 121 注意到第 1 1 1 位只能被第 2 2 2 位影响&#xff0c;以此类推位置&#xff0c;对于 a i a_i ai​ , 如果 < 0 < 0 <0 &#xff0c;不合法 ; 否则&#xff0c; a i − a i , a i 1 − 2 ∗ a i , a i 2 − a …...

基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现

基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言…...

7.4.分块查找

一.分块查找的算法思想&#xff1a; 1.实例&#xff1a; 以上述图片的顺序表为例&#xff0c; 该顺序表的数据元素从整体来看是乱序的&#xff0c;但如果把这些数据元素分成一块一块的小区间&#xff0c; 第一个区间[0,1]索引上的数据元素都是小于等于10的&#xff0c; 第二…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

Robots.txt 文件

什么是robots.txt&#xff1f; robots.txt 是一个位于网站根目录下的文本文件&#xff08;如&#xff1a;https://example.com/robots.txt&#xff09;&#xff0c;它用于指导网络爬虫&#xff08;如搜索引擎的蜘蛛程序&#xff09;如何抓取该网站的内容。这个文件遵循 Robots…...

Matlab | matlab常用命令总结

常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

如何在网页里填写 PDF 表格?

有时候&#xff0c;你可能希望用户能在你的网站上填写 PDF 表单。然而&#xff0c;这件事并不简单&#xff0c;因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件&#xff0c;但原生并不支持编辑或填写它们。更糟的是&#xff0c;如果你想收集表单数据&#xff…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机&#xff0c;它可以执行Java字节码。Java虚拟机是Java平台的一部分&#xff0c;Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...

消息队列系统设计与实践全解析

文章目录 &#x1f680; 消息队列系统设计与实践全解析&#x1f50d; 一、消息队列选型1.1 业务场景匹配矩阵1.2 吞吐量/延迟/可靠性权衡&#x1f4a1; 权衡决策框架 1.3 运维复杂度评估&#x1f527; 运维成本降低策略 &#x1f3d7;️ 二、典型架构设计2.1 分布式事务最终一致…...

0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化

是不是受够了安装了oracle database之后sqlplus的简陋&#xff0c;无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话&#xff0c;配置.bahs_profile后也能解决上下翻页这些&#xff0c;但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可&#xff0c…...

数据结构:递归的种类(Types of Recursion)

目录 尾递归&#xff08;Tail Recursion&#xff09; 什么是 Loop&#xff08;循环&#xff09;&#xff1f; 复杂度分析 头递归&#xff08;Head Recursion&#xff09; 树形递归&#xff08;Tree Recursion&#xff09; 线性递归&#xff08;Linear Recursion&#xff09;…...