当前位置: 首页 > news >正文

使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。

本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
在这里插入图片描述
DCGAN模型有以下特点:

  1. 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
  2. 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  3. 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  4. 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。

代码

model.py:

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True))else: # Final Layerreturn nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh())def unsqueeze_noise(self, noise):return noise.view(len(noise), self.z_dim, 1, 1)    # [b,c,h,w]def forward(self, noise):x = self.unsqueeze_noise(noise)return self.gen(x)class Discriminator(nn.Module):def __init__(self, im_chan=1, hidden_dim=16):super(Discriminator, self).__init__()self.disc = nn.Sequential(self.make_disc_block(im_chan, hidden_dim),self.make_disc_block(hidden_dim, hidden_dim * 2),self.make_disc_block(hidden_dim * 2, 1, final_layer=True),)def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True))else:  # Final Layerreturn nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride))def forward(self, image):disc_pred = self.disc(image)return disc_pred.view(len(disc_pred), -1)

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)## Update discriminator ##disc_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()## Update generator ##gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)disc_fake_pred = disc(fake_2)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))gen_loss.backward()gen_opt.step()# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step## Visualization code ##if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

每500个batch展示一次
每500个batch展示一次。
在这里插入图片描述
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
在这里插入图片描述
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
在这里插入图片描述
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。

下一篇构建WGAN_GP。

相关文章:

使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。 因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其…...

[AI]文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯

AI文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯 1.背景介绍 随着人工智能技术的不断发展,自然语言处理(NLP)技术在近年来取得了显著的进步。其中,聊天机器人技术作为NLP领域的一个重要应用,已经广…...

vue.js设计与实现(分支切换与cleanup)

如存在三元运算符时,怎么处理 // 原始数据 const data { text: hello world,ok:true}// 副作用函数存在三元运算符 effect(function effectFn(){document.body.innerText obj.ok ? obj.text : not })// 理解如此,obj.ok和obj.text都会绑定effectFn函…...

206基于matlab的无人机航迹规划(UAV track plannin)

基于matlab的无人机航迹规划(UAV track plannin)。输入输出参数包括 横滚、俯仰、航向角(单位:度);横滚速率、俯仰速率、航向角速率(单位:度/秒);飞机运动速度——X右翼、…...

【Linux 】查看veth-pair对的映射关系

1. 查看当前存在的ns ip netns add netns199 //新建一个命名空间 # ip netns show netns199 (id: 3)可以看到一个名称叫做netns199 的命名空间,其 id为3 2. 创建一个对,并加入其中一个到其他命名空间中 $ sudo ip link add veth100 type veth peer n…...

Cisco Firepower FMCv修改管理Ip方法

FMCv 是部署在VMWARE虚拟平台上的FMC 部署完成后,如何修改管理IP 1 查看当前版本 show version 可以看到是for VMware 2 修改管理IP步骤 2.1 进入expert模式 expert2.2 进入超级用户 sudo su并输入密码 2.3 查看当前网卡Ip 2.4 修改Ip 命令: /…...

PHP开发全新29网课交单平台源码修复全开源版本,支持聚合登陆易支付

这是一套最新版本的PHP开发的网课交单平台源代码,已进行全开源修复,支持聚合登录和易支付功能。 项目 地 址 : runruncode.com/php/19721.html 以下是对该套代码的主要更新和修复: 1. 移除了论文编辑功能。 2. 移除了强国接码…...

【Web前端】CSS基本语法规范和引入方式常见选择器用法常见元素属性

一、基本语法规范 选择器 {一条/N条声明} 选择器决定针对谁修改 (找谁) 声明决定修改什么.。(干什么) 声明的属性是键值对.。使用 &#xff1a; 区分键值对&#xff0c; 使用 &#xff1a; 区分键和值。 <!DOCTYPE html> <html lang"en"> <head>&…...

SnapGene 5 for Mac 分子生物学软件

SnapGene 5 for Mac是一款专为Mac操作系统设计的分子生物学软件&#xff0c;以其强大的功能和用户友好的界面&#xff0c;为科研人员提供了高效、便捷的基因克隆和分子实验设计体验。 软件下载&#xff1a;SnapGene 5 for Mac v5.3.1中文激活版 这款软件支持DNA构建和克隆设计&…...

本地部署大模型的几种工具(上-相关使用)

