当前位置: 首页 > news >正文

自注意力简介

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师的课程讲得最好。

自注意力就是输入一堆向量,假设称为a1,a2,a3,a4,那么这四个向量都会参与自注意力机制的运算,得到的结果仍然是四个输出,这四个输出再去做全连接运算。而每一个自注意力机制的输出都用到了a1~a4四个向量来进行运算,也就是说每个输出都是观察了所有的输入之后才得到的。

首先,输入a1需要和a2,a3,a4分别计算相关性,这个相关性可以由缩放点积方式来计算,也就称作缩放点积注意力,也可以由两个两个输入向量相加后再做非线性处理得到,称为加性注意力。

缩放点积的计算方法如下:

输入一个向量v1和一个向量v2,v1去乘上一个可训练矩阵Wq得到q,v2去乘上一个可训练矩阵Wk得到k,再把这个q和k做一个点积运算,得到的就是α,类似于相似度。

回到前面的例子中,a1这里既作为q,又作为k,又作为v。其中,Wq*a1就是q1,Wk*a1就是k1,q1和k1的点积就是α11,相当于a1自己和自己的相似度,同样的,a1和a2,a3,a4分别计算得到α12,α13,α14,然后将α11,α12,α13,α14经过softmax得到最终的四个输出,如下图所示:

然后再用一个可训练矩阵Wv去乘以a1得到v1,用计算得到的相似度α'11去乘以v1,得到一个值temp11;同样的,用可训练矩阵Wv去乘以a2得到v2,用计算得到的相似度α‘12去乘以v2,得到temp12;类似的得到temp13,temp14,然后把temp11+temp12+temp13+temp14得到b1,这个b1就是自注意力机制的第一个输出。

我们刚刚是以a1的视角做的运算,得到b1,同样可以以a2,a3,a4的视角做运算,得到b2,b3,b4。这次就得到了自注意力机制的输出。光看最后这个结构图,有点类似全连接,只是里面的运算过程比全连接要复杂。

下面,我们来看一下如何用代码实现自注意力的计算。

import torch  
import torch.nn as nn  
import torch.nn.functional as F  class SelfAttention(nn.Module):# embed_size代表输入的向量维度,heads代表多头注意力机制中的头数量def __init__(self, embed_size, heads): super(SelfAttention,self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads # 每个头的维度# 用assert断言机制判断assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads" # 没有偏置项,其实这个线性层本质上就是为了计算值Wv*a = Vself.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)# 最后的全连接操作,输出仍是输入的向量维度,也就是说大小是不变的self.fc_out = nn.Linear(heads*self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):# 这个mask也很关键,它用于控制模型在处理序列数据时应该关注哪些部分,以及忽略哪些部分N = query.shape[0] # 获取输入的批量个数print("N:",N)value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 获取输入序列的长度# Split the embedding into self.heads different pieces  # 把k,q,v都切分为多个组values = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)# 计算k,q,vvalues = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk",[queries, keys]) # 格式转化print("queries.shape:", queries.shape)print("keys.shape:", keys.shape)print("energy.shape:", energy.shape)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20"))attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3) # softmax内部是缩放点积print("attention.shape:", attention.shape)print("values.shape:", values.shape)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)out = self.fc_out(out)return out
embed_size = 512
heads = 8
attention = SelfAttention(embed_size, heads)# batch size 1, seq length 60
values = torch.rand(1,60,embed_size)
keys = torch.rand(1,60,embed_size)
queries = torch.rand(1,60,embed_size)
mask = None # 假设没有maskout = attention(values, keys, queries, mask)
print(out.shape)# 输出
N: 1
queries.shape: torch.Size([1, 60, 8, 64])
keys.shape: torch.Size([1, 60, 8, 64])
energy.shape: torch.Size([1, 8, 60, 60])
attention.shape: torch.Size([1, 8, 60, 60])
values.shape: torch.Size([1, 60, 8, 64])
torch.Size([1, 60, 512])

通过这个程序,我们可以看到,自注意力机制是不改变输入和输出的形状的,输入的Q,K,V格式是[1,60,512],输出的结果的仍然是[1,60,512]。

下面是几点说明:

1. 这里的embed_size代表的是输入到自注意力层中的每个元素的向量维度。在Transformer模型中,输入数据首先会被转换成一个固定长度的向量,这个向量的长度就称为embed_size。

2. mask表示的是模型在处理序列数据时,应该忽略掉哪部分,我这里设置为None,也就是全部参与计算。

