当前位置: 首页 > news >正文

transformer架构解析{掩码,(自)注意力机制,多头(自)注意力机制}(含代码)-3

目录

前言

掩码张量

什么是掩码张量

掩码张量的作用

生成掩码张量实现

注意力机制

学习目标

注意力计算规则

注意力和自注意力

注意力机制

注意力机制计算规则的代码实现

多头注意力机制

学习目标

什么是多头注意力机制 

多头注意力计算机制的作用

多头注意力机制的代码实现


前言

        在之前的小节中我们学习了词嵌入层(词向量编码)以及加入了位置编码的输入层的概念和代码实现的学习。在本小节中我们将学习transformer中最重要的部分-注意力机制

掩码张量

        我们先学习掩码张量。现提出两个问题:什么是掩码张量?生成掩码张量的过程?

什么是掩码张量

        张量尺寸不定,里面只有(0,1)元素,代表位置被遮掩或者不遮掩,它的作用就是让另外一张张量中的一些数值被遮掩,被替换,表现形式是一个张量。

掩码张量的作用

        在transformer中掩码张量的主要作用在应用attention时,有一些生成的attention张量中的值计算有可能已知了未来信息而得到的,未来信息被看到是因为训练时会把整个输出结果都一次性的Embedding,但是理论上的解码器的输出却并不是一次就能产生的最终结果的,而是一次次通过上一次结果综合得出的,因此,未来信息可能被提前利用,所以进行遮掩。

生成掩码张量实现

#生成掩码张量的代码分析
def  subsequeent_mask(size):#生成向后遮掩的掩码张量,参数size是掩码张量最后两个维度的大小,形成一个方阵#在函数中,首先定义掩码张量的形状attn_shape = (1,size,size)#使用np.ones方法像这个方阵中添加1元素,形成上三角阵#节约空间将数据类型变为无符号8位整型数字unit8subsequeent_mask = np.triu(np.ones(attn_shape),k=1 ).astype('uint8')#转换成tensor,内部做一个 1 - 操作实现反转return torch.from_numpy(1 - subsequeent_mask)
size = 5
sm = subsequeent_mask(size)
print('sm:',sm)

#掩码张量的可视化
plt.figure(figsize=(5,5))
plt.imshow(subsequeent_mask(20)[0])

        黄色是1的部分,这里代表被遮掩,紫色代表没有被遮掩,横坐标代表目标词汇的位置,纵坐标代表可查的位置。从上往下看,在0的位置看过去都是黄色,都被遮住了,1的位置望过去还是黄色,说明第一次词还没有产生,从第二位置看过去,就能看到位置1的词,其他位置看不到,以此类推

注意力机制

学习目标

        掌握注意力计算规则和注意力机制

        掌握注意力计算规则的实现过程(最具有辨识度的部分)

注意力计算规则

        它需要3个指定的输入Q(query),K(key),V(value),通过公式计算得出注意力的计算结果

        query在key和Value的作用下表示:

        Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V

        大家想要了解具体注意力计算规则可以去了解自然语言处理-BERT处理框架-transformer这篇文章,里面有具体的Q,K,V注意力计算规则介绍。

注意力和自注意力

        这两者的区别在于Q,K,V矩阵。注意力,刚开始时Q矩阵输入原词向量,而K,V矩阵输入人为添加的已经总结好的特征向量且默认K=V。自注意力,开始时输入Q=K=V,K,V没有人为干扰完全自己去迭代。

注意力机制

        之前我们都是纸上谈兵,要使注意力计算规则能够应用在深度学习的网络,成为载体,包括全连接层以及相关张量的处理。

        注意力机制在网络中实现的图形表示:

注意力机制计算规则的代码实现

#注意力计算规则的代码分析
def attention(query,key,value,mask=None,dropout=None):#注意力机制实现:输入分别是query,key,value,mask:掩码张量#在函数中,首先取query的最后一维的大小,词嵌入的维度d_kd_k,来进行缩放d_k = query.size(-1)#按照公式,将query与key的转置相乘,这里面key是将最后两个维度进行转置,再除以缩放系数#得到得分张量scoresscores = torch.matmul(query,key.transpose(-2,-1)) / math.sqrt(d_k)#判断是否使用掩码张量if mask is not None:#使用tensor的masked_fill方法,将掩码张量和scores张量的每一个位置一一比较,如果等于0就一一替换#则对应的socers张量用-1e9这个值来代替scores = scores.masked_fill(mask == 0,-1e9)#对scores的最后一维进行softmax操作,使用F.softmax方法,第一个参数是softmax对象,#获得最终的注意力张量p_attn = F.softmax(scores,dim=-1)#判断是否使用dropout进行随机置0if dropout is not None:#将p_attn传入dropout对象进行‘丢弃’处理p_attn = dropout(p_attn)#最后,根据公式将p_attn与value张量相乘获得最终的query注意力表示,同时返回注意力张量return torch.matmul(p_attn,value), p_attn
query = key = value = pe_result #自注意力机制
print(query.shape)
mask = Variable(torch.zeros(2,4,4))
attn,p_attn = attention(query,key,value,mask=mask)
print('attn:',attn)
print(attn.shape)
print('p_attn:',p_attn)
print(p_attn.shape)

