当前位置: 首页 > news >正文

DP-GAN-生成器代码

首先看一下数据生成:
在这里插入图片描述
在预处理阶段会将label经过ont-hot编码转换为35个通道,即每个通道都是由(0,1)组成。
在这里插入图片描述
在train文件中,对生成器和判别器分别进行更新,根据loss的不同,分别计算对于的损失:

loss_G, losses_G_list = model(image, label, "losses_G", losses_computer)
loss_D, losses_D_list = model(image, label, "losses_D", losses_computer)

在model中:

from models.sync_batchnorm import DataParallelWithCallback
import models.generator as generators
import models.discriminator as discriminators
import os
import copy
import torch
import torch.nn as nn
from torch.nn import init
import models.losses as losses
class DP_GAN_model(nn.Module):def __init__(self, opt):super(DP_GAN_model, self).__init__()self.opt = opt#--- generator and discriminator ---self.netG = generators.DP_GAN_Generator(opt).cuda()if opt.phase == "train" or opt.phase == "eval":self.netD = discriminators.DP_GAN_Discriminator(opt)self.print_parameter_count()self.init_networks()#--- EMA of generator weights ---with torch.no_grad():self.netEMA = copy.deepcopy(self.netG) if not opt.no_EMA else None#--- load previous checkpoints if needed ---self.load_checkpoints()#--- perceptual loss ---#if opt.phase == "train":if opt.add_vgg_loss:self.VGG_loss = losses.VGGLoss(self.opt.gpu_ids)self.GAN_loss = losses.GANLoss()self.MSELoss = nn.MSELoss(reduction='mean')def align_loss(self, feats, feats_ref):loss_align = 0for f, fr in zip(feats, feats_ref):loss_align += self.MSELoss(f, fr)return loss_aligndef forward(self, image, label, mode, losses_computer):# Branching is applied to be compatible with DataParallelif mode == "losses_G":loss_G = 0fake = self.netG(label)output_D, scores, feats = self.netD(fake)_, _, feats_ref = self.netD(image)loss_G_adv = losses_computer.loss(output_D, label, for_real=True)loss_G += loss_G_advloss_ms = self.GAN_loss(scores, True, for_discriminator=False)loss_G += loss_ms.item()loss_align = self.align_loss(feats, feats_ref)loss_G += loss_alignif self.opt.add_vgg_loss:loss_G_vgg = self.opt.lambda_vgg * self.VGG_loss(fake, image)loss_G += loss_G_vggelse:loss_G_vgg = Nonereturn loss_G, [loss_G_adv, loss_G_vgg]if mode == "losses_D":loss_D = 0with torch.no_grad():fake = self.netG(label)output_D_fake, scores_fake, _ = self.netD(fake)loss_D_fake = losses_computer.loss(output_D_fake, label, for_real=False)loss_ms_fake = self.GAN_loss(scores_fake, False, for_discriminator=True)loss_D += loss_D_fake + loss_ms_fake.item()output_D_real, scores_real, _ = self.netD(image)loss_D_real = losses_computer.loss(output_D_real, label, for_real=True)loss_ms_real = self.GAN_loss(scores_real, True, for_discriminator=True)loss_D += loss_D_real + loss_ms_real.item()if not self.opt.no_labelmix:mixed_inp, mask = generate_labelmix(label, fake, image)output_D_mixed, _, _ = self.netD(mixed_inp)loss_D_lm = self.opt.lambda_labelmix * losses_computer.loss_labelmix(mask, output_D_mixed, output_D_fake,output_D_real)loss_D += loss_D_lmelse:loss_D_lm = Nonereturn loss_D, [loss_D_fake, loss_D_real, loss_D_lm]if mode == "generate":with torch.no_grad():if self.opt.no_EMA:fake = self.netG(label)else:fake = self.netEMA(label)return fakeif mode == "eval":with torch.no_grad():pred, _, _ = self.netD(image)return preddef load_checkpoints(self):if self.opt.phase == "test":which_iter = self.opt.ckpt_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")if self.opt.no_EMA:self.netG.load_state_dict(torch.load(path + "G.pth"))else:self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))elif self.opt.phase == "eval":which_iter = self.opt.ckpt_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")self.netD.load_state_dict(torch.load(path + "D.pth"))elif self.opt.continue_train:which_iter = self.opt.which_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")self.netG.load_state_dict(torch.load(path + "G.pth"))self.netD.load_state_dict(torch.load(path + "D.pth"))if not self.opt.no_EMA:self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))def print_parameter_count(self):if self.opt.phase == "train":networks = [self.netG, self.netD]else:networks = [self.netG]for network in networks:param_count = 0for name, module in network.named_modules():if (isinstance(module, nn.Conv2d)or isinstance(module, nn.Linear)or isinstance(module, nn.Embedding)):param_count += sum([p.data.nelement() for p in module.parameters()])print('Created', network.__class__.__name__, "with %d parameters" % param_count)def init_networks(self):def init_weights(m, gain=0.02):classname = m.__class__.__name__if classname.find('BatchNorm2d') != -1:if hasattr(m, 'weight') and m.weight is not None:init.normal_(m.weight.data, 1.0, gain)if hasattr(m, 'bias') and m.bias is not None:init.constant_(m.bias.data, 0.0)elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):init.xavier_normal_(m.weight.data, gain=gain)if hasattr(m, 'bias') and m.bias is not None:init.constant_(m.bias.data, 0.0)if self.opt.phase == "train":networks = [self.netG, self.netD]else:networks = [self.netG]for net in networks:net.apply(init_weights)def put_on_multi_gpus(model, opt):if opt.gpu_ids != "-1":gpus = list(map(int, opt.gpu_ids.split(",")))model = DataParallelWithCallback(model, device_ids=gpus).cuda()else:model.module = modelassert len(opt.gpu_ids.split(",")) == 0 or opt.batch_size % len(opt.gpu_ids.split(",")) == 0return modeldef preprocess_input(opt, data):data['label'] = data['label'].long()if opt.gpu_ids != "-1":data['label'] = data['label'].cuda()data['image'] = data['image'].cuda()label_map = data['label']bs, _, h, w = label_map.size()nc = opt.semantic_ncif opt.gpu_ids != "-1":input_label = torch.cuda.FloatTensor(bs, nc, h, w).zero_()else:input_label = torch.FloatTensor(bs, nc, h, w).zero_()input_semantics = input_label.scatter_(1, label_map, 1.0)return data['image'], input_semanticsdef generate_labelmix(label, fake_image, real_image):target_map = torch.argmax(label, dim = 1, keepdim = True)all_classes = torch.unique(target_map)for c in all_classes:target_map[target_map == c] = torch.randint(0,2,(1,)).cuda()target_map = target_map.float()mixed_image = target_map*real_image+(1-target_map)*fake_imagereturn mixed_image, target_map

