当前位置: 首页 > news >正文

Transformer系列:注意力机制的优化,MQA和GQA原理简述

前言

多查询注意力(MQA)、分组查询注意力(GQA)是Transformer中多头注意力(MHA)的变种,它们大幅提高了解码器的推理效率,在LLaMA-2,ChatGLM2等大模型中有广泛使用,本篇介绍MQA、GQA的原理并分析其源码实现。


使用MQA,GQA的背景介绍

多查询注意力(Multi Query Attention,MQA)提出于2019年的论文《Fast Transformer Decoding: One Write-Head is All
You Need》,旨在解决Transformer增量推理阶段效率低下的问题,在当时并没有引起关注,而随着近几年Transformer和GPT成为生成式大模型的基座,面临着产业落地的实际情况,导致GPT的推理加速备受关注,因此MQA又重新被提及起来。
分组查询注意力(Group Query Attention,GQA)提出于2023年,是MQA更一般的形式,它介于MQA和MHA之间,是模型预测表现和模型推理性能之间的一个折衷。


MQA,GQA原理简述

MQA的原理很简单,它将原生Transformer每一层多头注意力的Key线性映射矩阵、Value线性映射矩阵改为该层下所有头共享,也就是说K、V矩阵每层只有一个,而Q矩阵不受影响,其数量和注意力头数相等。以ChatGLM2-6B为例,一共28层,32个注意力头,输入从4096经过Q、K、V矩阵映射维度为128,若采用原生多头注意力机制,则Q、K、V矩阵各有28×32个,而采用MQA的方式则整个模型包含28×32个Q矩阵,28个K矩阵,28个V矩阵,示意图如下

MHA和MQA的差别

可想而知,MQA这种方式大幅减小了参数数量,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定,因此在此基础上提出了GQA,它将Query进行分组,每个组内共享一组Key、Value,令组的数量为N,若N等于1此时等效于MQA,若N等于Query头的数量,此时退化为MHA。GQA是推理效率和模型性能的trade-off。

GQA(中)和MQA(右)对比


MQA,GQA推理加速分析

MQA能够大幅加速采用MHA的Transformer的推理,但是会有明显的性能损失,而GQA通过设置合适的分组大小,可以和MQA的推理性能几乎相等,同时逼近MHA的模型性能。作者在GQA的论文中给到了实验结论来印证这一点。
作者采用T5模型作为研究对象,模型版本采用T5-Large和T5-XXL,它们都采用MHA注意力方式,其中Large参数量770M,XXL参数量11B,在此基础上作者通过up-training方法将T5-XXL改造为MQA和GQA,最终一共四个版本模型进行精度和推理效率的对比,结果如下

横轴代表平均每条样本的推理耗时,越大代表延迟越大,纵轴代表在众多数据集上的评价得分,越大代表得分越高。在MHA方式下,由于XXL的模型参数更大,因此MHA-XXL的推理延迟高于MHA-Large,同时MHA-XXL的模型评分在所有版本里面最高。
经过up-training方法将MHA改造为MQA之后,MQA-XXL获得了所有版本的最低延迟,甚至还低于小一个型号的Large模型,同时MQA使得其模型评分比MHA-XXL降低了,但还是超越了小一个信号的Large的模型,表明大参数量的MQA模型不论在精度还是效率上都超越了小参数量的MHA模型。
而经过up-training方法将MHA改造为GQA,GQA-XXL的推理延迟几乎和MQA-XXL相等,而其性能评分也和参数量最大的MHA-XXL十分接近,整体上GQA达到了最佳效果,推理性能和模型评分都十分优秀。

分组的大小对模型推理延迟的影响

GQA的分组数是一个超参数,组数越大越接近MHA,推理延迟越大,同时模型精度也越高,作者给出了他的实验结论表明,当组数量从1逐渐上升到8时,模型推理的开销并没有明显的增长,在8以后推理开销显著变大,最终作者采用8个分组作为他的最佳选择。
以上从实验结果层给到结论,MQA略微损失了模型精度,但是确实能够大幅降低推理开销,而如果选择了合适的分组数,GQA能够两者皆得。在理论层,MQA和GQA对推理的帮助主要是以下两点

    1. 降低内存读取模型权重的时间开销:由于Key矩阵和Value矩阵数量变少了,因此权重参数量也减少了,需要读取到内存的数量量少了,因此减少了读取权重的等待时间
    1. KV-Cache空间占用降低:KV-Cache会将之前推理过的Key、Value向量存储在内存中,而随着步长和batch_size的增长,KV-Cache空间占用越来越高,使得KV-Cache不能被高效的读写,而MHA和GQA方式使得KV-Cache需要存储的参数量降低了head_num倍,从而提高KV-Cache的读写效率;另一方面,可以有空间来增大batch_size,从而提高模型推理的吞吐量

