当前位置: 首页 > article >正文

昇腾CANN ops-transformer FlashAttention 反向传播:不存 Attention 矩阵怎么求梯度

FlashAttention 前向传播的精髓不存 N×N 的 attention 矩阵只存 O(N) 的输出和 softmax 归一化因子。反向传播时需要 attention 矩阵来计算梯度——但矩阵没存。解法重新算一遍。用额外的计算换显存——这是典型的 compute-for-memory tradeoff。512K 上下文下标准 attention backward 需要 512GB存 attention 矩阵 512GB存梯度 1TB 显存。FlashAttention backward 只需要 ~几 GB。标准 Attention 反向传播的梯度公式前向O softmax(QK^T / √d) × V P × V反向需要三个梯度dQ, dK, dV都需要 attention 矩阵 PdV P^T × dO dP dO × V^T dS dP ⊙ P - P ⊙ (sum(dP ⊙ P, dim-1)) # softmax 反向 dQ dS × K dK dS^T × Q所有公式都依赖 Pattention 矩阵——但 FlashAttention 前向没存它。FlashAttention 反向重算 分块反向传播的核心思路在反向 pass 中重新执行前向计算。前向时跑了分块 softmax 分块加权求和反向时再次分块——但这次不仅计算 O还要计算 dQ, dK, dV。// ops-transformer/kernels/flash_attention/flash_attention_backward.cpp__aicore__voidFlashAttentionBackward(GlobalTensorfloat16dO,// 输出梯度 [B, H, N, D]GlobalTensorfloat16Q,// 前向的 Q保留GlobalTensorfloat16K,// 前向的 K保留GlobalTensorfloat16V,// 前向的 V保留GlobalTensorfloat16L,// 前向的 row_sum (softmax 分母)GlobalTensorfloat16dQ,// Q 的梯度GlobalTensorfloat16dK,// K 的梯度GlobalTensorfloat16dV,// V 的梯度intN,intD){constexprintBr32;constexprintBc32;// 第一步重算 dV最简单——只需 P 和 dOfor(intj0;jnum_kv_blocks;j){LocalTensorfloat16dV_local(Bc,D);for(intbc0;bcBc;bc)for(intd0;dD;d)dV_local[bc][d]0.0f;for(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);LocalTensorfloat16Kj(Bc,D);DataCopy(Qi,Qi*Br*D,Br*D);DataCopy(Kj,Kj*Bc*D,Bc*D);// 重算 S Qi × Kj^TLocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}// 重算 P softmax(S)用前向存的 row_max 和 Lfor(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;// dV_j sum_i(P_ij × dO_i)for(intd0;dD;d)dV_local[c][d]P_val*float(dO[(i*Brr)*Dd]);}}}DataCopy(dVj*Bc*D,dV_local,Bc*D);}// 第二步重算 dQ需要 dP 和 dSfor(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);DataCopy(Qi,Qi*Br*D,Br*D);floatdQi_init[Br][D]{0.0f};for(intj0;jnum_kv_blocks;j){LocalTensorfloat16Kj(Bc,D);LocalTensorfloat16Vj(Bc,D);LocalTensorfloat16dOi(Br,D);DataCopy(Kj,Kj*Bc*D,Bc*D);DataCopy(Vj,Vj*Bc*D,Bc*D);DataCopy(dOi,dOi*Br*D,Br*D);// 重算 S_blockLocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}for(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];// dP_ij dO_i × V_j^TfloatdP[Bc];for(intc0;cBc;c){dP[c]0.0f;for(intd0;dD;d)dP[c]float(dOi[r*Dd])*float(Vj[c*Dd]);}// D_i sum_j(dP_ij × P_ij)floatD_i0.0f;for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;D_idP[c]*P_val;}// dS_ij (dP_ij - D_i) × P_ij → dQ_i sum_j(dS_ij × Kj)for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;floatdS_val(dP[c]-D_i)*P_val;for(intd0;dD;d)dQi_init[r][d]dS_val*float(Kj[c*Dd]);}}}for(intr0;rBr;r)for(intd0;dD;d)dQ[(i*Brr)*Dd]float16(dQi_init[r][d]);}// 第三步dK对称于 dQ公式是 dK_j sum_i(dS_ij^T × Qi)for(intj0;jnum_kv_blocks;j){LocalTensorfloat16Kj(Bc,D);DataCopy(Kj,Kj*Bc*D,Bc*D);floatdKj_init[Bc][D]{0.0f};for(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);LocalTensorfloat16Vj(Bc,D);LocalTensorfloat16dOi(Br,D);DataCopy(Qi,Qi*Br*D,Br*D);DataCopy(Vj,Vj*Bc*D,Bc*D);DataCopy(dOi,dOi*Br*D,Br*D);LocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}for(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];floatdP[Bc];for(intc0;cBc;c){dP[c]0.0f;for(intd0;dD;d)dP[c]float(dOi[r*Dd])*float(Vj[c*Dd]);}floatD_i0.0f;for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;D_idP[c]*P_val;}for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;floatdS_val(dP[c]-D_i)*P_val;// dK_j_c sum_r(dS_rc × Qi_r) ← 转置关系for(intd0;dD;d)dKj_init[c][d]dS_val*float(Qi[r*Dd]);}}}for(intc0;cBc;c)for(intd0;dD;d)dK[(j*Bcc)*Dd]float16(dKj_init[c][d]);}}前向保存的关键数据FlashAttention 前向完成后只保存两个向量不是 N×N 矩阵structFlashAttentionForwardCache{float32*row_max;// [num_q_blocks × Br] — softmax 每行的最大值float32*L;// [num_q_blocks × Br] — softmax 分母指数和float16*O;// [B, H, N, D] — 正常大小的输出};反向传播时用row_max和L重算 softmax 矩阵 P——每次只重算一个分块不同时存在于显存中。计算量分析标准 Attention 反向传播 - 前向O(N² × D) 计算 O(N²) 存储存 P - 反向O(N² × D) 计算直接读 P O(N²) 存储 P dP - 显存O(N²) 512GB (512K seq) FlashAttention 反向传播 - 前向O(N² × D) 计算 O(N) 存储只存 row_max L - 反向O(N² × D) 计算 × 2重算两次 P → dV 和 dQ/dK - 显存O(N) ~几 GB 额外计算反向多一倍重算了 P 两次 显存节省O(N²) → O(N)512K seq 下 512GB → 几 GB踩坑一row_max 和 L 用 FP16 保存 → 梯度偏移重算 softmax 需要前向的 row_max 和 L。FP16 保存 ±0.001 误差在 exp(S - max) 中偏差被放大// ❌ FP16 保存 row_max → 还原时 ±0.001 误差float16 row_max_fwd_fp16[N];floatrow_max_restoredfloat(row_max_fwd_fp16[i]);// 偏差 0.001// exp(S - max) 中偏差放大floattrue_expexpf(88.0f-88.0f)1.0f;floatwrong_expexpf(88.0f-88.001f)0.999f;// 偏差 0.1%// ✅ FP32 保存 row_max 和 Lfloat32 row_max_fwd_fp32[N];float32 L_fwd_fp32[N];实测FP16 row_max → LLaMA 7B 训练 loss 在 5000 步后偏离 0.03vs 基线FP32 row_max → 只偏离 0.001。踩坑二dQ 和 dK 各自重算一遍 P → 白白多算一次反向需要重算两次前向一次算 dV需要 P一次算 dQ 和 dK也需要 P。两次重算之间 P 没保存 → 算了两遍。// ❌ 两次重算 P —— 第二次浪费了for(i,j){Precompute(Qi,Kj);dVP^T × dO;// 第一次重算——只用了一次}for(i,j){Precompute(Qi,Kj);// 又算一次——浪费dQdS × K;dKdS^T × Q;}// ✅ 一次重算 P同时输出 dV/dQ/dKfor(i,j){Precompute(Qi,Kj);dVP^T × dO[i];// 一次 P三种梯度dS(dP-D_i)× P;dQdS × K[j];dKdS^T × Q[i];}踩坑三FP16 累加器精度损失dQi_init[r][d]跨 8 个 Bc chunks 累加——每个贡献微量。FP16 累加 8 次 → 误差累积。// ❌ FP16 累加器float16 dQi_init[Br][D];// 8 次累加后每次舍入 → 总误差 ~0.1%// ✅ FP32 累加器只在写回时转 FP16floatdQi_init[Br][D];// ... 8 次累加全精度dQ[...]float16(dQi_init[r][d]);// 只一次转换FlashAttention 反向的本质不存 N² 矩阵用 N² 的额外计算换回来。前向保存 row_max 和 LO(N) 大小反向重算两次 PO(N²) 计算得到完整的 dQ/dK/dV。512K seq 下显存从 512GB 降到几 GB——这是训练长上下文模型的唯一可行路径。三个关键row_max/L 用 FP32 保存不要节约 4 bytes 丢了精度、dQ 和 dK 的重算合并为一次重算P 算一次输出三种梯度、累加器全用 FP32最后才转 FP16。

