当前位置: 首页 > news >正文

Swin Transformer模型详解(附pytorch实现)

写在前面

Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模型,在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制,显著改善了Vision Transformer(ViT)在处理高分辨率图像时的性能,尤其是在图像分类、物体检测等计算机视觉任务中表现出色。

Swin Transformer的最大创新之一是其引入了“平移窗口”机制,克服了传统自注意力方法在大图像处理时计算资源消耗过大的问题。这一机制使得模型能够在不同层次上以局部的方式计算自注意力,同时保持全局信息的处理能力。

在本文中,我们将通过详细的分析,介绍Swin Transformer的模型结构、核心思想及其实现,最后提供一个基于PyTorch的简单实现。

论文地址:https://arxiv.org/pdf/2103.14030

官方代码实现:https://github.com/microsoft/Swin-Transformer

Swin网络结构

如下图所示,Swin Transformer的Encoder采用分层的方式,通过多个阶段(Stage)逐渐减少特征图的分辨率,同时增加特征维度。每个Stage包含若干个Transformer Block。

每个Block通常由以下几个部分组成:

  • Window-based Self-Attention:每个Block使用窗口自注意力机制,在每个窗口内计算自注意力。这种方式减少了计算量,因为自注意力只在局部窗口内进行计算,而不是整个图像。
  • Shifted Window:为了增强不同窗口之间的联系,Swin Transformer在每一层的Block中采用了“窗口位移”策略。每一层中的窗口会偏移一定的步长,使得窗口之间的重叠区域增加,从而促进信息交流。

Patch Partition

Patch Partition 是将输入图像分割成固定大小的块(patch)并将其映射到高维空间的操作。就相当于是VIT模型当中的 Patch Embedding。

from functools import partialimport torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPathLayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output,_ ,_ = patch_partition(x)print(f"Output shape: {output.shape}")

Patch Merging

PatchMerging 这一层用于将输入的特征图进行下采样,类似于卷积神经网络中的池化层。

如果图像的高度或宽度是奇数,PatchMerging 会进行填充,使得其变为偶数。这是因为下采样操作需要将图像分割为以2为步长的区域。如果图像的高度或宽度是奇数,直接进行切片会导致不均匀的分割,因此需要填充以保证每个块的大小一致。

这里我们在吧如上图的相同颜色块提取并进行拼接,沿着通道维度合并成一个更大的特征,将合并后的张量重新调整形状,新的空间分辨率是原来的一半(H/2 和 W/2)。

class PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output, Wh, Ww = patch_partition(x)patch_merging = PatchMerging(dim=embed_dim)output = patch_merging(output, Wh, Ww)print(output.shape)

在代码中呢就是在高和宽的维度通过切片的形式获得,x0表示的是左上角,x1表示的是右上角,x2表示的是左下角,x3表示的是右下角。经过一系列操作后,最后通过线性层实现通道数翻倍。

W-MSA

W-MSA(Window-based Multi-Head Self-Attention)是Swin Transformer中的一个核心创新,它是为了优化传统自注意力机制在高分辨率输入图像处理中的效率问题而提出的。

\Omega (MSA) = 4hwC^{2}+2(hw)^{2}C

\Omega (W$-$MSA) = 4hwC^{2} + 2(M)^{2}hwC

这是原论文当中给出的计算公式,h,w和C分别表示特征的高度,宽度和深度,M表示窗口的大小。在标准的 Transformer 模型中,自注意力机制需要对整个输入进行计算,这使得计算和内存的消耗随着输入的增大而急剧增长。而在图像任务中,输入图像往往具有非常高的分辨率,因此直接应用标准的全局自注意力在计算上不可行。

W-MSA 通过在局部窗口内进行自注意力计算来解决这一问题,极大地减少了计算和内存开销,同时保持了模型的表示能力。

SW-MSA

SW-MSA (Shifted Window-based Multi-Head Self-Attention)结合了局部窗口化自注意力和窗口偏移(shifted)策略,既提升了计算效率,又能在捕捉局部信息的基础上,保持对全局信息的建模能力。

