当前位置: 首页 > news >正文

YOLO即插即用模块---AgentAttention

Agent Attention: On the Integration of Softmax and Linear Attention

论文地址:https://arxiv.org/pdf/2312.08874

问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。

方法: 提出了一个新的注意力机制,名为 Agent Attention,通过引入一组代理 token (A) 来解决计算复杂度过高的问题。

具体步骤

  1. 代理聚合 (Agent Aggregation): 将代理 token (A) 作为查询 token (Q) 的代理,从键 (K) 和值 (V) 中聚合信息,形成代理特征 (VA)。

  2. 代理广播 (Agent Broadcast): 将代理 token (A) 作为键,将全局信息从代理特征 (VA) 广播到每个查询 token (Q),形成最终的输出。

代理 token (A) 的获取方式

  • 可学习的参数

  • 从输入特征中提取 (例如,通过池化或卷积)

Agent Attention 模块

  • 包含纯 Agent Attention、代理偏置 (Agent Bias) 和深度可分离卷积 (DWC) 模块。

  • 代理偏置用于添加位置信息,帮助不同的代理 token 关注不同的区域。

  • DWC 模块用于保持特征多样性,弥补线性注意力的不足。

Agent Attention 的优势

  • 高效计算和高表达能力: 结合了 Softmax 注意力和线性注意力的优点,既降低了计算复杂度,又保持了高表达能力。

  • 大感受野: 可以采用更大的感受野,甚至全局感受野,同时保持相同的计算量。P8

实验结果

  • 在图像分类、目标检测、语义分割和图像生成等任务上,Agent Attention 都取得了显著的性能提升。

  • 在高分辨率场景中,Agent Attention 表现出优异的性能。

  • 将 Agent Attention 应用于 Stable Diffusion,可以加速图像生成过程,并显著提高图像生成质量,无需任何额外的训练。

总结: Agent Attention 是一种高效且高表达的注意力机制,可以有效地解决 Softmax 注意力计算复杂度过高的问题,在各种视觉任务中取得了显著的性能提升,特别是在高分辨率场景中。

即插即用代码:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_class AgentAttention(nn.Module):def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,sr_ratio=1, agent_num=49, **kwargs):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_patches = num_patcheswindow_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sr_ratio = sr_ratioif sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)self.agent_num = agent_numself.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))trunc_normal_(self.an_bias, std=.02)trunc_normal_(self.na_bias, std=.02)trunc_normal_(self.ah_bias, std=.02)trunc_normal_(self.aw_bias, std=.02)trunc_normal_(self.ha_bias, std=.02)trunc_normal_(self.wa_bias, std=.02)pool_size = int(agent_num ** 0.5)self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))self.softmax = nn.Softmax(dim=-1)def forward(self, x, H, W):b, n, c = x.shapenum_heads = self.num_headshead_dim = c // num_headsq = self.q(x)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(b, c, H, W)x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)x_ = self.norm(x_)kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)else:kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)k, v = kv[0], kv[1]agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias = position_bias1 + position_bias2agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)agent_attn = self.attn_drop(agent_attn)agent_v = agent_attn @ vagent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)agent_bias = agent_bias1 + agent_bias2q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)q_attn = self.attn_drop(q_attn)x = q_attn @ agent_vx = x.transpose(1, 2).reshape(b, n, c)v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)if self.sr_ratio > 1:v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)x = self.proj(x)x = self.proj_drop(x)return xif __name__ == '__main__':dim = 4num_patches = 64block = AgentAttention(dim=dim, num_patches=num_patches)H, W = 8,8x = torch.rand(1, num_patches, dim)output = block(x, H, W)print(f"Input size: {x.size()}")print(f"Output size: {output.size()}")

YOLO小伙伴可进群交流:

相关文章:

YOLO即插即用模块---AgentAttention

Agent Attention: On the Integration of Softmax and Linear Attention 论文地址:https://arxiv.org/pdf/2312.08874 问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。 方法&a…...

探索开源语音识别的未来:高效利用先进的自动语音识别技术20241030

🚀 探索开源语音识别的未来:高效利用自动语音识别技术 🌟 引言 在数字化时代,语音识别技术正在引领人机交互的新潮流,为各行业带来了颠覆性的改变。开源的自动语音识别(ASR)系统,如…...

学习路之TP6--workman安装

一、安装 首先通过 composer 安装 composer require topthink/think-worker 报错: 分析:最新版本需要TP8,或装低版本的 composer require topthink/think-worker:^3.*安装后, 增加目录 vendor\workerman vendor\topthink\think-w…...

.NET内网实战:通过白名单文件反序列化漏洞绕过UAC

01阅读须知 此文所节选自小报童《.NET 内网实战攻防》专栏,主要内容有.NET在各个内网渗透阶段与Windows系统交互的方式和技巧,对内网和后渗透感兴趣的朋友们可以订阅该电子报刊,解锁更多的报刊内容。 02基本介绍 03原理分析 在渗透测试和红…...

AI Agents - 自动化项目:计划、评估和分配