相关文章:

昇腾CANN ops-transformer FlashAttention 反向传播:不存 Attention 矩阵怎么求梯度

FlashAttention 前向传播的精髓:不存 NN 的 attention 矩阵,只存 O(N) 的输出和 softmax 归一化因子。反向传播时,需要 attention 矩阵来计算梯度——但矩阵没存。解法:重新算一遍。用额外的计算换显存——这是典型的 compute-for…...

在node js后端服务中集成taotoken实现多模型智能客服响应

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 在 Node.js 后端服务中集成 Taotoken 实现多模型智能客服响应 构建一个在线客服系统时,一个核心挑战是如何平衡响应质量…...

通过Taotoken的Token Plan套餐实现项目成本的可预测与精细控制

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 通过Taotoken的Token Plan套餐实现项目成本的可预测与精细控制 对于有长期、稳定大模型调用需求的团队而言,项目预算的…...

现在停用默认filter_config将导致合规风险!DeepSeek最新CVE-2024-7812漏洞预警及3小时紧急加固方案

更多请点击: https://codechina.net 第一章:DeepSeek敏感信息过滤 DeepSeek系列大模型在企业级部署中,需严格遵循数据安全与隐私合规要求。敏感信息过滤(Sensitive Information Filtering, SIF)是其推理链路中关键的前…...