注意MQA和GQA并没有降低Attention的计算量(FLOPs),因为Key、Value映射矩阵会以广播变量的形式拓展到和MHA和一样,因此计算量不变,只是Key、Value参数共享。


ChatGLM2-6B中的MQA/GQA源码分析

本节采用ChatGLM2-6B的模型源码modeling_chatglm.py来说明MQA和GQA的实现,这两者在代码上没有区别,因为MQA是GQA的特例,当分组数等于1时就是MQA,而chatglm2-6B采用的是分组数为2的GQA,从它的配置文件config.json可以观察得到

{"_name_or_path": "THUDM/chatglm2-6b","model_type": "chatglm","architectures": ["ChatGLMModel"],"auto_map": {"AutoConfig": "configuration_chatglm.ChatGLMConfig","AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration","AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"},..."multi_query_attention": true,"multi_query_group_num": 2,....
}

其中multi_query_attention代表是否开启多查询注意力,multi_query_group_num代表分组数。
MQA和GQA仅涉及到注意力层,因此直接定位到SelfAttention的代码块

class SelfAttention(torch.nn.Module):def __init__(self, config: ChatGLMConfig, layer_number, device=None):super(SelfAttention, self).__init__()self.layer_number = max(1, layer_number)# TODO 128 * 32self.projection_size = config.kv_channels * config.num_attention_heads# Per attention head and per partition values.self.hidden_size_per_attention_head = self.projection_size // config.num_attention_headsself.num_attention_heads_per_partition = config.num_attention_heads# TODO true 多查询注意力self.multi_query_attention = config.multi_query_attention# TODO qkv线性映射层到3*dself.qkv_hidden_size = 3 * self.projection_sizeif self.multi_query_attention:self.num_multi_query_groups_per_partition = config.multi_query_group_num  # 2self.qkv_hidden_size = (# TODO (128 * 32 + 2 * 128 * 2)self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num)self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,bias=config.add_bias_linear or config.add_qkv_bias,device=device, **_config_to_kwargs(config))

该SelfAttention代表某一层下,32个注意力头的运算。在初始化阶段,作者采用大矩阵方案用一个矩阵将所有头QKV创建出来,因此projection_size为128×3,然后以MHA的方式将映射维度乘以3得到qkv_hidden_size,当采用多查询注意力时对qkv_hidden_size重新修改,它等于所有头的Q矩阵,加上KV矩阵各一个,因此projection_size为128 * 32 + 2 * 128 * 2=4608,最后通过一个Linear层实现QKV矩阵的创建。
在推理阶段使用query_key_value进行同意QKV映射,然后通过split算子将QKV进行分解,分解之后所有头的Query为4096维,Key和Value为256维,因为分组数为2,所有存在两个Key和两个Value。

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):# TODO [1(seq), 1(batch), 4608]mixed_x_layer = self.query_key_value(hidden_states)if self.multi_query_attention:# TODO [17, 1, 4096], [17, 1, 256], [17, 1, 256](query_layer, key_layer, value_layer) = mixed_x_layer.split([self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,  # TODO 32 *128self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,  # TODO 2 * 128self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,  # TODO 2 * 128],dim=-1,)

然后作者将QKV向量的维度进行reshape,将头维度拿出来,准备在下面的代码中对KV进行广播

            # TODO [17, 1, 32, 128]query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))# TODO [17, 1, 2, 128]key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))# TODO TODO [17, 1, 2, 128]value_layer = value_layer.view(value_layer.size()[:-1]+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))

在这之前,需要将原始的Query和Key携带旋转位置编码RoPE,使得注意力能够感知到相对位置信息,此步骤和Value无关

        if rotary_pos_emb is not None:query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