首先看生成器流程:
标签输入到生成器中得到fake image,fake image 和 real image 共同输入到判别器中得到中间变量输出,接着分别计算四个损失。我们需要明白生成器和辨别器模型的搭建,损失计算过程。
在这里插入图片描述
首先是生成器的组成:
在这里插入图片描述
在这里插入图片描述
输入标签大小是(b,c,h,w),首先z等于一个正态分布的随机数,大小为(b,64),接着view为(b,64,1,1),再扩张到(b,64,h,w)和(b,c,h,w)沿着通道维度拼接起来。将拼接的结果上采样到W和H大小。
在这里插入图片描述
其中在CityscapesDataset指定了:
在这里插入图片描述
则w=512//2^5=16,h=16/2=8.
在这里插入图片描述
令s等于input label,输入到pyrmid中,生成结果添加到列表中。

self.seg_pyrmid = nn.ModuleList([])if not self.opt.no_3dnoise:self.fc = nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 16 * ch, 3, padding=1)self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)))else:self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * ch, 3, padding=1)self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc, 32, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))for i in range(len(self.channels)-2):self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))         

而pyrmid是一个modulist,便利添加的每一个module,生成一个结果:
首先将标签图和噪声拼接起来经过一个3x3卷积,输出通道变为32,再经过一个1x1卷积,输出通道变为64.再经过经过5个步长为2的3x3卷积,下采样32倍。这样pyrmid列表中就有7个结果。
接着将已经采样的x输入到Fc中,输出通道是1024.这里需要清楚两个变量x,和pyrmid.
1:x是输入下采样到(H,W)大小的label+noise.
2:pyrmid是储存经过七次(五次下采样)卷积之后的label+noise。
接着将pyrmid最后一个值采样到x的大小。然后和pyrmid的第i个值拼接在一起。
在这里插入图片描述
对应于:
在这里插入图片描述
每拼接一次生成的值和经过Fc之后的label+noise共同作为输入:
在这里插入图片描述
输入到SPADE块中:
首先要判断SPAD的两个参数即输入通道是否相等。
在这里插入图片描述
在这里插入图片描述
如果相等就输入到SPADE模块,如果不等令变量等于输入值。
在这里插入图片描述
其中最后一个参数是类别值:在Cityscape数据集设定语义标签是34类。有一类是未知,加上噪声的64个通道。
在这里插入图片描述
SPADE:

class SPADE(nn.Module):def __init__(self, opt, norm_nc, label_nc):super().__init__()self.first_norm = get_norm_layer(opt, norm_nc)ks = opt.spade_ksnhidden = 128pw = ks // 2#self.mlp_shared = nn.Sequential(#    nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),#    nn.ReLU()#)self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)def forward(self, x, segmap):normalized = self.first_norm(x)#segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')#actv = self.mlp_shared(segmap)actv = segmapgamma = self.mlp_gamma(actv)beta = self.mlp_beta(actv)out = normalized * (1 + gamma) + betareturn out

公式:
在这里插入图片描述
首先X经过一个norm层,即为分布式BN。
在这里插入图片描述
在这里插入图片描述
接着使用卷积学习β和γ。
在这里插入图片描述
在这里插入图片描述
卷积核大小都为3,padding为1。
接着经过bn之后的变量和γ相乘在和β相加,再和经过归一化之后的x相加。
在这里插入图片描述
接着:x和seg经过相同的norm操作。再进过一个LeakyReLU,再进行一个卷积层。中间有个midlayer过渡。
在这里插入图片描述
在这里插入图片描述
输出的结果经过一个跳连接得到最后输出。
在这里插入图片描述
经过SPADE之后的输出上采样两倍作为输入输入到下一个SPADE中。
最终输出一个通道为3的RGB图片。

相关文章:

DP-GAN-生成器代码

首先看一下数据生成: 在预处理阶段会将label经过ont-hot编码转换为35个通道,即每个通道都是由(0,1)组成。 在train文件中,对生成器和判别器分别进行更新,根据loss的不同,分别计算对于的损失&a…...

2020-2023中国高等级自动驾驶产业发展趋势研究

1.1 概念界定 2020-2023中国高等级自动驾驶产业发展趋势研究Trends in China High-level Autonomous Driving from 2020 to 2023自动驾驶发展过程中,中国出现了诸多专注于研发L3级以上自动驾驶的公司,其在业界地位也越来越重要。本报告围绕“高等级自动…...

JDK19 - synchronized关键字导致的虚拟线程PINNED

JDK19 - synchronized关键字导致的虚拟线程PINNED 前言一. PINNED是什么意思1.1 synchronized 绑定测试1.2 synchronized 关键字的替代 二. -Djdk.tracePinnedThreads的作用和坑2.1 死锁案例测试2.2 发生原因的推测2.3 总结 前言 在 虚拟线程详解 这篇文章里面,我们…...

用msys2安装verilator并用spinal进行仿真