DeepSeek免费额度怎么用才不浪费?资深MLOps工程师的6小时压测报告与最优请求批处理公式

更多请点击: https://kaifayun.com 第一章:DeepSeek免费额度怎么用才不浪费?资深MLOps工程师的6小时压测报告与最优请求批处理公式 在连续6小时、覆盖12种负载模式的真实压测中,我们发现DeepSeek API免费额度(当前为1…...

DeepSeek监控告警设置实战指南(告警失效率下降92%的7个关键开关)

更多请点击: https://kaifayun.com 第一章:DeepSeek监控告警设置的核心价值与落地挑战 在大模型推理服务规模化部署的背景下,DeepSeek系列模型(如DeepSeek-V2、DeepSeek-Coder)对资源稳定性、延迟敏感性及异常响应时效…...

Google 广告场景下 Uniswap 钓鱼攻击机理与 Web3 防御体系研究

摘要 2026 年 5 月 22 日,GoPlus 安全团队发布预警,针对 Web3 领域头部去中心化交易平台 Uniswap 的搜索引擎钓鱼攻击呈规模化爆发态势。攻击者通过购买 Google Ads 关键词广告,将高仿钓鱼网站置顶于搜索结果前列,结合视觉相似域名…...

人机协同闭环:AI 时代邮件安全 “人在回路” 防御体系研究

