
DeBiFormer in Practice: Implementing an Image Classification Task with DeBiFormer (Part 1)

Abstract

1. Paper Overview

  • Research background: Vision Transformers have shown great potential in computer vision; they capture long-range dependencies and are highly parallelizable, which benefits the training and inference of large models.
  • Existing problem: although much work has designed efficient attention patterns, the key-value pairs a query attends to are not necessarily drawn from the regions most semantically relevant to that query, and forcing every query onto an insufficient set of tokens can yield suboptimal results. Bi-level routing attention does route queries to semantically selected key-value pairs, but it may not produce optimal results in all cases.
  • Goal of the paper: propose DeBiFormer, a vision Transformer with Deformable Bi-level Routing Attention (DBRA), which optimizes query-key-value interactions by adaptively selecting semantically relevant regions.

2. Innovations

  • Deformable Bi-level Routing Attention (DBRA): an attention-in-attention architecture that combines deformable points with a bi-level routing mechanism for more efficient and meaningful attention allocation.
  • Deformable-point-aware region partitioning: ensures that each deformable point interacts with only a small subset of key-value pairs, balancing attention between important and less important regions.
  • Region-to-region routing: builds a directed graph of attention relations and uses the topk operator with a routing index matrix to keep the top-k connections for each region (a toy sketch of this routing step follows).
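
To make the routing step concrete, here is a toy sketch (not the paper's exact code; shapes and names are illustrative) of region-level top-k routing with torch.topk: pooled per-region queries and keys form a region-to-region affinity matrix, and only the k strongest edges per region are kept.

import torch

n, p2, c, k = 2, 49, 64, 4                  # batch, regions (7x7), channels, top-k
q_win = torch.randn(n, p2, c)               # per-region pooled queries
k_win = torch.randn(n, p2, c)               # per-region pooled keys
affinity = q_win @ k_win.transpose(-2, -1)  # (n, p^2, p^2) directed region graph
topk_logit, topk_index = torch.topk(affinity, k=k, dim=-1)  # keep top-k edges per region
r_weight = topk_logit.softmax(dim=-1)       # routing weights; topk_index is the routing index matrix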

3. Method

  • Deformable attention module: an offset network generates offsets for reference points, producing deformable points that shift toward important regions with high flexibility and efficiency, capturing more informative features.
  • Bi-level token-to-deformable-token attention: using the region routing matrix, each deformable query token in a region attends over all key-value pairs located in the top-k routed regions.
  • DeBiFormer architecture: a four-stage pyramid built from overlapping patch embedding, patch merging modules, and DeBiFormer blocks, which progressively reduce spatial resolution and increase channel count while modeling cross-position relations and per-position embeddings.

4. Module Roles

  • Deformable Bi-level Routing Attention (DBRA) module: optimizes query-key-value interactions and adaptively selects semantically relevant regions for more efficient and meaningful attention. Through deformable points and bi-level routing, it concentrates attention on important regions while reducing attention to less important ones.
  • 3x3 depthwise convolution: applied at the start of each DeBiFormer block to implicitly encode relative position information and strengthen local sensitivity.
  • 2-ConvFFN module: handles per-position embedding and expands the model's representational capacity.

5. Experimental Results

  • Image classification: models trained from scratch on ImageNet-1K validate DeBiFormer's effectiveness.
  • Semantic segmentation: fine-tuning the pretrained backbone on ADE20K, DeBiFormer performs strongly, demonstrating its ability on dense prediction tasks.
  • Object detection and instance segmentation: with DeBiFormer as the backbone in Mask R-CNN and RetinaNet, performance is evaluated on COCO 2017; despite limited resources, DeBiFormer outperforms some of the most competitive existing methods on large objects.
  • Ablation studies: verify the effectiveness of DBRA and of DeBiFormer's top-k selection, confirming the contribution of deformable bi-level routing attention to model performance.

Summary: DeBiFormer is a new hierarchical vision Transformer designed for image classification and dense prediction tasks. Its Deformable Bi-level Routing Attention (DBRA) optimizes query-key-value interactions and adaptively selects semantically relevant regions, yielding more efficient and meaningful attention. Experiments show that DeBiFormer performs well across multiple computer vision tasks and offers insights for designing flexible, semantics-aware attention mechanisms.

This article applies DeBiFormer to an image classification task using the debi_tiny variant; on the plant-seedling classification task it reaches 82%+ accuracy.


By working through this article, you will learn the following key skills:

  1. Multiple data augmentation strategies: basic augmentation with PyTorch's transforms library, plus advanced techniques such as CutOut, MixUp, and CutMix, which can significantly improve model generalization.

  2. Training DeBiFormer: how to build and train DeBiFormer (or another deep learning model) from scratch, covering model definition, data loading, the training loop, and other key pieces.

  3. Mixed-precision training: how to use PyTorch's built-in mixed-precision support to speed up training while reducing memory consumption (a combined sketch follows this list).

  4. Gradient clipping: applying gradient clipping to prevent exploding gradients and keep training stable (covered in the same sketch below).

  5. Data-parallel (DP) training: how to use PyTorch's data-parallel functionality in a multi-GPU environment to accelerate training of large models.

  6. Visualizing training: plotting the loss and accuracy curves during training to monitor how the model is learning.

  7. Evaluation and reporting: evaluating the model on the validation set and generating a detailed report with metrics such as ACC.

  8. Writing test scripts: predicting on the test set to assess how the model performs in practice.

  9. Learning-rate scheduling: understanding and applying cosine annealing to adjust the learning rate dynamically.

  10. Custom statistics: using an AverageMeter class (or a similar tool) to track and record ACC, loss, and other key metrics for later analysis.

  11. ACC1 vs. ACC5: the meaning and computation of Top-1 (ACC1) and Top-5 (ACC5) accuracy in image classification (a reference implementation follows this list).

  12. Exponential Moving Average (EMA): applying EMA during training to further improve performance on the test set.
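
For items 3 and 4, here is a minimal sketch of one training step with mixed precision and gradient clipping, using PyTorch's torch.cuda.amp. The names model, images, targets, optimizer, and criterion are placeholders; this is not the article's full train.py:

import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_step(model, images, targets, optimizer, criterion, max_norm=1.0):
    optimizer.zero_grad()
    with autocast():                    # forward pass in float16 where safe
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()       # scale the loss to avoid float16 underflow
    scaler.unscale_(optimizer)          # unscale gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)              # skips the update if gradients overflowed
    scaler.update()

And for item 11, a standard reference implementation of Top-1/Top-5 accuracy (the same pattern used by timm and torchvision):

def accuracy(output, target, topk=(1, 5)):
    """Return top-k accuracies (%) for each k; output has shape (batch, classes)."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                        # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * (100.0 / batch_size)
            for k in topk]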

If your foundation in any of these areas is weak, see my column "经典主干网络精讲与实战" (Classic Backbones: Explanations and Practice), which builds up all of the above knowledge from scratch, step by step.

Installing Packages

Install timm

pip is all you need:

pip install timm

Both the Mixup augmentation and EMA rely on timm.

Install einops with:

pip install einops

Data Augmentation: Cutout and Mixup

To improve generalization and performance, I add two augmentations, Cutout and Mixup, at the preprocessing stage. Cutout randomly masks out part of an image, forcing the model to learn more robust features, while Mixup blends two images and their labels to generate new training samples, increasing data diversity. Implementing both requires torchtoolbox, installed with:

pip install torchtoolbox

Cutout is implemented as a transform and goes into the transforms pipeline:

from torchtoolbox.transform import Cutout

# data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

Mixup requires the import: from timm.data.mixup import Mixup.

Define Mixup and SoftTargetCrossEntropy:

mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
criterion_train = SoftTargetCrossEntropy()

Mixup is a common augmentation for image classification: it linearly combines two images and their labels to generate new samples and targets (a usage sketch follows the parameter list).
Parameters:

mixup_alpha (float): mixup alpha value; mixup is active if > 0.

cutmix_alpha (float): cutmix alpha value; cutmix is active if > 0.

cutmix_minmax (List[float]): cutmix min/max image ratio; if not None, cutmix is active and this is used instead of alpha. When cutmix_minmax is set, cutmix_alpha defaults to 1.0.

prob (float): probability of applying mixup or cutmix per batch or element.

switch_prob (float): probability of switching to cutmix instead of mixup when both are active.

mode (str): how mixup/cutmix parameters are applied: 'batch', 'pair' (pairs of elements), or 'elem' (per element).

correct_lam (bool): apply lambda correction when the cutmix bounding box is clipped by the image border.

label_smoothing (float): amount of label smoothing applied to the mixed target tensor.

num_classes (int): number of target classes.
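
A minimal usage sketch (variable names assumed): mixup_fn is applied inside the training loop before the forward pass, and the resulting soft labels go to SoftTargetCrossEntropy. Note that timm's Mixup asserts an even batch size:

for images, targets in train_loader:              # batch size must be even
    images, targets = images.to(device), targets.to(device)
    images, targets = mixup_fn(images, targets)   # targets become soft labels
    loss = criterion_train(model(images), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()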

EMA

EMA (Exponential Moving Average) is a parameter-smoothing technique in deep learning: it maintains an exponential moving average of the model's parameters during training. This improves stability and generalization, especially in the later stages of training. A summary of EMA follows.

EMA Overview

EMA is a weighted moving average in which each new average is a weighted sum of the previous average and the current value. In deep learning it is applied to parameter updates to damp rapid fluctuations during training, giving a smoother and more stable model.

How It Works

During training, a separate copy of EMA parameters is kept alongside the current model parameters. At every training step (or every few steps), the EMA parameters are updated from the current model parameters with exponential decay. The update rule is typically:

$$\text{EMA}_{\text{new}} = \text{decay} \times \text{EMA}_{\text{old}} + (1 - \text{decay}) \times \text{model\_parameters}$$
Here decay is a hyperparameter between 0 and 1 that balances the old EMA value against the new parameter value. A larger decay means the update leans more on the old value, i.e. stronger smoothing: with decay = 0.99, a parameter jumping from 1.0 to 2.0 moves its EMA only to 0.99 × 1.0 + 0.01 × 2.0 = 1.01.

Benefits

  1. Stability: by smoothing parameter updates, EMA reduces fluctuations during training, making the model more stable.
  2. Generalization: because the EMA parameters are a smoothed version of the parameter history, they tend to capture the global trend of training, so evaluating with the EMA weights often yields better generalization.
  3. Faster convergence: EMA does not directly speed up training, but by stabilizing the parameters it can indirectly help the model converge to a better solution sooner.

Use Cases

EMA is widely used in deep learning, especially in tasks that demand high stability and good generalization, such as image classification and object detection. It is particularly useful when training large models, as it helps reduce the risk of overfitting and improves performance on unseen data.

A concrete implementation follows:


import logging
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)


class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

Hook it into the training code:

# initialization
if use_ema:
    model_ema = ModelEma(
        model_ft,
        decay=model_ema_decay,
        device='cpu',
        resume=resume)

# during training, after the optimizer updates the parameters, sync the shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)

# pass model_ema into the validation function
val(model_ema.ema, DEVICE, test_loader)

Note: for models without pretrained weights, EMA often fails to improve the score — keep this in mind!
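
One simple guard against this, sketched below under the assumption that val(...) returns validation accuracy: evaluate both the raw weights and the EMA weights each epoch, and checkpoint whichever is better.

acc_raw = val(model, DEVICE, test_loader)           # raw weights
acc_ema = val(model_ema.ema, DEVICE, test_loader)   # EMA shadow weights
best = model if acc_raw >= acc_ema else model_ema.ema
torch.save(best.state_dict(), 'best.pth')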

Project Structure

DeBiFormer_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─debiformer.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py: computes the mean and std values.
makedata.py: generates the dataset.
train.py: trains the DeBiFormer model defined under the models directory.
models: taken from the official code.

Computing the Mean and Std

In deep learning, especially with image data, computing the dataset's mean and standard deviation (std) and normalizing with them is one of the key steps for speeding up convergence and improving performance. This section explains both concepts and how they help the model learn.

Mean

The mean is the sum of all values divided by their count. For image data we usually compute a separate mean per color channel (e.g. the three channels of an RGB image): across all images in the dataset, we average the pixel values of the R channel, and likewise for the G and B channels.

Standard Deviation (Std)

The standard deviation measures how spread out the data is around the mean. It, too, is computed per color channel. A channel with a large std has widely varying pixel values, while a channel with a small std is relatively stable.

Normalization

Normalization rescales data into a small, fixed range, typically [0, 1] or [-1, 1]. For images we usually normalize with the computed mean and std:

$$\text{Normalized Value} = \frac{\text{Original Value} - \text{Mean}}{\text{Std}}$$

Note that in some cases, to simplify computation and keep values non-negative, data is instead scaled to [0, 1] with min-max normalization rather than mean/std normalization. Here we focus on mean/std normalization because it preserves the shape of the data distribution.

Why Normalize?

  1. Faster convergence: normalized features share a similar scale, which helps gradient descent find the optimum faster, since gradient updates across features stay on the same order of magnitude and avoid the slow training or vanishing/exploding gradients caused by overly large or small feature scales.

  2. Better accuracy: normalization improves generalization, because the model learns the relative relationships between features rather than being dominated by their absolute magnitudes.

  3. Stability: normalized data is more stable, reducing fluctuations during training and helping the model converge steadily.

How to Compute and Use the Mean and Std

  1. Compute the global mean and std: compute them over the whole dataset, typically before training starts, and use the same values to normalize the training, validation, and test sets.

  2. Use library functions: many deep learning frameworks (PyTorch, TensorFlow, etc.) provide convenient functions for computing the mean and std and applying them to dataset normalization.

  3. Dynamic updates: when the dataset is very large or continuously growing, the statistics may need to be updated dynamically, e.g. with a moving average (such as EMA) during training.

Computing the mean and std and normalizing with them is a basic but important preprocessing step in deep learning, with real impact on convergence speed, accuracy, and stability. Create mean_std.py and insert the following code:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms


def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

Dataset structure:


Running it prints:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

Record these values; we will use them later!
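
These values are used later when building the training transforms (a sketch; the exact pipeline lives in train.py):

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.3281186, 0.28937867, 0.20702125],
                         [0.09407319, 0.09732835, 0.106712654]),
])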

Generating the Dataset

The image classification dataset, as organized, has this structure:

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

PyTorch and Keras load data in the ImageNet directory layout by default, which looks like this:

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

Add a conversion script, makedata.py, with the following code:

import glob
import os
import shutil

image_list = glob.glob('data1/*/*.png')
print(image_list)
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    # os.rmdir(file_dir)
    shutil.rmtree(file_dir)  # delete and recreate
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split

trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)

for file in trainval_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

With the above in place, you are ready to start training and testing.

DeBiFormer Code

import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from timm.models.registry import register_model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision
from torch import Tensor
from typing import Tuple
import numbers
from timm.models.layers import to_2tuple, trunc_normal_
from einops import rearrange
import gc
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.vision_transformer import _cfgclass LayerNorm2d(nn.Module):def __init__(self, channels):super().__init__()self.ln = nn.LayerNorm(channels)def forward(self, x):x = rearrange(x, "N C H W -> N H W C")x = self.ln(x)x = rearrange(x, "N H W C -> N C H W")return xdef init_linear(m):if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.kaiming_normal_(m.weight)if m.bias is not None: nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def to_4d(x,h,w):return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)#def to_4d(x,s,h,w):
#    return rearrange(x, 'b (s h w) c -> b c s h w',s=s,h=h,w=w)def to_3d(x):return rearrange(x, 'b c h w -> b (h w) c')#def to_3d(x):
#    return rearrange(x, 'b c s h w -> b (s h w) c')class Partial:def __init__(self, module, *args, **kwargs):self.module = moduleself.args = argsself.kwargs = kwargsdef __call__(self, *args_c, **kwargs_c):return self.module(*args_c, *self.args, **kwargs_c, **self.kwargs)class LayerNormChannels(nn.Module):def __init__(self, channels):super().__init__()self.norm = nn.LayerNorm(channels)def forward(self, x):x = x.transpose(1, -1)x = self.norm(x)x = x.transpose(-1, 1)return xclass LayerNormProxy(nn.Module):def __init__(self, dim):super().__init__()self.norm = nn.LayerNorm(dim)def forward(self, x):x = rearrange(x, 'b c h w -> b h w c')x = self.norm(x)return rearrange(x, 'b h w c -> b c h w')class BiasFree_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(BiasFree_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):sigma = x.var(-1, keepdim=True, unbiased=False)return x / torch.sqrt(sigma+1e-5) * self.weightclass WithBias_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(WithBias_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):mu = x.mean(-1, keepdim=True)sigma = x.var(-1, keepdim=True, unbiased=False)return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.biasclass LayerNorm(nn.Module):def __init__(self, dim, LayerNorm_type):super(LayerNorm, self).__init__()if LayerNorm_type =='BiasFree':self.body = BiasFree_LayerNorm(dim)else:self.body = WithBias_LayerNorm(dim)def forward(self, x):h, w = x.shape[-2:]return to_4d(self.body(to_3d(x)), h, w)#class LayerNorm(nn.Module):
#    def __init__(self, dim, LayerNorm_type):
#        super(LayerNorm, self).__init__()
#        if LayerNorm_type =='BiasFree':
#            self.body = BiasFree_LayerNorm(dim)
#        else:
#            self.body = WithBias_LayerNorm(dim)
#    def forward(self, x):
#        s, h, w = x.shape[-3:]
#        return to_4d(self.body(to_3d(x)),s, h, w)class DWConv(nn.Module):def __init__(self, dim=768):super(DWConv, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)def forward(self, x):"""x: NHWC tensor"""x = x.permute(0, 3, 1, 2) #NCHWx = self.dwconv(x)x = x.permute(0, 2, 3, 1) #NHWCreturn xclass ConvFFN(nn.Module):def __init__(self, dim=768):super(DWConv, self).__init__()self.dwconv = nn.Conv2d(dim, dim, 1, 1, 0)def forward(self, x):"""x: NHWC tensor"""x = x.permute(0, 3, 1, 2) #NCHWx = self.dwconv(x)x = x.permute(0, 2, 3, 1) #NHWCreturn xclass Attention(nn.Module):"""vanilla attention"""def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_heads# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weightsself.scale = qk_scale or head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):"""args:x: NHWC tensorreturn:NHWC tensor"""_, H, W, _ = x.size()x = rearrange(x, 'n h w c -> n (h w) c')#######################################B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)#######################################x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W)return xclass AttentionLePE(nn.Module):"""vanilla attention"""def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):super().__init__()self.num_heads = num_headshead_dim = dim // num_heads# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weightsself.scale = qk_scale or head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)def forward(self, x):"""args:x: NHWC tensorreturn:NHWC tensor"""_, H, W, _ = x.size()x = rearrange(x, 'n h w c -> n (h w) c')#######################################B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))lepe = rearrange(lepe, 'n c h w -> n (h w) c')attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = x + lepex = self.proj(x)x = self.proj_drop(x)#######################################x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W)return xclass nchwAttentionLePE(nn.Module):"""Attention with LePE, takes nchw input"""def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):super().__init__()self.num_heads = num_headsself.head_dim = dim // num_headsself.scale = qk_scale or self.head_dim ** -0.5self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, 
bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Conv2d(dim, dim, kernel_size=1)self.proj_drop = nn.Dropout(proj_drop)self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)def forward(self, x:torch.Tensor):"""args:x: NCHW tensorreturn:NCHW tensor"""B, C, H, W = x.size()q, k, v = self.qkv.forward(x).chunk(3, dim=1) # B, C, H, Wattn = q.view(B, self.num_heads, self.head_dim, H*W).transpose(-1, -2) @ \k.view(B, self.num_heads, self.head_dim, H*W)attn = torch.softmax(attn*self.scale, dim=-1)attn = self.attn_drop(attn)# (B, nhead, HW, HW) @ (B, nhead, HW, head_dim) -> (B, nhead, HW, head_dim)output:torch.Tensor = attn @ v.view(B, self.num_heads, self.head_dim, H*W).transpose(-1, -2)output = output.permute(0, 1, 3, 2).reshape(B, C, H, W)output = output + self.lepe(v)output = self.proj_drop(self.proj(output))return outputclass TopkRouting(nn.Module):"""differentiable topk routing with scalingArgs:qk_dim: int, feature dimension of query and keytopk: int, the 'topk'qk_scale: int or None, temperature (multiply) of softmax activationwith_param: bool, wether inorporate learnable params in routing unitdiff_routing: bool, wether make routing differentiablesoft_routing: bool, wether make output value multiplied by routing weights"""def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):super().__init__()self.topk = topkself.qk_dim = qk_dimself.scale = qk_scale or qk_dim ** -0.5self.diff_routing = diff_routing# TODO: norm layer before/after linear?self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()# routing activationself.routing_act = nn.Softmax(dim=-1)def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor]:"""Args:q, k: (n, p^2, c) tensorReturn:r_weight, topk_index: (n, p^2, topk) tensor"""if not self.diff_routing:query, key = query.detach(), key.detach()query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c)attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)return r_weight, topk_indexclass KVGather(nn.Module):def __init__(self, mul_weight='none'):super().__init__()assert mul_weight in ['none', 'soft', 'hard']self.mul_weight = mul_weightdef forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):"""r_idx: (n, p^2, topk) tensorr_weight: (n, p^2, topk) tensorkv: (n, p^2, w^2, c_kq+c_v)Return:(n, p^2, topk, w^2, c_kq+c_v) tensor"""# select kv according to routing indexn, p2, w2, c_kv = kv.size()topk = r_idx.size(-1)# print(r_idx.size(), r_weight.size())# FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpydim=2,index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv))if self.mul_weight == 'soft':topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)elif self.mul_weight == 'hard':raise NotImplementedError('differentiable hard routing TBA')# else: #'none'#     topk_kv = topk_kv # do nothingreturn topk_kvclass QKVLinear(nn.Module):def __init__(self, dim, qk_dim, bias=True):super().__init__()self.dim = dimself.qk_dim = qk_dimself.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)def forward(self, x):q, kv = 
self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)return q, kv# q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)# return q, k, vclass QKVConv(nn.Module):def __init__(self, dim, qk_dim, bias=True):super().__init__()self.dim = dimself.qk_dim = qk_dimself.qkv = nn.Conv2d(dim,  qk_dim + qk_dim + dim, 1, 1, 0)def forward(self, x):q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=1)return q, kvclass BiLevelRoutingAttention(nn.Module):"""n_win: number of windows in one side (so the actual number of windows is n_win*n_win)kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.topk: topk for window filteringparam_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attentionparam_routing: extra linear for routingdiff_routing: wether to set routing differentiablesoft_routing: wether to multiply soft routing weights """def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,auto_pad=False):super().__init__()# local attention settingself.dim = dimself.n_win = n_win  # Wh, Wwself.num_heads = num_headsself.qk_dim = qk_dim or dimassert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'self.scale = qk_scale or self.qk_dim ** -0.5################side_dwconv (i.e. LCE in ShuntedTransformer)###########self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)################ global routing setting #################self.topk = topkself.param_routing = param_routingself.diff_routing = diff_routingself.soft_routing = soft_routing# routerassert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=Falseself.router = TopkRouting(qk_dim=self.qk_dim,qk_scale=self.scale,topk=self.topk,diff_routing=self.diff_routing,param_routing=self.param_routing)if self.soft_routing: # soft routing, always diffrentiable (if no detach)mul_weight = 'soft'elif self.diff_routing: # hard differentiable routingmul_weight = 'hard'else:  # hard non-differentiable routingmul_weight = 'none'self.kv_gather = KVGather(mul_weight=mul_weight)# qkv mapping (shared by both global routing and local attention)self.param_attention = param_attentionif self.param_attention == 'qkvo':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Linear(dim, dim)elif self.param_attention == 'qkv':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Identity()else:raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')self.kv_downsample_mode = kv_downsample_modeself.kv_per_win = kv_per_winself.kv_downsample_ratio = kv_downsample_ratioself.kv_downsample_kenel = kv_downsample_kernelif self.kv_downsample_mode == 'ada_avgpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'ada_maxpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'maxpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'avgpool':assert 
self.kv_downsample_ratio is not Noneself.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'identity': # no kv downsamplingself.kv_down = nn.Identity()elif self.kv_downsample_mode == 'fracpool':# assert self.kv_downsample_ratio is not None# assert self.kv_downsample_kenel is not None# TODO: fracpool# 1. kernel size should be input size dependent# 2. there is a random factor, need to avoid independent sampling for k and v raise NotImplementedError('fracpool policy is not implemented yet!')elif kv_downsample_mode == 'conv':# TODO: need to consider the case where k != v so that need two downsample modulesraise NotImplementedError('conv policy is not implemented yet!')else:raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')# softmax for local attentionself.attn_act = nn.Softmax(dim=-1)self.auto_pad=auto_paddef forward(self, x, ret_attn_mask=False):"""x: NHWC tensorReturn:NHWC tensor"""# NOTE: use padding for semantic segmentation###################################################if self.auto_pad:N, H_in, W_in, C = x.size()pad_l = pad_t = 0pad_r = (self.n_win - W_in % self.n_win) % self.n_winpad_b = (self.n_win - H_in % self.n_win) % self.n_winx = F.pad(x, (0, 0, # dim=-1pad_l, pad_r, # dim=-2pad_t, pad_b)) # dim=-3_, H, W, _ = x.size() # padded sizeelse:N, H, W, C = x.size()#assert H%self.n_win == 0 and W%self.n_win == 0 ##################################################### patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv sizex = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)#################qkv projection#################### q: (n, p^2, w, w, c_qk)# kv: (n, p^2, w, w, c_qk+c_v)# NOTE: separte kv if there were memory leak issue caused by gatherq, kv = self.qkv(x) # pixel-wise qkv# q_pix: (n, p^2, w^2, c_qk)# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)##################side_dwconv(lepe)################### NOTE: call contiguous to avoid gradient warning when using ddplepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)############ gather q dependent k/v #################r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensorskv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)# kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)# v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)######### do attention as normal ####################k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)# param-free multihead attentionattn_weight = (q_pix * self.scale) @ k_pix_sel # 
(n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)attn_weight = self.attn_act(attn_weight)out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,h=H//self.n_win, w=W//self.n_win)out = out + lepe# output linearout = self.wo(out)# NOTE: use padding for semantic segmentation# crop padded regionif self.auto_pad and (pad_r > 0 or pad_b > 0):out = out[:, :H_in, :W_in, :].contiguous()if ret_attn_mask:return out, r_weight, r_idx, attn_weightelse:return outclass TransformerMLPWithConv(nn.Module):def __init__(self, channels, expansion, drop):super().__init__()self.dim1 = channelsself.dim2 = channels * expansionself.linear1 = nn.Sequential(nn.Conv2d(self.dim1, self.dim2, 1, 1, 0),# nn.GELU(),# nn.BatchNorm2d(self.dim2, eps=1e-5))self.drop1 = nn.Dropout(drop, inplace=True)self.act = nn.GELU()# self.bn = nn.BatchNorm2d(self.dim2, eps=1e-5)self.linear2 = nn.Sequential(nn.Conv2d(self.dim2, self.dim1, 1, 1, 0),# nn.BatchNorm2d(self.dim1, eps=1e-5))self.drop2 = nn.Dropout(drop, inplace=True)self.dwc = nn.Conv2d(self.dim2, self.dim2, 3, 1, 1, groups=self.dim2)def forward(self, x):x = self.linear1(x)x = self.drop1(x)x = x + self.dwc(x)x = self.act(x)# x = self.bn(x)x = self.linear2(x)x = self.drop2(x)return xclass DeBiLevelRoutingAttention(nn.Module):"""n_win: number of windows in one side (so the actual number of windows is n_win*n_win)kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.topk: topk for window filteringparam_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attentionparam_routing: extra linear for routingdiff_routing: wether to set routing differentiablesoft_routing: wether to multiply soft routing weights"""def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,auto_pad=False, param_size='small'):super().__init__()# local attention settingself.dim = dimself.n_win = n_win  # Wh, Wwself.num_heads = num_headsself.qk_dim = qk_dim or dim#############################################################if param_size=='tiny':if self.dim == 64 :self.n_groups = 1self.top_k_def = 16   # 2    128self.kk = 9self.stride_def = 8self.expain_ratio = 3self.q_size=to_2tuple(56)if self.dim == 128 :self.n_groups = 2self.top_k_def = 16   # 4    256self.kk = 7self.stride_def = 4self.expain_ratio = 3self.q_size=to_2tuple(28)if self.dim == 256 :self.n_groups = 4self.top_k_def = 4   # 8    512self.kk = 5self.stride_def = 2self.expain_ratio = 3self.q_size=to_2tuple(14)if self.dim == 512 :self.n_groups = 8self.top_k_def = 49   # 8    512self.kk = 3self.stride_def = 1self.expain_ratio = 3self.q_size=to_2tuple(7)
#############################################################if param_size=='small':if self.dim == 64 :self.n_groups = 1self.top_k_def = 16   # 2    128self.kk = 9self.stride_def = 8self.expain_ratio = 3self.q_size=to_2tuple(56)if self.dim == 128 :self.n_groups = 2self.top_k_def = 16   # 4    256self.kk = 7self.stride_def = 4self.expain_ratio = 3self.q_size=to_2tuple(28)if self.dim == 256 :self.n_groups = 4self.top_k_def = 4   # 8    512self.kk = 5self.stride_def = 2self.expain_ratio = 3self.q_size=to_2tuple(14)if self.dim == 512 :self.n_groups = 8self.top_k_def = 49   # 8    512self.kk = 3self.stride_def = 1self.expain_ratio = 1self.q_size=to_2tuple(7)
#############################################################if param_size=='base':if self.dim == 96 :self.n_groups = 1self.top_k_def = 16   # 2    128self.kk = 9self.stride_def = 8self.expain_ratio = 3self.q_size=to_2tuple(56)if self.dim == 192 :self.n_groups = 2self.top_k_def = 16   # 4    256self.kk = 7self.stride_def = 4self.expain_ratio = 3self.q_size=to_2tuple(28)if self.dim == 384 :self.n_groups = 3self.top_k_def = 4   # 8    512self.kk = 5self.stride_def = 2self.expain_ratio = 3self.q_size=to_2tuple(14)if self.dim == 768 :self.n_groups = 6self.top_k_def = 49   # 8    512self.kk = 3self.stride_def = 1self.expain_ratio = 3self.q_size=to_2tuple(7)self.q_h, self.q_w = self.q_sizeself.kv_h, self.kv_w = self.q_h // self.stride_def, self.q_w // self.stride_defself.n_group_channels = self.dim // self.n_groupsself.n_group_heads = self.num_heads // self.n_groupsself.n_group_channels = self.dim // self.n_groupsself.offset_range_factor = -1self.head_channels = dim // num_headsself.n_group_heads = self.num_heads // self.n_groups#assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'self.scale = qk_scale or self.qk_dim ** -0.5self.rpe_table = nn.Parameter(torch.zeros(self.num_heads, self.q_h * 2 - 1, self.q_w * 2 - 1))trunc_normal_(self.rpe_table, std=0.01)################side_dwconv (i.e. LCE in ShuntedTransformer)###########self.lepe1 = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=self.stride_def, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)################ global routing setting #################self.topk = topkself.param_routing = param_routingself.diff_routing = diff_routingself.soft_routing = soft_routing# router#assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=Falseself.router = TopkRouting(qk_dim=self.qk_dim,qk_scale=self.scale,topk=self.topk,diff_routing=self.diff_routing,param_routing=self.param_routing)if self.soft_routing: # soft routing, always diffrentiable (if no detach)mul_weight = 'soft'elif self.diff_routing: # hard differentiable routingmul_weight = 'hard'else:  # hard non-differentiable routingmul_weight = 'none'self.kv_gather = KVGather(mul_weight=mul_weight)# qkv mapping (shared by both global routing and local attention)self.param_attention = param_attentionif self.param_attention == 'qkvo':#self.qkv = QKVLinear(self.dim, self.qk_dim)self.qkv_conv = QKVConv(self.dim, self.qk_dim)#self.wo = nn.Linear(dim, dim)elif self.param_attention == 'qkv':#self.qkv = QKVLinear(self.dim, self.qk_dim)self.qkv_conv = QKVConv(self.dim, self.qk_dim)#self.wo = nn.Identity()else:raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')self.kv_downsample_mode = kv_downsample_modeself.kv_per_win = kv_per_winself.kv_downsample_ratio = kv_downsample_ratioself.kv_downsample_kenel = kv_downsample_kernelif self.kv_downsample_mode == 'ada_avgpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'ada_maxpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'maxpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'avgpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 
1 else nn.Identity()elif self.kv_downsample_mode == 'identity': # no kv downsamplingself.kv_down = nn.Identity()elif self.kv_downsample_mode == 'fracpool':raise NotImplementedError('fracpool policy is not implemented yet!')elif kv_downsample_mode == 'conv':raise NotImplementedError('conv policy is not implemented yet!')else:raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')self.attn_act = nn.Softmax(dim=-1)self.auto_pad=auto_pad##########################################################################################self.proj_q = nn.Conv2d(dim, dim,kernel_size=1, stride=1, padding=0)self.proj_k = nn.Conv2d(dim, dim,kernel_size=1, stride=1, padding=0)self.proj_v = nn.Conv2d(dim, dim,kernel_size=1, stride=1, padding=0)self.proj_out = nn.Conv2d(dim, dim,kernel_size=1, stride=1, padding=0)self.unifyheads1 = nn.Conv2d(dim, dim,kernel_size=1, stride=1, padding=0)self.conv_offset_q = nn.Sequential(nn.Conv2d(self.n_group_channels, self.n_group_channels, (self.kk,self.kk), (self.stride_def,self.stride_def), (self.kk//2,self.kk//2), groups=self.n_group_channels, bias=False),LayerNormProxy(self.n_group_channels),nn.GELU(),nn.Conv2d(self.n_group_channels, 1, 1, 1, 0, bias=False),)### FFNself.norm = nn.LayerNorm(dim, eps=1e-6)self.norm2 = nn.LayerNorm(dim, eps=1e-6)self.mlp =TransformerMLPWithConv(dim, self.expain_ratio, 0.)@torch.no_grad()def _get_ref_points(self, H_key, W_key, B, dtype, device):ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device))ref = torch.stack((ref_y, ref_x), -1)ref[..., 1].div_(W_key).mul_(2).sub_(1)ref[..., 0].div_(H_key).mul_(2).sub_(1)ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2return ref@torch.no_grad()def _get_q_grid(self, H, W, B, dtype, device):ref_y, ref_x = torch.meshgrid(torch.arange(0, H, dtype=dtype, device=device),torch.arange(0, W, dtype=dtype, device=device),indexing='ij')ref = torch.stack((ref_y, ref_x), -1)ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2return refdef forward(self, x, ret_attn_mask=False):dtype, device = x.dtype, x.device"""x: NHWC tensorReturn:NHWC tensor"""
# NOTE: use padding for semantic segmentation
###################################################if self.auto_pad:N, H_in, W_in, C = x.size()pad_l = pad_t = 0pad_r = (self.n_win - W_in % self.n_win) % self.n_winpad_b = (self.n_win - H_in % self.n_win) % self.n_winx = F.pad(x, (0, 0, # dim=-1pad_l, pad_r, # dim=-2pad_t, pad_b)) # dim=-3_, H, W, _ = x.size() # padded sizeelse:N, H, W, C = x.size()assert H%self.n_win == 0 and W%self.n_win == 0 ##print("X_in")#print(x.shape)####################################################q=self.proj_q_def(x)x_res = rearrange(x, "n h w c -> n c h w")
#################qkv projection###################q,kv = self.qkv_conv(x.permute(0, 3, 1, 2))q_bi = rearrange(q, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win)kv = rearrange(kv, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win)q_pix = rearrange(q_bi, 'n p2 h w c -> n p2 (h w) c')kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)##################side_dwconv(lepe)################### NOTE: call contiguous to avoid gradient warning when using ddplepe1 = self.lepe1(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())#################################################################   Offset Qq_off = rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)offset_q = self.conv_offset_q(q_off).contiguous() # B * g 2 Sg HWgHk, Wk = offset_q.size(2), offset_q.size(3)n_sample = Hk * Wkif self.offset_range_factor > 0:offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)offset_q = offset_q.tanh().mul(offset_range).mul(self.offset_range_factor)offset_q = rearrange(offset_q, 'b p h w -> b h w p') # B * g 2 Hg Wg -> B*g Hg Wg 2reference = self._get_ref_points(Hk, Wk, N, dtype, device)if self.offset_range_factor >= 0:pos_k = offset_q + referenceelse:pos_k = (offset_q + reference).clamp(-1., +1.)x_sampled_q = F.grid_sample(input=x_res.reshape(N * self.n_groups, self.n_group_channels, H, W),grid=pos_k[..., (1, 0)], # y, x -> x, ymode='bilinear', align_corners=True) # B * g, Cg, Hg, Wgq_sampled = x_sampled_q.reshape(N, C, Hk, Wk)########  Bi-LEVEL Gatheringif self.auto_pad:q_sampled=q_sampled.permute(0, 2, 3, 1)Ng, Hg, Wg, Cg = q_sampled.size()pad_l = pad_t = 0pad_rg = (self.n_win - Wg % self.n_win) % self.n_winpad_bg = (self.n_win - Hg % self.n_win) % self.n_winq_sampled = F.pad(q_sampled, (0, 0, # dim=-1pad_l, pad_rg, # dim=-2pad_t, pad_bg)) # dim=-3_, Hg, Wg, _ = q_sampled.size() # padded sizeq_sampled=q_sampled.permute(0, 3, 1, 2)lepe1 = F.pad(lepe1.permute(0, 2, 3, 1), (0, 0, # dim=-1pad_l, pad_rg, # dim=-2pad_t, pad_bg)) # dim=-3lepe1=lepe1.permute(0, 3, 1, 2)pos_k = F.pad(pos_k, (0, 0, # dim=-1pad_l, pad_rg, # dim=-2pad_t, pad_bg)) # dim=-3queries_def = self.proj_q(q_sampled)  #Linnear projectionqueries_def = rearrange(queries_def, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win).contiguous()q_win, k_win = queries_def.mean([2, 3]), kv[..., 0:(self.qk_dim)].mean([2, 3])r_weight, r_idx = self.router(q_win, k_win)kv_gather = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c )k_gather, v_gather = kv_gather.split([self.qk_dim, self.dim], dim=-1)###     Bi-level Routing MHAk = rearrange(k_gather, 'n p2 k hw (m c) -> (n p2) m c (k hw)', m=self.num_heads)v = rearrange(v_gather, 'n p2 k hw (m c) -> (n p2) m (k hw) c', m=self.num_heads)q_def = rearrange(queries_def,  'n p2 h w (m c)-> (n p2) m (h w) c',m=self.num_heads)attn_weight = (q_def * self.scale) @ kattn_weight = self.attn_act(attn_weight)out = attn_weight @ vout_def = rearrange(out, '(n j i) m (h w) c -> n (m c) (j h) (i w)', j=self.n_win, i=self.n_win, h=Hg//self.n_win, w=Wg//self.n_win).contiguous()out_def = out_def + lepe1out_def = self.unifyheads1(out_def)out_def = q_sampled + out_defout_def = out_def + self.mlp(self.norm2(out_def.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)) # (N, C, H, 
W)#####################################################################################################   Deformable Gathering
#############################################################################################  out_def = self.norm(out_def.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)k = self.proj_k(out_def)v = self.proj_v(out_def)k_pix_sel = rearrange(k, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)v_pix_sel = rearrange(v, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)q_pix = rearrange(q, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)attn = torch.einsum('b c m, b c n -> b m n', q_pix, k_pix_sel) # B * h, HW, Nsattn = attn.mul(self.scale)### Biasrpe_table = self.rpe_tablerpe_bias = rpe_table[None, ...].expand(N, -1, -1, -1)q_grid = self._get_q_grid(H, W, N, dtype, device)displacement = (q_grid.reshape(N * self.n_groups, H * W, 2).unsqueeze(2) - pos_k.reshape(N * self.n_groups, Hg*Wg, 2).unsqueeze(1)).mul(0.5)attn_bias = F.grid_sample(input=rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),grid=displacement[..., (1, 0)],mode='bilinear', align_corners=True) # B * g, h_g, HW, Nsattn_bias = attn_bias.reshape(N * self.num_heads, H * W, Hg*Wg)attn = attn + attn_bias### attn = F.softmax(attn, dim=2)out = torch.einsum('b m n, b c n -> b c m', attn, v_pix_sel)out = out.reshape(N,C,H,W).contiguous()out = self.proj_out(out).permute(0,2,3,1)############################################################################################## NOTE: use padding for semantic segmentation# crop padded regionif self.auto_pad and (pad_r > 0 or pad_b > 0):out = out[:, :H_in, :W_in, :].contiguous()if ret_attn_mask:return out, r_weight, r_idx, attn_weightelse:return outdef get_pe_layer(emb_dim, pe_dim=None, name='none'):if name == 'none':return nn.Identity()else:raise ValueError(f'PE name {name} is not surpported!')class Block(nn.Module):def __init__(self, dim, drop_path=0., layer_scale_init_value=-1,num_heads=8, n_win=7, qk_dim=None, qk_scale=None,kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='ada_avgpool',topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, mlp_ratio=4, param_size='small',mlp_dwconv=False,side_dwconv=5, before_attn_dwconv=3, pre_norm=True, auto_pad=False):super().__init__()qk_dim = qk_dim or dim# modulesif before_attn_dwconv > 0:self.pos_embed1 = nn.Conv2d(dim, dim,  kernel_size=before_attn_dwconv, padding=1, groups=dim)self.pos_embed2 = nn.Conv2d(dim, dim,  kernel_size=before_attn_dwconv, padding=1, groups=dim)else:self.pos_embed = lambda x: 0self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing#if topk > 0:if topk == 4:self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=1, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,auto_pad=auto_pad)self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=topk, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,auto_pad=auto_pad,param_size=param_size)elif topk == 8:self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, 
n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=4, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,auto_pad=auto_pad)self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=topk, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,uto_pad=auto_pad,param_size=param_size)elif topk == 16:self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=16, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,auto_pad=auto_pad)self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=topk, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,uto_pad=auto_pad,param_size=param_size)elif topk == -1:self.attn = Attention(dim=dim)elif topk == -2:self.attn1 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=49, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,uto_pad=auto_pad,param_size=param_size)self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,topk=49, param_attention=param_attention, param_routing=param_routing,diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,uto_pad=auto_pad,param_size=param_size)elif topk == 0:self.attn = nn.Sequential(Rearrange('n h w c -> n c h w'), # compatiabilitynn.Conv2d(dim, dim, 1), # pseudo qkv linearnn.Conv2d(dim, dim, 5, padding=2, groups=dim), # pseudo attentionnn.Conv2d(dim, dim, 1), # pseudo out linearRearrange('n c h w -> n h w c'))self.norm2 = nn.LayerNorm(dim, eps=1e-6)self.mlp1 = TransformerMLPWithConv(dim, mlp_ratio, 0.)self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
            else nn.Identity()
        self.norm3 = nn.LayerNorm(dim, eps=1e-6)
        self.norm4 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp2 = TransformerMLPWithConv(dim, mlp_ratio, 0.)

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.gamma4 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm

    def forward(self, x):
        """
        x: NCHW tensor
        """
        # conv pos embedding
        x = x + self.pos_embed1(x)
        # permute to NHWC tensor for attention & mlp
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                x = x + self.drop_path1(self.gamma1 * self.attn1(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path1(self.gamma2 * self.mlp1(self.norm2(x)))  # (N, H, W, C)
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                x = x + self.drop_path2(self.gamma3 * self.attn2(self.norm3(x)))  # (N, H, W, C)
                x = x + self.drop_path2(self.gamma4 * self.mlp2(self.norm4(x)))  # (N, H, W, C)
            else:
                x = x + self.drop_path1(self.attn1(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path1(self.mlp1(self.norm2(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1))  # (N, H, W, C)
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                x = x + self.drop_path2(self.attn2(self.norm3(x)))  # (N, H, W, C)
                x = x + self.drop_path2(self.mlp2(self.norm4(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1))  # (N, H, W, C)
        else:  # post-norm,参见 https://kexue.fm/archives/9009
            if self.use_layer_scale:
                x = self.norm1(x + self.drop_path1(self.gamma1 * self.attn1(x)))  # (N, H, W, C)
                x = self.norm2(x + self.drop_path1(self.gamma2 * self.mlp1(x)))  # (N, H, W, C)
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                x = self.norm3(x + self.drop_path2(self.gamma3 * self.attn2(x)))  # (N, H, W, C)
                x = self.norm4(x + self.drop_path2(self.gamma4 * self.mlp2(x)))  # (N, H, W, C)
            else:
                x = self.norm1(x + self.drop_path1(self.attn1(x)))  # (N, H, W, C)
                x = x + self.drop_path1(self.mlp1(self.norm2(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1))  # (N, H, W, C)
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                x = self.norm3(x + self.drop_path2(self.attn2(x)))  # (N, H, W, C)
                x = x + self.drop_path2(self.mlp2(self.norm4(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1))  # (N, H, W, C)

        # permute back
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return x


class DeBiFormer(nn.Module):
    def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
                 head_dim=64, qk_scale=None, representation_size=None,
                 drop_path_rate=0., drop_rate=0.,
                 use_checkpoint_stages=[],
                 ########
                 n_win=7,
                 kv_downsample_mode='ada_avgpool',
                 kv_per_wins=[2, 2, -1, -1],
                 topks=[8, 8, -1, -1],
                 side_dwconv=5,
                 layer_scale_init_value=-1,
                 qk_dims=[None, None, None, None],
                 param_routing=False, diff_routing=False, soft_routing=False,
                 pre_norm=True,
                 pe=None,
                 pe_stages=[0],
                 before_attn_dwconv=3,
                 auto_pad=False,
                 #-----------------------
                 kv_downsample_kernels=[4, 2, 1, 1],
                 kv_downsample_ratios=[4, 2, 1, 1],  # -> kv_per_win = [2, 2, 2, 1]
                 mlp_ratios=[4, 4, 4, 4],
                 param_attention='qkvo',
                 param_size='small',
                 mlp_dwconv=False):
        """
        Args:
            depth (list): depth of each stage
            img_size (int, tuple): input image size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (list): embedding dimension of each stage
            head_dim (int): head dimension
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer (nn.Module): normalization layer
            conv_stem (bool): whether use overlapped patch stem
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.ModuleList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv
        stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0]),
        )
        if (pe is not None) and 0 in pe_stages:
            stem.append(get_pe_layer(emb_dim=embed_dim[0], name=pe))
        if use_checkpoint_stages:
            stem = checkpoint_wrapper(stem)
        self.downsample_layers.append(stem)

        for i in range(3):
            downsample_layer = nn.Sequential(
                nn.Conv2d(embed_dim[i], embed_dim[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                nn.BatchNorm2d(embed_dim[i + 1])
            )
            if (pe is not None) and i + 1 in pe_stages:
                downsample_layer.append(get_pe_layer(emb_dim=embed_dim[i + 1], name=pe))
            if use_checkpoint_stages:
                downsample_layer = checkpoint_wrapper(downsample_layer)
            self.downsample_layers.append(downsample_layer)
        ##########################################################################

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        nheads = [dim // head_dim for dim in qk_dims]
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=embed_dim[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        topk=topks[i],
                        num_heads=nheads[i],
                        n_win=n_win,
                        qk_dim=qk_dims[i],
                        qk_scale=qk_scale,
                        kv_per_win=kv_per_wins[i],
                        kv_downsample_ratio=kv_downsample_ratios[i],
                        kv_downsample_kernel=kv_downsample_kernels[i],
                        kv_downsample_mode=kv_downsample_mode,
                        param_attention=param_attention,
                        param_size=param_size,
                        param_routing=param_routing,
                        diff_routing=diff_routing,
                        soft_routing=soft_routing,
                        mlp_ratio=mlp_ratios[i],
                        mlp_dwconv=mlp_dwconv,
                        side_dwconv=side_dwconv,
                        before_attn_dwconv=before_attn_dwconv,
                        pre_norm=pre_norm,
                        auto_pad=auto_pad) for j in range(depth[i])],
            )
            if i in use_checkpoint_stages:
                stage = checkpoint_wrapper(stage)
            self.stages.append(stage)
            cur += depth[i]
        ##########################################################################

        self.norm = nn.BatchNorm2d(embed_dim[-1])
        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim[-1], representation_size)),  # 注:embed_dim 是列表,此处取最后一级维度
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        # 注:应遍历 self.modules() 而非 self.parameters(),
        # 否则 isinstance(m, (nn.Linear, nn.Conv2d)) 永远不成立,初始化不会生效
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)  # res = (56, 28, 14, 7), wins = (64, 16, 4, 1)
            x = self.stages[i](x)
        x = self.norm(x)
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.flatten(2).mean(-1)  # 全局平均池化: (N, C, H, W) -> (N, C)
        x = self.head(x)
        return x
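下面三个工厂函数通过 timm 的 @register_model 装饰器,把 tiny/small/base 三种规模的 DeBiFormer 注册进 timm 的模型注册表。注册在模块被 import 时完成,此后既可以直接调用工厂函数,也可以用 timm.create_model 按名字创建模型。下面是一个使用草图(仅为示意,假设上述模型代码保存为 debiformer.py,文件名是本文的假设):

import timm
import debiformer  # 假设:模型代码保存为 debiformer.py,import 的同时完成模型注册

# num_classes 等关键字参数会经工厂函数的 **kwargs 透传给 DeBiFormer(见下方工厂函数)
model = timm.create_model('debi_tiny', pretrained=False, num_classes=12)
print(type(model).__name__)  # DeBiFormer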
@register_model
def debi_tiny(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
        depth=[1, 1, 4, 1],
        embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 3],
        param_size='tiny',
        drop_path_rate=0.,  # drop path rate
        #------------------------------
        n_win=7,
        kv_downsample_mode='identity',
        kv_per_wins=[-1, -1, -1, -1],
        topks=[4, 8, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[64, 128, 256, 512],
        head_dim=32,
        param_routing=False, diff_routing=False, soft_routing=False,
        pre_norm=True,
        pe=None,
        **kwargs)  # 透传 num_classes 等外部参数,否则 timm.create_model 传入的 num_classes 会被忽略
    return model

@register_model
def debi_small(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
        depth=[2, 2, 9, 3],
        embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 2],
        param_size='small',
        drop_path_rate=0.3,  # drop path rate
        #------------------------------
        n_win=7,
        kv_downsample_mode='identity',
        kv_per_wins=[-1, -1, -1, -1],
        topks=[4, 8, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[64, 128, 256, 512],
        head_dim=32,
        param_routing=False, diff_routing=False, soft_routing=False,
        pre_norm=True,
        pe=None,
        **kwargs)  # 同样透传外部关键字参数
    return model

@register_model
def debi_base(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
        depth=[2, 2, 9, 2],
        embed_dim=[96, 192, 384, 768], mlp_ratios=[3, 3, 3, 3],
        param_size='base',
        drop_path_rate=0.4,  # drop path rate
        #------------------------------
        n_win=7,
        kv_downsample_mode='identity',
        kv_per_wins=[-1, -1, -1, -1],
        topks=[4, 8, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[96, 192, 384, 768],
        head_dim=32,
        param_routing=False, diff_routing=False, soft_routing=False,
        pre_norm=True,
        pe=None,
        **kwargs)
    return model


if __name__ == '__main__':
    from mmcv.cnn.utils import flops_counter

    model = DeBiFormer(
        depth=[2, 2, 9, 1],
        embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 2],
        #------------------------------
        n_win=7,
        kv_downsample_mode='identity',
        kv_per_wins=[-1, -1, -1, -1],
        topks=[4, 8, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[64, 128, 256, 512],
        head_dim=32,
        param_routing=False, diff_routing=False, soft_routing=False,
        pre_norm=True,
        pe=None)
    input_shape = (3, 224, 224)
    flops_counter.get_model_complexity_info(model, input_shape)
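在接入完整训练流程之前,可以先做一次前向冒烟测试,确认输出形状与类别数一致。下面是一个最小草图(同样假设模型代码保存为 debiformer.py,且工厂函数已按前文透传 **kwargs):

import torch
from debiformer import debi_tiny  # 假设:模型代码保存为 debiformer.py

model = debi_tiny(num_classes=12)  # 植物幼苗分类任务共 12 类
model.eval()  # 切到推理模式,避免 BatchNorm 统计量受随机输入影响
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))  # NCHW 输入;各阶段分辨率 (56, 28, 14, 7) 均能被 n_win=7 整除
print(out.shape)  # 预期输出: torch.Size([2, 12])

输出形状正确后,即可把 debi_tiny 接入后续的训练与评估脚本。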
