当前位置: 首页 > news >正文

一文解答Swin Transformer + 代码【详解】

文章目录

    • 1、Swin Transformer的介绍
      • 1.1 Swin Transformer解决图像问题的挑战
      • 1.2 Swin Transformer解决图像问题的方法
    • 2、Swin Transformer的具体过程
      • 2.1 Patch Partition 和 Linear Embedding
      • 2.2 W-MSA、SW-MSA
      • 2.3 Swin Transformer代码解析
        • 2.3.1 代码解释
      • 2.4 W-MSA和SW-MSA的softmax计算
        • 2.4.1 代码实现
      • 2.5 patch merging 实现过程
    • 3、不同Swin Transformer版本的区别

1、Swin Transformer的介绍

下面是Swin Transformer论文的Abstract,

在这里插入图片描述

1.1 Swin Transformer解决图像问题的挑战

原论文中的Abstract讲到:Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text。 也就是说在一张图片中,不同的object有不同的尺度,以及与NLP相比,图像处理需要更多的token,下面来看在CNN中是如何解决不同的object有不同的尺度这个问题的,CNN是通过Conv和Pool对特征图中像素的不同感受野来检测不同尺度的物体,如下图所示,

在这里插入图片描述

为什么多尺度检测对于Swin Transformer是有挑战性的,请看下图最左边的Transformer结构,当图片的宽高像素是 h × w \sf{h} \times \sf{w} h×w,所以一共有 h × w \sf{h} \times \sf{w} h×w 个token,输入给Transformer时,输出也是 h × w \sf{h} \times \sf{w} h×w ,所以Transformer并没有做下采样操作,下图中间的是ViT的结构,下图最右边的是Swin Transformer结构,MSA是multi-head self-attention的缩写,

在这里插入图片描述

1.2 Swin Transformer解决图像问题的方法

Swin Transformer如何解决相比NLP中更多的token呢,Swin Transformer提出了hierarchical Transformer,hierarchical Transformer是由shifted windows计算出来的,

下面是MSA,W-MSA,SW-MSA的示意图, 如果对于MSA处理56x56个像素的图片,需要计算 313 6 2 3136^2 31362次的相似度,计算量较大;W-MSA将56x56个像素分割成8x8个windows,每个windows由7x7个像素组成,每个windows单独计算MSA,作用是减少计算量,缺点是windows之间没有信息交互,计算相似度为;SW-MSA可以解决windows之间没有信息交互的问题,例如下图最右边的示例,把一个8x8的图片的分成4个windows,然后分割grid向右下角移动 window // 2,这样分割的windows就可以包含多个windows之间的信息,

在这里插入图片描述

如下图中最左边部分的1,2,3包含多个windows的信息,

在这里插入图片描述

下图黄色虚线框标记的是Swin Transformer处理模块,每次经过Swin Transformer模块,patch Merging使特征层减少一半,通道数增加一倍,

在这里插入图片描述
在这里插入图片描述

与ViT特征层不变不同,每次经过Swin Transformer模块,特征层减少一半,通道数增加一倍,

在这里插入图片描述

2、Swin Transformer的具体过程

2.1 Patch Partition 和 Linear Embedding

如下图,黄色虚线框的作用是下采样,

在这里插入图片描述

将224x224x3的特征图中的4x4x3的像素patch成1个token,然后通过concat操作,就得到了4x4x3=48个特征维度,把这个48个维度的向量看成一个token,所以224x224x3的特征图有56x56个token,每个token的特征维度是48,再经过一个映射把48个维度映射成96个维度,如下图所示,

在这里插入图片描述

Patch Partition 和 Linear Embedding在上面的示意图上是2个步骤,但是在代码中却是一步完成的,如下代码中的self.proj,

在这里插入图片描述

2.2 W-MSA、SW-MSA

如下图,2个连续的swin transformer结构与ViT的结构只有W-MSA、SW-MSA和MSA的区别,其他的MLP、Norm结构都相同,

在这里插入图片描述

下面是MSA和W-MSA的示意图,相较于MSA,W-MSA计算量较小,但是windows之间没有信息交互,

在这里插入图片描述

下面是SW-MSA的示意图,SW-MSA采用新的 window 划分方式,将 windows 向右下角移动 window size // 2,移动之后会发现windows的大小不一样,先把A+B向右移动,然后A+C再向下移动就解决了windows的大小不一样的问题,这样就可以并行地计算SW-MSA了,
需要注意到,移动之后每个window的像素并不是连续的,如何解决这个问题,请看下面的解释,

在这里插入图片描述

