注意力机制(四):多头注意力
专栏:神经网络复现目录
注意力机制
注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分。在自然语言处理、计算机视觉、语音识别等领域,注意力机制已经得到了广泛的应用。
注意力机制的主要思想是,在对序列数据进行处理时,通过给不同位置的输入信号分配不同的权重,使得模型更加关注重要的输入。例如,在处理一句话时,注意力机制可以根据每个单词的重要性来调整模型对每个单词的注意力。这种技术可以提高模型的性能,尤其是在处理长序列数据时。
在深度学习模型中,注意力机制通常是通过添加额外的网络层实现的,这些层可以学习到如何计算权重,并将这些权重应用于输入信号。常见的注意力机制包括自注意力机制(self-attention)、多头注意力机制(multi-head attention)等。
总之,注意力机制是一种非常有用的技术,它可以帮助神经网络更好地处理序列数据,提高模型的性能。
文章目录
- 注意力机制
- 多头注意力
- 数学逻辑
- 实现
多头注意力
多头注意力(Multi-Head Attention)是注意力机制的一种扩展形式,可以在处理序列数据时更有效地提取信息。
在标准的注意力机制中,我们计算一个加权的上下文向量来表示输入序列的信息。而在多头注意力中,我们使用多组注意力权重,每组权重可以学习到不同的语义信息,并且每组权重都会产生一个上下文向量。最后,这些上下文向量会被拼接起来,再通过一个线性变换得到最终的输出。
多头注意力是Transformer模型中的一个重要组成部分,被广泛用于各种自然语言处理任务,如机器翻译、文本分类等。

数学逻辑
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询q∈Rdqq\in R^{d_q}q∈Rdq、 键k∈Rdkk\in R^{d_k}k∈Rdk和值v∈Rdvv\in R^{d_v}v∈Rdv, 每个注意力头的计算方法为:
hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpvh_i=f(W_i^{(q)}q,W_i^{(k)}k,W_i^{(v)}v)\in R^{pv}hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv
其中,可学习的参数包括 Wi(q)W_i^{(q)}Wi(q)、 Wi(k)W_i^{(k)}Wi(k)和 Wi(v)W_i^{(v)}Wi(v), 以及代表注意力汇聚的函数fff。 fff可以是加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着hhh个头连结后的结果,因此其可学习参数是 WoW_oWo:

实现
在实现过程中通常选择缩放点积注意力作为每一个注意力头。 为了避免计算代价和参数代价的大幅增长, 我们设定pq=pk=pv=pp/hp_q=p_k=p_v=p_p/hpq=pk=pv=pp/h。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为pqh=pkh=pvh=ppp_qh=p_kh=p_vh=p_ppqh=pkh=pvh=pp, 则可以并行计算hhh个头。 在下面的实现中,是通过参数pop_oponum_hiddens指定的。
#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)
为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。
#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)
代码解释:
这段代码实现了多头注意力机制,其中 MultiHeadAttention 类实现了多头注意力的前向传播, transpose_qkv 函数将输入的 queries, keys, values 通过线性变换并按照 num_heads 进行分组,最终输出变换后的 queries, keys, values,在前向传播中使用这些变换后的 queries, keys, values 来计算注意力权重。在 transpose_qkv 函数的实现中,首先将 queries, keys, values 转换成形状为 (batch_size, queries/keys/values_num, num_hiddens) 的张量,然后根据 num_heads 将最后一维进行分组,变换成形状为 (batch_size, num_heads, queries/keys/values_num, num_hiddens/num_heads) 的张量,最后将第一维和第二维进行交换,输出形状为 (batch_size*num_heads, queries/keys/values_num, num_hiddens/num_heads) 的张量。transpose_output 函数实现了对 MultiHeadAttention 的输出进行逆转换的操作。
这么做的原因是因为多头注意力机制可以将输入张量进行 num_heads 个独立的注意力计算,将计算结果在最后一维拼接起来作为输出,这样可以提高模型的并行性,加快计算速度。同时,通过变换形状将 num_heads 独立处理,也可以增强模型对不同位置和特征的表征能力。
具体来说,这段代码实现的是一个MultiHeadAttention类,其中定义了一个forward方法。这个方法接收一个查询序列queries,一个键序列keys,一个值序列values和一个有效长度序列valid_lens作为输入,然后输出一个加权聚合的结果。
MultiHeadAttention类的初始化方法中,我们定义了几个线性层,以及注意力计算函数,然后用这些组件来定义一个多头注意力层。该层包括将输入queries、keys和values通过三个线性层进行变换,以便将它们的形状变为(batch_size * num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads),其中num_heads表示注意力头的数量。然后,我们通过调用transpose_qkv函数对这些变换后的输入进行一次变换,以便在注意力计算函数中实现多头并行计算。最后,我们通过调用transpose_output函数将输出重构成(batch_size,查询的个数,num_hiddens),并通过一个线性层对其进行变换,输出最终结果。
transpose_qkv函数将输入的queries、keys和values通过reshape和permute操作进行变换,以便多头并行计算。具体来说,它将输入变换为(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)的形状,然后将第2和第3个轴进行交换。最后,它将输出变换为(batch_size * num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)的形状。
transpose_output函数将多头并行计算得到的输出通过reshape和permute操作逆转回原来的形状,具体来说,它将输出变换为(batch_size,查询的个数,num_heads, num_hiddens/num_heads)的形状,然后将第2和第3个轴进行交换,最终将输出变换为(batch_size,查询的个数,num_hiddens)的形状。
这里似乎所有的单头都是同一些参数,这样不会导致每个单头的输出都是一样的吗?
这里的确有点难懂, 这里其实是把所有注意力头里面的参数拼起来, 变成了一个大的全连接层