左侧就是刚刚说到的W-MSA,经过窗口的偏移变成了右边的SW-MSA,偏移的策略能够让模型在每一层的计算中捕捉到不同窗口之间的依赖关系,避免了 W-MSA 只能在单一窗口内计算的局限。这样,相邻窗口之间的信息就能够通过偏移和交错的方式进行交流,增强了模型的全局感知能力。

但是,现在的窗口从原来的四个变成了九个,如果对每一个窗口再进行W-MSA那就太麻烦了。为了应对这种情况,作者提出了一种 高效批处理计算方法,旨在优化窗口偏移后的大规模窗口计算。其核心思想是:通过批处理计算的方式来有效地处理这些偏移后的窗口,而不是每个窗口单独计算。

意思就是说将图中的A,B,C的位置通过偏移和交错方式变化后,可以将这些窗口的计算统一进行批处理,而不是一个一个地处理。这样可以显著减少计算时间和内存占用。

这个过程我个人感觉比较像是卡诺图,具体的过程可以看我下面画的图:

然后这里的4还和原来的一样,5和3组合成一个窗口,1和7组合成一个窗口,8、2、6、0又组合成一个窗口,这样就和原来一样是4个4x4的窗口了,保证了计算量的不变。但是如果这样做了就会将不相邻的信息混合在一起了。作者这里采用掩蔽机制将自注意力计算限制在每个子窗口内,其实就是创建一个蒙板来屏蔽信息。

Relative Position Bias

关于这一部分,作者没有怎么提,只是经过了相对位置偏移,指标有明显的提示。

关于这一部分,我是参考的官方代码以及b站的讲解视频理解的。首先需要创建一个相对位置偏置的参数表,它的范围是从[-Wh+1, Wh-1],这里的 +1 和 -1 是因为偏移量是相对于当前元素的位置而言的,当前元素自身的偏移量为0,但我们不包括0在偏移量的计算中(因为0表示没有偏移,通常会在自注意力机制中以其他方式处理)。因此,对于垂直方向(或水平方向),总的偏移量数量是 win_h(或 win_w)的正偏移量数量加上 win_h(或 win_w)的负偏移量数量,再减去一个(因为我们不计算0偏移量)。因此,相对位置偏置表的尺寸为:

[(2 * Wh - 1) * (2 * Ww - 1), num_heads]

每个元素的查询(Query)和键(Key)之间的内积会得到一个相似度分数,在这些分数的基础上,会加入相对位置偏置,调整相似度:

Attention = softmax((QK^T + Relative_Position_Bias) / sqrt(d_k))

其中,Q 是查询向量,K 是键向量,Relative_Position_Bias 是根据相对位置计算得到的偏置。加入相对位置偏置后,模型可以更好地捕捉到局部结构的依赖关系。

网络实现

