当前位置: 首页 > article >正文

动手学深度学习——注意力机制代码

1. 前言上一篇我们已经从思想上理解了注意力机制基础 Seq2Seq 的问题在于固定长度上下文向量解码器在不同时间步其实应该关注输入序列的不同位置注意力机制的本质就是对输入表示做加权和权重由当前位置和各输入位置的相关性决定这一篇就继续按李沐的节奏把注意力机制真正落到代码上。这一节最重要的不是一开始就把所有复杂变体都铺开而是先把最核心的代码骨架看懂查询query是什么键key是什么值value是什么注意力权重怎么得到加权求和怎么实现你会发现注意力机制代码的灵魂其实很简单先算相关性分数再做 softmax再对 value 加权求和。2. 注意力代码到底在做什么如果从最抽象的角度看注意力机制的输入通常是三部分querykeyvalue然后输出一个结果根据 query 和 key 的匹配程度决定如何对 value 做加权汇总。所以它的计算主线可以写成三步第一步算分数score(query, key)第二步归一化成权重softmax(scores)第三步对 value 加权和attention_weights values所以注意力代码不是神秘黑箱本质就是分数 → 权重 → 加权和3. 为什么会有 query、key、value 这三个名字这三个名字第一次看会有点抽象但其实非常形象。你可以把它理解成“查询数据库”的过程query表示你现在想找什么。key表示每个候选位置的“索引标签”。value表示每个候选位置真正存放的内容。在注意力里query 决定当前需要什么信息key 决定每个位置和当前需求有多相关value 才是最终被加权汇总的内容所以query 用来问key 用来比value 用来取。4. 在 Seq2Seq 中query、key、value 分别是谁放到机器翻译的解码器场景里最常见的理解是query当前解码器时刻的隐藏状态也就是我现在要生成第t个目标词我当前最需要什么信息key编码器每个时间步的输出表示也就是源句子每个位置都提供一个“可匹配的表示”value通常也是编码器每个时间步的输出表示也就是最终真正被加权汇总的源句信息所以在最基础的 Seq2Seq 注意力里常见是query decoder hidden statekey encoder outputsvalue encoder outputs5. 最基础的注意力代码要先解决什么这一节李沐这里通常会先实现一种比较简单的注意力层例如“加性注意力”或一个通用注意力模块。但在进入具体分数函数之前通常会先把一个公共步骤处理掉masked softmax因为在序列任务里输入往往有 padding。如果不把 padding 位置屏蔽掉模型可能会把注意力错误地分给那些补齐出来的无效位置。所以注意力代码里非常基础的一步就是先算出分数再对无效位置 mask再做 softmax6. 什么是 masked softmaxmasked softmax 的作用是只在有效位置上做 softmax把 padding 位置的权重压成 0。为什么需要它假设一个 batch 里两条句子长度不同第一句长度是 5第二句长度是 3但 pad 到了 5那么第二句后面两个位置其实是无效的pad。如果注意力还把权重分给这两个位置就会污染上下文向量。所以必须在 softmax 前把这些位置“屏蔽掉”。7. masked softmax 代码怎么理解常见写法大致如下def masked_softmax(X, valid_lens): if valid_lens is None: return nn.functional.softmax(X, dim-1) else: shape X.shape if valid_lens.dim() 1: valid_lens torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens valid_lens.reshape(-1) X d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value-1e6) return nn.functional.softmax(X.reshape(shape), dim-1)这段代码乍一看有点绕但核心思想其实很简单第一步把无效位置赋成一个非常小的值例如-1e6第二步再做 softmax因为 softmax 后有效位置还能得到正常权重无效位置由于值极小权重几乎就是 0所以它本质上就是先 mask再 softmax8. 为什么把无效位置设成-1e6因为 softmax 的形式是指数归一化exp(x_i) / sum(exp(x_j))如果某个位置被设成-1e6那么exp(-1e6) ≈ 0这样它在 softmax 后的权重就几乎为 0。所以这种做法非常常见也非常实用。它不需要单独手写一个“软屏蔽公式”只要借助 softmax 的性质就行。9.valid_lens是什么valid_lens表示每个样本真实有效的序列长度例如一个 batch 有两条序列第一条长度 5第二条长度 3那么valid_lens [5, 3]这样注意力层就知道第一条的前 5 个位置有效第二条只有前 3 个位置有效后面是 padding所以valid_lens本质上就是 mask 的依据。10. 为什么注意力代码里常常要保存attention_weights很多实现里都会写self.attention_weights ...这是因为注意力机制一个很大的优点就是可解释性很强保存注意力权重有两个作用第一后续计算需要有些模块需要直接拿权重做加权和。第二便于可视化分析你可以把注意力权重画出来看模型当前到底在关注哪些输入位置。这也是注意力机制特别有魅力的一点它不像普通隐状态那么黑箱至少你能看到“它把注意力放在哪里”。11. 一个典型的注意力层长什么样李沐这里通常会实现一个加性注意力层例如class AdditiveAttention(nn.Module): def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs): super(AdditiveAttention, self).__init__(**kwargs) self.W_k nn.Linear(key_size, num_hiddens, biasFalse) self.W_q nn.Linear(query_size, num_hiddens, biasFalse) self.w_v nn.Linear(num_hiddens, 1, biasFalse) self.dropout nn.Dropout(dropout)这里只先看初始化。你会发现它并没有直接拿 query 和 key 点乘而是先做了几次线性变换。这就是“加性注意力”的特点。12. 加性注意力为什么叫“加性”因为它不是直接做内积而是先把 query 和 key 投影到同一个隐藏空间再相加、过非线性、再打分。直觉上可以写成score(q, k) w^T tanh(W_q q W_k k)这里最显眼的地方就是W_q q W_k k有个“加”。所以它被称为加性注意力Additive Attention这类注意力最早在 Seq2Seq 里非常经典也常叫Bahdanau attention的打分思路。13. 这三个线性层分别在干什么在初始化代码里self.W_k nn.Linear(key_size, num_hiddens, biasFalse) self.W_q nn.Linear(query_size, num_hiddens, biasFalse) self.w_v nn.Linear(num_hiddens, 1, biasFalse)可以这样理解。W_k把 key 投影到共同隐藏空间。W_q把 query 也投影到共同隐藏空间。w_v把两者融合后的隐藏表示再压成一个标量分数。也就是说加性注意力的分数不是直接算出来的而是先投影融合压缩最后得到一个注意力分数。14. 加性注意力的前向传播怎么写常见写法如下def forward(self, queries, keys, values, valid_lens): queries, keys self.W_q(queries), self.W_k(keys) features queries.unsqueeze(2) keys.unsqueeze(1) features torch.tanh(features) scores self.w_v(features).squeeze(-1) self.attention_weights masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)这段代码就是注意力机制代码里最值得细拆的一段。15.queries, keys self.W_q(queries), self.W_k(keys)在做什么这一句表示先把 query 投影到隐藏空间再把 key 投影到同一个隐藏空间这样做的好处是不管原始 query 和 key 维度是否一样都可以先映射到统一空间里再比较。这是一种非常常见的做法。因为不同来源的表示未必天然适合直接比较先投影能让匹配更灵活。16.unsqueeze和广播加法为什么这么写这一句是核心features queries.unsqueeze(2) keys.unsqueeze(1)它的目的就是让每个 query 和每个 key 两两配对。假设queries形状是(batch_size, num_queries, num_hiddens)keys形状是(batch_size, num_kv_pairs, num_hiddens)那么queries.unsqueeze(2)会变成(batch_size, num_queries, 1, num_hiddens)keys.unsqueeze(1)会变成(batch_size, 1, num_kv_pairs, num_hiddens)然后通过广播相加就得到(batch_size, num_queries, num_kv_pairs, num_hiddens)这就相当于每个 query 都和所有 key 组合了一遍。这一步特别关键因为注意力本质上就是要比较当前 query 和所有 key 的相关性17. 为什么后面要tanhfeatures torch.tanh(features)这是加性注意力的非线性变换步骤。它的作用是增强表达能力让 query-key 融合后的表示不只是线性相加为后面的分数计算提供更灵活特征这和前面 RNN/LSTM/GRU 中tanh的作用有些相似都是为了让模型不只是简单线性变换。18.scores self.w_v(features).squeeze(-1)在干什么这一句表示把最后那个num_hiddens维特征压成一个标量分数。也就是说对每个 query-key 对最终都会得到一个实数分数于是scores的形状通常是(batch_size, num_queries, num_kv_pairs)这正好对应每个 query 对所有 key 的相关性打分表这张分数表后面经过 softmax就会变成注意力权重。19.masked_softmax(scores, valid_lens)在这里的意义是什么这里就是把前面讲的 mask 用上了。因为 key/value 序列可能有 padding所以在注意力分数转成权重之前必须把无效位置屏蔽掉。这一步之后self.attention_weights就会变成一组合法的注意力分布非负和为 1padding 位置权重几乎为 0所以这一步本质上是在说只在真实有效输入位置上分配注意力。20.torch.bmm(attention_weights, values)为什么能得到上下文向量最后一步torch.bmm(self.attention_weights, values)这里的bmm是 batch matrix multiplication也就是批量矩阵乘法。假设attention_weights形状是(batch_size, num_queries, num_kv_pairs)values形状是(batch_size, num_kv_pairs, value_dim)那么相乘后得到(batch_size, num_queries, value_dim)这正好就是对每个 query把所有 value 按注意力权重做加权和所以bmm这一步其实就是把“加权求和”高效矩阵化实现了。这也是注意力代码最核心的落地点注意力输出 权重 × values21. 为什么 values 不一定等于 keys在很多基础 Seq2Seq 注意力里keys encoder outputsvalues encoder outputs所以两者看起来一样。但从更一般的框架看它们其实不是必须相同。key负责被 query 匹配决定权重。value负责被加权求和形成输出。在更复杂模型里key 和 value 可以来自不同投影或不同表示。所以把它们分开是一种更通用的设计。22. 这一节代码最该掌握什么如果从学习重点看最重要的是这几件事。22.1 理解 masked softmax知道为什么注意力一定要 mask padding。22.2 理解 query、key、value 的角色分工query当前需求key匹配对象value最终取出的内容22.3 理解unsqueeze broadcast的作用这是实现 query-key 两两配对的关键。22.4 理解注意力分数到注意力权重的转换也就是打分softmax得到分布22.5 理解bmm为什么就是加权和这是注意力机制代码最核心的一步。23. 这一节和下一节“注意力分数”是什么关系这一节主要是在讲注意力机制的基本代码框架怎么搭也就是分数算出来以后怎么办权重怎么算加权和怎么实现而下一节“注意力分数”会更聚焦于分数本身到底怎么设计例如加性注意力缩放点积注意力打分函数不同会带来什么差异所以这两节可以这么理解这一节偏整体计算流程。下一节偏分数函数本身。24. 本节总结这一节我们学习了注意力机制的代码基础核心内容可以总结为以下几点。24.1 注意力机制代码的主线是打分 → softmax → 加权和这是最核心的三步。24.2 masked softmax 用于屏蔽 padding 位置确保无效 token 不参与注意力分配。24.3 query、key、value 分别承担不同角色它们共同决定当前上下文向量如何生成。24.4 加性注意力通过线性变换、非线性融合和打分得到注意力分数这是经典的 Seq2Seq 注意力实现方式。24.5torch.bmm实现了对 values 的批量加权求和这是注意力输出的关键一步。25. 学习感悟这一节特别有价值因为它让注意力机制第一次真正“落地”成了一个你能看懂的计算过程。以前我们说模型在关注某些位置模型在动态分配注意力这些话听起来都很抽象。但代码一拆开你会发现它其实很朴素先比较相关性再把相关性变成权重再按权重把信息汇总出来。也就是说注意力机制的伟大之处不在于它特别复杂而在于它用一种很自然的方式把“选择性读信息”这件事变成了可训练的模块。

