当前位置: 首页 > news >正文

动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili

本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation

本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb


多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 h 组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 h 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (="https://zh.d2l.ai/chapter_references/zreferences.html#id174">Vaswaniet al., 2017)。 对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询 \mathbf{q} \in \mathbb{R}^{d_q} 、 键 \mathbf{k} \in \mathbb{R}^{d_k} 和 值 \mathbf{v} \in \mathbb{R}^{d_v} , 每个注意力头 \mathbf{h}_i( i = 1, \ldots, h )的计算方法为:

\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} 、 \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} 和 \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} , 以及代表注意力汇聚的函数 f 。 f 可以是 10.3节 中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着 h 个头连结后的结果,因此其可学习参数是 \mathbf W_o\in\mathbb R^{p_o\times h p_v} :

\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中通常[选择缩放点积注意力作为每一个注意力头]。 为了避免计算代价和参数代价的大幅增长, 我们设定 p_q = p_k = p_v = p_o / h 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p_q h = p_k h = p_v h = p_o , 则可以并行计算 h 个头。 在下面的实现中, p_o 是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

为了能够[使多个头并行计算], 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""# 输入X的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:((batch_size,查询或者“键-值”对的个数,num_hiddens)return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来[测试]我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

练习

  1. 分别可视化这个实验中的多个头的注意力权重。

解:
代码如下:

attention.attention.attention_weights.shape
# (batch_size*num_heads,查询的个数,“键-值”对的个数)

输出结果:
torch.Size([10, 4, 6])

d2l.show_heatmaps(attention.attention.attention_weights.reshape((2,5,4,6)), xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 6)],figsize=(8, 3.5))

输出结果:

2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

解:
首先定义评判注意力头重要性的指标,比如预测速度等;
然后采用单一变量法,修剪某一个头或某几个头的组合,重新训练模型,并在验证集上评估重要性指标的变化; 最后根据重要性指标的变化,判断最不重要的一个或几个注意力头,并修剪。

相关文章:

动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili 本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation 本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb 多头注…...

13 设计模式之外观模式(家庭影院案例)

一、什么是外观模式? 1.定义 在日常生活中,许多人喜欢通过遥控器来控制家中的电视、音响、DVD 播放器等设备。虽然这些设备各自独立工作,但遥控器提供了一个简洁的界面,让用户可以轻松地操作多个设备。而这一设计理念正是 外观模…...

单片机学习笔记 12. 定时/计数器_定时

更多单片机学习笔记:单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…...

Web安全基础实践

实践目标 (1)理解常用网络攻击技术的基本原理。(2)Webgoat实践下相关实验。 WebGoat WebGoat是由著名的OWASP负责维护的一个漏洞百出的J2EE Web应用程序,这些漏洞并非程序中的bug,而是故意设计用来讲授We…...

Zookeeper集群数据是如何同步的?