"""
Copyright (c) 2025, Auorui.
All rights reserved.Swin Transformer: Hierarchical Vision Transformer using Shifted Windows<https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.pyhttps://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partialimport torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_LayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop_ratio)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xclass WindowAttention(nn.Module):"""Window based multi-head self attention (W-MSA) module with relative position bias.It supports shifted and non-shifted windows."""def __init__(self,dim,window_size,num_heads,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,):super().__init__()self.dim = dimself.window_size = to_2tuple(window_size)win_h, win_w = self.window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))   # [2*Wh-1 * 2*Ww-1, nHeads]   Offset Range: -Wh+1, Wh-1self.register_buffer("relative_position_index",self.get_relative_position_index(win_h, win_w), persistent=False)trunc_normal_(self.relative_position_bias_table, std=.02)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attention_dropout_ratio)self.proj = nn.Linear(dim, dim, bias=proj_bias)self.proj_drop = nn.Dropout(proj_drop)self.softmax = nn.Softmax(dim=-1)def get_relative_position_index(self, win_h: int, win_w: int):# get pair-wise relative position index for each token inside the windowcoords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += win_h - 1  # shift to start from 0relative_coords[:, :, 1] += win_w - 1relative_coords[:, :, 0] *= 2 * win_w - 1return relative_coords.sum(-1)  # Wh*Ww, Wh*Wwdef forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[:3]q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的windowArgs:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size: int, H: int, W: int):"""将一个个window还原成一个feature mapArgs:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window size(M)H (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return xclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block."""mlp_ratio = 4def __init__(self,dim,num_heads,window_size=7,shift_size=0,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_ratio=0.,norm_layer=LayerNorm,act_layer=nn.GELU,):super(SwinTransformerBlock, self).__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeassert 0 <= self.shift_size < window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim,window_size=self.window_size,num_heads=num_heads,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * self.mlp_ratio)self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_bias)self.H = Noneself.W = Nonedef forward(self, x, mask_matrix):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature.mask_matrix: Attention mask for cyclic shift."""B, L, C = x.shapeH, W = self.H, self.Wassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window sizepad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_sizex = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))attn_mask = mask_matrixelse:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xif pad_r > 0 or pad_b > 0:x = x[:, :H, :W, :].contiguous()x = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass BasicLayer(nn.Module):""" A basic Swin Transformer layer for one stage."""def __init__(self,dim,num_layers,num_heads,drop_path,window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,norm_layer=LayerNorm,act_layer=nn.GELU,downsample=None):super().__init__()self.window_size = window_sizeself.shift_size = window_size // 2self.num_layers = num_layers# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,act_layer=act_layer)for i in range(num_layers)])# patch merging layerif downsample is not None:self.downsample = downsample(dim=dim, norm_layer=norm_layer)else:self.downsample = Nonedef forward(self, x, H, W):""" Forward function.Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""# calculate attention mask for SW-MSAHp = int(np.ceil(H / self.window_size)) * self.window_sizeWp = int(np.ceil(W / self.window_size)) * self.window_sizeimg_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))for blk in self.blocks:blk.H, blk.W = H, Wx = blk(x, attn_mask)if self.downsample is not None:x = self.downsample(x, H, W)H, W = (H + 1) // 2, (W + 1) // 2return x, H, Wclass SwinTransformer(nn.Module):""" Swin Transformer backbone."""def __init__(self,patch_size=4,in_channels=3,num_classes=1000,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_rate=0.2,norm_layer=LayerNorm,patch_norm=True,):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.num_layers = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_norm# stage4输出特征矩阵的channelsself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))# split image into non-overlapping patchesself.patch_embed = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)self.pos_drop = nn.Dropout(p=proj_drop)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]layers = []for i_layer in range(self.num_layers):layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),num_layers=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,)layers.append(layer)self.layers = nn.Sequential(*layers)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward(self, x):# x: [B, L, C]x, H, W = self.patch_embed(x)x = self.pos_drop(x)for layer in self.layers:x, H, W = layer(x, H, W)x = self.norm(x)  # [B, L, C]x = self.avgpool(x.transpose(1, 2))x = torch.flatten(x, 1)x = self.head(x)return xdef swin_t(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_s(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 18, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_b(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=128,depths=(2, 2, 18, 2),num_heads=(4, 8, 16, 32),num_classes=num_classes)return modeldef swin_l(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=192,depths=(2, 2, 18, 2),num_heads=(6, 12, 24, 48),num_classes=num_classes)return modelif __name__=="__main__":import pyzjrdevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = swin_l(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)pyzjr.summary_1(net, input_size=(3, 224, 224))# swin_t Total params: 27,499,108# swin_s Total params: 48,792,676# swin_b Total params: 86,683,780# swin_l Total params: 194,906,308

参考文章

Swin-Transformer网络结构详解_swin transformer-CSDN博客

Swin-transformer详解_swin transformer-CSDN博客 

【深度学习】详解 Swin Transformer (SwinT)-CSDN博客 

推荐的视频:12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili 

相关文章:

Swin Transformer模型详解(附pytorch实现)

写在前面 Swin Transformer&#xff08;Shifted Window Transformer&#xff09;是一种新颖的视觉Transformer模型&#xff0c;在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制&#xff0c;显著改善了Vision Transformer&#xff08;ViT&#xf…...

