Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释。
1. Attention计算公式
Attention的计算主要涉及三个矩阵:Q、K、V。我们先不考虑multi-head attention,只考虑one head的self attention。在大模型的prefill阶段,这三个矩阵的维度均为N x d,N即为上下文的长度;在decode阶段,Q的维度为1 x d, KV还是N x d。然后通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(dQKT)V
在真正使用attention的时候,我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致,它改变了Q、K、V每一行的定义:将维度d的向量分成h组变成一个h x dk的矩阵,Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵(不考虑batch维)。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。
不过,在真正进行大模型推理的时候就会发现KV Cache是非常占显存的,所以大家尝试各种手段压缩KV Cache,具体可以参考《大模型推理–KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA),这块在torch2.5以上的SDPA中也已经得到了支持。
2. SDPA伪代码
在SDPA的注释中,给出了伪代码:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:L, S = query.size(-2), key.size(-2)scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scaleattn_bias = torch.zeros(L, S, dtype=query.dtype)if is_causal:assert attn_mask is Nonetemp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))attn_bias.to(query.dtype)if attn_mask is not None:if attn_mask.dtype == torch.bool:attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))else:attn_bias += attn_maskif enable_gqa:key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)attn_weight = query @ key.transpose(-2, -1) * scale_factorattn_weight += attn_biasattn_weight = torch.softmax(attn_weight, dim=-1)attn_weight = torch.dropout(attn_weight, dropout_p, train=True)return attn_weight @ value
可以看出,我们实际在使用SDPA时除了query、key和value之外,还有另外几个参数:attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子,一般无需传递。dropout_p表示Dropout概率,在推理阶段也不需要传递,不过官方建议如下输入:dropout_p=(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。
先看enable_gqa。前面提到GQA是一种KV Cache压缩方法,MHA的KV和Q一样,也会有h个头,GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型,Q的h等于28,KV的h等于4,相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache,但是真正要计算Attention的时候还是需要对齐KV与Q的head数,所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作,目的就是将QKV的head数统一。在torch2.5以后的版本中,我们无需再手动去执行repeat_kv,直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv,而且速度比自己去做repaet_kv还要更快。
attn_mask和is_causal两个参数的作用相同,目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵,is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出,SDPA会首先构造一个L x S的零矩阵attn_bias,L表示Q的上下文长度,S表示KV Cache的长度。在prefill阶段,L和S相等,在decode阶段,L为1,S还是N。所以在prefill阶段,attn_bias就是一个N x N的矩阵,将is_causal设置为True时就会构造一个下三角为0,上三角为负无穷的矩阵作为attn_bias,然后将其加到QKT矩阵上,这样就实现了因果关系的Attention计算。在decode阶段,attn_bias就是一个1 x N的向量,此时可以将is_causal设置为False,attn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响,表示KV Cache所有的行都参与计算,因果关系保持正确。
attn_mask作用和is_causal一样,但是需要我们自行构造,如果你对如何构造不了解建议就使用is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None。不过,如果prefill按照chunk来执行也即chunk_prefill阶段,我们会发现is_causal设置为True时的attn_bias设置的不正确,我们不是从左上角开始构造下三角矩阵,而是要从右下角开始构造下三角矩阵,这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式,一种是bool类型,True的位置会保持不变,False的位置会置为负无穷;一种是float类型,会直接将attn_mask加到SDPA内部的attn_bias上,和bool类型一样,我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说,绝大多数情况下我们只需要设置is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None即可。如果推理阶段引入了chunk_prefill,则我们需要自行构造attn_mask,但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。
3. SDPA实现(翻译自SDPA注释)
目前SDPA有三种实现:
- 基于FlashAttention-2的实现;
- Memory-Efficient Attention(facebook xformers);
- Pytorch版本对上述伪代码的c++实现(对应MATH后端)。
针对CUDA后端,SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。所有实现方式默认都是启用的,SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制,torch提供了以下函数来启用和禁用各种实现方式:
- torch.nn.attention.sdpa_kernel:一个上下文管理器,用于启用或禁用任何一种实现方式;
- torch.backends.cuda.enable_flash_sdp:全局启用或禁用FlashAttention
- torch.backends.cuda.enable_mem_efficient_sdp:全局启用或禁用memory efficient attention
- torch.backends.cuda.enable_math_sdp:全局启用或禁用PyTorch的C++实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式,请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C++实现。如果某个融合实现方式不可用,将会发出警告,说明该融合实现方式无法运行的原因。由于融合浮点运算的特性,此函数的输出可能会因所选择的后端内核而异。C++ 实现支持torch.float64,当需要更高精度时可以使用。对于math后端,如果输入是torch.half或torch.bfloat16类型,那么所有中间计算结果都会保持为torch.float类型。
4. SDPA使用示例
首先强调一点,灌入SDPA的QKV都是做过转置的,也即维度为batch x head x N x d,在老版本的torch中还需要QKV都是contiguous的,新版本下无此要求。SDPA注释中还给了两个示例,我们在此也给出:
# Optionally use the context manager to ensure one of the fused kernels is runquery = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):F.scaled_dot_product_attention(query,key,value)
上述示例中,给定的输入为batch等于32,head等于8,上下文长度128,embedding维度64,然后通过sdpa_kernel选择使用FlashAttention。
示例二:
# Sample for GQA for llama3
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with sdpa_kernel(backends=[SDPBackend.MATH]):F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
示例二演示了GQA的用法,给定的query head数为32,key和value均为8,此时我们可以通过enable_gqa选项来实现对GQA的支持,此外代码还通过sdpa_kernel选项使用了MATH后端。
5. 参考
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Memory-Efficient Attention
- Grouped-Query Attention
- Attention Is All You Need
相关文章:
Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释…...
Linux练级宝典->Linux进程概念介绍
目录 进程基本概念 PCB概念 task_struct tack_struct内容分类 PID和PPID fork函数创建子进程 进程优先级概念 4个名词 进程地址空间 进程地址空间的意义 内核进程调度队列 优先级 活动队列 过期队列 进程基本概念 一个正在执行的程序。担当分配系统资源的实体&#…...
OpenHarmony 5.0 mpegts封装的H265视频播放失败的解决方案
问题现象 OpenHarmony 5.0版本使用AVPlayer播放mpegts封装格式的H.265(HEVC)编码格式的视频时出现报错导致播放失败 问题原因 OpenHarmony 5.0版本AVPlayer播放器使用histreamer引擎,因为 libav_codec_hevc_parser.z.so 动态库未开源导致H265编码格式视频解析不到…...
Qt从入门到入土(九) -model/view(模型/视图)框架
简介 Qt的模型/视图(Model/View)架构是一种用于分离数据处理和用户界面展示的设计模式。它允许开发者将数据存储和管理(模型)与数据的显示和交互(视图)解耦,从而提高代码的可维护性和可扩展性。…...
缓存之美:Guava Cache 相比于 Caffeine 差在哪里?
大家好,我是 方圆。本文将结合 Guava Cache 的源码来分析它的实现原理,并阐述它相比于 Caffeine Cache 在性能上的劣势。为了让大家对 Guava Cache 理解起来更容易,我们还是在开篇介绍它的原理: Guava Cache 通过分段(…...
[漏洞篇]XSS漏洞详解
[漏洞篇]XSS漏洞 一、 介绍 概念 XSS:通过JS达到攻击效果 XSS全称跨站脚本(Cross Site Scripting),为避免与层叠样式表(Cascading Style Sheets, CSS)的缩写混淆,故缩写为XSS。这是一种将任意 Javascript 代码插入到其他Web用户页面里执行以…...
【Leetcode 每日一题】2269. 找到一个数字的 K 美丽值
问题背景 一个整数 n u m num num 的 k k k 美丽值定义为 n u m num num 中符合以下条件的 子字符串 数目: 子字符串长度为 k k k。子字符串能整除 n u m num num。 给你整数 n u m num num 和 k k k,请你返回 n u m num num 的 k k k 美丽值…...
IO进程线程(线程)
作业 1.创建两个线程,分支线程1拷贝文件的前一部分,分支线程2拷贝文件的后一部分 2.创建三个线程,实现线程A打印A,线程B打印B,线程C打印C;重复打印顺序ABC。 信号量实现: 条件变量实现&#x…...
1-002:MySQL InnoDB引擎中的聚簇索引和非聚簇索引有什么区别?
在 MySQL InnoDB 存储引擎 中,索引主要分为 聚簇索引(Clustered Index) 和 非聚簇索引(Secondary Index)。它们的主要区别如下: 1. 聚簇索引(Clustered Index) 定义 聚簇索引是表数…...
tomcat单机多实例部署
一、部署方法 多实例可以运行多个不同的应用,也可以运行相同的应用,类似于虚拟主机,但是他可以做负载均衡。 方式一: 把tomcat的主目录挨个复制,然后把每台主机的端口给改掉就行了。 优点是最简单最直接,…...
论文阅读分享——UMDF(AAAI-24)
概述 题目:A Unified Self-Distillation Framework for Multimodal Sentiment Analysis with Uncertain Missing Modalities 发表:The Thirty-Eighth AAAI Conference on Artificial Intelligence (AAAI-24) 年份:2024 Github:暂…...
解决asp.net mvc发布到iis下安全问题
解决asp.net mvc发布到iis下安全问题 环境信息1.The web/application server is leaking version information via the "Server" HTTP response2.确保您的Web服务器、应用程序服务器、负载均衡器等已配置为强制执行Strict-Transport-Security。3.在HTML提交表单中找不…...
概念|RabbitMQ 消息生命周期 待消费的消息和待应答的消息有什么区别
目录 消息生命周期 一、消息创建与发布阶段 二、消息路由与存储阶段 三、消息存活与过期阶段 四、消息投递与消费阶段 五、消息生命周期终止 关键配置建议 待消费的消息和待应答的消息 一、待消费的消息(Unconsumed Messages) 二、待应答的消息…...
springboot三层架构详细讲解
目录 springBoot三层架构 0.简介1.各层架构 1.1 Controller层1.2 Service层1.3 ServiceImpl1.4 Mapper1.5 Entity1.6 Mapper.xml 2.各层之间的联系 2.1 Controller 与 Service2.2 Service 与 ServiceImpl2.3 Service 与 Mapper2.4 Mapper 与 Mapper.xml2.5 Service 与 Entity2…...
2025最新群智能优化算法:云漂移优化(Cloud Drift Optimization,CDO)算法求解23个经典函数测试集,MATLAB
一、云漂移优化算法 云漂移优化(Cloud Drift Optimization,CDO)算法是2025年提出的一种受自然现象启发的元启发式算法,它模拟云在大气中漂移的动态行为来解决复杂的优化问题。云在大气中受到各种大气力的影响,其粒子的…...
2025年Draw.io最新版本下载安装教程,附详细图文
2025年Draw.io最新版本下载安装教程,附详细图文 大家好,今天给大家介绍一款非常实用的流程图绘制软件——Draw.io。不管你是平时需要设计流程图、绘制思维导图,还是制作架构图,甚至是简单的草图,它都能帮你轻松搞定。…...
记录--洛谷 P1451 求细胞数量
如果想查看完整题目,请前往洛谷 P1451 求细胞数量 P1451 求细胞数量 题目描述 一矩形阵列由数字 0 0 0 到 9 9 9 组成,数字 1 1 1 到 9 9 9 代表细胞,细胞的定义为沿细胞数字上下左右若还是细胞数字则为同一细胞,求给定矩形…...
Android Studio 配置国内镜像源
Android Studio版本号:2022.1.1 Patch 2 1、配置gradle国内镜像,用腾讯云 镜像源地址:https\://mirrors.cloud.tencent.com/gradle 2、配置Android SDK国内镜像 地址:Index of /AndroidSDK/...
做到哪一步才算精通SQL
做到哪一步才算精通SQL-Structured Query Language 数据定义语言 DDL for StructCREATE:用来创建数据库、表、索引等对象ALTER:用来修改已存在的数据库对象DROP:用来删除整个数据库或者数据库中的表TRUNCATE:用来删除表中所有的行…...
Manus演示案例: 英伟达财务估值建模 解锁投资洞察的深度剖析
在当今瞬息万变的金融投资领域,精准剖析企业价值是投资者决胜市场的关键。英伟达(NVIDIA),作为科技行业的耀眼明星,其在人工智能和半导体领域的卓越表现备受瞩目。Manus 凭借专业的财务估值建模能力,深入挖…...
逻辑回归:给不确定性划界的分类大师
想象你是一名医生。面对患者的检查报告(肿瘤大小、血液指标),你需要做出一个**决定性判断**:恶性还是良性?这种“非黑即白”的抉择,正是**逻辑回归(Logistic Regression)** 的战场&a…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...
微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
dify打造数据可视化图表
一、概述 在日常工作和学习中,我们经常需要和数据打交道。无论是分析报告、项目展示,还是简单的数据洞察,一个清晰直观的图表,往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server,由蚂蚁集团 AntV 团队…...
云原生安全实战:API网关Kong的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关(API Gateway) API网关是微服务架构中的核心组件,负责统一管理所有API的流量入口。它像一座…...
力扣热题100 k个一组反转链表题解
题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...
Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统
💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storms…...
系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文通过代码驱动的方式,系统讲解PyTorch核心概念和实战技巧,涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...
