YOLO即插即用模块---AgentAttention
Agent Attention: On the Integration of Softmax and Linear Attention
论文地址:https://arxiv.org/pdf/2312.08874
问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。
方法: 提出了一个新的注意力机制,名为 Agent Attention,通过引入一组代理 token (A) 来解决计算复杂度过高的问题。
具体步骤:
-
代理聚合 (Agent Aggregation): 将代理 token (A) 作为查询 token (Q) 的代理,从键 (K) 和值 (V) 中聚合信息,形成代理特征 (VA)。
-
代理广播 (Agent Broadcast): 将代理 token (A) 作为键,将全局信息从代理特征 (VA) 广播到每个查询 token (Q),形成最终的输出。
代理 token (A) 的获取方式:
-
可学习的参数
-
从输入特征中提取 (例如,通过池化或卷积)
Agent Attention 模块:
-
包含纯 Agent Attention、代理偏置 (Agent Bias) 和深度可分离卷积 (DWC) 模块。
-
代理偏置用于添加位置信息,帮助不同的代理 token 关注不同的区域。
-
DWC 模块用于保持特征多样性,弥补线性注意力的不足。
-

Agent Attention 的优势:
-
高效计算和高表达能力: 结合了 Softmax 注意力和线性注意力的优点,既降低了计算复杂度,又保持了高表达能力。
-
大感受野: 可以采用更大的感受野,甚至全局感受野,同时保持相同的计算量。P8
实验结果:
-
在图像分类、目标检测、语义分割和图像生成等任务上,Agent Attention 都取得了显著的性能提升。
-
在高分辨率场景中,Agent Attention 表现出优异的性能。
-
将 Agent Attention 应用于 Stable Diffusion,可以加速图像生成过程,并显著提高图像生成质量,无需任何额外的训练。
总结: Agent Attention 是一种高效且高表达的注意力机制,可以有效地解决 Softmax 注意力计算复杂度过高的问题,在各种视觉任务中取得了显著的性能提升,特别是在高分辨率场景中。
即插即用代码:
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_class AgentAttention(nn.Module):def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,sr_ratio=1, agent_num=49, **kwargs):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_patches = num_patcheswindow_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sr_ratio = sr_ratioif sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)self.agent_num = agent_numself.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))trunc_normal_(self.an_bias, std=.02)trunc_normal_(self.na_bias, std=.02)trunc_normal_(self.ah_bias, std=.02)trunc_normal_(self.aw_bias, std=.02)trunc_normal_(self.ha_bias, std=.02)trunc_normal_(self.wa_bias, std=.02)pool_size = int(agent_num ** 0.5)self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))self.softmax = nn.Softmax(dim=-1)def forward(self, x, H, W):b, n, c = x.shapenum_heads = self.num_headshead_dim = c // num_headsq = self.q(x)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(b, c, H, W)x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)x_ = self.norm(x_)kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)else:kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)k, v = kv[0], kv[1]agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias = position_bias1 + position_bias2agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)agent_attn = self.attn_drop(agent_attn)agent_v = agent_attn @ vagent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)agent_bias = agent_bias1 + agent_bias2q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)q_attn = self.attn_drop(q_attn)x = q_attn @ agent_vx = x.transpose(1, 2).reshape(b, n, c)v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)if self.sr_ratio > 1:v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)x = self.proj(x)x = self.proj_drop(x)return xif __name__ == '__main__':dim = 4num_patches = 64block = AgentAttention(dim=dim, num_patches=num_patches)H, W = 8,8x = torch.rand(1, num_patches, dim)output = block(x, H, W)print(f"Input size: {x.size()}")print(f"Output size: {output.size()}")
YOLO小伙伴可进群交流:

相关文章:
YOLO即插即用模块---AgentAttention
Agent Attention: On the Integration of Softmax and Linear Attention 论文地址:https://arxiv.org/pdf/2312.08874 问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。 方法&a…...
探索开源语音识别的未来:高效利用先进的自动语音识别技术20241030
🚀 探索开源语音识别的未来:高效利用自动语音识别技术 🌟 引言 在数字化时代,语音识别技术正在引领人机交互的新潮流,为各行业带来了颠覆性的改变。开源的自动语音识别(ASR)系统,如…...
学习路之TP6--workman安装
一、安装 首先通过 composer 安装 composer require topthink/think-worker 报错: 分析:最新版本需要TP8,或装低版本的 composer require topthink/think-worker:^3.*安装后, 增加目录 vendor\workerman vendor\topthink\think-w…...
.NET内网实战:通过白名单文件反序列化漏洞绕过UAC
01阅读须知 此文所节选自小报童《.NET 内网实战攻防》专栏,主要内容有.NET在各个内网渗透阶段与Windows系统交互的方式和技巧,对内网和后渗透感兴趣的朋友们可以订阅该电子报刊,解锁更多的报刊内容。 02基本介绍 03原理分析 在渗透测试和红…...
AI Agents - 自动化项目:计划、评估和分配
Agents: Role 角色Goal 目标Backstory 背景故事 Tasks: Description 描述Expected Output 期望输出Agent 代理 Automated Project: Planning, Estimation, and Allocation Initial Imports 1.本地文件helper.py # Add your utilities or helper functions to…...
Git的.gitignore文件
一、各语言对应的.gitignore模板文件 项目地址:https://github.com/github/gitignore 二、.gitignore文件不生效 .gitignore文件只是ignore没有被追踪的文件,已被追踪的文件,要先删除缓存文件。 # 单个文件 git rm --cached file/path/to…...
网站安全,WAF网站保护暴力破解
雷池的核心功能 通过过滤和监控 Web 应用与互联网之间的 HTTP 流量,功能包括: SQL 注入保护:防止恶意 SQL 代码的注入,保护网站数据安全。跨站脚本攻击 (XSS):阻止攻击者在用户浏览器中执行恶意脚本。暴力破解防护&a…...
深度学习:梯度下降算法简介
梯度下降算法简介 梯度下降算法 我们思考这样一个问题,现在需要用一条直线来回归拟合这三个点,直线的方程是 y w ^ x b y \hat{w}x b yw^xb,我们假设斜率 w ^ \hat{w} w^是已知的,现在想要找到一个最好的截距 b b b。 一条…...
SparkSQL整合Hive后,如何启动hiveserver2服务
当spark sql与hive整合后,我们就无法启动hiveserver2的服务了,每次都要先启动hive的元数据服务(nohup hive --service metastore)才能启动hive,之前的beeline命令也用不了,hiveserver2的无法启动,这也导致我…...
前端路由如何从0开始配置?vue-router 的使用
在 Web 开发中,路由是指根据 URL 的不同部分将请求分发到不同的处理函数或页面的过程。路由是单页应用(SPA, Single Page Application)和服务器端渲染(SSR, Server-Side Rendering)应用中的一个重要概念。 在开发中如何…...
Java中的运算符【与C语言的区别】
目录 1. 算术运算符 1.0 赋值运算符: 1.1 四则运算符: - * / % 【取余与C有点不同】 1.2 增量运算符: - * / % * 【右侧运算结果会自动转换类型】 1.3 自增、自减:、-- 2. 关系运算符 3. 逻辑运算符 3.1 短路求值 3.2 【…...
二、基础语法
入门了解 注释 **作用:**在代码中加一些注释和说明,方便自己或者其他程序员阅读代码 两种格式: 单行注释:// 描述信息 通常放在一行代码的上方,或者一条语句的末尾,对该行代码进行说明 多行注释&#x…...
DB-GPT系列(一):DB-GPT能帮你做什么?
DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Development framework with AWEL and Agents),围绕大模型提供灵活、可拓展的AI原生数据应用管理与开发能力,可以帮助企业快速构建、部署智能AI数据应用,通过智能数据分析、洞察…...
【Python各个击破】numpy
简介 NumPy是一个开源的Python库,它提供了一个强大的N维数组对象和许多用于操作这些数组的函数。它是大多数Python科学计算的基础,包括Pandas、SciPy和scikit-learn等库都建立在NumPy之上。 安装 !pip install numpy导入 import numpy as np用法 # …...
【STM32 Blue Pill编程实例】-4位7段数码管使用
4位7段数码管使用 文章目录 4位7段数码管使用1、7段数码介绍2、硬件准备与接线3、模块配置4、代码实现在本文中,我们将介绍如何将 STM32 Blue Pill开发板与 4 位 7 段数码管连接,并在 STM32CubeIDE 中对其进行编程。 在文章中首先将介绍 4 位 7 段数码管及其与 STM32 Blue Pi…...
[进阶]java基础之集合(三)数据结构
文章目录 数据结构概述常见的数据结构数据结构(栈)数据结构(队列)数据结构(数组)数据结构(链表) 数据结构 概述 数据结构是计算机底层存储、组织数据的方式。是指数据相互之间是以什么方式排列在一起的。数据结构是为了更加方便的管理和使用数据,需要结合具体的业…...
《Apache Cordova/PhoneGap 使用技巧分享》
一、引言 在移动应用开发的领域中,Apache Cordova(也被称为 PhoneGap)是一个强大的工具,它允许开发者使用 HTML、CSS 和 JavaScript 等 Web 技术来构建跨平台的移动应用。这种方式不仅能够提高开发效率,还能降低开发成…...
SCP(Secure Copy
SCP(Secure Copy)是Linux系统下基于SSH协议的安全文件传输工具,用于在本地和远程主机间安全、快速地传输文件和目录。SCP命令通过加密传输确保数据的安全性,并且不占用过多系统资源。 SCP的基本用法 基本语法:…...
uniApp 省市区自定义数据
关于自定义省市区选择 其实也是用了 uniApp的内置组件 picker <picker mode"multiSelector" change"bindRegionChange" columnchange"bindMultiPickerColumnChange" :value"valueRegion" :range"multiArray"><v…...
图解Redis 06 | Hash数据类型的原理及应用场景
介绍 Hash 类型特别适合存储对象,例如用户信息等。 String类型也可以用于存储用户信息,Hash与String存储用户信息的区别如下图所示: 内部实现 Hash 类型 的底层数据结构是通过压缩列表(Ziplist)或哈希表ÿ…...
Oracle Product Hub Portal Cloud(简称 OPH Cloud)是 Oracle 提供的基于云的主数据管理(MDM)解决方案
Oracle Product Hub Portal Cloud(简称 OPH Cloud)是 Oracle 提供的基于云的主数据管理(MDM)解决方案,专为统一、治理和分发产品主数据而设计。它是 Oracle Cloud Enterprise Resource Planning (ERP)、Supply Chain M…...
多轴点焊机器人产业动能强劲:538.2亿元市场规模奠基,2032年将跃升至近1154.9亿元
据恒州诚思调研统计,2025年全球多轴点焊机器人市场规模约达538.2亿元。在全球工业自动化浪潮的推动下,预计未来该市场将持续平稳增长,到2032年市场规模将接近1154.9亿元,未来六年复合年均增长率(CAGR)为11.…...
别再只盯着蓝牙和ZigBee了!用Telink TLSR8258芯片的2.4G私有协议,自己动手做个低功耗遥控器
从零构建2.4G私有协议遥控器:Telink TLSR8258实战指南 当市面上大多数IoT设备还在蓝牙和ZigBee的框架下挣扎时,Telink TLSR8258芯片的2.4G私有协议正在悄然改写低功耗无线通信的规则。我曾在一个智能农业项目中,需要控制200米外的灌溉阀门&am…...
深度解析:SillyTavern如何通过五大革新打造终极AI对话体验?
深度解析:SillyTavern如何通过五大革新打造终极AI对话体验? 【免费下载链接】SillyTavern LLM Frontend for Power Users. 项目地址: https://gitcode.com/GitHub_Trending/si/SillyTavern 你是否曾想过,一个AI对话前端能如何超越简单…...
中文语义相似度计算新范式:技术演进与实践路径
中文语义相似度计算新范式:技术演进与实践路径 【免费下载链接】Awesome-Chinese-LLM 整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主,包括底座模型,垂直领域微调及应用,数据集与教程等。…...
软件毕业设计新手避坑指南:从选题到部署的全链路技术实践
最近在帮几个学弟学妹看他们的软件毕业设计,发现大家遇到的问题都惊人的相似:选题要么太大做不完,要么太小没亮点;技术栈东拼西凑,代码写得像一锅粥;好不容易本地跑通了,一到部署就各种报错&…...
遗传算法优化PID控制:MATLAB 2021b下的 m 文件与Simulink联合仿真之旅
遗传算法优化 PID 控制,采用 m 文件联合 Simulink进行仿真,MATLAB2021b,在控制系统领域,PID控制凭借其结构简单、鲁棒性好等优点,一直占据着重要地位。然而,传统PID控制器参数的整定往往依赖经验࿰…...
如何利用OpenCode实现高效专业的AI驱动开发工作流?
如何利用OpenCode实现高效专业的AI驱动开发工作流? 【免费下载链接】opencode 一个专为终端打造的开源AI编程助手,模型灵活可选,可远程驱动。 项目地址: https://gitcode.com/GitHub_Trending/openc/opencode 在当今快速迭代的软件开发…...
ArcGIS JS API调用天地图WMTS服务实战:从GetCapabilities解析到完整代码实现
ArcGIS JS API调用天地图WMTS服务全流程解析 在WebGIS开发中,将第三方地图服务无缝集成到ArcGIS生态系统中是常见需求。天地图作为国内权威的地理信息服务,其WMTS(Web Map Tile Service)接口的调用尤为关键。本文将深入剖析从服务…...
hadoop+spark+hive爬虫农产品推荐系统 农产品爬虫 农产品可视化 农产品价格预测系统 爬虫+线性回归预测算法+Flask框架
1、项目 介绍 技术栈: python语言、FLASK框架、requests爬虫技术、Echarts可视化、HTML、线性回归预测算法模型 惠农网https://www.cnhnb.com/农产品价格预测系统在现代农业领域发挥着重要作用,它不仅有助于农民合理安排农作物的种植和销售,…...