gitee 使用教程

前言 Gitee 是一个中国的开源代码托管平台&#xff0c;类似于 GitHub&#xff0c;旨在为开发者提供一个高效、稳定、安全的代码管理和协作开发环境。Gitee 支持 Git 协议&#xff0c;可以托管 Git 仓库&#xff0c;进行版本控制、代码协作、项目管理等操作。 1. Gitee 的主要…...

基于YOLOv8的水下目标检测系统

基于YOLOv8的水下目标检测系统 (价格90) 使用的是DUO水下目标检测数据集 训练集 6671张 验证集 1111张 测试集 1111张 包含 [holothurian, echinus, scallop, starfish] [海参, 海胆, 扇贝, 海星] 4个类 通过PYQT构建UI界面&#xff0c;包含图片检测&#xff0c;视…...

浅析PCIe链路均衡技术原理与演进

在现代计算机硬件体系的持续演进中&#xff0c;PCIe技术始终扮演着核心角色&#xff0c;其作为连接 CPU 与各类周边设备的关键高速通信链路&#xff0c;不断推动着计算机性能边界的拓展。而 PCIe Link Equalization均衡技术&#xff0c;作为保障数据在高速传输过程中准确性与稳…...

js代理模式

允许在不改变原始对象的情况下&#xff0c;通过代理对象来访问原始对象。代理对象可以在访问原始对象之前或之后&#xff0c;添加一些额外的逻辑或功能。 科学上网过程 一般情况下,在访问国外的网站,会显示无法访问 因为在dns解析过程,这些ip被禁止解析,所以显示无法访问 引…...

C++虚函数(八股总结)

什么是虚函数 虚函数是在父类中定义的一种特殊类型的函数&#xff0c;允许子类重写该函数以适应其自身需求。虚函数的调用取决于对象的实际类型&#xff0c;而不是指针或引用类型。通过将函数声明为虚函数&#xff0c;可以使继承层次结构中的每个子类都能够使用其自己的实现&a…...

vue的路由守卫逻辑处理不当导致部署在nginx上无法捕捉后端异步响应消息等问题

近期对前端的路由卫士有了更多的认识。 何为路由守卫&#xff1f;这可能是一种约定俗成的名称。就是VUE中的自定义函数&#xff0c;用来处理路由跳转。 import { createRouter, createWebHashHistory } from "vue-router";const router createRouter({history: cr…...

[备忘.OFD]OFD是什么、OFD与PDF格式文件的互转换

‌OFD&#xff08;Open Fixed-layout Document&#xff09;是一种由工业和信息化部软件司牵头中国电子技术标准化研究院制定的版式文档国家标准&#xff0c;属于中国的一种自主格式‌‌。OFD旨在打破政府部门和党委机关电子公文格式不统一的问题&#xff0c;以方便电子文档的存…...

Pycharm连接远程解释器

这里写目录标题 0 前言1 给项目添加解释器2 通过SSH连接3 找到远程服务器的torch环境所对应的python路径&#xff0c;并设置同步映射&#xff08;1&#xff09;配置服务器的系统环境&#xff08;2&#xff09;配置服务器的conda环境 4 进入到程序入口&#xff08;main.py&#…...

嵌入式系统 tensorflow

&#x1f3ac; 秋野酱&#xff1a;《个人主页》 &#x1f525; 个人专栏:《Java专栏》《Python专栏》 ⛺️心若有所向往,何惧道阻且长 文章目录 探索嵌入式系统中的 TensorFlow&#xff1a;机遇与挑战一、TensorFlow 适配嵌入式的优势二、面临的硬件瓶颈三、软件优化策略四、实…...

深度学习知识点:LSTM

文章目录 1.应用现状2.发展历史3.基本结构4.LSTM和RNN的差异 1.应用现状 长短期记忆神经网络&#xff08;LSTM&#xff09;是一种特殊的循环神经网络(RNN)。原始的RNN在训练中&#xff0c;随着训练时间的加长以及网络层数的增多&#xff0c;很容易出现梯度爆炸或者梯度消失的问…...