一 参考 SpinalHDL 开发环境搭建一步到位(图文版) - 极术社区 - 连接开发者与智能计算生态 (aijishu.com)https://aijishu.com/a/1060000000255643Setup and installation of Verilator — SpinalHDL documentation...

【ARM64 常见汇编指令学习 13 -- ARM 汇编 ORG 伪指令学习】

文章目录 ARM ORG 指令介绍UEFI 中对 ORG 指令的使用 ARM ORG 指令介绍 在ARM汇编中,"org"是一个汇编器伪指令,用于设置下一条指令的装入地址。"org"后面跟着的是一个表达式,这个表达式的值就是下一条指令的装入地址。如…...

Vue使用QuillEditor富文本编辑器问题记录

1.内容绑定的问题 绑定内容要使用 v-model:content"xxx" 的形式。 2.设置字体字号 字体以及字号大小的设置需要先注册。 <script> import { QuillEditor,Quill } from vueup/vue-quill import vueup/vue-quill/dist/vue-quill.snow.css; // 设置字体大小 c…...

spring AOP学习

概念 面向切面编程横向扩展动态代理 相关术语 动态代理 spring在运行期&#xff0c;生成动态代理对象&#xff0c;不需要特殊的编译器 Spring AOP的底层就是通过JDK动态代理或者CGLIb动态代理技术为目标Bean执行横向织入 目标对象实现了接口&#xff0c;spring使用JDK的ja…...

16.M端事件和JS插件

16.1移动端 移动端也有自己独特的地方 ●触屏事件touch (也称触摸事件)&#xff0c;Android 和I0S都有。 ●touch对象代表一个触摸点。触摸点可能是一根手指&#xff0c;也可能是一根触摸笔。触屏事件可响应用户手指(或触控笔)对屏幕或者触控板操作。 ●常见的触屏事件如下: …...

Zebec APP:构建全面、广泛的流支付应用体系

目前&#xff0c;流支付协议 Zebec Protocol 基本明确了生态的整体轮廓&#xff0c;它包括由其社区推动的模块化 Layer3 构架的公链 Nautilus Chain、流支付应用 Zebec APP 以及 流支付薪酬工具 Zebec payroll 。其中&#xff0c;Zebec APP 是原有 Zebec Protocol 的主要部分&a…...

Spark 3.1.1 遇到的 from_json regexp_replace组合表达式慢问题的解决

背景 目前公司在从spark 2.4.x升级到3.1.1的时候&#xff0c;遇到了一类SQL极慢的情况&#xff0c;该SQL的如下(只列举了关键的)&#xff1a; select device_personas.* from(selectdevice_id, ads_id, from_json(regexp_replace(device_personas, (?<(\\{|,))"devic…...

Docker 容器常用的命令和操作

1.容器操作 - 运行容器: docker run [OPTIONS] IMAGE [COMMAND] [ARG...] 示例&#xff1a; docker run -it --rm ubuntu /bin/bash - 查看正在运行的容器: docker ps [OPTIONS] 示例&#xff1a; docker ps -a - 停止容器: docker stop CONTAINER [CONTAINER...] 示…...

iTOP-RK3568开发板Windows 安装 RKTool 驱动

在烧写镜像之前首先需要安装 RKTool 驱动。 RKTool 驱动在网盘资料“iTOP-3568 开发板\01_【iTOP-RK3568 开发板】基础资料 \02_iTOP-RK3568 开发板烧写工具及驱动”路径下。 驱动如下图所示&#xff1a; 解压缩后&#xff0c;进入文件夹&#xff0c;如下图所示&#xff1a;…...

nginx rtmp http_flv直播推流

安装配置nginx yum install epel-release -y sudo rpm -Uvh http://li.nux.ro/download/nux/dextop/el7/x86_64/nux-dextop-release-0-5.el7.nux.noarch.rpm yum install ffmpeg ffmpeg-devel -y yum install gcc -y yum install pcre pcre-devel -y yum install openssl open…...

Day50 算法记录| 动态规划 17(子序列)