摘要 2026 年,生成式 AI 全面渗透网络钓鱼攻击链,攻击从批量群发转向精准定制、从静态模板转向动态逃逸,传统纯技术防护出现显著盲区。数据显示,AI 自动化鱼叉式钓鱼点击率达 54%,攻击从投放至全面入侵的窗口压缩至秒级…...

高校邮件安全体系升级与 Proofpoint 部署实践研究 —— 以特拉华大学为例

摘要:随着网络钓鱼、垃圾邮件与恶意邮件攻击持续威胁高校信息系统,电子邮件安全已成为校园网络防护的核心环节。特拉华大学自 2026 年 6 月 1 日起全面启用 Proofpoint 邮件安全平台,构建覆盖邮件过滤、威胁隔离、用户自助处置与安全运营的全…...

Kali365 设备代码钓鱼攻击机理、危害及防御体系研究

摘要 2026 年 5 月 FBI 发布预警,新型钓鱼即服务平台 Kali365 通过滥用 Microsoft 365 OAuth 2.0 设备代码授权流程,可在不窃取密码、不伪造登录页面的前提下绕过多因素认证,获取长期有效访问令牌,实现账户持久化控制。该平台依托…...

基于 OAuth 设备码流滥用的 Kali365 钓鱼攻击机理与防御体系研究

摘要 2026 年 5 月,美国联邦调查局(FBI)发布安全预警,披露针对 Microsoft 365 环境的 PhaaS 平台 Kali365 正通过滥用 OAuth 设备码认证流程实施规模化钓鱼攻击,可绕过多因素认证(MFA)窃取合法访…...

为什么92%的DeepSeek微调失败?资深架构师拆解3类致命配置错误及实时诊断命令

更多请点击: https://kaifayun.com 第一章:DeepSeek模型微调失败率的行业现状与根本归因 近年来,DeepSeek系列大模型(如DeepSeek-V2、DeepSeek-Coder)在开源社区和企业私有化部署中广泛应用,但实证调研显示…...

【ChatGPT故事化表达黄金法则】:20年AI内容专家亲授3步叙事框架,让提示词转化率提升300%

更多请点击: https://intelliparadigm.com 第一章:ChatGPT故事化表达的底层认知革命 传统人机交互长期受限于指令式范式——用户需精确编码意图,系统则机械匹配关键词或规则。ChatGPT 的突破性不在于参数规模,而在于其将语言建模…...

C++学习笔记26:static 静态成员

目录 一、为什么需要静态成员? 二、静态成员变量 三、静态成员变量需要类外定义 四、用静态成员变量统计对象个数 五、静态成员变量不占对象空间 六、静态成员函数 七、静态成员函数没有 this 指针 八、静态成员函数可以访问静态成员 九、调用方式 1. 通过…...

【限时解锁】Gemini深度研究模式私有化部署方案:仅3家头部科研机构掌握的本地化推理链配置

更多请点击: https://codechina.net 第一章:Gemini深度研究模式的核心原理与能力边界 Gemini深度研究模式并非简单增强上下文长度的推理机制,而是一种面向复杂知识密集型任务的分层式认知架构。其核心原理在于动态构建“问题-证据-推理”三元…...

【Gemini生命周期价值深度解码】:20年AI架构师亲授5大阶段ROI测算模型与避坑指南

更多请点击: https://intelliparadigm.com 第一章:Gemini生命周期价值分析 Gemini 模型的生命周期价值(LTV)不仅体现在其推理性能与多模态能力上,更贯穿于从模型部署、持续微调、监控反馈到迭代升级的完整闭环。相较于…...

【ChatGPT投资人邮件撰写黄金法则】:20年FA/VC顾问亲授——3类高回复率模板+5个致命话术雷区

更多请点击: https://codechina.net 第一章:ChatGPT投资人邮件撰写的核心认知与底层逻辑 投资人邮件不是信息的简单堆砌,而是认知对齐、信任构建与决策催化三重目标的高度凝练表达。其底层逻辑根植于风险投资行业的决策机制——LP关注资金效…...