如下图,0作为query,其本身和其他window内的数字作为k,示意图中矩形框表示0和移动的像素的乘积,这个乘积再作 − 100 -100 100计算,取softmax之后,值就接近于0了,

在这里插入图片描述

计算完成之后,再把移动之后的像素还原,

在这里插入图片描述

W-MSA、SW-MSA与MSA的不同还在于softmax的计算方式,W-MSA、SW-MSA与MSA的不同有W-MSA、SW-MSA是在softmax里面加入相对位置偏置,MSA是在输入加入位置编码,

在这里插入图片描述

2.3 Swin Transformer代码解析

代码:

class SwinTransformerBlock(nn.Module):r""" Swin Transformer Block.Args:dim (int): Number of input channels.num_heads (int): Number of attention heads.window_size (int): Window size.shift_size (int): Shift size for SW-MSA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Truedrop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioassert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask)def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=self.attn_mask)  # [nW*B, Mh*Mw, C]# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # [B, H', W', C]# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)x = shortcut + self.drop_path(x)# FFNx = x + self.drop_path(self.mlp(self.norm2(x)))return x
2.3.1 代码解释

代码解释:

在这里插入图片描述
接上面的代码:
在这里插入图片描述

window_partition函数:

在这里插入图片描述

window_reverse函数:

在这里插入图片描述

MLP函数:

在这里插入图片描述

2.4 W-MSA和SW-MSA的softmax计算

在这里插入图片描述

W-MSA和SW-MSA的softmax计算分为2个步骤,
对于window size为2的矩阵,计算其relative position index矩阵,矩阵大小为 w i n d o w s i z e 2 \sf{windowsize}^2 windowsize2 x w i n d o w s i z e 2 \sf{windowsize}^2 windowsize2,relative position index矩阵的size和相似度矩阵的size一致,计算结果不用赘述,很容易看出来,相对位置的矩阵size是 2 M − 1 \sf{2M-1} 2M1 x 2 M − 1 \sf{2M-1} 2M1

在这里插入图片描述

relative position bias matrix是一个可学习的参数,是依照相对位置的所有组合方式而来, relative position bias B 的值来自 relative position bias matrix B ^ \sf{\hat{B}} B^ 的值,

在这里插入图片描述

2.4.1 代码实现

因为2维矩阵需要考虑2个维度的值,不方便计算,所以把2维坐标转为1维坐标,计算过程如下,

在这里插入图片描述

在这里插入图片描述

代码以及代码解释:

下图右边的代码假设window size为7,

在这里插入图片描述

在这里插入图片描述

2.5 patch merging 实现过程

patch merging在swin transformer中的位置如下,

在这里插入图片描述

patch merging的实现过程如下图,原来的像素在每隔1个像素组合,经过concat和LayerNorm以及Linear层之后size减倍,channel数加倍,

在这里插入图片描述

代码:

在这里插入图片描述

3、不同Swin Transformer版本的区别

区别在于stage1的channel数,stage3的Swin Transformer计算的次数,

在这里插入图片描述

参考:
1、原论文:https://arxiv.org/pdf/2103.14030
2、哔哩哔哩:https://www.bilibili.com/video/BV1xF41197E4/?p=6&spm_id_from=pageDriver

相关文章:

一文解答Swin Transformer + 代码【详解】

文章目录 1、Swin Transformer的介绍1.1 Swin Transformer解决图像问题的挑战1.2 Swin Transformer解决图像问题的方法 2、Swin Transformer的具体过程2.1 Patch Partition 和 Linear Embedding2.2 W-MSA、SW-MSA2.3 Swin Transformer代码解析2.3.1 代码解释 2.4 W-MSA和SW-MSA…...

Vue3:<Teleport>传送门组件的使用和注意事项

你好&#xff0c;我是沐爸&#xff0c;欢迎点赞、收藏、评论和关注。 Vue3 引入了一个新的内置组件 <Teleport>&#xff0c;它允许你将子组件树渲染到 DOM 中的另一个位置&#xff0c;而不是在父组件的模板中直接渲染。这对于需要跳出当前组件的 DOM 层级结构进行渲染的…...

项目之家:又一家项目信息发布合作对接及一手接单平台

这几天“小三劝退师时薪700”的消息甚嚣尘上&#xff0c;只能说从某一侧面来看心理咨询师这个职业的前景还是可以的&#xff0c;有兴趣的朋友可以关注下。话说上一篇文章给大家介绍了U客直谈&#xff0c;今天趁热打铁再给大家分享一个地推拉新项目合作平台~项目之家&#xff1a…...

02-java实习工作一个多月-经历分享

一、描述一下最近不写博客的原因 离我发java实习的工作的第一天的博客已经过去了一个多月了&#xff0c;本来还没入职的情况是打算每天工作都要写一份博客来记录一下的&#xff08;最坏的情况也是每周至少总结一下的&#xff09;&#xff0c;其实这个第一天的博客都是在公司快…...

JVM 调优篇2 jvm的内存结构以及堆栈参数设置与查看

一 jvm的内存模型 2.1 jvm内存模型概览 二 实操案例 2.1 设置和查看栈大小 1.代码 /*** 演示栈中的异常:StackOverflowError** author shkstart* create 2020 下午 9:08** 设置栈的大小&#xff1a; -Xss (-XX:ThreadStackSize)** -XX:PrintFlagsFinal*/ public class S…...

微信可以设置自动回复吗?

在日常的微信聊天中&#xff0c;我们或许会频繁地遭遇客户提出的相同问题&#xff0c;尤其是对于从事销售工作的朋友们来说&#xff0c;客户在添加好友后的第一句话往往是“在吗”或者“你好”。当我们的好友数量众多时&#xff0c;手动逐个回复可能会耗费大量的时间。因此&…...

同样数据源走RTMP播放延迟低还是RTSP低?

背景 在比较同一个数据源&#xff0c;是RTMP播放延迟低还是RTSP延迟低之前&#xff0c;我们先看看RTMP和RTSP的区别&#xff0c;我们知道&#xff0c;RTMP&#xff08;Real-Time Messaging Protocol&#xff09;和RTSP&#xff08;Real Time Streaming Protocol&#xff09;是…...

@开发者极客们,网易2024低代码大赛来啦

极客们&#xff0c;网易云信拍了拍你 9月6日起&#xff0c;2024网易低代码大赛正式开启啦&#xff01; 低代码大赛是由网易主办的权威赛事&#xff0c;鼓励开发者们用低代码开发的方式快速搭建应用&#xff0c;并最终以作品决出优胜。 从2022年11月起&#xff0c;网易低代码大赛…...

数据分析-16-时间序列分析的常用模型

1 什么是时间序列 时间序列是一组按时间顺序排列的数据点的集合,通常以固定的时间间隔进行观测。这些数据点可以是按小时、天、月甚至年进行采样的。时间序列在许多领域中都有广泛应用,例如金融、经济学、气象学和工程等。 时间序列的分析可以帮助我们理解和预测未来的趋势和…...

SpringMVC使用:类型转换数据格式化数据验证

01-类型转换器 先在pom.xml里面导入依赖&#xff0c;一个是mvc框架的依赖&#xff0c;一个是junit依赖 然后在web.xml里面导入以下配置&#xff08;配置的详细说明和用法我在前面文章中有写到&#xff09; 创建此测试类的方法用于测试springmvc是具备自动类型转换功能的 user属…...

多语言ASO – 本地化的10个技巧

ASO优化是一个复杂的领域&#xff0c;即使你只关注讲英语的用户。如果您想面向国际受众并在全球范围内发展您的应用程序业务&#xff0c;您必须在App Store和Google Play Store上本地化应用程序的产品页面。不过&#xff0c;应用程序商店本地化的过程也有很多陷阱。 应用商店本…...

C程序设计——函数0

函数定义 前面说过C语言是结构化的程序设计语言&#xff0c;他把所有问题抽象为数据和对数据的操作&#xff0c;前面讲的变量、常量&#xff0c;都是数据。现在开始讲对数据操作——函数。 C语言的函数&#xff0c;定义方式如下&#xff1a; 返回值类型 函数名(参数列表) {…...

第二十一章 rust与动静态库的结合使用

注意 本系列文章已升级、转移至我的自建站点中,本章原文为:rust与动静态库的结合使用 目录 注意一、前言二、库生成三、库使用四、总结一、前言 rust中多了很多类型的库,比如前面章节中我们提到基本的bin与lib这两种crate类型库。 如果你在命令行执行下列语句: rustc -…...

修改服务器DNS解析及修改自动对时时区

修改服务器DNS解析&#xff1a; 1、搜索一下当地的DNS服务器的地址 2、登录服务器&#xff0c;执行 vim /etc/resolv.conf文件&#xff0c;在nameserver字段后填写DNS服务的地址 3、chattr i /etc/resolv.conf 加上不可修改权限&#xff0c;防止重启DNS被修改 修改自动对时…...

中科院TOP“灌水神刊”合集!盘点那些“又牛又水”的国人友好SCI

【SciencePub学术】本期&#xff0c;小编给大家推荐几本“又牛又水”的期刊&#xff0c;并且都是清一色的国人友好刊&#xff0c;涵盖各领域&#xff0c;以供各位学者参考&#xff01; NO.1 Nature Communications IF&#xff1a;14.7 分区&#xff1a;JCR1区中科院1区TOP 年…...

Python列表浅拷贝的陷阱与破解之道

引言 在Python编程世界中&#xff0c;列表的拷贝操作看似简单&#xff0c;却常常隐藏着一些令人意想不到的陷阱&#xff0c;尤其是当涉及到浅拷贝时。今天&#xff0c;我们将深入探讨Python列表浅拷贝现象及产生原因&#xff0c;并提供有效的解决方案&#xff0c;帮助你写出更…...

开放式系统互连(OSI)模型的实际意义

0 前言 开放式系统互连&#xff08;OSI&#xff0c;Open Systems Interconnection&#xff09;模型&#xff0c;由国际标准化组织&#xff08;ISO&#xff09;在1984年提出&#xff0c;目的是为了促进不同厂商生产的网络设备之间的互操作性。 定义了一种在层之间进行协议实现…...

回溯——10.全排列 II

力扣题目链接 给定一个可包含重复数字的序列 nums &#xff0c;按任意顺序 返回所有不重复的全排列。 示例 1&#xff1a; 输入&#xff1a;nums [1,1,2]输出&#xff1a; [[1,1,2], [1,2,1], [2,1,1]] 解题思路&#xff1a; 排序&#xff1a;首先对数组进行排序&#xf…...

基于百度AIStudio飞桨paddleRS-develop版道路模型开发训练

基于百度AIStudio飞桨paddleRS-develop版道路模型开发训练 参考地址&#xff1a;https://aistudio.baidu.com/projectdetail/8271882 基于python35paddle120env环境 预测可视化结果&#xff1a; &#xff08;一&#xff09;安装环境&#xff1a; 先上传本地下载的源代码Pad…...

【 C++ 】C/C++内存管理

前言&#xff1a; &#x1f618;我的主页&#xff1a;OMGmyhair-CSDN博客 目录 一、C/C内存分布 二、C语言中动态内存管理方式&#xff1a;malloc/calloc/realloc/free malloc&#xff1a; calloc&#xff1a; realloc&#xff1a; free&#xff1a; 三、C内存管理方式…...

synchronized 学习

学习源&#xff1a; https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖&#xff0c;也要考虑性能问题&#xff08;场景&#xff09; 2.常见面试问题&#xff1a; sync出…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

Robots.txt 文件

什么是robots.txt&#xff1f; robots.txt 是一个位于网站根目录下的文本文件&#xff08;如&#xff1a;https://example.com/robots.txt&#xff09;&#xff0c;它用于指导网络爬虫&#xff08;如搜索引擎的蜘蛛程序&#xff09;如何抓取该网站的内容。这个文件遵循 Robots…...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么&#xff1f; WebAssembly&#xff08;WASM&#xff09; 是一种能在现代浏览器中高效运行的二进制指令格式&#xff0c;它不是传统的编程语言&#xff0c;而是一种 低级字节码格式&#xff0c;可由高级语言&#xff08;如 C、C、Rust&am…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难&#xff0c;相信大家会学的很愉快&#xff0c;当然对于有后端基础的朋友来说&#xff0c;本期内容更加容易了解&#xff0c;当然没有基础的也别担心&#xff0c;本期内容会详细解释有关内容 本期用到的软件&#xff1a;yakit&#xff08;因为经过之前好多期…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

nnUNet V2修改网络——暴力替换网络为UNet++

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...

LCTF液晶可调谐滤波器在多光谱相机捕捉无人机目标检测中的作用

中达瑞和自2005年成立以来&#xff0c;一直在光谱成像领域深度钻研和发展&#xff0c;始终致力于研发高性能、高可靠性的光谱成像相机&#xff0c;为科研院校提供更优的产品和服务。在《低空背景下无人机目标的光谱特征研究及目标检测应用》这篇论文中提到中达瑞和 LCTF 作为多…...

GraphQL 实战篇:Apollo Client 配置与缓存

GraphQL 实战篇&#xff1a;Apollo Client 配置与缓存 上一篇&#xff1a;GraphQL 入门篇&#xff1a;基础查询语法 依旧和上一篇的笔记一样&#xff0c;主实操&#xff0c;没啥过多的细节讲解&#xff0c;代码具体在&#xff1a; https://github.com/GoldenaArcher/graphql…...