多头注意力机制

学习目标

        了解多头注意力机制的作用,掌握多头注意力机制的实现过程

        多头注意力机制结构图:

什么是多头注意力机制 

        在图中,我们可以看到,有一组Linear层进行线性变换,变换前后的维度不变,就当是一个方阵的张量,每个张量的值不同,那么变化后的结果也不同,特征就丰富起来了。变换后进入注意力计算机制,一组有多少个linear层并行,就代表有几个头,将计算结果最后一维(Q*K^T)分割,然后组合成scores。

多头注意力计算机制的作用

        这种结构的设计能够让每个注意力机制去优化每个词的不同特征部分,从而均衡同一种注意力机制可能产生的偏差,让词义拥有更多元的表达,实验表明可以提升模型的效果。

多头注意力机制的代码实现

#多头注意力机制的实现#定义一个克隆函数,因为在多头注意力机制的实现中,用到多个结构相同的线性层
#我们将使用clone函数将他们一同初始化在一个网络层列表对象中,之后的结构中也会使用到该函数
def clones(module,N):#用于生成相同的网络层的克隆函数,它的参数module表示要克隆的目标网络层,N代表需要克隆的数量#在函数中,我们通过for循环对module进行N次的深度拷贝,使其每个module成为独立的层#然后将其放在nn.ModuleList类型的列表中存放return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
#使用一个类来实现多头注意力机制的处理
class MutiHeadedAttention(nn.Module):def __init__(self, head,embedding_dim,dropout=0.1):#在类的初始化时,会传入3个参数,head代表头数,embedding_dim代表词嵌入的维度#dropout代表进行dropout操作时置0比率super(MutiHeadedAttention,self).__init__()#在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除#这是因为我们之后要每个头分配等量的词特征,也就是emdedding_dim/head个assert embedding_dim % head == 0#得到每个头获得的分割词向量的维度d_kself.d_k = embedding_dim // head#传入头数self.head = headself.embedding_dim = embedding_dim#然后获得线性层的对象,通过nn.linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim#为什么是4个呢,这是因为在多头注意力中Q,K,V各需要一个,拼接矩阵也需要一个self.linears = clones(nn.Linear(embedding_dim,embedding_dim),4)#self.attn为None,它代表最后获得的注意力张量self.attn = None#最后一个是self.dropout对象,它通过nn中的Dropout实例化而来,置0比率为传进来的的参数dropoutself.dropout = nn.Dropout(p=dropout)def forward(self,query,key,value,mask=None):#前向逻辑函数,它的输入参数有四个,前三个就是注意力机制需要的Q,K,V#最后一个是注意力机制中可能需要的mask掩码,默认是None#如果存在掩码张量maskif mask is not None:#使用unsqueeze拓展维度,代表多头中的第n个头mask = mask.unsqueeze(1)#接着,我们获得一个batch_size的变量,他是query尺寸的第一个数字,代表有多少条样本batch_size = query.size(0)#之后就进入多头处理环节#首先利用zip将QKV与三个线性层住到一起,然后使用for循环,将输入的QKV分别传到线性层中#做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结果进行维度重塑#这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度#计算机会根据这种变换自动计算这里的值,然后对第二维度和第三维度进行转置操作#为了句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系#从attention函数中可以看到,利用的是原始输入的倒数第一和第二维,这样我们就能得到每个头的query,key,value = \[model(x).view(batch_size,-1,self.head,self.d_k).transpose(1,2) for model,x in zip(self.linears,(query,key,value))]#print(query.shape)#print(key.shape)#print(value.shape)#得到每个头的输入后,接下来就是将他们传入到attention中#这里直接调用我们之前实现的attention函数。同时也将mask和dropout传入其中x,self.attn = attention(query,key,value,mask=mask,dropout=self.dropout)#通过多头注意力计算后,我们就得到每个头计算结果组成的4维张量,我们需要将其转换成为输入的格式#因此这里卡开始进行第一步处理:逆操作,先对第二第三维进行转置,然后使用contiguous方法#这个方法的作用就是让能够让转置后的张量应用view方法,否则将无法直接使用#所以,下一步就是使用view重塑,变成和输入形状相同x = x.transpose(1,2).contiguous().view(batch_size,-1,self.head * self.d_k)#最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出return self.linears[-1](x)
#实例化参数
head = 8
embedding_dim = 512
dropout = 0.2#若干输入参数的初始化
query = key = value = pe_resultmask = Variable(torch.zeros(2,4,4))mha = MutiHeadedAttention(head,embedding_dim,dropout)mha_result = mha(query,key,value,mask)print(mha_result)
print(mha_result.shape)