大家好,我是锋哥。今天分享关于【Zookeeper集群数据是如何同步的?】面试题。希望对大家有帮助; Zookeeper集群数据是如何同步的? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Zookeeper集群中的数据同步是通过一种称为ZAB(Zo…...

SpringCloud框架学习(第六部分:Sentinel实现熔断与限流)

目录 十四、SpringCloud Alibaba Sentinel实现熔断与限流 1.简介 2.作用 3.下载安装 4.微服务 8401 整合 Sentinel 入门案例 5.流控规则 (1)基本介绍 (2)流控模式 Ⅰ. 直接 Ⅱ. 关联 Ⅲ. 链路 (3&#xff0…...

动态规划-----路径问题

动态规划-----路径问题 下降最小路径和1:状态表示2:状态转移方程3 初始化4 填表顺序5 返回值6 代码实现 总结: 下降最小路径和 1:状态表示 假设:用dp[i][j]表示:到达[i,j]的最小路径 2:状态转…...

Rust循环引用与多线程并发

循环引用与自引用 循环引用的概念 循环引用指的是两个或多个对象之间相互持有对方的引用。在 Rust 中&#xff0c;由于所有权和生命周期的严格约束&#xff0c;直接创建循环引用通常会导致编译失败。例如&#xff1a; // 错误的循环引用示例 struct Node {next: Option<B…...

东方隐侠网安瞭望台第8期

谷歌应用商店贷款应用中的 SpyLoan 恶意软件影响 800 万安卓用户 迈克菲实验室的新研究发现&#xff0c;谷歌应用商店中有十多个恶意安卓应用被下载量总计超过 800 万次&#xff0c;这些应用包含名为 SpyLoan 的恶意软件。安全研究员费尔南多・鲁伊斯上周发布的分析报告称&…...

底部导航栏新增功能按键

场景需求&#xff1a; 在底部导航栏添加power案件&#xff0c;单击息屏&#xff0c;长按 关机 如下实现图 借此需求&#xff0c;需要掌握技能&#xff1a; 底部导航栏如何实现新增、修改、删除底部导航栏流程对底部导航栏部分样式如何修改。 比如放不下、顺序排列、坑点如…...

C++ 之弦上舞:string 类与多样字符串操作的优雅旋律

string 类的重要性及与 C 语言字符串对比 在 C 语言中&#xff0c;字符串是以 \0 结尾的字符集合&#xff0c;操作字符串需借助 C 标准库的 str 系列函数&#xff0c;但这些函数与字符串分离&#xff0c;不符合 OOP 思想&#xff0c;且底层空间管理易出错。而在 C 中&#xff0…...

centos8:Could not resolve host: mirrorlist.centos.org

【1】错误消息&#xff1a; [rootcentos211 redis-7.0.15]# yum update CentOS Stream 8 - AppStream …...

Linux 定时任务 命令解释 定时任务格式详解

目录 时间命令 修改时间和日期 定时任务格式 定时任务执行 查看定时任务进程 重启定时任务 时间命令 #查看时间 [rootlocalhost ~]# date 2021年 07月 23日 星期五 14:38:19 CST --------------------------------------- [rootlocalhost ~]# date %F 2021-07-23 -----…...

aws(学习笔记第十五课) 如何从灾难中恢复(recover)

aws(学习笔记第十五课) 如何从灾难中恢复 学习内容&#xff1a; 使用CloudWatch对服务器进行监视与恢复区域(region)&#xff0c;可用区(available zone)和子网(subnet)使用自动扩展(AutoScalingGroup) 1. 使用CloudWatch对服务器进行监视与恢复 整体架构 这里模拟Jenkins Se…...

github webhooks 实现网站自动更新

本文目录 Github Webhooks 介绍Webhooks 工作原理配置与验证应用云服务器通过 Webhook 自动部署网站实现复制私钥编写 webhook 接口Github 仓库配置 webhook以服务的形式运行 app.py Github Webhooks 介绍 Webhooks是GitHub提供的一种通知方式&#xff0c;当GitHub上发生特定事…...

【C语言】递归的内存占用过程

递归 递归是函数调用自身的一种编程技术。在C语言中&#xff0c;递归的实现会占用内存栈&#xff08;Call Stack&#xff09;&#xff0c;每次递归调用都会在栈上分配一个新的 “栈帧&#xff08;Stack Frame&#xff09;”&#xff0c;用于存储本次调用的函数局部变量、返回地…...

365天深度学习训练营-第P6周:VGG-16算法-Pytorch实现人脸识别

&#x1f368; 本文为&#x1f517;365天深度学习训练营中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 文为「365天深度学习训练营」内部文章 参考本文所写记录性文章&#xff0c;请在文章开头带上「&#x1f449;声明」 &#x1f37a;要求&#xff1a; 保存训练过…...

企业AI助理在数据分析与决策中扮演的角色

在当今这个数据驱动的时代&#xff0c;企业每天都需要处理和分析大量的数据&#xff0c;以支持其业务决策。然而&#xff0c;面对如此庞大的数据量&#xff0c;传统的数据分析方法已经显得力不从心。幸运的是&#xff0c;随着人工智能&#xff08;AI&#xff09;技术的不断发展…...

洛谷 B2029:大象喝水 ← 圆柱体体积

【题目来源】https://www.luogu.com.cn/problem/B2029【题目描述】 一只大象口渴了&#xff0c;要喝 20 升水才能解渴&#xff0c;但现在只有一个深 h 厘米&#xff0c;底面半径为 r 厘米的小圆桶 &#xff08;h 和 r 都是整数&#xff09;。问大象至少要喝多少桶水才会解渴。 …...

go每日一题:mock打桩、defer、recovery、panic的调用顺序

题目一&#xff1a;单元测试中使用—打桩 打桩概念&#xff1a;使用A替换 原函数B&#xff0c;那么A就是打桩函数打桩原理&#xff1a;运行时&#xff0c;通过一个包&#xff0c;将内存中函数的地址替换为桩函数的地址打桩操作&#xff1a;利用Patch&#xff08;&#xff09;函…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本&#xff1a; 3.8.1 语言&#xff1a; JavaScript/TypeScript、C、Java 环境&#xff1a;Window 参考&#xff1a;Java原生反射机制 您好&#xff0c;我是鹤九日&#xff01; 回顾 在上篇文章中&#xff1a;CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

select、poll、epoll 与 Reactor 模式

在高并发网络编程领域&#xff0c;高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表&#xff0c;以及基于它们实现的 Reactor 模式&#xff0c;为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。​ 一、I…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

虚拟电厂发展三大趋势:市场化、技术主导、车网互联

市场化&#xff1a;从政策驱动到多元盈利 政策全面赋能 2025年4月&#xff0c;国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》&#xff0c;首次明确虚拟电厂为“独立市场主体”&#xff0c;提出硬性目标&#xff1a;2027年全国调节能力≥2000万千瓦&#xff0…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...