Agents: Role 角色Goal 目标Backstory 背景故事 Tasks: Description 描述Expected Output 期望输出Agent 代理 Automated Project: Planning, Estimation, and Allocation Initial Imports 1.本地文件helper.py # Add your utilities or helper functions to…...

Git的.gitignore文件

一、各语言对应的.gitignore模板文件 项目地址:https://github.com/github/gitignore 二、.gitignore文件不生效 .gitignore文件只是ignore没有被追踪的文件,已被追踪的文件,要先删除缓存文件。 # 单个文件 git rm --cached file/path/to…...

网站安全,WAF网站保护暴力破解

雷池的核心功能 通过过滤和监控 Web 应用与互联网之间的 HTTP 流量,功能包括: SQL 注入保护:防止恶意 SQL 代码的注入,保护网站数据安全。跨站脚本攻击 (XSS):阻止攻击者在用户浏览器中执行恶意脚本。暴力破解防护&a…...

深度学习:梯度下降算法简介

梯度下降算法简介 梯度下降算法 我们思考这样一个问题,现在需要用一条直线来回归拟合这三个点,直线的方程是 y w ^ x b y \hat{w}x b yw^xb,我们假设斜率 w ^ \hat{w} w^是已知的,现在想要找到一个最好的截距 b b b。 一条…...

SparkSQL整合Hive后,如何启动hiveserver2服务

当spark sql与hive整合后,我们就无法启动hiveserver2的服务了,每次都要先启动hive的元数据服务(nohup hive --service metastore)才能启动hive,之前的beeline命令也用不了,hiveserver2的无法启动,这也导致我…...

前端路由如何从0开始配置?vue-router 的使用

在 Web 开发中,路由是指根据 URL 的不同部分将请求分发到不同的处理函数或页面的过程。路由是单页应用(SPA, Single Page Application)和服务器端渲染(SSR, Server-Side Rendering)应用中的一个重要概念。 在开发中如何…...

Java中的运算符【与C语言的区别】

目录 1. 算术运算符 1.0 赋值运算符: 1.1 四则运算符: - * / % 【取余与C有点不同】 1.2 增量运算符: - * / % * 【右侧运算结果会自动转换类型】 1.3 自增、自减:、-- 2. 关系运算符 3. 逻辑运算符 3.1 短路求值 3.2 【…...

二、基础语法

入门了解 注释 **作用:**在代码中加一些注释和说明,方便自己或者其他程序员阅读代码 两种格式: 单行注释:// 描述信息 通常放在一行代码的上方,或者一条语句的末尾,对该行代码进行说明 多行注释&#x…...

DB-GPT系列(一):DB-GPT能帮你做什么?

DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Development framework with AWEL and Agents),围绕大模型提供灵活、可拓展的AI原生数据应用管理与开发能力,可以帮助企业快速构建、部署智能AI数据应用,通过智能数据分析、洞察…...

【Python各个击破】numpy

简介 NumPy是一个开源的Python库,它提供了一个强大的N维数组对象和许多用于操作这些数组的函数。它是大多数Python科学计算的基础,包括Pandas、SciPy和scikit-learn等库都建立在NumPy之上。 安装 !pip install numpy导入 import numpy as np用法 # …...

【STM32 Blue Pill编程实例】-4位7段数码管使用

4位7段数码管使用 文章目录 4位7段数码管使用1、7段数码介绍2、硬件准备与接线3、模块配置4、代码实现在本文中,我们将介绍如何将 STM32 Blue Pill开发板与 4 位 7 段数码管连接,并在 STM32CubeIDE 中对其进行编程。 在文章中首先将介绍 4 位 7 段数码管及其与 STM32 Blue Pi…...

[进阶]java基础之集合(三)数据结构

文章目录 数据结构概述常见的数据结构数据结构(栈)数据结构(队列)数据结构(数组)数据结构(链表) 数据结构 概述 数据结构是计算机底层存储、组织数据的方式。是指数据相互之间是以什么方式排列在一起的。数据结构是为了更加方便的管理和使用数据,需要结合具体的业…...

《Apache Cordova/PhoneGap 使用技巧分享》

一、引言 在移动应用开发的领域中,Apache Cordova(也被称为 PhoneGap)是一个强大的工具,它允许开发者使用 HTML、CSS 和 JavaScript 等 Web 技术来构建跨平台的移动应用。这种方式不仅能够提高开发效率,还能降低开发成…...