相关文章:
注意力机制(四):多头注意力
专栏:神经网络复现目录 注意力机制 注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分。在自然语言处理、计算机视觉、语…...
【2023Unity游戏开发教程】零基础带你从小白到超神19——射线检测
文章目录 射线检测从某点发射一条射线从摄像机发射一条射线射线检测 游戏中的红外线,默认肉眼是看不到的,从某个初始点开始,沿着特定的方向发射一条不可见且无限长的射线,通过此射线检测是否有任何模型添加了Collider碰撞器组件。一旦检测到碰撞,停止射线继续发射。 碰撞检…...
内存泄漏和内存溢出的区别
参考答案 内存溢出(out of memory):指程序在申请内存时,没有足够的内存空间供其使用,出现 out of memory。内存泄露(memory leak):指程序在申请内存后,无法释放已申请的内存空间,内存泄露堆积会导致内存被…...
文本三剑客之sed编辑器
文本三剑客:都是按行读取后处理。 grep 过滤行内容。awk 过滤字段。sed 过滤行内容;修改行内容。sed编辑器 sed是一种流编辑器,流编辑器会在编辑器处理数据之前基于预先提供的一组规则来编辑数据流。 sed编辑器可以根据命令来处理数据流中…...
深度学习:GPT1、GPT2、GPT-3
深度学习:GPT1、GPT2、GPT3的原理与模型代码解读GPT-1IntroductionFramework自监督学习微调ExperimentGPT-2IntroductionApproachConclusionGPT-3GPT-1 Introduction GPT-1(Generative Pre-training Transformer-1)是由OpenAI于2018年发布的…...
使用Docker 一键部署SpringBoot和SpringCloud项目
使用Docker 一键部署SpringBoot和SpringCloud项目 1. 准备工作2. 创建Dockerfile3. 创建Docker Compose文件4. 构建和运行Docker镜像5. 验证部署6. 总结Docker是一个非常流行的容器化技术,可以方便地将应用程序和服务打包成容器并运行在不同的环境中。在本篇博客中,我将向您展…...
【数据结构】用栈实现队列
💯💯💯 本篇总结利用栈如何实现队列的相关操作,不难观察,栈和队列是可以相互转化的,需要好好总结它们的特性,构造出一个恰当的结构来实现即可,所以本篇难点不在代码思维,…...
[Netty源码] 服务端启动过程 (二)
文章目录1.ServerBootstrap2.服务端启动过程3.具体步骤分析3.1 创建服务端Channel3.2 初始化服务端Channel3.3 注册selector3.4 端口绑定1.ServerBootstrap ServerBootstrap引导服务端启动流程: //主EventLoopGroup NioEventLoopGroup master new NioEventLoopGroup(); //从E…...
Week 14
代码源每日一题Div2 106. 订单编号 原题链接:订单编号 思路:这题本来没啥思路,直到获得了某位佬的提示才会做( 我们可以用set来维护一些区间,这些区间为 pair 类型,表示没有使用过的编号,每次…...
【微信小程序】-- 使用 Git 管理项目(五十)
💌 所属专栏:【微信小程序开发教程】 😀 作 者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &…...
leetcode每日一题:134. 加油站
系列:贪心算法 语言:java 题目来源:Leetcode134. 加油站 题目 在一条环路上有 n 个加油站,其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车,从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[…...
开放式基金实时排行 API 数据接口
开放式基金实时排行 API 数据接口 多维度参数返回,实时数据,类型参数筛选。 1. 产品功能 返回实时开放式基金排行数据可定义查询基金类型参数;多个基金属性值返回多维指标,一次查询毫秒级返回;数据持续更新与维护&am…...
Android开发中synchronized的实现原理
synchronized的三种使用方式 **1.修饰实例方法,**作用于当前实例加锁,进入同步代码前要获得当前实例的锁。 没有问题的写法: public class AccountingSync implements Runnable{//共享资源(临界资源)static int i0;/*** synchronized 修饰实例方法*/p…...
【华为OD机试 2023最新 】 统一限载货物数最小值(C++)
题目描述 火车站附近的货物中转站负责将到站货物运往仓库,小明在中转站负责调度2K辆中转车(K辆干货中转车,K辆湿货中转车)。 货物由不同供货商从各地发来,各地的货物是依次进站,然后小明按照卸货顺序依次装货到中转车,一个供货商的货只能装到一辆车上,不能拆装,但是…...
【生活工作经验 十】ChatGPT模型对话初探
最近探索了下全球大火的ChatGPT,想对此做个初步了解 一篇博客 当今社会,自然语言处理技术得到了迅速的发展,人工智能技术也越来越受到关注。其中,基于深度学习的大型语言模型,如GPT(Generative Pre-train…...
基于Spring Boot房产销售平台的设计与实现【源码+论文】分享
开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven3.3.9 摘要 信息技术的发展…...
不同类型的电机的工作原理和控制方法汇总
电机控制是指对电机的启动、调速(加速、减速)、运转方向和停止进行的控制,不同类型的电机有着不同的工作原理和控制方法。 一、无刷电机 无刷电机是由电机主体和电机驱动板组成的一种没有电刷和换向器的机电一体化产品。在无刷电机中…...
计算机网络管理 TCP三次握手的建立过程,Wireshark抓包分析并验证TCP三次握手建立连接的报文
⬜⬜⬜ ---🟧🟨🟩🟦🟪 (*^▽^*)欢迎光临 🟧🟨🟩🟦🟪---⬜⬜⬜ ✏️write in front✏️ 📝个人主页:陈丹宇jmu 🎁欢迎各位→…...
HTTP/2.x:最新的网页加载技术,快速提高您的SEO排名
2.1 http2概念HTTP/2.0(又称HTTP2)是HTTP协议的第二个版本。它是对HTTP/1.x的更新,旨在提高网络性能和安全性。HTTP/2.0是由互联网工程任务组(IETF)标准化的,并于2015年发布。2.2 http2.x与http1.x区别HTTP…...
机器学习----线性回归
第一关:简单线性回归与多元线性回归 1、下面属于多元线性回归的是? A、 求得正方形面积与对角线之间的关系。 B、 建立股票价格与成交量、换手率等因素之间的线性关系。 C、 建立西瓜价格与西瓜大小、西瓜产地、甜度等因素之间的线性关系。 D、 建立西瓜…...
RocketMQ延迟消息机制
两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数,对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后…...
mongodb源码分析session执行handleRequest命令find过程
mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令,把数据流转换成Message,状态转变流程是:State::Created 》 St…...
【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
srs linux
下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...
在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案
这个问题我看其他博主也写了,要么要会员、要么写的乱七八糟。这里我整理一下,把问题说清楚并且给出代码,拿去用就行,照着葫芦画瓢。 问题 在继承QWebEngineView后,重写mousePressEvent或event函数无法捕获鼠标按下事…...
STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...
Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)
引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)已成为技术领域的焦点。从智能写作到代码生成,LLM 的应用场景不断扩展,深刻改变了我们的工作和生活方式。然而,理解这些模型的内部…...
计算机基础知识解析:从应用到架构的全面拆解
目录 前言 1、 计算机的应用领域:无处不在的数字助手 2、 计算机的进化史:从算盘到量子计算 3、计算机的分类:不止 “台式机和笔记本” 4、计算机的组件:硬件与软件的协同 4.1 硬件:五大核心部件 4.2 软件&#…...
Vite中定义@软链接
在webpack中可以直接通过符号表示src路径,但是vite中默认不可以。 如何实现: vite中提供了resolve.alias:通过别名在指向一个具体的路径 在vite.config.js中 import { join } from pathexport default defineConfig({plugins: [vue()],//…...
