Transformer中Self-Attention以及Multi-Head Attention模块详解(附pytorch实现)
写在前面
最近在项目中需要使用Transformer模型来处理图像任务,所以稍微补充一下这部分的知识,本篇主要了解一下Self-Attention以及Multi-Head Attention模块。
原论文链接:https://arxiv.org/pdf/1706.03762
原文代码:tensor2tensor/tensor2tensor/models/transformer.py
自注意力 Self-Attention
Self-Attention(自注意力机制)是一种动态建模输入数据内部依赖关系的方法,能够让模型关注输入数据中不同部分之间的相关性。

注意力机制的公式表示如下:
这里我们介绍一下这个公式,假设输入有一个序列,其中
为是输入的第 i 个元素,维度为 d ,那么对于 Self-Attention,关键的公式如下:
- 计算Query、Key和Value输入
通过线性变换分别得到Query(Q)、Key(K)和Value(V):
其中分别是训练的权重矩阵。
- 计算注意力分数(Attention Scores)
利用 Query 和 Key 计算注意力分数。注意力分数是 Query 和 Key 的点积,然后经过缩放处理(除以,其中
是 Key 向量的维度)
这个分数反映了序列中每一对元素之间的相似度。接下来,我们对这些分数进行归一化处理(通过 Softmax 函数):
这个步骤确保了所有注意力权重的和为 1,使得它们可以作为概率分布。
- 计算加权和(Weighted Sum)
将注意力权重与 Value 进行加权平均,得到最终的输出:
这一步的输出是一个新的表示,它是所有输入位置的加权求和结果,权重由其与 Query 的相关性决定。
以上为公式的原理,现在我举一个实际的例子来帮助大家理解这一部分。假设我们有一个长度为 3 的输入序列,每个元素是一个 2 维的嵌入:
其中,X 是一个形状为 (3,2) 的矩阵,表示每个位置的 2 维特征。为了计算 Query(Q)、Key(K)和 Value(V),我们首先需要通过一组权重矩阵对输入进行线性变换。假设权重矩阵如下(为简单起见,设定为 2 x 2 矩阵):
这些矩阵分别用于计算 Query、Key 和 Value。
根据上面的公式可得到:
得到 Query 和 Key 后,我们可以计算注意力分数。具体来说,对于每个 Query,我们与所有 Key 进行点积运算。计算如下:
注意,我们在这里计算的是 Query 和 Key 的点积,然后将其结果进行缩放。缩放因子通常是,其中
是 Key 向量的维度。在这个例子中,Key 的维度是 2,所以缩放因子是
。
即是,
然后,我们对每行的 Scaled Attention Scores 应用 softmax 操作,得到注意力权重:
最后,使用注意力权重对 Value(V)进行加权求和:
最后两步的计算结果用我下面给的代码跑一下就知道了。
import torch
import torch.nn as nnclass Softmax(nn.Module):def __init__(self, dim=-1):super().__init__()self.dim = dimdef _softmax(self, x):exp_x = torch.exp(x)softmax = exp_x / torch.sum(exp_x, dim=self.dim, keepdim=True)return softmaxdef forward(self, x):return self._softmax(x)matrix = torch.tensor([[0.707, 1.414, 2.121], [0, 0.707, 0.707], [0.707, 2.121, 2.828]])
Value = torch.tensor([[1, 0],[0, 1], [1, 1]], dtype=torch.float32)
print(matrix, "\n", Value)
softmax = Softmax()
score = softmax(matrix)
print(score)
result = torch.matmul(score, Value)
print(result)
多头注意力 Multi-Head Attention
在多头注意力机制中,模型会并行地计算多个不同的注意力头,每个头都有自己独立的 Query、Key 和 Value 权重,然后将每个头的输出连接起来,并通过一个线性变换得到最终的结果。