SCP(Secure Copy

SCP(Secure Copy)‌是Linux系统下基于SSH协议的安全文件传输工具,用于在本地和远程主机间安全、快速地传输文件和目录。SCP命令通过加密传输确保数据的安全性,并且不占用过多系统资源‌。 SCP的基本用法 ‌基本语法‌&#xff1a…...

uniApp 省市区自定义数据

关于自定义省市区选择 其实也是用了 uniApp的内置组件 picker <picker mode"multiSelector" change"bindRegionChange" columnchange"bindMultiPickerColumnChange" :value"valueRegion" :range"multiArray"><v…...

图解Redis 06 | Hash数据类型的原理及应用场景

介绍 Hash 类型特别适合存储对象&#xff0c;例如用户信息等。 String类型也可以用于存储用户信息&#xff0c;Hash与String存储用户信息的区别如下图所示&#xff1a; 内部实现 Hash 类型 的底层数据结构是通过压缩列表&#xff08;Ziplist&#xff09;或哈希表&#xff…...

Oracle Product Hub Portal Cloud(简称 OPH Cloud)是 Oracle 提供的基于云的主数据管理(MDM)解决方案

Oracle Product Hub Portal Cloud&#xff08;简称 OPH Cloud&#xff09;是 Oracle 提供的基于云的主数据管理&#xff08;MDM&#xff09;解决方案&#xff0c;专为统一、治理和分发产品主数据而设计。它是 Oracle Cloud Enterprise Resource Planning (ERP)、Supply Chain M…...

多轴点焊机器人产业动能强劲:538.2亿元市场规模奠基,2032年将跃升至近1154.9亿元

据恒州诚思调研统计&#xff0c;2025年全球多轴点焊机器人市场规模约达538.2亿元。在全球工业自动化浪潮的推动下&#xff0c;预计未来该市场将持续平稳增长&#xff0c;到2032年市场规模将接近1154.9亿元&#xff0c;未来六年复合年均增长率&#xff08;CAGR&#xff09;为11.…...

别再只盯着蓝牙和ZigBee了!用Telink TLSR8258芯片的2.4G私有协议,自己动手做个低功耗遥控器

从零构建2.4G私有协议遥控器&#xff1a;Telink TLSR8258实战指南 当市面上大多数IoT设备还在蓝牙和ZigBee的框架下挣扎时&#xff0c;Telink TLSR8258芯片的2.4G私有协议正在悄然改写低功耗无线通信的规则。我曾在一个智能农业项目中&#xff0c;需要控制200米外的灌溉阀门&am…...

深度解析:SillyTavern如何通过五大革新打造终极AI对话体验?

深度解析&#xff1a;SillyTavern如何通过五大革新打造终极AI对话体验&#xff1f; 【免费下载链接】SillyTavern LLM Frontend for Power Users. 项目地址: https://gitcode.com/GitHub_Trending/si/SillyTavern 你是否曾想过&#xff0c;一个AI对话前端能如何超越简单…...

中文语义相似度计算新范式:技术演进与实践路径

中文语义相似度计算新范式&#xff1a;技术演进与实践路径 【免费下载链接】Awesome-Chinese-LLM 整理开源的中文大语言模型&#xff0c;以规模较小、可私有化部署、训练成本较低的模型为主&#xff0c;包括底座模型&#xff0c;垂直领域微调及应用&#xff0c;数据集与教程等。…...

软件毕业设计新手避坑指南:从选题到部署的全链路技术实践

最近在帮几个学弟学妹看他们的软件毕业设计&#xff0c;发现大家遇到的问题都惊人的相似&#xff1a;选题要么太大做不完&#xff0c;要么太小没亮点&#xff1b;技术栈东拼西凑&#xff0c;代码写得像一锅粥&#xff1b;好不容易本地跑通了&#xff0c;一到部署就各种报错&…...

遗传算法优化PID控制:MATLAB 2021b下的 m 文件与Simulink联合仿真之旅

遗传算法优化 PID 控制&#xff0c;采用 m 文件联合 Simulink进行仿真&#xff0c;MATLAB2021b&#xff0c;在控制系统领域&#xff0c;PID控制凭借其结构简单、鲁棒性好等优点&#xff0c;一直占据着重要地位。然而&#xff0c;传统PID控制器参数的整定往往依赖经验&#xff0…...

如何利用OpenCode实现高效专业的AI驱动开发工作流?

如何利用OpenCode实现高效专业的AI驱动开发工作流&#xff1f; 【免费下载链接】opencode 一个专为终端打造的开源AI编程助手&#xff0c;模型灵活可选&#xff0c;可远程驱动。 项目地址: https://gitcode.com/GitHub_Trending/openc/opencode 在当今快速迭代的软件开发…...

ArcGIS JS API调用天地图WMTS服务实战:从GetCapabilities解析到完整代码实现

ArcGIS JS API调用天地图WMTS服务全流程解析 在WebGIS开发中&#xff0c;将第三方地图服务无缝集成到ArcGIS生态系统中是常见需求。天地图作为国内权威的地理信息服务&#xff0c;其WMTS&#xff08;Web Map Tile Service&#xff09;接口的调用尤为关键。本文将深入剖析从服务…...

hadoop+spark+hive爬虫农产品推荐系统 农产品爬虫 农产品可视化 农产品价格预测系统 爬虫+线性回归预测算法+Flask框架

1、项目 介绍 技术栈&#xff1a; python语言、FLASK框架、requests爬虫技术、Echarts可视化、HTML、线性回归预测算法模型 惠农网https://www.cnhnb.com/农产品价格预测系统在现代农业领域发挥着重要作用&#xff0c;它不仅有助于农民合理安排农作物的种植和销售&#xff0c;…...