相关文章:

动手学深度学习——注意力机制代码

1. 前言上一篇我们已经从思想上理解了注意力机制:基础 Seq2Seq 的问题在于固定长度上下文向量解码器在不同时间步,其实应该关注输入序列的不同位置注意力机制的本质,就是对输入表示做加权和权重由当前位置和各输入位置的相关性决定这一篇就继…...

Python 安全开发全栈指南:零基础

Python 安全开发当前时间背景:2026年4月 (Python 3.14) 核心工具:Python 3.x | Requests | Lxml | Re️ 全栈知识体系思维导图mindmaproot((Python安全开发))基础核心变量与数据类型数值 (int, float)字符串 (str)布尔 (bool)运算符算术 ( - * /)赋值 ()…...

深入osgEarth内核:3DTiles加载背后的多线程机制与性能优化

深入osgEarth内核:3DTiles加载背后的多线程机制与性能优化 在三维地理信息系统开发中,osgEarth作为开源的高性能三维地球引擎,其加载海量3DTiles数据的能力直接影响用户体验。本文将深入剖析osgEarth加载3DTiles时的多线程架构设计&#xff0…...

乐高Studio与Solidworks联动指南:如何快速导入自定义3D模型并生成积木设计

乐高Studio与Solidworks联动指南:如何快速导入自定义3D模型并生成积木设计 在数字设计与实体搭建的交汇点上,乐高Studio和Solidworks的联动为创意工作者开辟了全新可能。想象一下,当你精心设计的机械结构或建筑模型能够直接转化为可拼装的乐…...