然后作者对Key和Value做广播,因为Query有32个头,而GQA有2组,因此要翻32/2=16倍数,通过torch的expand进行广播实现参数共享,广播之后QKV三者的shape变成一致

        if self.multi_query_attention:# TODO TODO [17, 1, 2, 128] => [17, 1, 2, 1, 128]key_layer = key_layer.unsqueeze(-2)# TODO [17, 1, 2, 16, 128]key_layer = key_layer.expand(# TODO expand 进行广播,k,v向量共享# TODO 只能对维度值是1的进行拓展,如果某些维不需要拓展,写为-1, 32 // 2=16# TODO 有32个头,KV组只有2组,要复制16份-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)# TODO [17, 1, 2, 16, 128] => [17, 1, 32, 128]key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))value_layer = value_layer.unsqueeze(-2)value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))

最后所有处理好之后计算注意力权重,并且将权重和Value相乘得到注意力的输出,这里和传统的MHA没有任何却别

context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

MQA和GQA的主要代码流程结束,核心是先创建分组数的Key和Value矩阵,注意力点乘之前将Key和Value广播到和Query一致即可,全文完毕。

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

相关文章:

Transformer系列:注意力机制的优化,MQA和GQA原理简述

前言 多查询注意力(MQA)、分组查询注意力(GQA)是Transformer中多头注意力(MHA)的变种,它们大幅提高了解码器的推理效率,在LLaMA-2,ChatGLM2等大模型中有广泛使用,本篇介绍MQA、GQA的原理并分析其源码实现。 使用MQA,G…...

Python知识点11---高阶函数

提前说一点:如果你是专注于Python开发,那么本系列知识点只是带你入个门再详细的开发点就要去看其他资料了,而如果你和作者一样只是操作其他技术的Python API那就足够了。 本篇介绍一下Python的内置函数也叫高阶函数,就是Python自…...

JavaSE——【逻辑控制】(习题)