- 分割 Query、Key 和 Value 成多个头
对于每个头i,分别计算独立的 Query、Key 和 Value:
- 计算每个头的注意力输出
- 拼接各头的输出
将所有头的输出拼接在一起,h 是头的数量
- 最终输出
对拼接后的结果进行一次线性变换,得到最终的多头注意力输出。
这里我们不再举例了,下面你可以根据下面的代码进行测试。
注意力pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""初始化 MultiHeadAttention 模块:param embed_dim: 输入嵌入的特征维度:param num_heads: 注意力头的数量"""super(MultiHeadAttention, self).__init__()assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads # 每个头的特征维度# 定义 Query, Key 和 Value 的线性变换self.q_linear = nn.Linear(embed_dim, embed_dim)self.k_linear = nn.Linear(embed_dim, embed_dim)self.v_linear = nn.Linear(embed_dim, embed_dim)# 输出的线性变换self.out_linear = nn.Linear(embed_dim, embed_dim)def forward(self, x):""":param x: 输入张量,形状为 (batch_size, seq_len, embed_dim):return: 注意力后的输出,形状为 (batch_size, seq_len, embed_dim)"""batch_size, seq_len, embed_dim = x.size()# 生成 Query, Key, Value (batch_size, seq_len, embed_dim)Q = self.q_linear(x)K = self.k_linear(x)V = self.v_linear(x)# 分成多头 (batch_size, num_heads, seq_len, head_dim)Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力分数 (batch_size, num_heads, seq_len, seq_len)attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention_weights = F.softmax(attention_scores, dim=-1)# 加权求和 (batch_size, num_heads, seq_len, head_dim)attention_output = torch.matmul(attention_weights, V)# 拼接多头输出 (batch_size, seq_len, embed_dimattention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)# 输出线性变换 (batch_size, seq_len, embed_dim)output = self.out_linear(attention_output)return outputclass SelfAttention(MultiHeadAttention):def __init__(self, embed_dim):"""初始化 SelfAttention 模块:param embed_dim: 输入嵌入的特征维度"""super(SelfAttention, self).__init__(embed_dim, num_heads=1)def forward(self, x):""":param x: 输入张量,形状为 (batch_size, seq_len, embed_dim):return: 注意力后的输出,形状为 (batch_size, seq_len, embed_dim)"""return super(SelfAttention, self).forward(x)if __name__ == "__main__":embed_dim = 64 # 输入特征维度num_heads = 8 # 注意力头的数量model = SelfAttention(embed_dim)multi_model = MultiHeadAttention(embed_dim, num_heads)batch_size = 2seq_len = 10x = torch.rand(batch_size, seq_len, embed_dim)output = model(x)output2 = multi_model(x)print("输出形状:", output.shape) # 应为 (batch_size, seq_len, embed_dim)print("输出形状:", output2.shape)
以上实现,与torch官方内部内部的实现略有不同,官方提供了一个实现好的多头注意力模块 torch.nn.MultiheadAttention。这个实现做了很多优化,比如对于输入和输出的形状、注意力分数的计算以及参数的处理,都进行了更加简化和高效的实现。官方实现默认要求输入形状为 (seq_len, batch_size, embed_dim)(但是实际可以通过参数batch_first来修改),这是因为官方实现是为实现批量并行化优化的。官方实现直接接受 Query、Key 和 Value 作为三个输入张量。
官方写的很好,但可能不够直观,我这里写的就比较的简洁了,并且由于内部实现有些许不同,我这个无法与其进行比较是否相同。下面是我做测试时候写的草稿,大家将就着看吧。
import torch
import torch.nn as nn
from model import MultiHeadAttentionclass OfficialMultiheadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(OfficialMultiheadAttention, self).__init__()self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)def forward(self, x):output, attention_weights = self.multihead_attn(x, x, x)# 检查每行的和是否为1attention_weights_sum = torch.sum(attention_weights, dim=-1)print("官方实现 - 每行的和:", attention_weights_sum)print("官方实现 - 每行的和是否为1:", torch.allclose(attention_weights_sum, torch.ones_like(attention_weights_sum)))return outputif __name__=="__main__":embed_dim = 8 # 输入特征维度num_heads = 2 # 注意力头的数量batch_size = 2seq_len = 5x = torch.rand(batch_size, seq_len, embed_dim)my_attention_model = MultiHeadAttention(embed_dim, num_heads)official_attention_model = OfficialMultiheadAttention(embed_dim, num_heads)my_attention_output = my_attention_model(x)official_attention_output = official_attention_model(x)print("SelfAttention 输出形状:", my_attention_output.shape) # 应为 (batch_size, seq_len, embed_dim)print("Official MultiheadAttention 输出形状:", official_attention_output.shape) # 应为 (batch_size, seq_len, embed_dim)is_same = torch.allclose(my_attention_output, official_attention_output, atol=1e-2)print("两个输出是否相同:", is_same)
总结
尽管官方的 MultiheadAttention 模块经过优化,具有更高的效率,但手动实现能够帮助大家更好地理解多头注意力机制的各个计算步骤。通过这些实验,我们不仅深入了解了注意力机制的原理,还能在实际应用中灵活使用这些机制,尤其是在图像任务中,Transformer 的强大能力得到了广泛的应用。
参考文章
详解Transformer中Self-Attention以及Multi-Head Attention_transformer multi head-CSDN博客
第四篇:一文搞懂Transformer架构的三种注意力机制_c3tr 注意力 详解-CSDN博客
一文搞定自注意力机制(Self-Attention)-CSDN博客
十分推荐的参考视频:Transformer中Self-Attention以及Multi-Head Attention详解_哔哩哔哩_bilibili
相关文章:
Transformer中Self-Attention以及Multi-Head Attention模块详解(附pytorch实现)
写在前面 最近在项目中需要使用Transformer模型来处理图像任务,所以稍微补充一下这部分的知识,本篇主要了解一下Self-Attention以及Multi-Head Attention模块。 原论文链接:https://arxiv.org/pdf/1706.03762 原文代码:tensor2…...
在Nvidia Jetson ADX Orin中使用TensorRT-LLM运行llama3-8b
目录 背景:步骤 1.获取模型权重第 2 步:准备第 3 步:构建 TensorRT-LLM 引擎 背景: 大型语言模型 (LLM) 推理的关键瓶颈在于 GPU 内存资源短缺。因此,各种加速框架主要强调减少峰值 GPU 内存使…...
六十一:HTTP/2的问题及HTTP/3的意义
随着互联网的快速发展,网络协议的升级成为优化用户体验和提升网络效率的重要手段。HTTP/2 于 2015 年发布,标志着超文本传输协议的重大改进。然而,尽管 HTTP/2 带来了许多新特性,它也存在一定的问题。在此背景下,HTTP/…...
IOS开发如何从入门进阶到高级
针对iOS开发的学习,不同阶段应采取不同的学习方式,以实现高效提升.本文将iOS开发的学习分为入门、实战、进阶三个阶段,下面分别详细介绍. 一、学习社区 iOS开源中国社区 这个社区专注于iOS开发的开源项目分享与协作,汇集了大量开…...
非一般的小数:小数的概念新解、小数分类、浮点数的存储
非一般的小数:小数的概念新解、小数分类、浮点数的存储 一、小数的概念二、小数的分类1.有限小数、无限循环小数、无限不循环小数2.纯小数、带小数3.定点数、浮点数 三、浮点数的存储 一、小数的概念 这还用解释吗?小…...
关于游戏销量的思考
1、黑神话达到2300万套,分析师上调预期到超过100亿营收。 以往的我的世界、小鸟、超级食肉男孩等游戏也都是几千万,上亿的销量。 也改变了相关开发者的命运。 一个开发者,卖出一个30万,或100万销量的作品,就足够改变…...
JuiceFS 详解:一款为云原生设计的高性能分布式文件系统
JuiceFS 详解:一款为云原生设计的高性能分布式文件系统 1. 什么是 JuiceFS? JuiceFS(Juiced File System)是一款高性能、POSIX 兼容的云原生分布式文件系统。它采用对象存储作为底层存储,支持多种元数据引擎…...
百度Android面试题及参考答案 (下)
Executorservice 和 Executor 有什么区别? Executor 接口 Executor 是一个简单的接口,它定义了一个方法execute(Runnable command)。这个接口的主要目的是将任务的提交和任务的执行分离,它提供了一种通用的方式来执行一个Runnable任务,但是它没有提供更多高级的功能,比如任…...
RK3588+FPGA全国产异步LED显示屏控制卡/屏幕拼接解决方案
RK3588FPGA核心板采用Rockchip RK3588新一代旗舰 级八核64位处理器,支持8K视频编解码,多屏4K输出,可实现12屏联屏拼接、同显、异显,适配多种操作系统,广泛适用于展览展示、广告内容投放、新零售、商超等领域实现各种媒…...
Elasticsearch:Query rules 疑难解答
作者:来自 Elastic Kathleen_DeRusso 查询规则(Query rules)为用户提供了一种对特定查询进行细粒度控制的方法。目前,查询规则的功能允许你将你选择的搜索结果固定在结果集的顶部,和/或根据上下文查询数据从结果集中排…...
四、VSCODE 使用GIT插件
VSCODE 使用GIT插件 一下载git插件与git Graph插件二、git插件使用三、文件提交到远程仓库四、git Graph插件 一下载git插件与git Graph插件 二、git插件使用 git插件一般VSCode自带了git,就是左边栏目的图标 在下载git软件后vscode的git插件会自动识别当前项目 …...
键盘鼠标共享工具Barrier(kail与windows操作系统)
键鼠共享工具Barrier(kail与windows操作系统)_barrier软件-CSDN博客 sudo apt install barrier...
QTcpSocket 中设置接收缓冲区大小
在 QTcpSocket 中设置接收缓冲区大小 使用setSocketOption方法 在QTcpSocket类中,可以使用setSocketOption函数来设置接收缓冲区大小。具体来说,对于 TCP 套接字,你可以使用QAbstractSocket::ReceiveBufferSizeSocketOption选项。以下是一个简…...
Arduino IDE刷微控制器并下载对应固件的原由
在使用Arduino IDE刷写某个微控制器时,下载对应的固件通常是为了确保微控制器能够正确识别和执行Arduino IDE中编写的代码。以下是对这一过程的详细解释: 一、固件的作用 固件是微控制器或嵌入式设备上运行的软件,它负责控制硬件设备的操作…...
Jurgen提出的Highway Networks:LSTM时间维方法应用到深度维
Jurgen提出的Highway Networks:LSTM时间维方法应用到深度维 具体实例与推演 假设我们有一个离散型随机变量 X X X,它表示掷一枚骰子得到的点数,求 X X X 的期望。 步骤: 列出 X X X 的所有可能取值 x i x_i xi(…...
Netron可视化深度学习的模型框架,大大降低了大模型的学习门槛
深度学习是机器学习的一个子领域,灵感来源于人脑的神经网络。深度学习通过多层神经网络自动提取数据中的高级特征,能够处理复杂和大量的数据,尤其在图像、语音、自然语言处理等任务中表现出色。常见的深度学习模型: 卷积神经网络…...
Android客制化------7.0设置壁纸存在的一些问题
ro.wallpaper.fixsize这个节点应该是RK这边导入的,可以通过追这个节点的代码查看具体的实现方式; 最近在开7.0的坑,遇到了一些小问题,记录一下。很大可能这个问题只是我这个芯片的代码上才存在的,不过殊途同归啦。 第…...
VuePress2配置unocss的闭坑指南
文章目录 1. 安装依赖:准备魔法材料2. 检查依赖版本一定要一致:确保魔法配方准确无误3. 新建uno.config.js:编写咒语书4. 配置config.js和client.js:完成仪式 1. 安装依赖:准备魔法材料 在开始我们的前端魔法之前&…...
海陵HLK-TX510人脸识别模块 stm32使用
一.主函数 #include "stm32f10x.h" // Device header #include "delay.h" #include "lcd.h" #include "dht11.h" #include "IOput.h" #include "usart.h" //#include "adc.h" …...
安卓14无法安装应用解决历程
客户手机基本情况: 安卓14,对应的 targetSdkVersion 34 前天遇到了安卓14适配问题,客户发来的截图是这样的 描述:无法安装我们公司的B应用。 型号:三星google美版 解决步骤: 1、寻找其他安卓14手机测试…...
接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...
【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...
智能在线客服平台:数字化时代企业连接用户的 AI 中枢
随着互联网技术的飞速发展,消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁,不仅优化了客户体验,还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用,并…...
自然语言处理——循环神经网络
自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元(GRU)长短期记忆神经网络(LSTM)…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
Docker 本地安装 mysql 数据库
Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker ;并安装。 基础操作不再赘述。 打开 macOS 终端,开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...
基于 TAPD 进行项目管理
起因 自己写了个小工具,仓库用的Github。之前在用markdown进行需求管理,现在随着功能的增加,感觉有点难以管理了,所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD,需要提供一个企业名新建一个项目&#…...
LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf
FTP 客服管理系统 实现kefu123登录,不允许匿名访问,kefu只能访问/data/kefu目录,不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...
永磁同步电机无速度算法--基于卡尔曼滤波器的滑模观测器
一、原理介绍 传统滑模观测器采用如下结构: 传统SMO中LPF会带来相位延迟和幅值衰减,并且需要额外的相位补偿。 采用扩展卡尔曼滤波器代替常用低通滤波器(LPF),可以去除高次谐波,并且不用相位补偿就可以获得一个误差较小的转子位…...