3. einsum,称为爱因斯坦求和,起源是爱因斯坦在研究广义相对论时,需要处理大量求和运算,为了简化这种繁复的运算,提出了求和约定,推动了张量分析的发展。einsum 可以计算向量、矩阵、张量运算,如果利用得当,sinsum可完全代替其他的矩阵计算方法。

例如,C = einsum('ij,jk->ik', A, B),就相当于两个矩阵求内积:cik = Σj AijBjk。

通过输出可以看到,在计算前queries的形状是[1,60,8,64],keys的形状是[1,60,8,64],在表达式"nqhd,nkhd->nhqk"中,n=1,q=60,h=8,d=64,k=60,两个矩阵进行内积,因此得到的结果是nhqk,也就是[1,8,60,60]。

相关文章:

自注意力简介

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师…...

【GameFramework框架】7-2、GameFramework框架是否“过度设计”?

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群:398291828大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录: https://blog.csdn.net/q764424567/article/details/1…...

RISC-V异常处理流程概述(2):异常处理机制

RISC-V异常处理流程概述(2):异常处理机制 一、异常处理流程和异常委托1.1 异常处理流程1.2 异常委托二、RISC-V异常处理中软件相关内容2.1 异常处理准备工作2.2 异常处理函数2.3 Opensbi系统调用的注册一、异常处理流程和异常委托 1.1 异常处理流程 发生异常时,首先需要执…...

Unity3D中如何降低游戏的Drawcall详解

在Unity3D游戏开发中,Drawcall是一个至关重要的性能指标,它指的是CPU通知GPU绘制一个物体的命令次数。过多的Drawcall会导致游戏性能下降,因此优化Drawcall的数量是提高游戏性能的关键。本文将详细介绍Unity3D中降低Drawcall的几种主要方法&a…...

小程序-设置环境变量

在实际开发中,不同的开发环境,调用的接口地址是不一样的 例如:开发环境需要调用开发版的接口地址,生产环境需要正式版的接口地址 这时候,我们就可以使用小程序提供了 wx.getAccountInfoSync() 接口,用来获取…...

【RabbitMQ】一文详解消息可靠性

目录: 1.前言 2.生产者 3.数据持久化 4.消费者 5.死信队列 1.前言 RabbitMQ 是一款高性能、高可靠性的消息中间件,广泛应用于分布式系统中。它允许系统中的各个模块进行异步通信,提供了高度的灵活性和可伸缩性。然而,这种通…...

RuntimeError: Unexpected error from cudaGetDeviceCount

RuntimeError: Unexpected error from cudaGetDeviceCount 0. 引言1. 临时解决方法 0. 引言 使用 vllm-0.4.2 部署时,多卡正常运行。升级到 vllm-0.5.1 时,报错如下: (VllmWorkerProcess pid30692) WARNING 07-12 08:16:22 utils.py:562] U…...

uboot学习:(一)基础认知