一、分支结构 2.1 if 语句 【练习】2.1.1 小明,如果这次考到90分以上,给你奖励一个大鸡腿,否则奖你一个大嘴巴子 int score 92;if(score > 90){System.out.println("吃个大鸡腿!!!");}else{System.out.println("挨大嘴…...

自动驾驶仿真:python和carsim联合仿真案例

文章目录 前言一、Carsim官方案例二、Carsim配置1、车辆模型2、procedure配置3、Run Control配置 三、python编写四、运行carsim五、运行python总结 前言 carsim内部有许多相关联合仿真的demo,simulink、labview等等都有涉及,这里简单介绍下python和car…...

Qt报错:libvlc开发的程序,出现Direct3D output全屏窗口

问题描述: 在qt中开发重播模块时,第一次在窗口正常播放,点击重播按钮后会弹出新的Direct3D output窗口播放视频 分析: 因为libvlc_media_player_set_hwnd 这个函数 设置了不存在的窗口句柄,导致vlc视频播放窗口没有嵌…...

yolov5的口罩识别系统+GUI界面 (附代码)

基于YOLOv5模型的口罩识别系统,结合了GUI界面,旨在帮助用户快速、准确地识别图像或视频中佩戴口罩的情况。YOLOv5是一种流行的目标检测模型,具有高效的实时检测能力,而GUI界面则提供了友好的用户交互界面,使得整个系统…...

WPF中Window的外观实现及常用属性

文章目录 1. 概要2. Window的外观2.1 Window的外观组成2.2 Window的实现2.3 Window外观配置2.4 Window 的其他常用属性1. AllowsTransparency 2. WindowStartupLocation3. ShowInTaskbar4. ShowActivated5. SizeToContent6. Topmost7. WindowStyle 1. 概要 和 Android 类似, W…...

(有代码示例)Vue 或 JavaScript中使用全局通信的3种方式

在 Vue 或 JavaScript 应用中,可以使用以下库来实现全局事件通信: Vue.js 中的 EventBus: 在 Vue.js 中,可以使用 EventBus 来实现全局事件通信。EventBus 是一个 Vue 实例,用于在组件之间传递事件。你可以使用 $on、…...

MAB规范(1):概览介绍

前言 MATLAB的MAAB(MathWorks Automotive Advisory Board)建模规范是一套由MathWorks主导的建模指南,旨在提高基于Simulink和Stateflow进行建模的代码质量、可读性、可维护性和可重用性。这些规范最初是由汽车行业的主要厂商共同制定的&…...

基于振弦采集仪的土木工程安全监测技术研究

基于振弦采集仪的土木工程安全监测技术研究 随着土木工程的发展,安全监测成为了非常重要的一部分。土木工程的安全监测旨在及早发现结构的变形、位移、振动等异常情况,以便及时采取措施进行修复或加固,从而保障工程的安全运行。振弦采集仪作…...

这个高考作文满分的极客,想和你聊聊新媒体写作

计育韬 曾为上海市高考作文满分考生 微信官方 SVG AttributeName 开发者 新榜 500 强运营人 复旦大学青年智库讲师 浙江传媒学院客座导师 上海团市委新媒体顾问 上海市金山区青联副主席 文案能力,从来就不是一蹴而就的。今天,来和大家聊聊当年我的…...

AI推介-多模态视觉语言模型VLMs论文速览(arXiv方向):2024.05.25-2024.05.31

文章目录~ 1.Empowering Visual Creativity: A Vision-Language Assistant to Image Editing Recommendations2.Bootstrap3D: Improving 3D Content Creation with Synthetic Data3.Video-MME: The First-Ever Comprehensive Evaluation Benchmark of Multi-modal L…...

如何通过Python SMTP配置示例发附件邮件?

Python SMTP配置的步骤?SMTP服务器的优缺点有哪些? 当我们需要发送包含附件的邮件时,自动化的解决方案显得尤为重要。Python提供了SMTP库,使我们能够轻松配置并发送带有附件的邮件。AokSend将通过一个示例来展示如何操作&#xf…...

amd64

MD64,或"x64",是一种64位元的电脑处理器架构。它是基于现有32位元的x86架构,由AMD公司所开发,应用AMD64指令集的自家产品有Athlon(速龙) 64、Athlon 64 FX、Athlon 64 X2、Turion(炫龙) 64、Opteron(皓龙)、Sempron(闪龙…...

2024如何优化SEO?

在2024年的今天,要问我会如何优化seo,我会专注于几个关键的方面。首先,随着AI内容生成技术的发展,我会利用这些工具来帮助创建或优化我的网站内容,但是,随着谷歌3月份的算法更新,纯粹的ai内容可…...

【NoSQL数据库】Redis命令、持久化、主从复制

Redis命令、持久化、主从复制 redis配置 Redis命令、持久化、主从复制Redis数据类型redis数据库常用命令redis多数据库常用命令1、多数据库间切换2、多数据库间移动数据3、清除数据库内数据 key命令1、keys 命令2、判断键值是否存在exists3、删除当前数据库的指定key del4、获取…...

使用Django JWT实现身份验证

文章目录 安装依赖配置Django设置创建API生成和验证Token总结与展望 在现代Web应用程序中,安全性和身份验证是至关重要的。JSON Web Token(JWT)是一种流行的身份验证方法,它允许在客户端和服务器之间安全地传输信息。Django是一个…...

MT2084 检测敌人

思路: 1. 以装置为中心->以敌人为中心。 以敌人为中心,r为半径做圆,与x轴交于a,b点,则在[a,b]之间的装置都能覆盖此敌人。 每个敌人都有[a,b]区间,则此题转化为:有多少个装置能覆盖到这些[a,b]区间。…...

支持向量机、随机森林、K最近邻和逻辑回归-九五小庞

支持向量机(Support Vector Machine, SVM)、随机森林(Random Forest)、K最近邻(K-Nearest Neighbors, KNN)和逻辑回归(Logistic Regression)是机器学习和统计学习中常用的分类算法。…...

MySQL—多表查询—多表关系介绍

一、引言 提到查询,我们想到之前学习的单表查询(DQL语句)。而这一章节部分的博客我们将要去学习和了解多表查询。 对于多表查询,主要从以下7个方面进行学习。 (1)第一部分:介绍 1、多表关系 2、…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局&#xff1a;刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断"&#xff0c;医生需通过显微镜观察组织切片&#xff0c;在细胞迷宫中捕捉癌变信号。某省病理质控报告显示&#xff0c;基层医院误诊率达12%-15%&#xff0c;专家会诊…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...

密码学基础——SM4算法

博客主页&#xff1a;christine-rr-CSDN博客 ​​​​专栏主页&#xff1a;密码学 &#x1f4cc; 【今日更新】&#x1f4cc; 对称密码算法——SM4 目录 一、国密SM系列算法概述 二、SM4算法 2.1算法背景 2.2算法特点 2.3 基本部件 2.3.1 S盒 2.3.2 非线性变换 ​编辑…...