使用pytorch构建一个无监督的深度卷积GAN网络模型
本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。
本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
DCGAN模型有以下特点:
- 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
- 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
- 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
- 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。
代码
model.py:
from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True))else: # Final Layerreturn nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh())def unsqueeze_noise(self, noise):return noise.view(len(noise), self.z_dim, 1, 1) # [b,c,h,w]def forward(self, noise):x = self.unsqueeze_noise(noise)return self.gen(x)class Discriminator(nn.Module):def __init__(self, im_chan=1, hidden_dim=16):super(Discriminator, self).__init__()self.disc = nn.Sequential(self.make_disc_block(im_chan, hidden_dim),self.make_disc_block(hidden_dim, hidden_dim * 2),self.make_disc_block(hidden_dim * 2, 1, final_layer=True),)def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True))else: # Final Layerreturn nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride))def forward(self, image):disc_pred = self.disc(image)return disc_pred.view(len(disc_pred), -1)
train.py:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)## Update discriminator ##disc_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()## Update generator ##gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)disc_fake_pred = disc(fake_2)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))gen_loss.backward()gen_opt.step()# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step## Visualization code ##if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1
每500个batch展示一次。
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。
下一篇构建WGAN_GP。
相关文章:

使用pytorch构建一个无监督的深度卷积GAN网络模型
本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。 因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其…...

[AI]文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯
AI文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯 1.背景介绍 随着人工智能技术的不断发展,自然语言处理(NLP)技术在近年来取得了显著的进步。其中,聊天机器人技术作为NLP领域的一个重要应用,已经广…...

vue.js设计与实现(分支切换与cleanup)
如存在三元运算符时,怎么处理 // 原始数据 const data { text: hello world,ok:true}// 副作用函数存在三元运算符 effect(function effectFn(){document.body.innerText obj.ok ? obj.text : not })// 理解如此,obj.ok和obj.text都会绑定effectFn函…...

206基于matlab的无人机航迹规划(UAV track plannin)
基于matlab的无人机航迹规划(UAV track plannin)。输入输出参数包括 横滚、俯仰、航向角(单位:度);横滚速率、俯仰速率、航向角速率(单位:度/秒);飞机运动速度——X右翼、…...

【Linux 】查看veth-pair对的映射关系
1. 查看当前存在的ns ip netns add netns199 //新建一个命名空间 # ip netns show netns199 (id: 3)可以看到一个名称叫做netns199 的命名空间,其 id为3 2. 创建一个对,并加入其中一个到其他命名空间中 $ sudo ip link add veth100 type veth peer n…...

Cisco Firepower FMCv修改管理Ip方法
FMCv 是部署在VMWARE虚拟平台上的FMC 部署完成后,如何修改管理IP 1 查看当前版本 show version 可以看到是for VMware 2 修改管理IP步骤 2.1 进入expert模式 expert2.2 进入超级用户 sudo su并输入密码 2.3 查看当前网卡Ip 2.4 修改Ip 命令: /…...

PHP开发全新29网课交单平台源码修复全开源版本,支持聚合登陆易支付
这是一套最新版本的PHP开发的网课交单平台源代码,已进行全开源修复,支持聚合登录和易支付功能。 项目 地 址 : runruncode.com/php/19721.html 以下是对该套代码的主要更新和修复: 1. 移除了论文编辑功能。 2. 移除了强国接码…...

【Web前端】CSS基本语法规范和引入方式常见选择器用法常见元素属性
一、基本语法规范 选择器 {一条/N条声明} 选择器决定针对谁修改 (找谁) 声明决定修改什么.。(干什么) 声明的属性是键值对.。使用 : 区分键值对, 使用 : 区分键和值。 <!DOCTYPE html> <html lang"en"> <head>&…...

SnapGene 5 for Mac 分子生物学软件
SnapGene 5 for Mac是一款专为Mac操作系统设计的分子生物学软件,以其强大的功能和用户友好的界面,为科研人员提供了高效、便捷的基因克隆和分子实验设计体验。 软件下载:SnapGene 5 for Mac v5.3.1中文激活版 这款软件支持DNA构建和克隆设计&…...

本地部署大模型的几种工具(上-相关使用)
目录 前言 为什么本地部署 目前的工具 vllm 介绍 下载模型 安装vllm 运行 存在问题 chatglm.cpp 介绍 下载 安装 运行 命令行运行 webdemo运行 GPU推理 ollama 介绍 下载 运行 运行不同参数量的模型 存在问题 lmstudio 介绍 下载 使用 下载模型文件…...

Spring Boot集成itext实现html生成PDF功能
1.itext介绍 iText是著名的开放源码的站点sourceforge一个项目,是用于生成PDF文档的一个java类库。通过iText不仅可以生成PDF或rtf的文档,而且可以将XML、Html文件转化为PDF文件 iText 的特点 以下是 iText 库的显着特点 − Interactive − iText 为你提供类(API)来生成…...

Java 多态、包、final、权限修饰符、静态代码块
多态 Java多态是指一个对象可以具有多种形态。它是面向对象编程的一个重要特性,允许子类对象可以被当作父类对象使用。多态的实现主要依赖于继承、接口和方法重写。 在Java中,多态的实现主要通过以下两种方式: 继承:子类继承父类…...

基于Spring boot + Vue协同过滤算法的电影推荐系统
末尾获取源码作者介绍:大家好,我是墨韵,本人4年开发经验,专注定制项目开发 更多项目:CSDN主页YAML墨韵 学如逆水行舟,不进则退。学习如赶路,不能慢一步。 目录 一、项目简介 二、开发技术与环…...