这里写目录标题 647. 回文子串516.最长回文子序列总结 647. 回文子串 1.动态规划和2.中心扩展 这个视频是基于上面的视频的代码 方法1:动态规划 布尔类型的dp[i][j]&#xff1a;表示区间范围[i,j] &#xff08;注意是左闭右闭&#xff09;的子串是否是回文子串&#xff0c;如…...

RabbitMQ:概念和安装,简单模式,工作,发布确认,交换机,死信队列,延迟队列,发布确认高级,其它知识,集群

1. 消息队列 1.0 课程介绍 1.1.MQ 的相关概念 1.1.1.什么是MQ MQ(message queue&#xff1a;消息队列)&#xff0c;从字面意思上看&#xff0c;本质是个队列&#xff0c;FIFO 先入先出&#xff0c;只不过队列中存放的内容是message 而已&#xff0c;还是一种跨进程的通信机制…...

小研究 - 基于解析树的 Java Web 灰盒模糊测试(二)

由于 Java Web 应用业务场景复杂, 且对输入数据的结构有效性要求较高, 现有的测试方法和工具在测试Java Web 时存在测试用例的有效率较低的问题. 为了解决上述问题, 本文提出了基于解析树的 Java Web 应用灰盒模糊测试方法. 首先为 Java Web 应用程序的输入数据包进行语法建模创…...

对于现有的分布式id发号器的思考 id生成器 雪花算法 uuid

在工作过程中接触了很多id生成策略&#xff0c;但是有一些问题 雪花id 强依赖时钟&#xff0c;对于时钟回拨无法很好解决 tinyid 滴滴开源&#xff0c;依赖mysql数据库&#xff0c;自增&#xff0c;无业务属性 uuid 生成是一个字符串没有顺序&#xff0c;数据库索引组织数据…...

jmeter中json提取器,获取多个值,并通过beanshell组成数组

jmeter中json提取器介绍 特别说明&#xff1a;**Compute concatenation var(suffix_ALL)&#x1f617;*如果找到许多结果&#xff0c;则插件将使用’ &#xff0c; 分隔符将它们连接起来&#xff0c;并将其存储在名为 _ALL的var中 json提取器调试 在查看结果树中选择JSON Pat…...

通过nvm工具快捷切换node.js版本、以及nvm的安装

使用nvm可以实现多个Node.js版本之间切换 步骤目录&#xff1a; 先卸载掉本系统中原有的node版本 去github上下载nvm安装包 安装node 常用的一些nvm命令 1、先卸载掉本系统中原有的node版本 2、去github上下载nvm安装包 https://github.com/coreybutler/nvm-windows/re…...

企业如何搭建矩阵内容,才能真正实现目的?

当下&#xff0c;新媒体矩阵营销已成为众多企业的营销选择之一&#xff0c;各企业可以通过新媒体矩阵实现扩大品牌声量、维持用户关系、提高销售业绩等不同的目的。 而不同目的的矩阵&#xff0c;它的内容运营模式会稍有差别&#xff0c;评价体系也会大不相同。 企业在运营某类…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件&#xff1a;-&#xff08;纯文本文件&#xff0c;二进制文件&#xff0c;数据格式文件&#xff09; 如文本文件、图片、程序文件等。 目录文件&#xff1a;d&#xff08;directory&#xff09; 用来存放其他文件或子目录。 设备…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

线程同步:确保多线程程序的安全与高效!

全文目录&#xff1a; 开篇语前序前言第一部分&#xff1a;线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分&#xff1a;synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分&#xff…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

Module Federation 和 Native Federation 的比较

前言 Module Federation 是 Webpack 5 引入的微前端架构方案&#xff0c;允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板&#xff0c;就像一个模具&#xff0c;里面可以将不同类型的材料做成一个形状&#xff0c;其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式&#xff1a;templa…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展&#xff0c;AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术&#xff0c;在客户服务、营销推广、信息查询等领域发挥着越来越重要…...