Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释。
1. Attention计算公式
Attention的计算主要涉及三个矩阵:Q、K、V。我们先不考虑multi-head attention,只考虑one head的self attention。在大模型的prefill阶段,这三个矩阵的维度均为N x d,N即为上下文的长度;在decode阶段,Q的维度为1 x d, KV还是N x d。然后通过下面的公式计算attention矩阵:
  O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(dQKT)V
 在真正使用attention的时候,我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致,它改变了Q、K、V每一行的定义:将维度d的向量分成h组变成一个h x dk的矩阵,Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵(不考虑batch维)。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。
不过,在真正进行大模型推理的时候就会发现KV Cache是非常占显存的,所以大家尝试各种手段压缩KV Cache,具体可以参考《大模型推理–KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA),这块在torch2.5以上的SDPA中也已经得到了支持。
2. SDPA伪代码
在SDPA的注释中,给出了伪代码:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:L, S = query.size(-2), key.size(-2)scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scaleattn_bias = torch.zeros(L, S, dtype=query.dtype)if is_causal:assert attn_mask is Nonetemp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))attn_bias.to(query.dtype)if attn_mask is not None:if attn_mask.dtype == torch.bool:attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))else:attn_bias += attn_maskif enable_gqa:key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)attn_weight = query @ key.transpose(-2, -1) * scale_factorattn_weight += attn_biasattn_weight = torch.softmax(attn_weight, dim=-1)attn_weight = torch.dropout(attn_weight, dropout_p, train=True)return attn_weight @ value
可以看出,我们实际在使用SDPA时除了query、key和value之外,还有另外几个参数:attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子,一般无需传递。dropout_p表示Dropout概率,在推理阶段也不需要传递,不过官方建议如下输入:dropout_p=(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。
先看enable_gqa。前面提到GQA是一种KV Cache压缩方法,MHA的KV和Q一样,也会有h个头,GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型,Q的h等于28,KV的h等于4,相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache,但是真正要计算Attention的时候还是需要对齐KV与Q的head数,所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作,目的就是将QKV的head数统一。在torch2.5以后的版本中,我们无需再手动去执行repeat_kv,直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv,而且速度比自己去做repaet_kv还要更快。
attn_mask和is_causal两个参数的作用相同,目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵,is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出,SDPA会首先构造一个L x S的零矩阵attn_bias,L表示Q的上下文长度,S表示KV Cache的长度。在prefill阶段,L和S相等,在decode阶段,L为1,S还是N。所以在prefill阶段,attn_bias就是一个N x N的矩阵,将is_causal设置为True时就会构造一个下三角为0,上三角为负无穷的矩阵作为attn_bias,然后将其加到QKT矩阵上,这样就实现了因果关系的Attention计算。在decode阶段,attn_bias就是一个1 x N的向量,此时可以将is_causal设置为False,attn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响,表示KV Cache所有的行都参与计算,因果关系保持正确。
attn_mask作用和is_causal一样,但是需要我们自行构造,如果你对如何构造不了解建议就使用is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None。不过,如果prefill按照chunk来执行也即chunk_prefill阶段,我们会发现is_causal设置为True时的attn_bias设置的不正确,我们不是从左上角开始构造下三角矩阵,而是要从右下角开始构造下三角矩阵,这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式,一种是bool类型,True的位置会保持不变,False的位置会置为负无穷;一种是float类型,会直接将attn_mask加到SDPA内部的attn_bias上,和bool类型一样,我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说,绝大多数情况下我们只需要设置is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None即可。如果推理阶段引入了chunk_prefill,则我们需要自行构造attn_mask,但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。
3. SDPA实现(翻译自SDPA注释)
目前SDPA有三种实现:
- 基于FlashAttention-2的实现;
- Memory-Efficient Attention(facebook xformers);
- Pytorch版本对上述伪代码的c++实现(对应MATH后端)。
针对CUDA后端,SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。所有实现方式默认都是启用的,SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制,torch提供了以下函数来启用和禁用各种实现方式:
- torch.nn.attention.sdpa_kernel:一个上下文管理器,用于启用或禁用任何一种实现方式;
- torch.backends.cuda.enable_flash_sdp:全局启用或禁用FlashAttention
- torch.backends.cuda.enable_mem_efficient_sdp:全局启用或禁用memory efficient attention
- torch.backends.cuda.enable_math_sdp:全局启用或禁用PyTorch的C++实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式,请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C++实现。如果某个融合实现方式不可用,将会发出警告,说明该融合实现方式无法运行的原因。由于融合浮点运算的特性,此函数的输出可能会因所选择的后端内核而异。C++ 实现支持torch.float64,当需要更高精度时可以使用。对于math后端,如果输入是torch.half或torch.bfloat16类型,那么所有中间计算结果都会保持为torch.float类型。
4. SDPA使用示例
首先强调一点,灌入SDPA的QKV都是做过转置的,也即维度为batch x head x N x d,在老版本的torch中还需要QKV都是contiguous的,新版本下无此要求。SDPA注释中还给了两个示例,我们在此也给出:
# Optionally use the context manager to ensure one of the fused kernels is runquery = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):F.scaled_dot_product_attention(query,key,value)
上述示例中,给定的输入为batch等于32,head等于8,上下文长度128,embedding维度64,然后通过sdpa_kernel选择使用FlashAttention。
 示例二:
# Sample for GQA for llama3
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with sdpa_kernel(backends=[SDPBackend.MATH]):F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
示例二演示了GQA的用法,给定的query head数为32,key和value均为8,此时我们可以通过enable_gqa选项来实现对GQA的支持,此外代码还通过sdpa_kernel选项使用了MATH后端。
5. 参考
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Memory-Efficient Attention
- Grouped-Query Attention
- Attention Is All You Need
相关文章:
Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释…...
 
Linux练级宝典->Linux进程概念介绍
目录 进程基本概念 PCB概念 task_struct tack_struct内容分类 PID和PPID fork函数创建子进程 进程优先级概念 4个名词 进程地址空间 进程地址空间的意义 内核进程调度队列 优先级 活动队列 过期队列 进程基本概念 一个正在执行的程序。担当分配系统资源的实体&#…...
OpenHarmony 5.0 mpegts封装的H265视频播放失败的解决方案
问题现象 OpenHarmony 5.0版本使用AVPlayer播放mpegts封装格式的H.265(HEVC)编码格式的视频时出现报错导致播放失败 问题原因 OpenHarmony 5.0版本AVPlayer播放器使用histreamer引擎,因为 libav_codec_hevc_parser.z.so 动态库未开源导致H265编码格式视频解析不到…...
 
Qt从入门到入土(九) -model/view(模型/视图)框架
简介 Qt的模型/视图(Model/View)架构是一种用于分离数据处理和用户界面展示的设计模式。它允许开发者将数据存储和管理(模型)与数据的显示和交互(视图)解耦,从而提高代码的可维护性和可扩展性。…...
 
缓存之美:Guava Cache 相比于 Caffeine 差在哪里?
大家好,我是 方圆。本文将结合 Guava Cache 的源码来分析它的实现原理,并阐述它相比于 Caffeine Cache 在性能上的劣势。为了让大家对 Guava Cache 理解起来更容易,我们还是在开篇介绍它的原理: Guava Cache 通过分段(…...
 
[漏洞篇]XSS漏洞详解
[漏洞篇]XSS漏洞 一、 介绍 概念 XSS:通过JS达到攻击效果 XSS全称跨站脚本(Cross Site Scripting),为避免与层叠样式表(Cascading Style Sheets, CSS)的缩写混淆,故缩写为XSS。这是一种将任意 Javascript 代码插入到其他Web用户页面里执行以…...
【Leetcode 每日一题】2269. 找到一个数字的 K 美丽值
问题背景 一个整数 n u m num num 的 k k k 美丽值定义为 n u m num num 中符合以下条件的 子字符串 数目: 子字符串长度为 k k k。子字符串能整除 n u m num num。 给你整数 n u m num num 和 k k k,请你返回 n u m num num 的 k k k 美丽值…...
 
IO进程线程(线程)
作业 1.创建两个线程,分支线程1拷贝文件的前一部分,分支线程2拷贝文件的后一部分 2.创建三个线程,实现线程A打印A,线程B打印B,线程C打印C;重复打印顺序ABC。 信号量实现: 条件变量实现&#x…...
1-002:MySQL InnoDB引擎中的聚簇索引和非聚簇索引有什么区别?
在 MySQL InnoDB 存储引擎 中,索引主要分为 聚簇索引(Clustered Index) 和 非聚簇索引(Secondary Index)。它们的主要区别如下: 1. 聚簇索引(Clustered Index) 定义 聚簇索引是表数…...
 
tomcat单机多实例部署
一、部署方法 多实例可以运行多个不同的应用,也可以运行相同的应用,类似于虚拟主机,但是他可以做负载均衡。 方式一: 把tomcat的主目录挨个复制,然后把每台主机的端口给改掉就行了。 优点是最简单最直接,…...
 
论文阅读分享——UMDF(AAAI-24)
概述 题目:A Unified Self-Distillation Framework for Multimodal Sentiment Analysis with Uncertain Missing Modalities 发表:The Thirty-Eighth AAAI Conference on Artificial Intelligence (AAAI-24) 年份:2024 Github:暂…...
 
解决asp.net mvc发布到iis下安全问题
解决asp.net mvc发布到iis下安全问题 环境信息1.The web/application server is leaking version information via the "Server" HTTP response2.确保您的Web服务器、应用程序服务器、负载均衡器等已配置为强制执行Strict-Transport-Security。3.在HTML提交表单中找不…...
 
概念|RabbitMQ 消息生命周期 待消费的消息和待应答的消息有什么区别
目录 消息生命周期 一、消息创建与发布阶段 二、消息路由与存储阶段 三、消息存活与过期阶段 四、消息投递与消费阶段 五、消息生命周期终止 关键配置建议 待消费的消息和待应答的消息 一、待消费的消息(Unconsumed Messages) 二、待应答的消息…...
 
springboot三层架构详细讲解
目录 springBoot三层架构 0.简介1.各层架构 1.1 Controller层1.2 Service层1.3 ServiceImpl1.4 Mapper1.5 Entity1.6 Mapper.xml 2.各层之间的联系 2.1 Controller 与 Service2.2 Service 与 ServiceImpl2.3 Service 与 Mapper2.4 Mapper 与 Mapper.xml2.5 Service 与 Entity2…...
 
2025最新群智能优化算法:云漂移优化(Cloud Drift Optimization,CDO)算法求解23个经典函数测试集,MATLAB
一、云漂移优化算法 云漂移优化(Cloud Drift Optimization,CDO)算法是2025年提出的一种受自然现象启发的元启发式算法,它模拟云在大气中漂移的动态行为来解决复杂的优化问题。云在大气中受到各种大气力的影响,其粒子的…...
 
2025年Draw.io最新版本下载安装教程,附详细图文
2025年Draw.io最新版本下载安装教程,附详细图文 大家好,今天给大家介绍一款非常实用的流程图绘制软件——Draw.io。不管你是平时需要设计流程图、绘制思维导图,还是制作架构图,甚至是简单的草图,它都能帮你轻松搞定。…...
记录--洛谷 P1451 求细胞数量
如果想查看完整题目,请前往洛谷 P1451 求细胞数量 P1451 求细胞数量 题目描述 一矩形阵列由数字 0 0 0 到 9 9 9 组成,数字 1 1 1 到 9 9 9 代表细胞,细胞的定义为沿细胞数字上下左右若还是细胞数字则为同一细胞,求给定矩形…...
 
Android Studio 配置国内镜像源
Android Studio版本号:2022.1.1 Patch 2 1、配置gradle国内镜像,用腾讯云 镜像源地址:https\://mirrors.cloud.tencent.com/gradle 2、配置Android SDK国内镜像 地址:Index of /AndroidSDK/...
做到哪一步才算精通SQL
做到哪一步才算精通SQL-Structured Query Language 数据定义语言 DDL for StructCREATE:用来创建数据库、表、索引等对象ALTER:用来修改已存在的数据库对象DROP:用来删除整个数据库或者数据库中的表TRUNCATE:用来删除表中所有的行…...
Manus演示案例: 英伟达财务估值建模 解锁投资洞察的深度剖析
在当今瞬息万变的金融投资领域,精准剖析企业价值是投资者决胜市场的关键。英伟达(NVIDIA),作为科技行业的耀眼明星,其在人工智能和半导体领域的卓越表现备受瞩目。Manus 凭借专业的财务估值建模能力,深入挖…...
 
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
【网络】每天掌握一个Linux命令 - iftop
在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
 
基于当前项目通过npm包形式暴露公共组件
1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹,并新增内容 3.创建package文件夹...
 
令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...
ip子接口配置及删除
配置永久生效的子接口,2个IP 都可以登录你这一台服务器。重启不失效。 永久的 [应用] vi /etc/sysconfig/network-scripts/ifcfg-eth0修改文件内内容 TYPE"Ethernet" BOOTPROTO"none" NAME"eth0" DEVICE"eth0" ONBOOT&q…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
 
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...
Vue 模板语句的数据来源
🧩 Vue 模板语句的数据来源:全方位解析 Vue 模板(<template> 部分)中的表达式、指令绑定(如 v-bind, v-on)和插值({{ }})都在一个特定的作用域内求值。这个作用域由当前 组件…...
 
WebRTC调研
WebRTC是什么,为什么,如何使用 WebRTC有什么优势 WebRTC Architecture Amazon KVS WebRTC 其它厂商WebRTC 海康门禁WebRTC 海康门禁其他界面整理 威视通WebRTC 局域网 Google浏览器 Microsoft Edge 公网 RTSP RTMP NVR ONVIF SIP SRT WebRTC协…...