Chrome之解决:浏览器插件不能使用问题(十三)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…...

【正版特惠】IDM 永久授权 优惠低至109元!
尽管小编有修改版IDM,但是由于软件太好用了,很多同学干脆就直接购买了正版,现在正版也不贵,并且授权码绑定自己的邮箱,直接官方下载激活,无需其他的绿化修改之类的操作,不喜欢那么麻烦的&#x…...

SpringBoot与Prometheus监控整合
参考: springboot实战之prometheus监控整合-腾讯云开发者社区-腾讯云 https://www.cnblogs.com/skevin/p/15874139.html https://www.jianshu.com/p/e5dc2b45c7a4...

Linux 系统 docker搭建LNMP环境
1、安装nginx docker pull nginx (默认安装的是最新版本) 2、运行nginx docker run --name nginx -p 80:80 -d nginx:latest 备注:--name nginx 表示容器名为 nginx -d 表示后台运行 -p 80:80 表示把本地80端口绑定到Nginx服务端的 80端口 nginx:lates…...

拉普拉斯变换
定义: 拉普拉斯变换是一种在信号处理、控制理论和其他领域中广泛使用的数学工具,用于将一个函数从时域转换到复频域。拉普拉斯变换将一个函数 f(t) 变换为一个复变量函数 F(s),其中 s 是复数变量。下面是拉普拉斯变换的推导过程:…...

Mashup-Math_Topic_One
Tutorial and Introspection A Rudolf and 121 注意到第 1 1 1 位只能被第 2 2 2 位影响,以此类推位置,对于 a i a_i ai , 如果 < 0 < 0 <0 ,不合法 ; 否则, a i − a i , a i 1 − 2 ∗ a i , a i 2 − a …...

基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现
基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言…...

逐步学习Go-Select多路复用
概述 这里又有多路复用,但是Go中的这个多路复用不同于网络中的多路复用。在Go里,select用于同时等待多个通信操作(即多个channel的发送或接收操作)。Go中的channel可以参考我的文章:逐步学习Go-并发通道chan(channel)…...

王道:OJ15
课时15作业 Description 读取10个元素 87 7 60 80 59 34 86 99 21 3,然后建立二叉查找树,排序后输出3 7 21 34 59 60 80 86 87 99,针对有序后的元素,存入一个长度为10的数组中,通过折半查找找到21的下标(…...

【案例·查】数据类型强制转换,方便查询匹配
问题描述: MySQL执行中需要将某种数据类型的表达式显式转换为另一种数据类型,可以使用 SQL 中的cast()来处理 案例: SELECT CAST(9.0 AS decimal) #String化为小数类型SELECT * FROM table_1 WHERE 1888-03-07 CAST(theDate AS DATE) …...

spring boot3自定义注解+拦截器+Redis实现高并发接口限流
⛰️个人主页: 蒾酒 🔥系列专栏:《spring boot实战》 🌊山高路远,行路漫漫,终有归途 目录 写在前面 内容简介 实现思路 实现步骤 1.自定义限流注解 2.编写限流拦截器 3.注册拦截器 4.接口限流测试 写在前…...

使用certbot为网站启用https
1. 安装certbot客户端 cd /usr/local/bin wget https://dl.eff.org/certbot-auto chmod ax ./certbot-auto 2. 创建目录和配置nginx用于验证域名 mkdir -p /data/www/letsencryptserver {listen 80;server_name ~^(?<subdomain>.).ninvfeng.com;location /.well-known…...

Unity 背包系统中拖拽物体到指定位置或互换位置效果的实现
在Unity中,背包系统是一种常见的游戏系统,可以用于管理和展示玩家所持有的物品、道具或装备。 其中的拖拽功能非常有意思,具体功能就是玩家可以通过拖拽物品图标来移动物品在背包中的位置,或者将物品拖拽到其他位置或界面中&…...

iOS客户端自动化UI自动化airtest+appium从0到1搭建macos+脚本设计demo演示+全网最全最详细保姆级有步骤有图
Android客户端自动化UI自动化airtest从0到1搭建macos脚本设计demo演示全网最全最详细保姆级有步骤有图-CSDN博客 避坑系列-必读: 不要安装iOS-Tagent ,安装appium -这2个性质其实是差不多的都是为了安装wda。注意安装appium最新版本,安装完…...

每周编辑精选|在线运行 Deepmoney 金融大模型、AI 偏好等多个优质数据集上线
目前,AI 领域对金融模型的研究成果大多是基于公共知识进行训练的,但在实际的金融实践中,这些公共知识对于当前市场的可解释性往往严重不足。一个理想的金融大模型应该能够理解新闻或数据事件,并能够即时地从主观和量化两个角度对事…...

C++多重继承与虚继承
多重继承的原理 多重继承(multiple inheritance)是指从多个直接基类中产生派生类的能力。 多重继承的派生类继承了所有父类的属性。 在面向对象的编程中,多重继承意味着一个类可以从多个父类继承属性和方法。 就像你有一杯混合果汁,它是由多种水果榨取…...

请简单介绍一下Shiro框架是什么?Shiro在Java安全领域的主要作用是什么?Shiro主要提供了哪些安全功能?
请简单介绍一下Shiro框架是什么? Shiro框架是一个强大且灵活的开源安全框架,为Java应用程序提供了全面的安全解决方案。它主要用于身份验证、授权、加密和会话管理等功能,可以轻松地集成到任何Java Web应用程序中,并提供了易于理解…...