目录 uboot是一个裸机程序(bootloader) 作用 要运行linux系统时,如何从外置的flash拷贝到DDR中,才能启动 uboot使用步骤 步骤1中的命令例子 注意 uboot源码获取方法 uboot是一个裸机程序(bootloader&#xff09…...

每天一个数据分析题(四百二十六)- 总体方差

为了比较两个总体方差,我们通常检验两个总体的() A. 方差差 B. 方差比 C. 方差乘积 D. 方差和 数据分析认证考试介绍:点击进入 题目来源于CDA模拟题库 点击此处获取答案 数据分析专项练习题库 内容涵盖Python,SQL,统计学&a…...

【C++】设计一套基于C++与C#的视频播放软件

在开发一款集视频播放与丰富交互功能于一体的软件时,结合C的高性能与C#在界面开发上的便捷性,是一个高效且实用的选择。以下,我们将概述这样一个系统的架构设计、关键技术点以及各功能模块的详细实现思路。 一、系统架构设计 1. 架构概览 …...

数学建模中的辅助变量、中间变量、指示变量

在数学建模中,除了决策变量外,还有一些其他类型的变量,如中间变量、辅助变量和指示变量。每种变量在模型中都有特定的用途和意义。以下是对这些变量的详细解释: 1. 决策变量(Decision Variables) 定义&am…...

python的seek()和tell()

seek() seek() 是用来在文件中移动指针位置的方法。它的作用是将文件内部的当前位置设置为指定的位置。 seek(offset, whence) 参数说明 offset: 这是一个整数值,表示相对于起始位置的偏移量。如果是正数,表示向文件末尾方向移动;如果是负…...

Go泛型详解

引子 如果我们要写一个函数分别比较2个整数和浮点数的大小&#xff0c;我们就要写2个函数。如下&#xff1a; func Min(x, y float64) float64 {if x < y {return x}return y }func MinInt(x, y int) int {if x < y {return x}return y }2个函数&#xff0c;除了数据类…...

【每日一练】python之sum()求和函数实例讲解

在Python中&#xff0c; sum()是一个内置函数&#xff0c;用于计算可迭代对象&#xff08;如列表、元组等&#xff09;中所有元素的总和。如下实例&#xff1a; """ 收入支出统计小程序 知识点:用户输入获取列表元素添加sum()函数&#xff0c;统计作用 "&…...

打造智慧校园德育管理,提升学生操行基础分

智慧校园的德育管理系统内嵌的操行基础分功能&#xff0c;是对学生日常行为规范和道德素养进行量化评估的一个创新实践。该功能通过将抽象的道德品质转化为具体可量化的指标&#xff0c;如遵守纪律、尊师重道、团结协作、爱护环境及参与集体活动的积极性等&#xff0c;为每个学…...

自定义函数---随机数系列函数

大家有没有发现平常在写随机数的时候&#xff0c;需要引入很多的头文件&#xff0c;然后还需要用一些复杂的函数&#xff0c;大家可能不太习惯&#xff0c;于是我就制作了一个头文件 // random_number.h #ifndef RANDOM_NUMBER_H // 预处理指令&#xff0c;防止头文件被重复包含…...

一文了解5G新通话技术演进与业务模型

5G新通话简介 5G新通话&#xff0c;也被称为VoNR&#xff0c;是基于R16及后续协议产生的一种增强型语音通话业务。 它在IMS网络里新增数据通道&#xff08;Data Channel&#xff09;&#xff0c;承载通话时的文本、图片、涂鸦、菜单等信息。它能在传统话音业务基础上提供更多服…...

视频使用操作说明书-T80002系列视频编码器如何对接海康NVR硬盘录像机,包括T80002系列高清HDMI编码器、4K超高清HDMI编码器

视频使用操作说明书-T80002系列视频编码器如何对接海康NVR硬盘录像机&#xff0c;包括T80002系列高清HDMI编码器、4K超高清HDMI编码器。 视频使用操作说明书-T80002系列视频编码器如何对接海康NVR硬盘录像机&#xff0c;包括T80002系列高清HDMI编码器、4K超高清HDMI编码器 同三…...

el-input-number计数器change事件校验数据,改变绑定数据值后change方法失效问题的原因及解决方法

在change事件中如果对el-input-number绑定的数据进行更改&#xff0c;会出现change事件失效的问题 试过&#xff1a;this.$set()及赋值等方法&#xff0c;都无法解决 解决方法&#xff1a;用$nextTick函数对绑定值进行更改&#xff08; this.$nextTick(() > { this.绑定…...

将vue项目整合到springboot项目中并在阿里云上运行

第一步&#xff0c;使用springboot中的thymeleaf模板引擎 导入依赖 <!-- thymeleaf 模板 --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</artifactId></dependency> 在r…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、&#x1f44b;&#x1f3fb;前言 二、&#x1f608;sinx波动的基本原理 三、&#x1f608;波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、&#x1f30a;波动优化…...

Python 包管理器 uv 介绍

Python 包管理器 uv 全面介绍 uv 是由 Astral&#xff08;热门工具 Ruff 的开发者&#xff09;推出的下一代高性能 Python 包管理器和构建工具&#xff0c;用 Rust 编写。它旨在解决传统工具&#xff08;如 pip、virtualenv、pip-tools&#xff09;的性能瓶颈&#xff0c;同时…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

提升移动端网页调试效率:WebDebugX 与常见工具组合实践

在日常移动端开发中&#xff0c;网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时&#xff0c;开发者迫切需要一套高效、可靠且跨平台的调试方案。过去&#xff0c;我们或多或少使用过 Chrome DevTools、Remote Debug…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一&#xff0c;能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时&#xff0c;需要添加Git仓库地址和凭证&#xff0c;设置构建触发器&#xff08;如GitHub…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改&#xff0c;实现简单cat命令 输出信息到显示器&#xff0c;你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...

AD学习(3)

1 PCB封装元素组成及简单的PCB封装创建 封装的组成部分&#xff1a; &#xff08;1&#xff09;PCB焊盘&#xff1a;表层的铜 &#xff0c;top层的铜 &#xff08;2&#xff09;管脚序号&#xff1a;用来关联原理图中的管脚的序号&#xff0c;原理图的序号需要和PCB封装一一…...