当前位置: 首页 > news >正文

DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制

空洞自注意力机制

文章目录

  • 摘要
  • 1. 模型解释
    • 1.1. 滑动窗口扩张注意力
    • 1.2. 多尺度扩张注意力
  • 2. 代码
  • 3. 流程图
    • 3.1. MultiDilatelocalAttention
    • 3.2. DilateAttention
    • 3.3. MLP

摘要

    本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳理其实现原理。

1. 模型解释

1.1. 滑动窗口扩张注意力

    根据在普通视觉变换器(ViTs)中浅层全局注意力中观察到的局部性稀疏性特性,我们提出了一种滑动窗口扩张注意力(SWDA) 操作,其中,keys和values被以query patch为中心的滑动窗口稀疏地选择。然后对这些代表性patches进行自注意力。我们的 SWDA 正式描述如下:

X = S W D A ( Q , K , V , r ) ( 1 ) \begin{aligned} &&&&&&&&&&&&& X = SWDA(Q,K,V,r) &&&&&&&&&&&&&&&& (1) \end{aligned} X=SWDA(Q,K,V,r)(1)

其中, Q , K , V Q,K,V Q,K,V分别代表query、key和value矩阵,三个矩阵的每一行表示一个query/key/value特征向量。对于原始特征图上 ( i , j ) (i,j) (i,j)位置的query,SWDA以尺寸为 w × w w×w w×w大小的滑动窗口,稀疏地选择key和value去指导自注意力。

    而且,我们定义一个扩张率 r ϵ N + r \epsilon N^+ N+去控制稀疏程度。特别地,对于位置 ( i , j ) (i,j) (i,j)SWDA计算的输出 X X X中的相应分量 x i j x_{ij} xij定义如下:

x i j = A t t e n t i o n ( q i j , K r , V r ) , ( 2 ) = S o f t m a x ( q i j K r T d k ) V r , 1 ≤ i ≤ W , 1 ≤ i ≤ H \begin{aligned} &&&&&&&&&&&& x_{ij} &= Attention(q_{ij},K_r,V_r), &&&&&&&&&&&&&&&& (2)\\ &&&&&&&&&&&&&=Softmax(\frac{q_{ij}K^T_r}{\sqrt{d_k}})V_r,& 1≤i≤W, 1≤i≤H \\ \end{aligned} xij=Attention(qij,Kr,Vr),=Softmax(dk qijKrT)Vr,1iW,1iH(2)

其中, H H H W W W 是特征图的高和宽。 K r K_r Kr V r V_r Vr表示从特征图 K K K V V V 中选择的keys和values。

    给定位于 ( i , j ) (i,j) (i,j)的query,位于坐标 ( i ′ , j ′ ) (i', j') (i,j) 下keys和values将被选择去指导自注意力(self-attetion):

{ ( i ′ , j ′ ) ∣ i ′ = i + p × r , j ′ = j + q × r } , − w 2 ≤ p , q ≤ w 2 . ( 3 ) \begin{aligned} &&&&&&&&&&&&& \{(i',j')|i'=i+p×r, j'=j+q×r \}, \frac{-w}{2}≤p, q≤\frac{w}{2}. &&&&&&&&&&&&&&&& (3) \end{aligned} {(i,j)i=i+p×r,j=j+q×r},2wp,q2w.(3)

    我们的 SWDA 以滑动窗口的方式对所有query patches进行自注意力操作。对于特征图边缘的query,我们简单地使用卷积运算中常用的 补零策略 来保持特征图的大小。通过稀疏地选择以queries为中心的keys和values,所提出的 SWDA 明确满足局部性和稀疏性属性,并且可以有效地对远程依赖关系进行建模

1.2. 多尺度扩张注意力

在这里插入图片描述

图4. 多尺度空洞注意力。

    首先,特征图的通道被划分不同的heads。然后,自注意力操作是在红色查询块周围的窗口中的彩色块之间执行的,在不同的头中使用不同的膨胀率。此外,不同heads中的特征被连接在一起,然后输入到线性层中。默认情况下,我们使用 3 × 3 的内核大小,膨胀率 r = 1、2 和 3,不同头中参与感受野的大小为 3 × 3、5 × 5 和 7 × 7。

    为了利用块级自注意力机制在不同尺度上的稀疏性,我们进一步提出了多尺度扩张注意力(MSDA) 块来提取多尺度语义信息。如图4所示,给定特征图 X X X,我们通过 线性投影(linear projection) 获得相应的query、kay和value。之后,我们将特征图的通道划分到 n n n 个不同的 h e a d s heads heads,并在不同的 h e a d s heads heads中以不同的膨胀率(dilation rates)执行多尺度SWDA。具体来说,我们的MSDA计算如下:

h i = S W D A ( Q i , K i , V i , r i ) , 1 ≤ i ≤ n , ( 4 ) X = L i n e a r ( C o n c a t [ h 1 , . . . , h n ] ) , ( 5 ) \begin{aligned} &&&&&&&&&&&&& h_i=SWDA(Q_i,K_i,V_i,r_i), &1≤i≤n, &&&&&&&&&&&&&&&& (4)\\ &\\ &&&&&&&&&&&&& X=Linear(Concat[h_1,...,h_n]), &&&&&&&&&&&&&&&&& (5) \end{aligned} hi=SWDA(Qi,Ki,Vi,ri),X=Linear(Concat[h1,...,hn]),1in,(4)(5)

其中, r i r_i ri是第 i i i h e a d head head的扩张率, Q i , K i Q_i,K_i Qi,Ki V i V_i Vi 代表馈入第 i i i h e a d head head的特征图切片。输出 { h i } i = 1 n \{h_i\}_{i=1}^n {hi}i=1n被concat到一起,然后送到线性层进行特征聚合。

    通过为不同的 h e a d s heads heads 设置不同的扩张率,我们的 MSDA 有效地聚合了参与感受野内不同尺度的语义信息,并有效地减少了自注意力机制的冗余,而无需复杂的操作和额外的计算成本。

2. 代码

