
FQ-GAN Code Walkthrough

This walkthrough focuses on how the model, loss, and data parts are implemented and handled.

  • model—VQ_models
    • VQModel
    • Encoder
    • VectorQuantizer
    • Decoder
  • loss—VQLoss_triple_codebook

model—VQ_models

The vq_model is created by picking VQ_8 or VQ_16 according to the requested compression factor (8 or 16); both simply instantiate a VQModel and differ only in ch_mult. As in a UNet, ch_mult sets the channel multiplier of each up/down-sampling block; the spatial compression factor is determined by the number of resolution levels, since every level except the last downsamples by 2, so 4 levels give 8x and 5 levels give 16x (see the short sketch after the code below).

	# create and load model
	vq_model = VQ_models[args.vq_model](
	    codebook_size=args.codebook_size,
	    codebook_embed_dim=args.codebook_embed_dim,
	    commit_loss_beta=args.commit_loss_beta,
	    entropy_loss_ratio=args.entropy_loss_ratio,
	    dropout_p=args.dropout_p,
	    with_clip_supervision=args.with_clip_supervision,
	    with_disentanglement=args.with_disentanglement,
	    disentanglement_ratio=args.disentanglement_ratio,
	)

	def VQ_8(**kwargs):
	    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))

	def VQ_16(**kwargs):
	    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))

	VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
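As a quick illustration (not code from the repo): the downsampling factor follows from the number of resolution levels rather than from the product of the multipliers. The helper below is a hypothetical one-liner that captures the relation.

# Hypothetical helper, assuming every resolution level except the last halves h and w.
def down_factor(ch_mult):
    return 2 ** (len(ch_mult) - 1)

assert down_factor([1, 2, 2, 4]) == 8       # VQ-8
assert down_factor([1, 1, 2, 2, 4]) == 16   # VQ-16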

VQModel

The structure of the VQModel, which contains 3 codebooks, is as follows:

  • Encoder: 1 Encoder (progressively compresses the spatial dimensions and outputs z_channels-dimensional features)
  • VectorQuantizer: 3 VectorQuantizers (a pixel-level one with no teacher, a mid-semantic-level one with a DINO teacher, and a high-semantic-level one with a CLIP teacher), together with 3 quant_conv layers (mapping z_channels to codebook_embed_dim)
  • Decoder: 1 post_quant_conv (mapping the embedding dimension from 3*codebook_embed_dim back to z_channels) and 1 Decoder (progressively restoring the spatial dimensions)
  • FeatPredHead: 2 FeatPredHeads (MLP heads that align the VQ features to the CLIP and DINO features respectively, used for distillation supervision)
class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        # Two head encoder
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Quantizer for visual detail head
        self.quantize_vis = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                            config.commit_loss_beta, config.entropy_loss_ratio,
                                            config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_vis = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for mid-level semantic head
        self.quantize_sem_mid = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                config.commit_loss_beta, config.entropy_loss_ratio,
                                                config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_mid = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for high-level semantic head
        self.quantize_sem_high = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                 config.commit_loss_beta, config.entropy_loss_ratio,
                                                 config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_high = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        print("Visual codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("Mid Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("High Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))

        # Pixel decoder
        input_dim = config.codebook_embed_dim * 3
        self.post_quant_conv = nn.Conv2d(input_dim, config.z_channels, 1)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Down-sample factor in encoder channel multiplier
        self.num_resolutions = len(config.encoder_ch_mult)
        if self.num_resolutions == 5:    # encoder_ch_mult=[1, 1, 2, 2, 4]
            down_factor = 16
        elif self.num_resolutions == 4:  # encoder_ch_mult=[1, 2, 2, 4]
            down_factor = 8
        else:
            raise NotImplementedError

        # Semantic feature prediction
        if self.config.with_clip_supervision:
            print("Include feature prediction head for representation supervision")
            self.mid_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=384, down_factor=down_factor)
            self.high_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=768, down_factor=down_factor)
        else:
            print("NO representation supervision")

        if self.config.with_disentanglement:
            print("Disentangle Ratio: ", self.config.disentanglement_ratio)
        else:
            print("No Disentangle Regularization")

The forward pass consists of three main stages, encode, vq, and decode, plus one extra feature-alignment step because knowledge distillation (KD) is required:

  1. ① The input goes through the encoder to obtain three different features (h_vis, h_sem_mid, h_sem_high); each then passes through its own quant_conv to align the channel dimension to codebook_embed_dim.
  2. ② The image features of the different levels are fed into their respective VectorQuantizers, yielding three different quant_features and emb_losses (each emb_loss consists of vq_loss, commit_loss, and entropy_loss).
  3. Because knowledge distillation is used, a FeatPredHead additionally projects the quantized features to the CLIP and DINO feature dimensions (mid_sem_feat_pred and high_sem_feat_pred).
  4. To keep the 3 codebooks mutually orthogonal (i.e., well disentangled), a disentanglement loss is constructed that pushes the VQ features of the 3 levels apart (a loss on the squared dot products of the embeddings, i.e., disentangle_loss).
  5. ③ The quant_features are concatenated, passed through post_quant_conv and the decoder, and decoded back to the pixel values of the original image (dec). The self.decode helper is not shown in this excerpt; a sketch follows the code below.
    def forward(self, input):
        # 1. encode
        h_vis, h_sem_mid, h_sem_high = self.encoder(input)
        h_vis = self.quant_conv_vis(h_vis)
        h_sem_mid = self.quant_conv_sem_mid(h_sem_mid)
        h_sem_high = self.quant_conv_sem_high(h_sem_high)

        # 2. vq
        quant_vis, emb_loss_vis, _ = self.quantize_vis(h_vis)
        quant_sem_mid, emb_loss_sem_mid, _ = self.quantize_sem_mid(h_sem_mid)
        quant_sem_high, emb_loss_sem_high, _ = self.quantize_sem_high(h_sem_high)

        # for knowledge distillation
        if self.config.with_clip_supervision:
            mid_lvl_sem_feat = self.mid_sem_feat_pred(quant_sem_mid)
            high_lvl_sem_feat = self.high_sem_feat_pred(quant_sem_high)
        else:
            mid_lvl_sem_feat = None
            high_lvl_sem_feat = None

        # disentangle the vq features of the 3 codebooks
        if self.config.with_disentanglement:
            disentangle_loss = (self.compute_disentangle_loss(quant_vis, quant_sem_mid) +
                                self.compute_disentangle_loss(quant_vis, quant_sem_high) +
                                self.compute_disentangle_loss(quant_sem_mid, quant_sem_high)) / 3.0
        else:
            disentangle_loss = 0

        # 3. decode
        quant = torch.cat([quant_vis, quant_sem_mid, quant_sem_high], dim=1)
        dec = self.decode(quant)

        return dec, \
            emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high, \
            disentangle_loss, \
            mid_lvl_sem_feat, high_lvl_sem_feat
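The decode method called at the end of forward is not reproduced in this excerpt. Based on the modules defined in __init__, it presumably just chains post_quant_conv and the pixel Decoder; the body below is an assumption, not copied from the repo.

    # Hypothetical reconstruction of VQModel.decode (assumption):
    # post_quant_conv maps 3*codebook_embed_dim -> z_channels, then the Decoder maps back to pixels.
    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec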

The core novelty behind the "FQ" in the paper's name is this disentangle_loss, which pushes the 3 codebooks toward mutual orthogonality. The design idea is that if two features are disentangled, their dot product should be close to zero, i.e., they should be orthogonal; minimizing this loss encourages the model to learn disentangled features at the different levels (restated as a formula after the code below).

    def compute_disentangle_loss(self, quant_vis, quant_sem):
        quant_vis = rearrange(quant_vis, 'b c h w -> (b h w) c')
        quant_sem = rearrange(quant_sem, 'b c h w -> (b h w) c')

        quant_vis = F.normalize(quant_vis, p=2, dim=-1)
        quant_sem = F.normalize(quant_sem, p=2, dim=-1)

        dot_product = torch.sum(quant_vis * quant_sem, dim=1)
        loss = torch.mean(dot_product ** 2) * self.config.disentanglement_ratio
        return loss
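Written out as a formula (a restatement of the code above, not taken from the paper): for one pair of codebooks, loss = disentanglement_ratio * mean_i((z_vis_i . z_sem_i)^2), where z_vis_i and z_sem_i are the L2-normalized quantized vectors at spatial position i. The final disentangle_loss in forward averages this term over the three codebook pairs (vis/mid, vis/high, mid/high).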

Encoder

The Encoder takes the input image, passes it through shared downsampling conv_blocks and mid blocks, and then feeds the result into 3 different adapters that output 3 different features:

  1. conv_in: the input image is first mapped by conv_in from its input channels to 128 channels.
  2. downsampling conv_blocks: built from ResnetBlock, AttnBlock, and Downsample modules according to ch_mult=(1,1,2,2,4) or ch_mult=(1, 2, 2, 4); ch_mult controls how much each conv_block multiplies the channel count (channels grow while h and w shrink). The output channels of block i are 128*ch_mult[i]; for example, with ch_mult=(1, 2, 2, 4) there are 4 blocks and the channels go 128 -> 128 -> 256 -> 256 -> 512 (an AttnBlock is only added at the last resolution level).
  3. mid: composed of ResnetBlock + AttnBlock + ResnetBlock; the channel count is unchanged through this stage.
  4. adapter: 3 different FactorizedAdapters convert the shared encoder feature into 3 different features, used for the subsequent VQ step of the 3 codebooks.
  5. conv_out: since the previous step splits the feature into 3 branches (h_vis, h_sem_mid, h_sem_high), 3 separate conv2d output layers align each branch's channel dimension to z_channels.
class Encoder(nn.Module):
    def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type='group',
                 dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        if self.num_resolutions == 5:
            down_factor = 16
        elif self.num_resolutions == 4:
            down_factor = 8
        else:
            raise NotImplementedError

        # semantic head mid-level
        self.semantic_head_mid = nn.ModuleList()
        self.semantic_head_mid.append(FactorizedAdapter(down_factor))
        # semantic head high-level
        self.semantic_head_high = nn.ModuleList()
        self.semantic_head_high.append(FactorizedAdapter(down_factor))
        # visual details head
        self.visual_head = nn.ModuleList()
        self.visual_head.append(FactorizedAdapter(down_factor))

        # end
        self.norm_out_sem_mid = Normalize(block_in, norm_type)
        self.conv_out_sem_mid = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
        self.norm_out_sem_high = Normalize(block_in, norm_type)
        self.conv_out_sem_high = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
        self.norm_out_vis = Normalize(block_in, norm_type)
        self.conv_out_vis = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.conv_in(x)

        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        h_vis = h
        h_sem_mid = h
        h_sem_high = h

        # semantic head mid-level
        for blk in self.semantic_head_mid:
            h_sem_mid = blk(h_sem_mid)
        h_sem_mid = self.norm_out_sem_mid(h_sem_mid)
        h_sem_mid = nonlinearity(h_sem_mid)
        h_sem_mid = self.conv_out_sem_mid(h_sem_mid)

        # semantic head high-level
        for blk in self.semantic_head_high:
            h_sem_high = blk(h_sem_high)
        h_sem_high = self.norm_out_sem_high(h_sem_high)
        h_sem_high = nonlinearity(h_sem_high)
        h_sem_high = self.conv_out_sem_high(h_sem_high)

        # visual head
        for blk in self.visual_head:
            h_vis = blk(h_vis)
        h_vis = self.norm_out_vis(h_vis)
        h_vis = nonlinearity(h_vis)
        h_vis = self.conv_out_vis(h_vis)

        return h_vis, h_sem_mid, h_sem_high

VectorQuantizer

The VectorQuantizer's initialization mainly creates a codebook embedding (embedding) of size [codebook_size, codebook_embed_dim].

class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e                                # codebook_size
        self.e_dim = e_dim                            # codebook_embed_dim
        self.beta = beta                              # commitment_loss scale
        self.entropy_loss_ratio = entropy_loss_ratio  # entropy_loss scale
        self.l2_norm = l2_norm                        # l2_norm for codebook embeddings
        self.show_usage = show_usage                  # show codebook usage

        # create codebook embedding and initialize
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:   # normalize embeddings
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:  # initialize codebook usage buffer
            self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))     # 1048576

The forward pass works the same way as in VQGAN: every token embedding of the image feature z is quantized by looking up the codebook entry with argmin(distances), giving the quantized image feature zq, and 3 losses are computed at the same time (used to optimize the codebook embeddings).

  • l2 norm: L2-normalizing both z and the codebook embeddings rescales all vectors to the same norm, i.e., projects them onto the unit sphere, so every vector contributes equally in the distance metric and different vectors become directly comparable; this makes matching easier and improves training stability and reconstruction quality.
  • argmin(distances): the classic VQ distance computation, expanded into two squared terms and a cross term: (z - e)^2 = z^2 + e^2 - 2*e*z. argmin(distances) then gives, for every embedding in z, the index of its nearest codebook entry, and those entries are gathered from the codebook to form zq.
  • codebook usage: the fraction of codebook embeddings that are actually being used.
    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        if self.l2_norm:  # normalize z and the codebook embeddings, mapping all vectors onto the unit sphere
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

        # argmin(distances)
        min_encoding_indices = torch.argmin(d, dim=1)
        # replace each z_i with its closest embedding e_j
        z_q = embedding[min_encoding_indices].view(z.shape)

        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        # compute codebook usage
        if self.show_usage and self.training:
            cur_len = min_encoding_indices.shape[0]
            self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()  # shift the usage buffer left by cur_len
            self.codebook_used[-cur_len:] = min_encoding_indices                  # append the newest indices at the end
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
  • embedding loss
    • vq_loss is the mean squared error (MSE) between the quantized vectors z_q and the original inputs z. z.detach() means z is detached from the computation graph, so this term does not propagate gradients into z; it pulls the selected codebook embeddings z_q toward the encoder outputs, i.e., toward quantizations that are as close to z as possible.
    • commit_loss is also an MSE, but here z_q is detached from the graph, so this term does not propagate gradients into z_q. It encourages the encoder to commit to the codebook during quantization, i.e., the encoder output z should stay close to its quantized version z_q. The hyperparameter self.beta scales this term in the total loss.
    • entropy_loss encourages uniform usage of the codebook, which helps generalization. compute_entropy_loss(-d) takes the negated squared distances from each input vector to all embeddings, turns them into a softmax distribution, and combines per-sample entropy with the entropy of the batch-averaged distribution (see the formula after the code below). self.entropy_loss_ratio scales this term in the total loss.
		# compute the 3 embedding losses
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
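In formula form (a restatement of the code above): with p_i = softmax(-d_i / T) giving, for each token i, a distribution over the codebook (T = temperature = 0.01), the loss is loss = mean_i H(p_i) - H(mean_i p_i). The first term rewards confident (low-entropy) assignments per token, while the second term, entering with a minus sign, rewards a high-entropy average distribution, i.e., uniform usage of the whole codebook.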

get_codebook_entry is used after a Transformer has autoregressively predicted an index sequence: it looks the indices up in the codebook and converts them into the corresponding embeddings (a hypothetical usage sketch follows the code).

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)
        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q
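For illustration, a hypothetical generation-time call might look like the sketch below. The index tensors, the 16x16 grid size, and the decode call are placeholders/assumptions, not code from the repo; with three codebooks, each quantizer would be queried with its own index sequence and the results concatenated before decoding.

# Hypothetical usage sketch: map AR-predicted indices back to an image (assumed names and shapes).
shape = (1, vq_model.config.codebook_embed_dim, 16, 16)   # (batch, channel, height, width)
quant_vis = vq_model.quantize_vis.get_codebook_entry(indices_vis, shape)
quant_mid = vq_model.quantize_sem_mid.get_codebook_entry(indices_mid, shape)
quant_high = vq_model.quantize_sem_high.get_codebook_entry(indices_high, shape)
img = vq_model.decode(torch.cat([quant_vis, quant_mid, quant_high], dim=1))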

Decoder

The whole VQ step from z to zq does not change the spatial shape of the image feature, and the post_quant_conv in front of the Decoder maps the concatenated channels back to z_channels=256. The Decoder then decodes zq into the image's pixel values as follows:

  1. conv_in: a Conv2d maps the channel dimension of zq from z_channels to block_in (determined by ch=128 and ch_mult).
  2. middle block: as in the Encoder, composed of ResnetBlock + AttnBlock + ResnetBlock; the channel dimension is unchanged.
  3. upsampling conv_blocks: the mirror image of the Encoder; multiple blocks are built from ch_mult, progressively upsampling, enlarging the spatial dimensions while reducing the channel dimension.
  4. conv_out: the final conv_out maps the channel dimension from block_in to out_channels=3, yielding the image pixel values.
class Decoder(nn.Module):
    def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
                 dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

loss—VQLoss_triple_codebook

The forward pass of the VQ_Model above already produces 3 embed_losses and 1 disentangle_loss:

  • codebook embedding loss: since there are 3 codebooks, the 3 VQ operations yield 3 embed_losses (emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high); each emb_loss consists of 3 parts (vq_loss, commit_loss, entropy_loss) and is used to optimize its codebook.
  • disentangle loss: one of the paper's contributions; the squared dot products between the zq of different codebooks are averaged into the disentangle_loss, encouraging the codebooks to be mutually orthogonal.

Beyond these, VQLoss_triple_codebook additionally computes the reconstruction_loss, perceptual_loss, adversarial losses, and kd_teacher_loss during training:

  • pixel loss

    • reconstruction_loss: the l1_loss or l2_loss between the pixel values of the input and the VQ_Model reconstruction.
    • perceptual_loss: the VGG-based LPIPS score between input and reconstruction, used as a loss.
  • discriminator loss: used to optimize the discriminator, which can be a PatchGAN or StyleGAN discriminator (it takes a real or reconstructed image and outputs logits predicting real vs. fake). The discriminator_loss type can be hinge, vanilla, or non-saturating (a standard hinge formulation is sketched after this list).

  • gen_adv_loss: used to optimize the generator, whose goal is to produce reconstructions realistic enough to fool the discriminator. It comes in hinge and non_saturating variants; both push logits_fake, the discriminator's output on the reconstruction, toward the "real" side.

  • semantic loss (kd_teacher_loss): the two FeatPredHead outputs (two projected image VQ features) are compared against the CLIP and DINO features respectively, distilling general-purpose understanding representations into the tokenizer.
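For reference, the hinge variants of these two losses are commonly implemented as below. This is a sketch of the standard taming-transformers/LlamaGen-style formulation; the exact bodies of hinge_d_loss and hinge_gen_loss in this repo are an assumption.

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Discriminator hinge loss: push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def hinge_gen_loss(logits_fake):
    # Generator hinge loss: raise the discriminator's logits on reconstructions.
    return -torch.mean(logits_fake)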

The initialization of VQLoss_triple_codebook simply prepares the parameters and models needed to compute the losses above:

class VQLoss_triple_codebook(nn.Module):
    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 with_clip_supervision=False, semantic_weight=0.5,
    ):
        super().__init__()
        # 1. discriminator loss
        assert disc_type in ["patchgan", "stylegan"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        # discriminator
        if disc_type == "patchgan":
            print("Using patchgan D")
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            print("Using stylegan D")
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
        # disc_loss type
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        assert gen_adv_loss in ["hinge", "non-saturating"]
        # 2. gen_adv_loss
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # 3. perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # 4. semantic loss
        self.with_clip_supervision = with_clip_supervision
        if with_clip_supervision:
            self.clip_model = CLIPVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/clip-vit-base-patch16").eval()
            self.dinov2_model = DinoVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/dinov2-small").eval()
            self.clip_model.requires_grad_(False)
            self.dinov2_model.requires_grad_(False)
            self.semantic_weight = semantic_weight
        else:
            self.clip_model = None
            self.dinov2_model = None
            self.semantic_weight = None

        # 5. reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # 6. codebook loss
        self.codebook_weight = codebook_weight

The forward pass of the VQLoss_triple_codebook class has two modes selected by optimizer_idx, executed back to back on the same batch: during training, the loss module is called twice per batch, once to compute the generator loss and once to compute the discriminator loss. The generator and discriminator also use two separate optimizers (optimizer and optimizer_disc):

  • When optimizer_idx == 0, the generator is optimized: the reconstruction loss, perceptual loss, semantic loss, and gen_adv_loss are computed and linearly combined with the codebook_embed_loss and disentangle_loss returned earlier by the VQModel forward pass, forming the total loss used to optimize the VQ_Model.
  • When optimizer_idx == 1, the discriminator is optimized: the discriminator loss is computed and used to update the Discriminator (a minimal two-pass training-step sketch follows the code below).
    def forward(self,
                codebook_loss_vis, codebook_loss_sem_mid, codebook_loss_sem_high,
                inputs, reconstructions,
                disentangle_loss,
                semantic_feat_mid, semantic_feat_high,
                optimizer_idx, global_step, last_layer=None,
                logger=None, log_every=100):
        # generator update
        if optimizer_idx == 0:
            # reconstruction loss
            rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

            # semantic loss
            if semantic_feat_mid is not None:
                assert semantic_feat_high is not None
                semantic_loss_mid = self.dinov2_model(inputs.contiguous(), semantic_feat_mid)  # how to compute semantic loss?
                semantic_loss_mid = torch.mean(semantic_loss_mid)
                semantic_loss_high = self.clip_model(inputs.contiguous(), semantic_feat_high)
                semantic_loss_high = torch.mean(semantic_loss_high)
            else:
                assert self.with_clip_supervision == False
                semantic_loss_mid = torch.mean(torch.zeros_like(rec_loss))
                semantic_loss_high = torch.mean(torch.ones_like(rec_loss))

            # perceptual loss
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            p_loss = torch.mean(p_loss)

            # discriminator loss
            logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)
            if self.disc_adaptive_weight:
                null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss  # pixel loss
                disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss, last_layer=last_layer)
            else:
                disc_adaptive_weight = 1
            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

            loss = self.rec_weight * rec_loss + \
                self.perceptual_weight * p_loss + \
                disc_adaptive_weight * disc_weight * generator_adv_loss + \
                codebook_loss_vis[0] + codebook_loss_vis[1] + codebook_loss_vis[2] + \
                codebook_loss_sem_mid[0] + codebook_loss_sem_mid[1] + codebook_loss_sem_mid[2] + \
                codebook_loss_sem_high[0] + codebook_loss_sem_high[1] + codebook_loss_sem_high[2] + \
                self.semantic_weight * semantic_loss_mid + self.semantic_weight * semantic_loss_high + disentangle_loss

            if global_step % log_every == 0:
                rec_loss = self.rec_weight * rec_loss
                p_loss = self.perceptual_weight * p_loss
                generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
                logger.info(
                    f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                    f"vq_loss_sem_mid: {codebook_loss_sem_mid[0]:.4f}, "
                    f"commit_loss_sem_mid: {codebook_loss_sem_mid[1]:.4f}, "
                    f"entropy_loss_sem_mid: {codebook_loss_sem_mid[2]:.4f}, "
                    f"codebook_usage_sem_mid: {codebook_loss_sem_mid[3]:.4f}, "
                    f"vq_loss_sem_high: {codebook_loss_sem_high[0]:.4f}, "
                    f"commit_loss_sem_high: {codebook_loss_sem_high[1]:.4f}, "
                    f"entropy_loss_sem_high: {codebook_loss_sem_high[2]:.4f}, "
                    f"codebook_usage_sem_high: {codebook_loss_sem_high[3]:.4f}, "
                    f"vq_loss_vis: {codebook_loss_vis[0]:.4f}, "
                    f"commit_loss_vis: {codebook_loss_vis[1]:.4f}, "
                    f"entropy_loss_vis: {codebook_loss_vis[2]:.4f}, "
                    f"codebook_usage_vis: {codebook_loss_vis[3]:.4f}, "
                    f"disentangle_loss: {disentangle_loss: .4f}"
                    f"generator_adv_loss: {generator_adv_loss:.4f}, "
                    f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}, "
                    f"semantic_loss_mid: {semantic_loss_mid:.4f}, semantic_loss_high: {semantic_loss_high:.4f}"
                )
                if dist.get_rank() == 0:
                    wandb.log({
                        "rec_loss": rec_loss,
                        "perceptual_loss": p_loss,
                        "disentangle_loss": disentangle_loss,
                        "codebook_loss_sem_mid": codebook_loss_sem_mid[0],
                        "commit_loss_sem_mid": codebook_loss_sem_mid[1],
                        "entropy_loss_sem_mid": codebook_loss_sem_mid[2],
                        "codebook_usage_sem_mid": codebook_loss_sem_mid[3],
                        "codebook_loss_sem_high": codebook_loss_sem_high[0],
                        "commit_loss_sem_high": codebook_loss_sem_high[1],
                        "entropy_loss_sem_high": codebook_loss_sem_high[2],
                        "codebook_usage_sem_high": codebook_loss_sem_high[3],
                        "codebook_loss_vis": codebook_loss_vis[0],
                        "commit_loss_vis": codebook_loss_vis[1],
                        "entropy_loss_vis": codebook_loss_vis[2],
                        "codebook_usage_vis": codebook_loss_vis[3],
                        "generator_adv_loss": generator_adv_loss,
                        "disc_adaptive_weight": disc_adaptive_weight,
                        "disc_weight": disc_weight,
                        "semantic_loss_mid": semantic_loss_mid,
                        "semantic_loss_high": semantic_loss_high,
                    })
            return loss

        # discriminator update
        if optimizer_idx == 1:
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())
            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
            d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

            if global_step % log_every == 0:
                logits_real = logits_real.detach().mean()
                logits_fake = logits_fake.detach().mean()
                logger.info(
                    f"(Discriminator) "
                    f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                    f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}"
                )
                if dist.get_rank() == 0:
                    wandb.log({
                        "discriminator_adv_loss": d_adversarial_loss,
                        "disc_weight": disc_weight,
                        "logits_real": logits_real,
                        "logits_fake": logits_fake,
                    })
            return d_adversarial_loss
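To make the two-pass schedule concrete, a minimal per-batch training step might look like the sketch below. The names vq_loss_fn, optimizer, optimizer_disc, x, step, and logger are placeholders; the repo's actual training script is not shown here, so this is an assumed usage pattern, not its code.

# Hypothetical training step with two optimizers (sketch, not the repo's script).
dec, emb_vis, emb_mid, emb_high, dis_loss, feat_mid, feat_high = vq_model(x)

# Pass 1: generator / tokenizer update (optimizer_idx=0)
g_loss = vq_loss_fn(emb_vis, emb_mid, emb_high, x, dec, dis_loss, feat_mid, feat_high,
                    optimizer_idx=0, global_step=step,
                    last_layer=vq_model.decoder.last_layer, logger=logger)
optimizer.zero_grad()
g_loss.backward()
optimizer.step()

# Pass 2: discriminator update (optimizer_idx=1); inputs/reconstructions are detached inside the loss.
d_loss = vq_loss_fn(emb_vis, emb_mid, emb_high, x, dec, dis_loss, feat_mid, feat_high,
                    optimizer_idx=1, global_step=step, logger=logger)
optimizer_disc.zero_grad()
d_loss.backward()
optimizer_disc.step()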