MusicFree插件开发初探:手把手教你写一个简单的音源接口(.js文件)

MusicFree插件开发实战:从零构建自定义音源接口 第一次看到MusicFree的插件列表时,我就被它的开放性震撼了——这个播放器本身只是个"空壳",所有音源功能都靠插件实现。作为开发者,这意味着我们不仅能自由选择音源&…...

AutoSAR MCAL DIO驱动深度解析:英飞凌TC3XX的GPIO控制底层是如何工作的?

AutoSAR MCAL DIO驱动深度解析:英飞凌TC3XX的GPIO控制底层是如何工作的? 在嵌入式系统开发中,GPIO控制是最基础却又最关键的环节之一。当项目复杂度上升到需要符合AutoSAR标准时,传统的裸机寄存器操作方式就显得力不从心了。英飞凌…...

避开这些坑!NCCL多GPU环境配置常见问题排查手册(附性能测试脚本)

避开这些坑!NCCL多GPU环境配置常见问题排查手册(附性能测试脚本) 当你在Ubuntu系统上配置多GPU深度学习训练环境时,NCCL(NVIDIA Collective Communications Library)的性能表现往往决定了整个训练过程的效…...

HakcMyVM-Quick4

信息搜集 主机发现 ┌──(kali㉿kali)-[~] └─$ nmap -sn 192.168.2.0/24 Starting Nmap 7.95 ( https://nmap.org ) at 2026-04-15 03:19 EDTNmap scan report for quick4 (192.168.2.9) Host is up (0.00028s latency). MAC Address: 08:00:27:AA:84:13 (PCS Systemtechni…...

从‘飞线’到‘倒装’:一文看懂WBCSP和FCCSP封装该怎么选(附内存与处理器封装实战解析)

从‘飞线’到‘倒装’:WBCSP与FCCSP封装技术全维度对比与选型策略 在移动设备处理器和内存芯片的设计中,封装技术直接影响着性能、功耗和体积三大核心指标。当硬件工程师面对WBCSP(引线键合芯片级封装)和FCCSP(倒装芯片…...

2026届最火的AI辅助写作方案实际效果

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 在当下的学术环境里头,论文重复率过高乃是对毕业以及发表产生影响的关键所在问题…...

TinyML实战:从模型压缩到MCU部署的全链路解析

1. TinyML入门:为什么我们需要在MCU上跑AI? 第一次尝试在STM32F407上部署人脸检测模型时,我被现实狠狠教育了——原以为轻量级的MobileNetV2模型(在PC端只要20MB内存)可以直接运行,结果编译时报错显示内存不…...

不用显示器也能搞定!虚拟机环境下Jetson Nano镜像烧录全流程

无显示器环境下的Jetson Nano镜像烧录实战指南 在边缘计算和嵌入式AI开发领域,Jetson Nano凭借其强大的GPU算力和紧凑的尺寸,成为众多开发者的首选平台。然而,初次接触这块开发板时,镜像烧录过程往往成为第一道门槛——特别是当手…...

瑞芯微开发板避坑指南:yolov5s模型在RK3566上的帧率优化实战

瑞芯微RK3566开发板实战:YOLOv5模型选型与帧率优化全解析 边缘计算设备上的AI模型部署,往往需要在性能和精度之间寻找微妙的平衡。当我们手握一块瑞芯微RK3566开发板,面对YOLOv5系列模型时,如何根据实际场景选择最合适的模型&…...

用Python和sklearn搞定百度慧眼数据:从抓包到坐标转换的完整实战

Python实战:百度慧眼数据爬取与坐标转换全流程解析 当我们需要分析城市人流分布时,百度慧眼提供的热力图数据是个不错的选择。但直接从API获取的数据往往需要经过一系列处理才能用于分析。本文将带你完整走通从数据获取到坐标转换的整个流程,…...

朱雀AIGC检测不通过?手把手教你3步搞定降AI

朱雀AIGC检测不通过?手把手教你3步搞定降AI “论文查了朱雀,AIGC检测没通过,怎么办?” 这个问题最近在各种毕业群里出现的频率越来越高。尤其是2026年毕业季,越来越多的高校把朱雀AIGC检测作为论文提交的硬性要求&…...

朱雀AI检测率高怎么降?保姆级攻略:用嘎嘎降AI从56%降到0%

朱雀AI检测率高怎么降?保姆级攻略:用嘎嘎降AI从56%降到0% 最近好几个同学私信问我:论文交上去之前自己查了一下朱雀,AI检测率直接显示56%,心态都崩了。 别慌。56%看着吓人,但只要方法对,降到学校…...

蓝牙5.0广播包PDU字段逐行解读:从ADV_IND到AUX_CHAIN_IND,手把手教你抓包分析

蓝牙5.0广播包深度解析:从基础字段到实战抓包技巧 在物联网设备爆发式增长的今天,低功耗蓝牙(BLE)技术已经成为连接智能设备的首选方案。作为BLE通信的"敲门砖",广播包承载着设备发现、连接建立和数据交换的…...

别再为显存发愁了:用vLLM 0.6.3在单张3090上部署Qwen2-VL-7B的保姆级调参指南

单卡3090极限调优:Qwen2-VL-7B视觉语言模型高效部署实战手册 当24GB显存遇上70亿参数的视觉语言模型,这场"内存捉襟见肘"的战役该如何打赢?本文将揭示如何通过vLLM 0.6.3的精细调参,让Qwen2-VL-7B在单张RTX 3090上流畅运…...

别再只买NXP了!盘点国产NFC标签芯片(复旦微/飞聚/聚辰)选型指南

国产NFC标签芯片深度选型指南:复旦微、飞聚、聚辰实战对比 在智能硬件和物联网设备爆发式增长的今天,NFC技术因其便捷的"碰一碰"交互方式,正在从传统的支付、门禁领域向更广阔的应用场景扩展。然而,当大多数开发者习惯性…...

新手也能懂:用Python+NumPy模拟雷达快慢时间采样数据矩阵(附代码)

用PythonNumPy模拟雷达快慢时间采样数据矩阵实战指南 雷达信号处理听起来像是硬件工程师的专属领域?其实只要掌握基础Python和NumPy操作,软件开发者也能轻松理解雷达数据的核心逻辑。本文将带你用代码构建快慢时间采样矩阵,无需任何硬件设备&…...

告别复杂多任务学习:深度解读Depth Anything V3如何用‘一个Transformer+一个目标’统一3D重建

深度估计新范式:Depth Anything V3如何用极简架构重塑3D视觉 当计算机视觉领域还在为多视图几何的复杂性绞尽脑汁时,Depth Anything V3(DA3)的出现像一股清流,用"一个Transformer一个目标"的极简设计&#…...

PX4飞控参数调优实战:从“飘”到“稳”,手把手教你调好四旋翼PID

PX4飞控参数调优实战:从“飘”到“稳”,手把手教你调好四旋翼PID 当你第一次放飞自己组装的四旋翼无人机时,那种兴奋感难以言表。但很快,现实给了你当头一棒——无人机在空中像醉汉一样左右摇摆,或者像被风吹动的树叶一…...

告警风暴 vs 告警静默:多模态大模型监控体系的双峰困境破解术(基于200+线上实例的告警压缩率提升87%实践)

第一章:告警风暴 vs 告警静默:多模态大模型监控体系的双峰困境破解术(基于200线上实例的告警压缩率提升87%实践) 2026奇点智能技术大会(https://ml-summit.org) 在超大规模大模型服务集群中,传统阈值驱动的告警机制正…...

SeaTunnel Transform插件实战:从零构建自定义JSON解析器

1. 为什么需要自定义JSON解析器 在实际的数据处理场景中,我们经常会遇到各种复杂的JSON格式数据。就拿最常见的日志处理来说,从Kafka等消息队列获取的原始数据往往包含多层嵌套的JSON结构。比如下面这个典型例子: {"path": "x…...

酷安UWP:在Windows电脑上体验完整酷安社区的终极指南

酷安UWP:在Windows电脑上体验完整酷安社区的终极指南 【免费下载链接】Coolapk-UWP 一个基于 UWP 平台的第三方酷安客户端 项目地址: https://gitcode.com/gh_mirrors/co/Coolapk-UWP 还在为手机小屏幕刷酷安而感到眼睛酸痛吗?想在大屏幕上舒适地…...

如何高效使用KMS_VL_ALL_AIO智能激活工具:完整Windows与Office激活指南

如何高效使用KMS_VL_ALL_AIO智能激活工具:完整Windows与Office激活指南 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 还在为Windows系统激活而烦恼吗?每次重装系统后都…...

深入浅出:双三相电机弱磁控制里的‘电压极限圆’与‘电流极限圆’到底怎么用?

深入浅出:双三相电机弱磁控制里的‘电压极限圆’与‘电流极限圆’到底怎么用? 想象一下驾驶电动汽车爬坡时突然失去动力,或是高速巡航时电机发出异常噪音——这些都可能与弱磁控制策略不当有关。对于从事电机控制的工程师而言,理解…...

昆仑通态触摸屏与PLC标签通讯避坑指南:为什么变量名不能用中文?

昆仑通态触摸屏与PLC标签通讯优化实践:变量命名规范与性能提升 在工业自动化项目中,昆仑通态触摸屏与PLC的稳定通讯是确保系统高效运行的关键环节。许多工程师在实际调试中都遇到过通讯卡顿、操作响应延迟的问题,却往往忽略了最基础的变量命名…...

从PPO到Q-learning:手把手教你根据项目需求选对强化学习模式(在线vs离线)

从PPO到Q-learning:实战选型指南与强化学习模式决策框架 引言:当强化学习遇上工程现实 去年夏天,我参与了一个工业机器人抓取系统的优化项目。团队最初选择了PPO算法进行在线训练,结果机械臂在真实环境中频繁发生碰撞,…...

CentOS 7上Python 3.6连接人大金仓KingbaseES V8的保姆级教程(含libkci库配置避坑指南)

CentOS 7上Python 3.6连接KingbaseES V8的深度实践指南 在国产化技术生态快速发展的背景下,人大金仓数据库KingbaseES V8凭借其稳定性和兼容性,逐渐成为企业级应用的热门选择。对于需要在CentOS 7环境下使用Python 3.6进行开发的工程师而言,如…...