【大模型LLM面试合集】大语言模型架构_MHA_MQA_GQA
MHA_MQA_GQA
1.总结
- 在 MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
- 在 MQA(Multi Query Attention) 中只会有一组 key-value 对;多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
- 在 GQA(Grouped Query Attention)中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。
GQA-N 是指具有 N 组的 Grouped Query Attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。
GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。
2.代码实现
2.1 MHA
多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。
- 为输入序列中的每个元素计算q, k, v,这是通过将输入此向量与三个权重矩阵相乘实现的:
q = x W q k = x W k v = x W v \begin{aligned} q & =x W_{q} \\ k & =x W_{k} \\ v & =x W_{v}\end{aligned} qkv=xWq=xWk=xWv
其中, x x x是输入词向量, W q W_q Wq, W k W_k Wk和 W v W_v Wv是q, k, v的权重矩阵 - 计算q, k 注意力得分: score ( q , k ) = q ⋅ k T d k \operatorname{score}(q, k)=\frac{q \cdot k^{T}}{\sqrt{d_{k}}} score(q,k)=dkq⋅kT,其中, d k d_k dk是k的维度
- 使用softmax得到注意力权重: Attention ( q , K ) = softmax ( score ( q , k ) ) \operatorname{Attention}(q, K)=\operatorname{softmax}(\operatorname{score}(q, k)) Attention(q,K)=softmax(score(q,k))
- 使用注意力权重和v,计算输出: O u t p u t = Attention ( q , K ) ⋅ V Output =\operatorname{Attention}(q, K) \cdot V Output=Attention(q,K)⋅V
- 拼接多头输出,乘以 W O W_O WO,得到最终输出: M u l t i H e a d O u t p u t = C o n c a t ( O u t p u t 1 , O u t p u t 2 , … , O u t p u t H ) W O MultiHeadOutput = Concat \left(\right. Output ^{1}, Output ^{2}, \ldots, Output \left.^{H}\right) W_{O} MultiHeadOutput=Concat(Output1,Output2,…,OutputH)WO
代码实现
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
2.2 MQA
上图最右侧,直观上就是在计算多头注意力的时候,query仍然进行分头,和多头注意力机制相同,而key和value只有一个头。
正常情况在计算多头注意力分数的时候,query、key的维度是相同的,所以可以直接进行矩阵乘法,但是在多查询注意力(MQA)中,query的维度为 [batch_size, num_heads, seq_len, head_dim]
,key和value的维度为 [batch_size, 1, seq_len, head_dim]
。这样就无法直接进行矩阵的乘法,为了完成这一乘法,可以采用torch的广播乘法
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)
相比于多头注意力,多查询注意力在W_k和W_v的维度映射上有所不同,还有就是计算注意力分数采用的是广播机制,计算最后的output也是广播机制,其他的与多头注意力完全相同。
2.3 GQA
GQA将MAQ中的key、value的注意力头数设置为一个能够被原本的注意力头数整除的一个数字,也就是group数。
不同的模型使用GQA有着不同的实现方式,但是总体的思路就是这么实现的,注意,设置的组一定要能够被注意力头数整除。
## 分组注意力查询
import torch
from torch import nn
class GroupQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x
相关文章:

【大模型LLM面试合集】大语言模型架构_MHA_MQA_GQA
MHA_MQA_GQA 1.总结 在 MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。在 MQA(Multi Query Attention) 中只会有一组 k…...

向上调整算法(详解)c++
算法流程: 与⽗结点的权值作⽐较,如果⽐它⼤,就与⽗亲交换; 交换完之后,重复 1 操作,直到⽐⽗亲⼩,或者换到根节点的位置 这里为什么插入85完后合法? 我们插入一个85,…...

【Transformer】手撕Attention
import torch from torch import nn import torch.functional as F import mathX torch.randn(16,64,512) # B,T,Dd_model 512 # 模型的维度 n_head 8 # 注意力头的数量多头注意力机制 class multi_head_attention(nn.Module): def __init__(self, d_model, n_hea…...
844.比较含退格的字符串
目录 题目思路解法收获 题目 给定 s 和 t 两个字符串,当它们分别被输入到空白的文本编辑器后,如果两者相等,返回 true 。# 代表退格字符。 注意:如果对空文本输入退格字符,文本继续为空。 思路 如何解退格之后left…...
图书管理系统 Axios 源码__编辑图书
目录 功能概述: 代码实现(index.js): 代码解析: 图书管理系统中,删除图书功能是核心操作之一。下是基于 HTML、Bootstrap、JavaScript 和 Axios 实现的删除图书功能的详细介绍。 功能概述: …...