11.C语言内存管理与常用内存操作函数解析

目录 1.简介2.void 指针3.malloc4.free5.calloc6.realloc7.restrict 说明符8.memcpy9.memmove()10.memcmp 1.简介 本篇原文为&#xff1a;C语言内存管理与常用内存操作函数解析。 更多C进阶、rust、python、逆向等等教程&#xff0c;可点击此链接查看&#xff1a;酷程网 C 语…...

Python 中的错误处理与调试技巧

&#x1f496; 欢迎来到我的博客&#xff01; 非常高兴能在这里与您相遇。在这里&#xff0c;您不仅能获得有趣的技术分享&#xff0c;还能感受到轻松愉快的氛围。无论您是编程新手&#xff0c;还是资深开发者&#xff0c;都能在这里找到属于您的知识宝藏&#xff0c;学习和成长…...

门禁系统与消防报警的几种联动方式

1、规范中要求的出入口系统与消防联动 1.1《建筑设计防火规范》GB 50016-2018 1.2《民用建筑电气设计规范》JGJ 16-2008  14.4出入口控制系统 3 设置在平安疏散口的出入口限制装置&#xff0c;应与火灾自动报警系统联动;在紧急状况下应自动释放出入口限制系统&…...

云原生安全风险分析

一、什么是云原生安全 云原生安全包含两层含义&#xff1a; 面向云原生环境的安全具有云原生特征的安全 0x1&#xff1a;面向云原生环境的安全 面向云原生环境的安全的目标是防护云原生环境中基础设施、编排系统和微服务等系统的安全。 这类安全机制不一定具备云原生的特性…...

解决cursor50次使用限制问题并恢复账号次数

视频内容&#xff1a; 在这个视频教程中&#xff0c;我们将演示如何解决科sir软件50次使用限制的问题&#xff0c;具体步骤包括删除和注销账号、重新登录并刷新次数。教程详细展示了如何使用官网操作将账号的剩余次数恢复到250次&#xff0c;并进行软件功能测试。通过简单的操…...

python学习笔记—16—数据容器之元组

1. 元组——tuple(元组是一个只读的list) (1) 元组的定义注意&#xff1a;定义单个元素的元组&#xff0c;在元素后面要加上 , (2) 元组也支持嵌套 (3) 下标索引取出元素 (4) 元组的相关操作 1. index——查看元组中某个元素在元组中的位置从左到右第一次出现的位置 t1 (&qu…...

rabbitmq——岁月云实战笔记

1 rabbitmq设计 生产者并不是直接将消息投递到queue,而是发送给exchange,由exchange根据type的规则来选定投递的queue,这样消息设计在生产者和消费者就实现解耦。 rabbitmq会给没有type预定义一些exchage,而实际我们却应该使用自己定义的。 1.1 用户注册设计 用户在…...

Matlab APP Designer

我想给聚类的代码加一个图形化界面&#xff0c;需要输入一些数据和一些参数并输出聚类后的图像和一些评价指标的值。 gpt说 可以用 app designer 界面元素设计 在 设计视图 中直接拖动即可 如图1&#xff0c;我拖进去一个 按钮 &#xff0c;图2 红色部分 出现一行 Button 图…...

CSS语言的编程范式

CSS语言的编程范式 引言 在现代网页开发中&#xff0c;CSS&#xff08;层叠样式表&#xff09;作为一种样式语言&#xff0c;承担着网站前端呈现的重要角色。无论是简单的静态网页还是复杂的单页应用&#xff0c;CSS都在人机交互中发挥着至关重要的作用。掩盖在美观背后的&am…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

vue3 定时器-定义全局方法 vue+ts

1.创建ts文件 路径&#xff1a;src/utils/timer.ts 完整代码&#xff1a; import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具&#xff0c;可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件&#xff0c;也不需要在线上传文件&#xff0c;保护您的隐私。 工具截图 主要特点 &#x1f680; 快速转换&#xff1a;本地转换&#xff0c;无需等待上…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...