目录 前言 为什么本地部署 目前的工具 vllm 介绍 下载模型 安装vllm 运行 存在问题 chatglm.cpp 介绍 下载 安装 运行 命令行运行 webdemo运行 GPU推理 ollama 介绍 下载 运行 运行不同参数量的模型 存在问题 lmstudio 介绍 下载 使用 下载模型文件…...

Spring Boot集成itext实现html生成PDF功能

1.itext介绍 iText是著名的开放源码的站点sourceforge一个项目,是用于生成PDF文档的一个java类库。通过iText不仅可以生成PDF或rtf的文档,而且可以将XML、Html文件转化为PDF文件 iText 的特点 以下是 iText 库的显着特点 − Interactive − iText 为你提供类(API)来生成…...

Java 多态、包、final、权限修饰符、静态代码块

多态 Java多态是指一个对象可以具有多种形态。它是面向对象编程的一个重要特性&#xff0c;允许子类对象可以被当作父类对象使用。多态的实现主要依赖于继承、接口和方法重写。 在Java中&#xff0c;多态的实现主要通过以下两种方式&#xff1a; 继承&#xff1a;子类继承父类…...

基于Spring boot + Vue协同过滤算法的电影推荐系统

末尾获取源码作者介绍&#xff1a;大家好&#xff0c;我是墨韵&#xff0c;本人4年开发经验&#xff0c;专注定制项目开发 更多项目&#xff1a;CSDN主页YAML墨韵 学如逆水行舟&#xff0c;不进则退。学习如赶路&#xff0c;不能慢一步。 目录 一、项目简介 二、开发技术与环…...

Chrome之解决:浏览器插件不能使用问题(十三)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…...

【正版特惠】IDM 永久授权 优惠低至109元!

尽管小编有修改版IDM&#xff0c;但是由于软件太好用了&#xff0c;很多同学干脆就直接购买了正版&#xff0c;现在正版也不贵&#xff0c;并且授权码绑定自己的邮箱&#xff0c;直接官方下载激活&#xff0c;无需其他的绿化修改之类的操作&#xff0c;不喜欢那么麻烦的&#x…...

SpringBoot与Prometheus监控整合

参考&#xff1a; springboot实战之prometheus监控整合-腾讯云开发者社区-腾讯云 https://www.cnblogs.com/skevin/p/15874139.html https://www.jianshu.com/p/e5dc2b45c7a4...

Linux 系统 docker搭建LNMP环境

1、安装nginx docker pull nginx (默认安装的是最新版本) 2、运行nginx docker run --name nginx -p 80:80 -d nginx:latest 备注&#xff1a;--name nginx 表示容器名为 nginx -d 表示后台运行 -p 80:80 表示把本地80端口绑定到Nginx服务端的 80端口 nginx:lates…...

拉普拉斯变换

定义&#xff1a; 拉普拉斯变换是一种在信号处理、控制理论和其他领域中广泛使用的数学工具&#xff0c;用于将一个函数从时域转换到复频域。拉普拉斯变换将一个函数 f(t) 变换为一个复变量函数 F(s)&#xff0c;其中 s 是复数变量。下面是拉普拉斯变换的推导过程&#xff1a;…...

Mashup-Math_Topic_One

Tutorial and Introspection A Rudolf and 121 注意到第 1 1 1 位只能被第 2 2 2 位影响&#xff0c;以此类推位置&#xff0c;对于 a i a_i ai​ , 如果 < 0 < 0 <0 &#xff0c;不合法 ; 否则&#xff0c; a i − a i , a i 1 − 2 ∗ a i , a i 2 − a …...

基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现

基于JavaWEB SSM SpringBoot婚纱影楼摄影预约网站设计和实现 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器

第一章 引言&#xff1a;语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域&#xff0c;文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量&#xff0c;支撑着搜索引擎、推荐系统、…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

人机融合智能 | “人智交互”跨学科新领域

本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...

【网络安全】开源系统getshell漏洞挖掘

审计过程&#xff1a; 在入口文件admin/index.php中&#xff1a; 用户可以通过m,c,a等参数控制加载的文件和方法&#xff0c;在app/system/entrance.php中存在重点代码&#xff1a; 当M_TYPE system并且M_MODULE include时&#xff0c;会设置常量PATH_OWN_FILE为PATH_APP.M_T…...

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...