ChatGPT移动端隐私红线报告(2024Q2):麦克风/剪贴板/位置数据采集路径全曝光,3步彻底锁死敏感权限

更多请点击: https://intelliparadigm.com 第一章:ChatGPT移动端隐私红线报告(2024Q2)核心发现与风险定级 高危数据外泄通道实证 本季度对iOS与Android平台主流ChatGPT客户端(含官方App v6.12.1及第三方封装SDK集成应…...

【小红书算法偏爱的文案结构】:ChatGPT无法自学的3层语义嵌套技巧(含2024Q2平台最新流量权重白皮书节选)

更多请点击: https://kaifayun.com 第一章:小红书算法偏爱的文案结构本质解构 小红书的推荐算法并非仅依赖关键词或标签匹配,其核心是通过多模态语义理解与用户行为反馈闭环,对文案的信息密度、情绪节奏和结构可读性进行加权评估…...

新手注册Taotoken后第一步如何获取并测试API Key

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 新手注册Taotoken后第一步如何获取并测试API Key 注册Taotoken平台后,您已经拥有了一个统一的入口来调用多种大模型。接…...

Taotoken的Token Plan套餐如何帮助初创公司控制AI实验成本

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 Taotoken的Token Plan套餐如何帮助初创公司控制AI实验成本 1. 成本不可预测:初创AI实验的常见困境 在产品原型和早期开…...

如何为嵌入式项目配置大模型API调用使用Taotoken与Python

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 如何为嵌入式项目配置大模型API调用使用Taotoken与Python 对于嵌入式或物联网开发者而言,在资源受限的开发环境中集成A…...

创业团队如何利用Taotoken统一管理多个AI应用API成本

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 创业团队如何利用Taotoken统一管理多个AI应用API成本 对于同时开发多个集成AI功能的初创公司而言,技术选型与快速迭代是…...

对比按量计费与Token Plan套餐如何为项目选择更优成本模型

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 对比按量计费与Token Plan套餐如何为项目选择更优成本模型 在将大模型能力集成到开发项目中时,成本控制是一个绕不开的…...

3步构建物联网数字孪生:Eclipse Ditto实战指南

3步构建物联网数字孪生:Eclipse Ditto实战指南 【免费下载链接】ditto Eclipse Ditto™: Digital Twin framework of Eclipse IoT - main repository 项目地址: https://gitcode.com/gh_mirrors/ditto6/ditto 在物联网(IoT)时代,如何高效管理成千…...

凸轮机构设计(黄老板)

1. 2. 3....

通过curl命令快速测试Taotoken不同模型的响应速度与效果

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 通过curl命令快速测试Taotoken不同模型的响应速度与效果 对于习惯使用命令行工具的技术人员来说,curl是一个直接且高效…...

Solr CVE-2019-0193漏洞深度解析:DataImportHandler远程代码执行原理与实战修复

1. 这个漏洞不是“能远程执行代码”那么简单,而是Solr管理员自己亲手打开的后门 Apache Solr 是企业级搜索领域绕不开的基础设施,我经手过的金融、电商、政务类项目里,有七成以上都用它做全文检索底座。但2019年爆出的 CVE-2019-0193&#xf…...

微信M4A文件打不开怎么办?m4a转MP3只需一招,小白也能操作

很多人会遇到这种情况:别人通过微信发来一段录音、会议音频、课程音频或者采访素材,文件后缀是.m4a,在微信里可能能播放,但保存到手机本地、发到电脑、导入剪辑软件或者复制到U盘后,就可能出现打不开、无法识别、格式不…...

有哪些免费好用的在线论文排版工具值得推荐?

毕业季最让人头疼的,从来都不是论文内容创作,而是繁琐的格式排版 —— 标题层级错乱、目录更新失效、参考文献格式不规范、页眉页脚混乱…… 手动调整动辄耗费数小时,还容易反复返工。其实,多款免费好用的在线论文排版工具已能完美…...