import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfgclass Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass DilateAttention(nn.Module):"Implementation of Dilate-attention"def __init__(self, head_dim, qk_scale=None, attn_drop=0, kernel_size=3, dilation=1):super().__init__()self.head_dim = head_dimself.scale = qk_scale or head_dim ** -0.5self.kernel_size=kernel_sizeself.unfold = nn.Unfold(kernel_size, dilation, dilation*(kernel_size-1)//2, 1)self.attn_drop = nn.Dropout(attn_drop)def forward(self,q,k,v):#B, C//3, H, Wq, k, v = q.detach(), k.detach(), v.detach()  # todo:!!!B,d,H,W = q.shapeq = q.reshape([B, d//self.head_dim, self.head_dim, 1 ,H*W]).permute(0, 1, 4, 3, 2)  # B,h,N,1,dk = self.unfold(k).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 2, 3)  #B,h,N,d,k*kattn = (q @ k) * self.scale  # B,h,N,1,k*kattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)v = self.unfold(v).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 3, 2)  # B,h,N,k*k,dx = (attn @ v).transpose(1, 2).reshape(B, H, W, d)return xclass MultiDilatelocalAttention(nn.Module):"Implementation of Dilate-attention"def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0.,proj_drop=0., kernel_size=3, dilation=[1, 2, 3]):super().__init__()self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.dilation = dilationself.kernel_size = kernel_sizeself.scale = qk_scale or head_dim ** -0.5self.num_dilation = len(dilation)assert num_heads % self.num_dilation == 0, f"num_heads{num_heads} must be the times of num_dilation{self.num_dilation}!!"self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)self.dilate_attention = nn.ModuleList([DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])for i in range(self.num_dilation)])self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, H, W, C = x.shapex = x.permute(0, 3, 1, 2)# B, C, H, Wqkv = self.qkv(x).reshape(B, 3, self.num_dilation, C//self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)#num_dilation,3,B,C//num_dilation,H,Wx = x.reshape(B, self.num_dilation, C//self.num_dilation, H, W).permute(1, 0, 3, 4, 2 )# num_dilation, B, H, W, C//num_dilationfor i in range(self.num_dilation):x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])# B, H, W,C//num_dilationx = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)x = self.proj(x)x = self.proj_drop(x)return xclass DilateBlock(nn.Module):"Implementation of Dilate-attention block"def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0., attn_drop=0.,drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, dilation=[1, 2, 3],cpe_per_block=False):super().__init__()self.dim = dimself.num_heads = num_headsself.mlp_ratio = mlp_ratioself.kernel_size = kernel_sizeself.dilation = dilationself.cpe_per_block = cpe_per_blockif self.cpe_per_block:self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)self.norm1 = norm_layer(dim)self.attn = MultiDilatelocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop=attn_drop, kernel_size=kernel_size, dilation=dilation)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)def forward(self, x):if self.cpe_per_block:x = x + self.pos_embed(x)x = x.permute(0, 2, 3, 1)x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))x = x.permute(0, 3, 1, 2)#B, C, H, Wreturn xif __name__ == "__main__":x = torch.rand([2,72,56,56])B, C, H, W = x.shapedim = Cnum_heads = 3   # 必须是dilation的整数倍 且 被dim整除head_dim = dim // num_heads#######################drop_path=0.1depths = [2, 2, 6, 2]num_layers = len(depths)dpr = [x.item() for x in torch.linspace(0, drop_path, sum(depths))]for i_layer in range(num_layers):drop_paths = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]#######################m = DilateBlock(dim=C,num_heads=num_heads,kernel_size=3,dilation=[1,2,3],mlp_ratio=4.,qkv_bias=True,qk_scale=head_dim ** -0.5,drop=0.,attn_drop=0.,drop_path=drop_paths[1] if isinstance(drop_paths, list) else drop_paths,norm_layer=nn.LayerNorm, act_layer=nn.GELU, cpe_per_block=True)y = m(x)print(y.shape)

3. 流程图

在这里插入图片描述


3.1. MultiDilatelocalAttention

在这里插入图片描述

3.2. DilateAttention

在这里插入图片描述

3.3. MLP

在这里插入图片描述

完整流程图如下:

请添加图片描述

相关文章:

DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制

空洞自注意力机制 文章目录 摘要1. 模型解释1.1. 滑动窗口扩张注意力1.2. 多尺度扩张注意力 2. 代码3. 流程图3.1. MultiDilatelocalAttention3.2. DilateAttention3.3. MLP 摘要 本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳…...

二十三种设计模式-适配器模式

适配器模式(Adapter Pattern)是一种结构型设计模式,它允许将不兼容的接口转换成客户端期望的接口,从而使原本因接口不匹配而不能一起工作的类可以协同工作。以下是关于适配器模式的详细介绍: 一、定义及作用 定义&am…...

复用类(2):代理、结合使用组合和继承

1 代理 第三种关系称为代理,这是继承与组合之间的中庸之道,因为我们将一个成员对象置于所要构造的类中(就像组合),但与此同时我们在新类中暴露了该成员对象的所有方法(就像继承)。例如&#xff…...

浅谈云计算07 | 云安全机制

云计算安全机制 一、引言二、加密技术:数据的隐形护盾三、散列机制:数据完整性的忠诚卫士四、数字签名:数据来源与真伪的鉴定专家五、公钥基础设施(PKI):信任的基石六、身份与访问管理(IAM&…...

【机器学习】零售行业的智慧升级:机器学习驱动的精准营销与库存管理

我的个人主页 我的领域:人工智能篇,希望能帮助到大家!!!👍点赞 收藏❤ 在当今数字化浪潮汹涌澎湃的时代,零售行业正站在转型升级的十字路口。市场竞争的白热化使得企业必须另辟蹊径&#xff0…...

深入理解 Entity、VO、QO、DTO 的区别及其在 MVC 架构中的应用

文章背景 在现代软件开发中,我们经常会接触到各种数据结构的概念,比如 Entity、VO(Value Object)、QO(Query Object)、DTO(Data Transfer Object)等。这些概念尽管看似相似&#xff…...

vue集成高德地图API实现坐标拾取功能

安装与配置: 组件 | vue-amapDescriptionhttps://elemefe.github.io/vue-amap/#/zh-cn/introduction/install简介 | vuemap/vue-amap简介https://vue-amap.guyixi.cn/zh-cn/introduction/introduction.html ​​​​我的应用 | 高德控制台高德开放平台官网控…...

Spring Boot Actuator 详细介绍

Spring Boot Actuator 详细介绍 1. 简介 Spring Boot Actuator 是 Spring Boot 提供的一个用于监控和管理应用程序的强大功能模块。它可以帮助我们了解应用程序的运行状况、指标收集、环境信息、日志级别管理等。 2. 添加依赖 2.1 在 pom.xml 中添加以下依赖: …...

联通用户管理系统(一)

#联通用户管理系统(一) 1.新建项目 如果你是windows的话,界面应该是如下的: 2.创建app python manage.py startapp app01一般情况下:我们是在pycharm的终端中运行上述指令,但是pychrm中为我们提供了工具…...

go chan底层分析

go chan底层分析 底层源码hchanmakechan 方法 环形队列阻塞机制向管道写数据流程图源码 从管道读数据流程图源码 关闭通道 底层源码 hchan type hchan struct {qcount uint // 当前队列中剩余元素个数dataqsiz uint // 环形队列长度,即可以…...

idea上git log面板的使用

文章目录 各种颜色含义具体的文件的颜色标签颜色🏷️ 节点和路线 各种颜色含义 具体的文件的颜色 红色:表示还没有 git add 提交到暂存区绿色:表示已经 git add 过,但是从来没有 commit 过蓝色:表示文件有过改动 标…...

WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现)

WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现) 目录 WOA-Transformer鲸鱼算法优化编码器时间序列预测(Matlab实现)预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现WOA-Transformer鲸鱼算法优化编…...

dock 制作 python环境

报错 :Get "https://registry-1.docker.io/v2/": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers) 解决方法 配置加速地址 vim /etc/docker/daemon.json 添加以下内容 { "registry-mirror…...

2025第3周 | json-server的基本使用

目录 1. json-server是什么?2. json-server怎么用?2.1 安装2.2 创建db.json2.3 启动服务2.4 查看效果 3. 前端进行模拟交互3.1 创建demo.html3.2 创建demo.js 2025,做想做的事,读想读的书,持续学习,自律生活…...

Autodl转发端口,在本地机器上运行Autodl服务器中的ipynb文件

通过 SSH 隧道将远程端口转发到本地机器 输入服务器示例的SSH指令和密码,将远程的6006端口代理到本地 在服务器终端,激活conda虚拟环境 conda activate posecnnexport PYOPENGL_PLATFORMegljupyter notebook --no-browser --port6006 --allow-root从…...

flutter Get GetMiddleware 中间件不起作用问题

当使用 get: ^5.0.0-release-candidate-9.2.1最新版本时,中间件GetMiddleware各种教程都是让我们在redirect中实现,比如: overrideRouteSettings? redirect(String? route) {return RouteSettings(name: "/companyAuthIndexPage"…...

RabbitMQ(三)

RabbitMQ中的各模式及其用法 工作队列模式一、生产者代码1、封装工具类2、编写代码3、发送消息效果 二、消费者代码1、编写代码2、运行效果 发布订阅模式一、生产者代码二、消费者代码1、消费者1号2、消费者2号 三、运行效果四、小结 路由模式一、生产者代码二、消费者代码1、消…...

【Python】Python之locust压测教程+从0到1demo:基础轻量级压测实战(1)

文章目录 一、什么是Locust二、Locust 架构组成三、实战 Demo准备一个可调用的接口编写一个接口测试用例编写一个性能测试用例执行性能测试用例代码1、通过 Web UI 执行(GUI模式)2、通过命令行执行(非GUI模式) 小知识:…...

【JavaScript】基础内容,HTML如何引用JavaScript, JS 常用的数据类型

HTML 嵌入 Javascript 的方式 引入外部 js 文件 <head> <script Language "javaScript" src"index.js"/> </head>内部声明 <head> <script language"javascript">function hello(){alert("hello word&qu…...

vue使用自动化导入api插件unplugin-auto-import,避免频繁手动导入

‌unplugin-auto-import‌是一个现代的自动导入插件&#xff0c;旨在简化前端开发中的导入过程&#xff0c;减少手动导入的繁琐工作&#xff0c;提升开发效率。它支持多种构建工具&#xff0c;包括Vite、Webpack、Rollup和esbuild&#xff0c;并且可以与TypeScript配合使用&…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎&#xff08;Physics Engine&#xff09; 物理引擎 是一种通过计算机模拟物理规律&#xff08;如力学、碰撞、重力、流体动力学等&#xff09;的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互&#xff0c;广泛应用于 游戏开发、动画制作、虚…...

Appium+python自动化(十六)- ADB命令

简介 Android 调试桥(adb)是多种用途的工具&#xff0c;该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具&#xff0c;其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利&#xff0c;如安装和调试…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

centos 7 部署awstats 网站访问检测

一、基础环境准备&#xff08;两种安装方式都要做&#xff09; bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...