DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制
空洞自注意力机制
文章目录
- 摘要
- 1. 模型解释
- 1.1. 滑动窗口扩张注意力
- 1.2. 多尺度扩张注意力
- 2. 代码
- 3. 流程图
- 3.1. MultiDilatelocalAttention
- 3.2. DilateAttention
- 3.3. MLP
摘要
本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳理其实现原理。
1. 模型解释
1.1. 滑动窗口扩张注意力
根据在普通视觉变换器(ViTs)中浅层全局注意力中观察到的局部性和稀疏性特性,我们提出了一种滑动窗口扩张注意力(SWDA) 操作,其中,keys和values被以query patch为中心的滑动窗口稀疏地选择。然后对这些代表性patches进行自注意力。我们的 SWDA 正式描述如下:
X = S W D A ( Q , K , V , r ) ( 1 ) \begin{aligned} &&&&&&&&&&&&& X = SWDA(Q,K,V,r) &&&&&&&&&&&&&&&& (1) \end{aligned} X=SWDA(Q,K,V,r)(1)
其中, Q , K , V Q,K,V Q,K,V分别代表query、key和value矩阵,三个矩阵的每一行表示一个query/key/value特征向量。对于原始特征图上 ( i , j ) (i,j) (i,j)位置的query,SWDA以尺寸为 w × w w×w w×w大小的滑动窗口,稀疏地选择key和value去指导自注意力。
而且,我们定义一个扩张率 r ϵ N + r \epsilon N^+ rϵN+去控制稀疏程度。特别地,对于位置 ( i , j ) (i,j) (i,j),SWDA计算的输出 X X X中的相应分量 x i j x_{ij} xij定义如下:
x i j = A t t e n t i o n ( q i j , K r , V r ) , ( 2 ) = S o f t m a x ( q i j K r T d k ) V r , 1 ≤ i ≤ W , 1 ≤ i ≤ H \begin{aligned} &&&&&&&&&&&& x_{ij} &= Attention(q_{ij},K_r,V_r), &&&&&&&&&&&&&&&& (2)\\ &&&&&&&&&&&&&=Softmax(\frac{q_{ij}K^T_r}{\sqrt{d_k}})V_r,& 1≤i≤W, 1≤i≤H \\ \end{aligned} xij=Attention(qij,Kr,Vr),=Softmax(dkqijKrT)Vr,1≤i≤W,1≤i≤H(2)
其中, H H H 和 W W W 是特征图的高和宽。 K r K_r Kr和 V r V_r Vr表示从特征图 K K K 和 V V V 中选择的keys和values。
给定位于 ( i , j ) (i,j) (i,j)的query,位于坐标 ( i ′ , j ′ ) (i', j') (i′,j′) 下keys和values将被选择去指导自注意力(self-attetion):
{ ( i ′ , j ′ ) ∣ i ′ = i + p × r , j ′ = j + q × r } , − w 2 ≤ p , q ≤ w 2 . ( 3 ) \begin{aligned} &&&&&&&&&&&&& \{(i',j')|i'=i+p×r, j'=j+q×r \}, \frac{-w}{2}≤p, q≤\frac{w}{2}. &&&&&&&&&&&&&&&& (3) \end{aligned} {(i′,j′)∣i′=i+p×r,j′=j+q×r},2−w≤p,q≤2w.(3)
我们的 SWDA 以滑动窗口的方式对所有query patches进行自注意力操作。对于特征图边缘的query,我们简单地使用卷积运算中常用的 补零策略 来保持特征图的大小。通过稀疏地选择以queries为中心的keys和values,所提出的 SWDA 明确满足局部性和稀疏性属性,并且可以有效地对远程依赖关系进行建模
1.2. 多尺度扩张注意力

