Swin Transformer模型详解(附pytorch实现)
写在前面
Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模型,在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制,显著改善了Vision Transformer(ViT)在处理高分辨率图像时的性能,尤其是在图像分类、物体检测等计算机视觉任务中表现出色。
Swin Transformer的最大创新之一是其引入了“平移窗口”机制,克服了传统自注意力方法在大图像处理时计算资源消耗过大的问题。这一机制使得模型能够在不同层次上以局部的方式计算自注意力,同时保持全局信息的处理能力。
在本文中,我们将通过详细的分析,介绍Swin Transformer的模型结构、核心思想及其实现,最后提供一个基于PyTorch的简单实现。
论文地址:https://arxiv.org/pdf/2103.14030
官方代码实现:https://github.com/microsoft/Swin-Transformer
Swin网络结构
如下图所示,Swin Transformer的Encoder采用分层的方式,通过多个阶段(Stage)逐渐减少特征图的分辨率,同时增加特征维度。每个Stage包含若干个Transformer Block。
每个Block通常由以下几个部分组成:
- Window-based Self-Attention:每个Block使用窗口自注意力机制,在每个窗口内计算自注意力。这种方式减少了计算量,因为自注意力只在局部窗口内进行计算,而不是整个图像。
- Shifted Window:为了增强不同窗口之间的联系,Swin Transformer在每一层的Block中采用了“窗口位移”策略。每一层中的窗口会偏移一定的步长,使得窗口之间的重叠区域增加,从而促进信息交流。
Patch Partition
Patch Partition 是将输入图像分割成固定大小的块(patch)并将其映射到高维空间的操作。就相当于是VIT模型当中的 Patch Embedding。
from functools import partialimport torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPathLayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x) # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output,_ ,_ = patch_partition(x)print(f"Output shape: {output.shape}")
Patch Merging
PatchMerging 这一层用于将输入的特征图进行下采样,类似于卷积神经网络中的池化层。
如果图像的高度或宽度是奇数,PatchMerging 会进行填充,使得其变为偶数。这是因为下采样操作需要将图像分割为以2为步长的区域。如果图像的高度或宽度是奇数,直接进行切片会导致不均匀的分割,因此需要填充以保证每个块的大小一致。
这里我们在吧如上图的相同颜色块提取并进行拼接,沿着通道维度合并成一个更大的特征,将合并后的张量重新调整形状,新的空间分辨率是原来的一半(H/2 和 W/2)。
class PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :] # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :] # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :] # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C) # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output, Wh, Ww = patch_partition(x)patch_merging = PatchMerging(dim=embed_dim)output = patch_merging(output, Wh, Ww)print(output.shape)
在代码中呢就是在高和宽的维度通过切片的形式获得,x0表示的是左上角,x1表示的是右上角,x2表示的是左下角,x3表示的是右下角。经过一系列操作后,最后通过线性层实现通道数翻倍。
W-MSA
W-MSA(Window-based Multi-Head Self-Attention)是Swin Transformer中的一个核心创新,它是为了优化传统自注意力机制在高分辨率输入图像处理中的效率问题而提出的。
这是原论文当中给出的计算公式,h,w和C分别表示特征的高度,宽度和深度,M表示窗口的大小。在标准的 Transformer 模型中,自注意力机制需要对整个输入进行计算,这使得计算和内存的消耗随着输入的增大而急剧增长。而在图像任务中,输入图像往往具有非常高的分辨率,因此直接应用标准的全局自注意力在计算上不可行。
W-MSA 通过在局部窗口内进行自注意力计算来解决这一问题,极大地减少了计算和内存开销,同时保持了模型的表示能力。
SW-MSA
SW-MSA (Shifted Window-based Multi-Head Self-Attention)结合了局部窗口化自注意力和窗口偏移(shifted)策略,既提升了计算效率,又能在捕捉局部信息的基础上,保持对全局信息的建模能力。
左侧就是刚刚说到的W-MSA,经过窗口的偏移变成了右边的SW-MSA,偏移的策略能够让模型在每一层的计算中捕捉到不同窗口之间的依赖关系,避免了 W-MSA 只能在单一窗口内计算的局限。这样,相邻窗口之间的信息就能够通过偏移和交错的方式进行交流,增强了模型的全局感知能力。
但是,现在的窗口从原来的四个变成了九个,如果对每一个窗口再进行W-MSA那就太麻烦了。为了应对这种情况,作者提出了一种 高效批处理计算方法,旨在优化窗口偏移后的大规模窗口计算。其核心思想是:通过批处理计算的方式来有效地处理这些偏移后的窗口,而不是每个窗口单独计算。
意思就是说将图中的A,B,C的位置通过偏移和交错方式变化后,可以将这些窗口的计算统一进行批处理,而不是一个一个地处理。这样可以显著减少计算时间和内存占用。
这个过程我个人感觉比较像是卡诺图,具体的过程可以看我下面画的图:
然后这里的4还和原来的一样,5和3组合成一个窗口,1和7组合成一个窗口,8、2、6、0又组合成一个窗口,这样就和原来一样是4个4x4的窗口了,保证了计算量的不变。但是如果这样做了就会将不相邻的信息混合在一起了。作者这里采用掩蔽机制将自注意力计算限制在每个子窗口内,其实就是创建一个蒙板来屏蔽信息。
Relative Position Bias
关于这一部分,作者没有怎么提,只是经过了相对位置偏移,指标有明显的提示。
关于这一部分,我是参考的官方代码以及b站的讲解视频理解的。首先需要创建一个相对位置偏置的参数表,它的范围是从[-Wh+1, Wh-1],这里的 +1 和 -1 是因为偏移量是相对于当前元素的位置而言的,当前元素自身的偏移量为0,但我们不包括0在偏移量的计算中(因为0表示没有偏移,通常会在自注意力机制中以其他方式处理)。因此,对于垂直方向(或水平方向),总的偏移量数量是 win_h(或 win_w)的正偏移量数量加上 win_h(或 win_w)的负偏移量数量,再减去一个(因为我们不计算0偏移量)。因此,相对位置偏置表的尺寸为:
[(2 * Wh - 1) * (2 * Ww - 1), num_heads]
每个元素的查询(Query)和键(Key)之间的内积会得到一个相似度分数,在这些分数的基础上,会加入相对位置偏置,调整相似度:
Attention = softmax((QK^T + Relative_Position_Bias) / sqrt(d_k))
其中,Q 是查询向量,K 是键向量,Relative_Position_Bias 是根据相对位置计算得到的偏置。加入相对位置偏置后,模型可以更好地捕捉到局部结构的依赖关系。
网络实现
"""
Copyright (c) 2025, Auorui.
All rights reserved.Swin Transformer: Hierarchical Vision Transformer using Shifted Windows<https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.pyhttps://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partialimport torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_LayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x) # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop_ratio)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :] # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :] # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :] # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C) # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xclass WindowAttention(nn.Module):"""Window based multi-head self attention (W-MSA) module with relative position bias.It supports shifted and non-shifted windows."""def __init__(self,dim,window_size,num_heads,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,):super().__init__()self.dim = dimself.window_size = to_2tuple(window_size)win_h, win_w = self.window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) # [2*Wh-1 * 2*Ww-1, nHeads] Offset Range: -Wh+1, Wh-1self.register_buffer("relative_position_index",self.get_relative_position_index(win_h, win_w), persistent=False)trunc_normal_(self.relative_position_bias_table, std=.02)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attention_dropout_ratio)self.proj = nn.Linear(dim, dim, bias=proj_bias)self.proj_drop = nn.Dropout(proj_drop)self.softmax = nn.Softmax(dim=-1)def get_relative_position_index(self, win_h: int, win_w: int):# get pair-wise relative position index for each token inside the windowcoords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij')) # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1) # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += win_h - 1 # shift to start from 0relative_coords[:, :, 1] += win_w - 1relative_coords[:, :, 0] *= 2 * win_w - 1return relative_coords.sum(-1) # Wh*Ww, Wh*Wwdef forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[:3]q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的windowArgs:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size: int, H: int, W: int):"""将一个个window还原成一个feature mapArgs:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window size(M)H (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return xclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block."""mlp_ratio = 4def __init__(self,dim,num_heads,window_size=7,shift_size=0,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_ratio=0.,norm_layer=LayerNorm,act_layer=nn.GELU,):super(SwinTransformerBlock, self).__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeassert 0 <= self.shift_size < window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim,window_size=self.window_size,num_heads=num_heads,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * self.mlp_ratio)self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_bias)self.H = Noneself.W = Nonedef forward(self, x, mask_matrix):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature.mask_matrix: Attention mask for cyclic shift."""B, L, C = x.shapeH, W = self.H, self.Wassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window sizepad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_sizex = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))attn_mask = mask_matrixelse:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xif pad_r > 0 or pad_b > 0:x = x[:, :H, :W, :].contiguous()x = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass BasicLayer(nn.Module):""" A basic Swin Transformer layer for one stage."""def __init__(self,dim,num_layers,num_heads,drop_path,window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,norm_layer=LayerNorm,act_layer=nn.GELU,downsample=None):super().__init__()self.window_size = window_sizeself.shift_size = window_size // 2self.num_layers = num_layers# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,act_layer=act_layer)for i in range(num_layers)])# patch merging layerif downsample is not None:self.downsample = downsample(dim=dim, norm_layer=norm_layer)else:self.downsample = Nonedef forward(self, x, H, W):""" Forward function.Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""# calculate attention mask for SW-MSAHp = int(np.ceil(H / self.window_size)) * self.window_sizeWp = int(np.ceil(W / self.window_size)) * self.window_sizeimg_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))for blk in self.blocks:blk.H, blk.W = H, Wx = blk(x, attn_mask)if self.downsample is not None:x = self.downsample(x, H, W)H, W = (H + 1) // 2, (W + 1) // 2return x, H, Wclass SwinTransformer(nn.Module):""" Swin Transformer backbone."""def __init__(self,patch_size=4,in_channels=3,num_classes=1000,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_rate=0.2,norm_layer=LayerNorm,patch_norm=True,):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.num_layers = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_norm# stage4输出特征矩阵的channelsself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))# split image into non-overlapping patchesself.patch_embed = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)self.pos_drop = nn.Dropout(p=proj_drop)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]layers = []for i_layer in range(self.num_layers):layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),num_layers=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,)layers.append(layer)self.layers = nn.Sequential(*layers)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward(self, x):# x: [B, L, C]x, H, W = self.patch_embed(x)x = self.pos_drop(x)for layer in self.layers:x, H, W = layer(x, H, W)x = self.norm(x) # [B, L, C]x = self.avgpool(x.transpose(1, 2))x = torch.flatten(x, 1)x = self.head(x)return xdef swin_t(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_s(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 18, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_b(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=128,depths=(2, 2, 18, 2),num_heads=(4, 8, 16, 32),num_classes=num_classes)return modeldef swin_l(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=192,depths=(2, 2, 18, 2),num_heads=(6, 12, 24, 48),num_classes=num_classes)return modelif __name__=="__main__":import pyzjrdevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = swin_l(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)pyzjr.summary_1(net, input_size=(3, 224, 224))# swin_t Total params: 27,499,108# swin_s Total params: 48,792,676# swin_b Total params: 86,683,780# swin_l Total params: 194,906,308
参考文章
Swin-Transformer网络结构详解_swin transformer-CSDN博客
Swin-transformer详解_swin transformer-CSDN博客
【深度学习】详解 Swin Transformer (SwinT)-CSDN博客
推荐的视频:12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili
相关文章:

Swin Transformer模型详解(附pytorch实现)
写在前面 Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模型,在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制,显著改善了Vision Transformer(ViT…...
gitee 使用教程
前言 Gitee 是一个中国的开源代码托管平台,类似于 GitHub,旨在为开发者提供一个高效、稳定、安全的代码管理和协作开发环境。Gitee 支持 Git 协议,可以托管 Git 仓库,进行版本控制、代码协作、项目管理等操作。 1. Gitee 的主要…...

基于YOLOv8的水下目标检测系统
基于YOLOv8的水下目标检测系统 (价格90) 使用的是DUO水下目标检测数据集 训练集 6671张 验证集 1111张 测试集 1111张 包含 [holothurian, echinus, scallop, starfish] [海参, 海胆, 扇贝, 海星] 4个类 通过PYQT构建UI界面,包含图片检测,视…...

浅析PCIe链路均衡技术原理与演进
在现代计算机硬件体系的持续演进中,PCIe技术始终扮演着核心角色,其作为连接 CPU 与各类周边设备的关键高速通信链路,不断推动着计算机性能边界的拓展。而 PCIe Link Equalization均衡技术,作为保障数据在高速传输过程中准确性与稳…...

js代理模式
允许在不改变原始对象的情况下,通过代理对象来访问原始对象。代理对象可以在访问原始对象之前或之后,添加一些额外的逻辑或功能。 科学上网过程 一般情况下,在访问国外的网站,会显示无法访问 因为在dns解析过程,这些ip被禁止解析,所以显示无法访问 引…...
C++虚函数(八股总结)
什么是虚函数 虚函数是在父类中定义的一种特殊类型的函数,允许子类重写该函数以适应其自身需求。虚函数的调用取决于对象的实际类型,而不是指针或引用类型。通过将函数声明为虚函数,可以使继承层次结构中的每个子类都能够使用其自己的实现&a…...

vue的路由守卫逻辑处理不当导致部署在nginx上无法捕捉后端异步响应消息等问题
近期对前端的路由卫士有了更多的认识。 何为路由守卫?这可能是一种约定俗成的名称。就是VUE中的自定义函数,用来处理路由跳转。 import { createRouter, createWebHashHistory } from "vue-router";const router createRouter({history: cr…...
[备忘.OFD]OFD是什么、OFD与PDF格式文件的互转换
OFD(Open Fixed-layout Document)是一种由工业和信息化部软件司牵头中国电子技术标准化研究院制定的版式文档国家标准,属于中国的一种自主格式。OFD旨在打破政府部门和党委机关电子公文格式不统一的问题,以方便电子文档的存…...

Pycharm连接远程解释器
这里写目录标题 0 前言1 给项目添加解释器2 通过SSH连接3 找到远程服务器的torch环境所对应的python路径,并设置同步映射(1)配置服务器的系统环境(2)配置服务器的conda环境 4 进入到程序入口(main.py&#…...
嵌入式系统 tensorflow
🎬 秋野酱:《个人主页》 🔥 个人专栏:《Java专栏》《Python专栏》 ⛺️心若有所向往,何惧道阻且长 文章目录 探索嵌入式系统中的 TensorFlow:机遇与挑战一、TensorFlow 适配嵌入式的优势二、面临的硬件瓶颈三、软件优化策略四、实…...

深度学习知识点:LSTM
文章目录 1.应用现状2.发展历史3.基本结构4.LSTM和RNN的差异 1.应用现状 长短期记忆神经网络(LSTM)是一种特殊的循环神经网络(RNN)。原始的RNN在训练中,随着训练时间的加长以及网络层数的增多,很容易出现梯度爆炸或者梯度消失的问…...
11.C语言内存管理与常用内存操作函数解析
目录 1.简介2.void 指针3.malloc4.free5.calloc6.realloc7.restrict 说明符8.memcpy9.memmove()10.memcmp 1.简介 本篇原文为:C语言内存管理与常用内存操作函数解析。 更多C进阶、rust、python、逆向等等教程,可点击此链接查看:酷程网 C 语…...

Python 中的错误处理与调试技巧
💖 欢迎来到我的博客! 非常高兴能在这里与您相遇。在这里,您不仅能获得有趣的技术分享,还能感受到轻松愉快的氛围。无论您是编程新手,还是资深开发者,都能在这里找到属于您的知识宝藏,学习和成长…...

门禁系统与消防报警的几种联动方式
1、规范中要求的出入口系统与消防联动 1.1《建筑设计防火规范》GB 50016-2018 1.2《民用建筑电气设计规范》JGJ 16-2008 14.4出入口控制系统 3 设置在平安疏散口的出入口限制装置,应与火灾自动报警系统联动;在紧急状况下应自动释放出入口限制系统&…...

云原生安全风险分析
一、什么是云原生安全 云原生安全包含两层含义: 面向云原生环境的安全具有云原生特征的安全 0x1:面向云原生环境的安全 面向云原生环境的安全的目标是防护云原生环境中基础设施、编排系统和微服务等系统的安全。 这类安全机制不一定具备云原生的特性…...
解决cursor50次使用限制问题并恢复账号次数
视频内容: 在这个视频教程中,我们将演示如何解决科sir软件50次使用限制的问题,具体步骤包括删除和注销账号、重新登录并刷新次数。教程详细展示了如何使用官网操作将账号的剩余次数恢复到250次,并进行软件功能测试。通过简单的操…...

python学习笔记—16—数据容器之元组
1. 元组——tuple(元组是一个只读的list) (1) 元组的定义注意:定义单个元素的元组,在元素后面要加上 , (2) 元组也支持嵌套 (3) 下标索引取出元素 (4) 元组的相关操作 1. index——查看元组中某个元素在元组中的位置从左到右第一次出现的位置 t1 (&qu…...

rabbitmq——岁月云实战笔记
1 rabbitmq设计 生产者并不是直接将消息投递到queue,而是发送给exchange,由exchange根据type的规则来选定投递的queue,这样消息设计在生产者和消费者就实现解耦。 rabbitmq会给没有type预定义一些exchage,而实际我们却应该使用自己定义的。 1.1 用户注册设计 用户在…...

Matlab APP Designer
我想给聚类的代码加一个图形化界面,需要输入一些数据和一些参数并输出聚类后的图像和一些评价指标的值。 gpt说 可以用 app designer 界面元素设计 在 设计视图 中直接拖动即可 如图1,我拖进去一个 按钮 ,图2 红色部分 出现一行 Button 图…...
CSS语言的编程范式
CSS语言的编程范式 引言 在现代网页开发中,CSS(层叠样式表)作为一种样式语言,承担着网站前端呈现的重要角色。无论是简单的静态网页还是复杂的单页应用,CSS都在人机交互中发挥着至关重要的作用。掩盖在美观背后的&am…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力
引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
unix/linux,sudo,其发展历程详细时间线、由来、历史背景
sudo 的诞生和演化,本身就是一部 Unix/Linux 系统管理哲学变迁的微缩史。来,让我们拨开时间的迷雾,一同探寻 sudo 那波澜壮阔(也颇为实用主义)的发展历程。 历史背景:su的时代与困境 ( 20 世纪 70 年代 - 80 年代初) 在 sudo 出现之前,Unix 系统管理员和需要特权操作的…...
Caliper 配置文件解析:config.yaml
Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

selenium学习实战【Python爬虫】
selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...