相关文章:

transformer架构解析{掩码,(自)注意力机制,多头(自)注意力机制}(含代码)-3

目录 前言 掩码张量 什么是掩码张量 掩码张量的作用 生成掩码张量实现 注意力机制 学习目标 注意力计算规则 注意力和自注意力 注意力机制 注意力机制计算规则的代码实现 多头注意力机制 学习目标 什么是多头注意力机制 多头注意力计算机制的作用 多头注意力机…...

【C++】switch 语句编译报错:error: jump to case label

/home/share/mcrockit_3588/prj_linux/../source/rkvpss.cpp: In member function ‘virtual u32 CRkVpss::Control(u32, void*, u32)’: /home/share/mcrockit_3588/prj_linux/../source/rkvpss.cpp:242:8: error: jump to case label242 | case emRkComCmd_DBG_SaveInput:|…...

linux中使用firewall命令操作端口

一、开放端口 1. 开放一个端口 sudo firewall-cmd --zonepublic --add-port8443/tcp --permanent sudo firewall-cmd --reload 2. 开放一组连续端口 sudo firewall-cmd --zonepublic --add-port100-500/tcp --permanent sudo firewall-cmd --reload 3. 一次开放多个不连续…...

C++第六节:stack和queue

本节目标: stack的介绍与使用queue的介绍与使用priority_queue的介绍与使用容器适配器模拟实现与结语 1 stack(堆)的介绍 stack是一种容器适配器,专门用在具有后进先出操作的上下文环境中,只能从容器的一端进行元素的插…...

算法 并查集

目录 前言 一 并查集的思路 二 并查集的代码分析 三 实操我们的代码 四 并查集的代码优化 总结 前言 并查集主要是用来求解集合问题的,用来查找集合还有就是合并集合,可以把这个运用到最小生成树里面 一 并查集的思路 1 并查集的相关的操作…...

yarn application命令中各参数的详细解释

yarn application 命令用于管理和监控 YARN 上运行的应用程序,下面为你详细解释该命令中各参数的含义和用途: 通用参数 -help [command] 作用:显示 yarn application 命令的帮助信息。如果指定了 command,则显示该子命令的详细使…...

算法之数据结构

目录 数据结构 数据结构与算法面试题 数据结构 《倚天村 • 图解数据结构》 | 小傅哥 bugstack 虫洞栈 ♥数据结构基础知识体系详解♥ | Java 全栈知识体系 线性数据结构 | JavaGuide 数据结构与算法面试题 数据结构与算法面试题 | 小林coding...

Android 图片压缩详解

在 Android 开发中,图片压缩是一个重要的优化手段,旨在提升用户体验、减少网络传输量以及降低存储空间占用。以下是几种主流的图片压缩方法,结合原理、使用场景和优缺点进行详细解析。 效果演示 直接先给大家对比几种图片压缩的效果 质量压缩 质量压缩:根据传递进去的质…...

迷你世界脚本计时器接口:MiniTimer