首先,特征图的通道被划分不同的heads。然后,自注意力操作是在红色查询块周围的窗口中的彩色块之间执行的,在不同的头中使用不同的膨胀率。此外,不同heads中的特征被连接在一起,然后输入到线性层中。默认情况下,我们使用 3 × 3 的内核大小,膨胀率 r = 1、2 和 3,不同头中参与感受野的大小为 3 × 3、5 × 5 和 7 × 7。
为了利用块级自注意力机制在不同尺度上的稀疏性,我们进一步提出了多尺度扩张注意力(MSDA) 块来提取多尺度语义信息。如图4所示,给定特征图 X X X,我们通过 线性投影(linear projection) 获得相应的query、kay和value。之后,我们将特征图的通道划分到 n n n 个不同的 h e a d s heads heads,并在不同的 h e a d s heads heads中以不同的膨胀率(dilation rates)执行多尺度SWDA。具体来说,我们的MSDA计算如下:
h i = S W D A ( Q i , K i , V i , r i ) , 1 ≤ i ≤ n , ( 4 ) X = L i n e a r ( C o n c a t [ h 1 , . . . , h n ] ) , ( 5 ) \begin{aligned} &&&&&&&&&&&&& h_i=SWDA(Q_i,K_i,V_i,r_i), &1≤i≤n, &&&&&&&&&&&&&&&& (4)\\ &\\ &&&&&&&&&&&&& X=Linear(Concat[h_1,...,h_n]), &&&&&&&&&&&&&&&&& (5) \end{aligned} hi=SWDA(Qi,Ki,Vi,ri),X=Linear(Concat[h1,...,hn]),1≤i≤n,(4)(5)
其中, r i r_i ri是第 i i i 个 h e a d head head的扩张率, Q i , K i Q_i,K_i Qi,Ki 和 V i V_i Vi 代表馈入第 i i i 个 h e a d head head的特征图切片。输出 { h i } i = 1 n \{h_i\}_{i=1}^n {hi}i=1n被concat到一起,然后送到线性层进行特征聚合。
通过为不同的 h e a d s heads heads 设置不同的扩张率,我们的 MSDA 有效地聚合了参与感受野内不同尺度的语义信息,并有效地减少了自注意力机制的冗余,而无需复杂的操作和额外的计算成本。
2. 代码
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfgclass Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass DilateAttention(nn.Module):"Implementation of Dilate-attention"def __init__(self, head_dim, qk_scale=None, attn_drop=0, kernel_size=3, dilation=1):super().__init__()self.head_dim = head_dimself.scale = qk_scale or head_dim ** -0.5self.kernel_size=kernel_sizeself.unfold = nn.Unfold(kernel_size, dilation, dilation*(kernel_size-1)//2, 1)self.attn_drop = nn.Dropout(attn_drop)def forward(self,q,k,v):#B, C//3, H, Wq, k, v = q.detach(), k.detach(), v.detach() # todo:!!!B,d,H,W = q.shapeq = q.reshape([B, d//self.head_dim, self.head_dim, 1 ,H*W]).permute(0, 1, 4, 3, 2) # B,h,N,1,dk = self.unfold(k).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 2, 3) #B,h,N,d,k*kattn = (q @ k) * self.scale # B,h,N,1,k*kattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)v = self.unfold(v).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 3, 2) # B,h,N,k*k,dx = (attn @ v).transpose(1, 2).reshape(B, H, W, d)return xclass MultiDilatelocalAttention(nn.Module):"Implementation of Dilate-attention"def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0.,proj_drop=0., kernel_size=3, dilation=[1, 2, 3]):super().__init__()self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.dilation = dilationself.kernel_size = kernel_sizeself.scale = qk_scale or head_dim ** -0.5self.num_dilation = len(dilation)assert num_heads % self.num_dilation == 0, f"num_heads{num_heads} must be the times of num_dilation{self.num_dilation}!!"self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)self.dilate_attention = nn.ModuleList([DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])for i in range(self.num_dilation)])self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, H, W, C = x.shapex = x.permute(0, 3, 1, 2)# B, C, H, Wqkv = self.qkv(x).reshape(B, 3, self.num_dilation, C//self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)#num_dilation,3,B,C//num_dilation,H,Wx = x.reshape(B, self.num_dilation, C//self.num_dilation, H, W).permute(1, 0, 3, 4, 2 )# num_dilation, B, H, W, C//num_dilationfor i in range(self.num_dilation):x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])# B, H, W,C//num_dilationx = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)x = self.proj(x)x = self.proj_drop(x)return xclass DilateBlock(nn.Module):"Implementation of Dilate-attention block"def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0., attn_drop=0.,drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, dilation=[1, 2, 3],cpe_per_block=False):super().__init__()self.dim = dimself.num_heads = num_headsself.mlp_ratio = mlp_ratioself.kernel_size = kernel_sizeself.dilation = dilationself.cpe_per_block = cpe_per_blockif self.cpe_per_block:self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)self.norm1 = norm_layer(dim)self.attn = MultiDilatelocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop=attn_drop, kernel_size=kernel_size, dilation=dilation)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)def forward(self, x):if self.cpe_per_block:x = x + self.pos_embed(x)x = x.permute(0, 2, 3, 1)x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))x = x.permute(0, 3, 1, 2)#B, C, H, Wreturn xif __name__ == "__main__":x = torch.rand([2,72,56,56])B, C, H, W = x.shapedim = Cnum_heads = 3 # 必须是dilation的整数倍 且 被dim整除head_dim = dim // num_heads#######################drop_path=0.1depths = [2, 2, 6, 2]num_layers = len(depths)dpr = [x.item() for x in torch.linspace(0, drop_path, sum(depths))]for i_layer in range(num_layers):drop_paths = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]#######################m = DilateBlock(dim=C,num_heads=num_heads,kernel_size=3,dilation=[1,2,3],mlp_ratio=4.,qkv_bias=True,qk_scale=head_dim ** -0.5,drop=0.,attn_drop=0.,drop_path=drop_paths[1] if isinstance(drop_paths, list) else drop_paths,norm_layer=nn.LayerNorm, act_layer=nn.GELU, cpe_per_block=True)y = m(x)print(y.shape)
3. 流程图

3.1. MultiDilatelocalAttention

3.2. DilateAttention

3.3. MLP

完整流程图如下:

相关文章:
DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制
空洞自注意力机制 文章目录 摘要1. 模型解释1.1. 滑动窗口扩张注意力1.2. 多尺度扩张注意力 2. 代码3. 流程图3.1. MultiDilatelocalAttention3.2. DilateAttention3.3. MLP 摘要 本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳…...
二十三种设计模式-适配器模式
适配器模式(Adapter Pattern)是一种结构型设计模式,它允许将不兼容的接口转换成客户端期望的接口,从而使原本因接口不匹配而不能一起工作的类可以协同工作。以下是关于适配器模式的详细介绍: 一、定义及作用 定义&am…...
复用类(2):代理、结合使用组合和继承
1 代理 第三种关系称为代理,这是继承与组合之间的中庸之道,因为我们将一个成员对象置于所要构造的类中(就像组合),但与此同时我们在新类中暴露了该成员对象的所有方法(就像继承)。例如ÿ…...
浅谈云计算07 | 云安全机制
云计算安全机制 一、引言二、加密技术:数据的隐形护盾三、散列机制:数据完整性的忠诚卫士四、数字签名:数据来源与真伪的鉴定专家五、公钥基础设施(PKI):信任的基石六、身份与访问管理(IAM&…...
【机器学习】零售行业的智慧升级:机器学习驱动的精准营销与库存管理
我的个人主页 我的领域:人工智能篇,希望能帮助到大家!!!👍点赞 收藏❤ 在当今数字化浪潮汹涌澎湃的时代,零售行业正站在转型升级的十字路口。市场竞争的白热化使得企业必须另辟蹊径࿰…...
深入理解 Entity、VO、QO、DTO 的区别及其在 MVC 架构中的应用
文章背景 在现代软件开发中,我们经常会接触到各种数据结构的概念,比如 Entity、VO(Value Object)、QO(Query Object)、DTO(Data Transfer Object)等。这些概念尽管看似相似ÿ…...
vue集成高德地图API实现坐标拾取功能
安装与配置: 组件 | vue-amapDescriptionhttps://elemefe.github.io/vue-amap/#/zh-cn/introduction/install简介 | vuemap/vue-amap简介https://vue-amap.guyixi.cn/zh-cn/introduction/introduction.html 我的应用 | 高德控制台高德开放平台官网控…...
Spring Boot Actuator 详细介绍
Spring Boot Actuator 详细介绍 1. 简介 Spring Boot Actuator 是 Spring Boot 提供的一个用于监控和管理应用程序的强大功能模块。它可以帮助我们了解应用程序的运行状况、指标收集、环境信息、日志级别管理等。 2. 添加依赖 2.1 在 pom.xml 中添加以下依赖: …...
联通用户管理系统(一)
#联通用户管理系统(一) 1.新建项目 如果你是windows的话,界面应该是如下的: 2.创建app python manage.py startapp app01一般情况下:我们是在pycharm的终端中运行上述指令,但是pychrm中为我们提供了工具…...
go chan底层分析
go chan底层分析 底层源码hchanmakechan 方法 环形队列阻塞机制向管道写数据流程图源码 从管道读数据流程图源码 关闭通道 底层源码 hchan type hchan struct {qcount uint // 当前队列中剩余元素个数dataqsiz uint // 环形队列长度,即可以…...
idea上git log面板的使用
文章目录 各种颜色含义具体的文件的颜色标签颜色🏷️ 节点和路线 各种颜色含义 具体的文件的颜色 红色:表示还没有 git add 提交到暂存区绿色:表示已经 git add 过,但是从来没有 commit 过蓝色:表示文件有过改动 标…...
WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现)
WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现) 目录 WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现)预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现WOA-Transformer鲸鱼算法优化编…...
dock 制作 python环境
报错 :Get "https://registry-1.docker.io/v2/": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers) 解决方法 配置加速地址 vim /etc/docker/daemon.json 添加以下内容 { "registry-mirror…...
2025第3周 | json-server的基本使用
目录 1. json-server是什么?2. json-server怎么用?2.1 安装2.2 创建db.json2.3 启动服务2.4 查看效果 3. 前端进行模拟交互3.1 创建demo.html3.2 创建demo.js 2025,做想做的事,读想读的书,持续学习,自律生活…...
Autodl转发端口,在本地机器上运行Autodl服务器中的ipynb文件
通过 SSH 隧道将远程端口转发到本地机器 输入服务器示例的SSH指令和密码,将远程的6006端口代理到本地 在服务器终端,激活conda虚拟环境 conda activate posecnnexport PYOPENGL_PLATFORMegljupyter notebook --no-browser --port6006 --allow-root从…...
flutter Get GetMiddleware 中间件不起作用问题
当使用 get: ^5.0.0-release-candidate-9.2.1最新版本时,中间件GetMiddleware各种教程都是让我们在redirect中实现,比如: overrideRouteSettings? redirect(String? route) {return RouteSettings(name: "/companyAuthIndexPage"…...
RabbitMQ(三)
RabbitMQ中的各模式及其用法 工作队列模式一、生产者代码1、封装工具类2、编写代码3、发送消息效果 二、消费者代码1、编写代码2、运行效果 发布订阅模式一、生产者代码二、消费者代码1、消费者1号2、消费者2号 三、运行效果四、小结 路由模式一、生产者代码二、消费者代码1、消…...
【Python】Python之locust压测教程+从0到1demo:基础轻量级压测实战(1)
文章目录 一、什么是Locust二、Locust 架构组成三、实战 Demo准备一个可调用的接口编写一个接口测试用例编写一个性能测试用例执行性能测试用例代码1、通过 Web UI 执行(GUI模式)2、通过命令行执行(非GUI模式) 小知识:…...
【JavaScript】基础内容,HTML如何引用JavaScript, JS 常用的数据类型
HTML 嵌入 Javascript 的方式 引入外部 js 文件 <head> <script Language "javaScript" src"index.js"/> </head>内部声明 <head> <script language"javascript">function hello(){alert("hello word&qu…...
vue使用自动化导入api插件unplugin-auto-import,避免频繁手动导入
unplugin-auto-import是一个现代的自动导入插件,旨在简化前端开发中的导入过程,减少手动导入的繁琐工作,提升开发效率。它支持多种构建工具,包括Vite、Webpack、Rollup和esbuild,并且可以与TypeScript配合使用&…...
智慧医疗能源事业线深度画像分析(上)
引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...
VB.net复制Ntag213卡写入UID
本示例使用的发卡器:https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
FastAPI 教程:从入门到实践
FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建 API,支持 Python 3.6。它基于标准 Python 类型提示,易于学习且功能强大。以下是一个完整的 FastAPI 入门教程,涵盖从环境搭建到创建并运行一个简单的…...
3-11单元格区域边界定位(End属性)学习笔记
返回一个Range 对象,只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意:它移动的位置必须是相连的有内容的单元格…...
【Go语言基础【13】】函数、闭包、方法
文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...
使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
Caliper 负载(Workload)详细解析
Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...
DAY 26 函数专题1
函数定义与参数知识点回顾:1. 函数的定义2. 变量作用域:局部变量和全局变量3. 函数的参数类型:位置参数、默认参数、不定参数4. 传递参数的手段:关键词参数5 题目1:计算圆的面积 任务: 编写一…...
Android屏幕刷新率与FPS(Frames Per Second) 120hz
Android屏幕刷新率与FPS(Frames Per Second) 120hz 屏幕刷新率是屏幕每秒钟刷新显示内容的次数,单位是赫兹(Hz)。 60Hz 屏幕:每秒刷新 60 次,每次刷新间隔约 16.67ms 90Hz 屏幕:每秒刷新 90 次,…...