LabVIEW纤维集合体微电流测试仪
LabVIEW开发纤维集合体微电流测试仪。该设备精确测量纤维材料在特定电压下的电流变化,以分析纤维的结构、老化及回潮率等属性,对于纤维材料的科学研究及质量控制具有重要意义。 项目背景 在纤维材料的研究与应用中,电学性能是评估其性能…...
Commander 一款命令行自定义命令依赖
一、安装 commander 插件 npm install commander 二、基本用法 1. 创建一个简单的命令行程序 创建一个 JavaScript 文件,例如 mycli.js,并添加以下代码: // 引入 commander 模块并获取 program 对象。const { program } require("…...
Day24 洛谷普及2004(内涵前缀和与差分算法)
零基础洛谷刷题记录 Day01 2024.11.18 Day02 2024.11.25 Day03 2024.11.26 Day04 2024.11.28 Day05 2024.11.29 Day06 2024 12.02 Day07 2024.12.03 Day08 2024 12 05 Day09 2024.12.07 Day10 2024.12.09 Day11 2024.12.10 Day12 2024.12.12 Day13 2024.12.16 Day14 2024.12.1…...

遗传算法与深度学习实战(33)——WGAN详解与实现
遗传算法与深度学习实战(33)——WGAN详解与实现 0. 前言1. 训练生成对抗网络的挑战2. GAN 优化问题2.1 梯度消失2.2 模式崩溃 2.3 无法收敛3 Wasserstein GAN3.1 Wasserstein 损失3.2 使用 Wasserstein 损失改进 DCGAN 小结系列链接 0. 前言 原始的生成…...

gitlab云服务器配置
目录 1、关闭防火墙 2、安装gitlab 3、修改配置 4、查看版本 GitLab终端常用命令 5、访问 1、关闭防火墙 firewall-cmd --state 检查防火墙状态 systemctl stop firewalld.service 停止防火墙 2、安装gitlab xftp中导入安装包 [rootgitlab ~]#mkdir -p /service/tool…...

SAP SD学习笔记27 - 请求计划(开票计划)之1 - 定期请求(定期开票)
上两章讲了贩卖契约(框架协议)的概要,以及贩卖契约中最为常用的 基本契约 - 数量契约和金额契约。 SAP SD学习笔记26 - 贩卖契约(框架协议)的概要,基本契约 - 数量契约_sap 框架协议-CSDN博客 SAP SD学习笔记27 - 贩卖契约(框架…...
HTML DOM 修改 HTML 内容
HTML DOM 修改 HTML 内容 引言 HTML DOM(文档对象模型)是浏览器内部用来解析和操作HTML文档的一种机制。通过DOM,我们可以轻松地修改HTML文档的结构、样式和行为。本文将详细介绍如何使用HTML DOM来修改HTML内容,包括元素的增删改查、属性修改以及事件处理等。 1. HTML …...

基于VMware的ubuntu与vscode建立ssh连接
1.首先安装openssh服务 sudo apt update sudo apt install openssh-server -y 2.启动并检查ssh服务状态 到这里可以按q退出 之后输入命令 : ip a 红色挡住的部分就是我们要的地址,这里就不展示了哈 3.配置vscode 打开vscode 搜索并安装:…...
Flutter Candies 一桶天下
| | | | | | | | 入魔的冬瓜 最近刚入桶的兄弟,有责任心的开发者,对自己的项目会不断进行优化,达到最完美的状态 自定义日历组件 主要功能 支持公历,农历,节气,传统节日,常用节假日 …...

maven如何不把依赖的jar打包到同一个jar?
spring boot项目打jar包部署: 经过以下步骤, 最终会形成maven依赖的多个jar(包括lib下添加的)、 我们编写的程序代码打成一个jar,将程序jar与 依赖jar分开,便于管理: success: 最终…...
HTML5 技术深度解读:本地存储与地理定位的最佳实践
系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...

AIGC技术中常提到的 “嵌入转换到同一个向量空间中”该如何理解
在AIGC(人工智能生成内容)技术中,“嵌入转换到同一个向量空间中”是一个核心概念,其主要目的是将不同类型的输入数据(如文本、图像、音频等)映射到一个统一的连续向量空间中,从而实现数据之间的…...
【机器学习理论】朴素贝叶斯网络
基础知识: 先验概率:对某个事件发生的概率的估计。可以是基于历史数据的估计,可以由专家知识得出等等。一般是单独事件概率。 后验概率:指某件事已经发生,计算事情发生是由某个因素引起的概率。一般是一个条件概率。 …...
Docker 部署 GLPI(IT 资产管理软件系统)
GLPI 简介 GLPI open source tool to manage Helpdesk and IT assets GLPI stands for Gestionnaire Libre de Parc Informatique(法语 资讯设备自由软件 的缩写) is a Free Asset and IT Management Software package, that provides ITIL Service De…...

【Vaadin flow 实战】第5讲-使用常用UI组件绘制页面元素
vaadin flow官方提供的UI组件文档地址是 https://vaadin.com/docs/latest/components这里,我简单实战了官方提供的一些免费的UI组件,使用案例如下: Accordion 手风琴 Accordion 手风琴效果组件 Accordion 手风琴-测试案例代码 Slf4j PageT…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...

python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...

家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...

vulnyx Blogger writeup
信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...

Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...
Qt 事件处理中 return 的深入解析
Qt 事件处理中 return 的深入解析 在 Qt 事件处理中,return 语句的使用是另一个关键概念,它与 event->accept()/event->ignore() 密切相关但作用不同。让我们详细分析一下它们之间的关系和工作原理。 核心区别:不同层级的事件处理 方…...