计时器接口:MiniTimer 彼得兔 更新时间: 2023-04-26 20:24:50 具体函数名及描述如下: 序号 函数名 函数描述 1 isExist(...) 判断计时器是否存在 2 createTimer(...) 添加计时器 3 deleteTimer(...) 删除计时器 4 startBackwardTimer(.…...

JavaScript的变量以及数据类型

JS变量 变量的声明 四种声明方式 1. <script>var abc;abc"变量声明1";alert(abc);</script>2. <script>var abc"变量声明2";alert(abc);</script><script>var abc1,abc2;abc1"变量声明3.1";abc2"变量声明3…...

私有云基础架构

基础配置 使用 VMWare Workstation 创建三台 2 CPU、8G内存、100 GB硬盘 的虚拟机 主机 IP 安装服务 web01 192.168.184.110 Apache、PHP database 192.168.184.111 MariaDB web02 192.168.184.112 Apache、PHP 由于 openEuler 22.09 系统已经停止维护了&#xff…...

在 Windows 和 Linux 系统上安装和部署 Ollama

引言 Ollama 是一个强大的本地大语言模型&#xff08;LLM&#xff09;运行工具&#xff0c;允许用户轻松下载和运行不同的 AI 模型&#xff0c;如 LLaMA、Mistral 和 Gemma。无论是开发者还是研究人员&#xff0c;Ollama 都提供了一种简单而高效的方式来在本地环境中部署 AI 模…...

从零开始学习Slam--数学概念

正交矩阵 矩阵的转置等于它的逆矩阵&#xff0c;这样的矩阵称之为正交矩阵 即&#xff1a; Q T Q I Q^T Q I QTQI&#xff0c; 这样的矩阵列向量都是单位向量且两两正交。 旋转矩阵属于特殊的正交群&#xff0c;即SO(n)&#xff0c;这里n通常是3&#xff0c;所以SO(3)就是…...

【零基础到精通Java合集】第十五集:Map集合框架与泛型

课程标题:Map集合框架与泛型(15分钟) 目标:掌握泛型在Map中的键值类型约束,理解类型安全的键值操作,熟练使用泛型Map解决实际问题 0-1分钟:泛型Map的意义引入 以“字典翻译”类比泛型Map:明确键和值的类型(如英文→中文)。说明泛型Map的作用——确保键值对的类型一…...

从小米汽车召回看智驾“命门”:智能化时代 — 时间就是安全

2025年1月&#xff0c;小米因车辆“授时同步异常”召回3万余辆小米SU7&#xff0c;成为其造车历程中的首个重大安全事件。 从小米SU7召回事件剖析&#xff0c;授时同步何以成为智能驾驶的命门&#xff1f; 2024年11月&#xff0c;多名车主反馈SU7标准版的智能泊车辅助功能出现…...

Visual Studio Code 如何编写运行 C、C++ 程序

目录 安装 MinGW-w64 编译器&#xff08;推荐&#xff09;在 VS Code 中配置 C 开发环境 参考链接 在vs code上运行c脚本&#xff0c;报了下面的错误&#xff0c;我仅仅安装了vs code及在商店里下载了插件&#xff0c;其它配置操作没有做&#xff0c;直接对一个脚本进行运行&am…...

动静态库-Linux 学习

在软件开发中&#xff0c;程序库是一组预先编写好的程序代码&#xff0c;它们存储了常用的函数、变量和数据结构等。这些库可以帮助开发者节省大量的时间和精力&#xff0c;避免重复编写相同的代码。当我们在 Linux 系统中开发程序时&#xff0c;经常会用到两种类型的程序库&am…...

【Hudi-SQL DDL创建表语法】

CREATE TABLE 命令功能 CREATE TABLE命令通过指定带有表属性的字段列表来创建Hudi Table。 命令格式 CREATE TABLE [ IF NOT EXISTS] [database_name.]table_name[ (columnTypeList)]USING hudi[ COMMENT table_comment ][ LOCATION location_path ][ OPTIONS (options_lis…...

HTML label 标签使用

点击 <label> 标签通常会使与之关联的表单控件获得焦点或被激活。 通过正确使用 <label> 标签&#xff0c;可以使表单更加友好和易于使用&#xff0c;同时提高整体的可访问性。 基本用法 <label> 标签通过 for 属性与 id 为 username 的 <input> 元素…...

bge-large-zh-v1.5 与Pro/BAAI/bge-m3 区别

ge-large-zh-v1.5 和 Pro/BAAI/bge-m3 是两种不同的模型&#xff0c;主要区别在于架构、性能和应用场景。以下是它们的对比&#xff1a; 1. 模型架构 bge-large-zh-v1.5&#xff1a; 基于Transformer架构&#xff0c;专注于中文文本的嵌入表示。 参数量较大&#xff0c;适合处…...

RestClient

什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端&#xff0c;它允许HTTP与Elasticsearch 集群通信&#xff0c;而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级&#xff…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK&#xff0c;开始写第二篇的内容了。这篇博客主要能写一下&#xff1a; 如何给一些三方库按照xmake方式进行封装&#xff0c;供调用如何按…...

23-Oracle 23 ai 区块链表(Blockchain Table)

小伙伴有没有在金融强合规的领域中遇见&#xff0c;必须要保持数据不可变&#xff0c;管理员都无法修改和留痕的要求。比如医疗的电子病历中&#xff0c;影像检查检验结果不可篡改行的&#xff0c;药品追溯过程中数据只可插入无法删除的特性需求&#xff1b;登录日志、修改日志…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案

JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停​​ 1. ​​安全点(Safepoint)阻塞​​ ​​现象​​:JVM暂停但无GC日志,日志显示No GCs detected。​​原因​​:JVM等待所有线程进入安全点(如…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

Windows安装Miniconda

一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...