当前位置: 首页 > article >正文

LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666


GQA
GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。

从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:

  • GQA:从头实现 LLaMA3 网络与推理流程
  • KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理

Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QK)V

1. MultiHeadAttention

MultiHeadAttention (多头注意力机制),合计 43 行:

  1. __init__ 初始化 (10行):
    • 输入:heads(头数)、d_model(维度)、dropout (用于 scores)
    • 计算 d_k 每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk
    • 线性层是 QKVO,Dropout 层
  2. attention 注意力 (10行):
    • q q q 的维度 [bs,h,s,d],与 k ⊤ k^{\top} k[bs,h,d,s],mm 之后 scores 是 [bs,h,s,s]
    • mask 的维度是 [bs,s,s],使用 unsqueeze(1),转换成 [bs,1,s,s]
    • QKV 的计算,额外支持 Dropout
  3. forward 推理 (12行):
    • QKV Linear 转换成 [bs,s,h,dk],再转换 [bs,h,s,dk]
    • 计算 attn 的 [bs,h,s,dk]
    • 转换 [bs,s,h,dk],再 contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d
    • 再过 O
  4. 测试 (11行):
    • torch.randn 构建数据
    • Mask 的 torch.tril(torch.ones(bs, s, s))

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):"""多头自注意力机制 MultiHeadAttention"""def __init__(self, heads, d_model, dropout=0.1):  # 10行super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):  # 10行scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):  # 12行bs = q.size(0)# 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数bs, s, h, d = 2, 10, 8, 512dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = MultiHeadAttention(h, d, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

2. GroupQueryAttention

GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention

  1. __init__ :增加参数 kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。
  2. forward:使用 repeat_interleave 扩充 KV 维度,其他相同,增加 3 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数, GQA 8//4=2组bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

3. GQA + KVCache

GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:

  1. forward :增加参数 kv_cache,合并 [cached_k, new_k],同时返回 new_kv_cache,用于迭代,增加 5 行。
  2. 设置 cur_qkvcur_mask,迭代序列s维度,合计 8 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None, kv_cache=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]# 处理 KV Cacheif kv_cache is not None:cached_k, cached_v = kv_cachenew_k = torch.cat([cached_k, new_k], dim=1)new_v = torch.cat([cached_v, new_v], dim=1)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = new_k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = new_v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前# KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)# 更新 KV Cachenew_kv_cache = (new_k, new_v)  # 当前的 KV 缓存return output, new_kv_cache
def main():# 设置超参数bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 GroupQueryAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 模拟逐步生成序列,测试 KV Cacheprint("Testing KV Cache...")kv_cache, output = None, Nonefor i in range(s):cur_q = q[:, i:i+1, :]cur_k = k[:, i:i+1, :]cur_v = v[:, i:i+1, :]cur_mask = mask[:, i:i+1, :i+1]   # q是 i:i+1,k是 :i+1output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)print(f"Output shape at step {i}:", output.shape)# 检查输出是否符合预期assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"print("Test passed!")
if __name__ == "__main__":main()

相关文章:

LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145368666 GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力…...

人工智能研究报告:技术、应用与未来趋势洞察

一、引言 1.1 研究背景 在当今科技飞速发展的时代,人工智能(Artificial Intelligence,简称 AI)已成为最为关键的技术领域之一。它犹如一股强大的变革力量,正深刻地重塑着各行业的发展格局,对社会的各个层…...

Kotlin开发(七):对象表达式、对象声明和委托的奥秘

Kotlin 让代码更优雅! 每个程序员都希望写出优雅高效的代码,但现实往往不尽人意。对象表达式、对象声明和 Kotlin 委托正是为了解决代码中的复杂性而诞生的。为什么选择这个主题?因为它不仅是 Kotlin 语言的亮点之一,还能极大地提…...

数据库、数据仓库、数据湖有什么不同

数据库、数据仓库和数据湖是三种不同的数据存储和管理技术,它们在用途、设计目标、数据处理方式以及适用场景上存在显著差异。以下将从多个角度详细说明它们之间的区别: 1. 数据结构与存储方式 数据库: 数据库主要用于存储结构化的数据&…...

【2024年华为OD机试】 (B卷,100分)- 字符串摘要(JavaScriptJava PythonC/C++)

一、问题描述 题目描述 给定一个字符串的摘要算法,请输出给定字符串的摘要值。具体步骤如下: 去除字符串中非字母的符号:只保留字母字符。处理连续字符:如果出现连续字符(不区分大小写),则输…...

DIY QMK量子键盘

最近放假了,趁这个空余在做一个分支项目,一款机械键盘,量子键盘取自固件名称QMK(Quantum Mechanical Keyboard)。 键盘作为计算机或其他电子设备的重要输入设备之一,通过将按键的物理动作转换为数字信号&am…...

mamba论文学习

rnn 1986 训练速度慢 testing很快 但是很快就忘了 lstm 1997 训练速度慢 testing很快 但是也会忘(序列很长的时候) GRU实在lstm的基础上改进,改变了一些门 transformer2017 训练很快,testing慢些,时间复杂度高&am…...

智慧消防营区一体化安全管控 2024 年度深度剖析与展望

在 2024 年,智慧消防营区一体化安全管控领域取得了令人瞩目的进展,成为保障营区安全稳定运行的关键力量。这一年,行业在政策驱动、技术创新应用、实践成果及合作交流等方面呈现出多元且深刻的发展态势,同时也面临着一系列亟待解决…...

解锁微服务:五大进阶业务场景深度剖析

目录 医疗行业:智能诊疗的加速引擎 电商领域:数据依赖的破局之道 金融行业:运维可观测性的提升之路 物流行业:智慧物流的创新架构 综合业务:服务依赖的优化策略 医疗行业:智能诊疗的加速引擎 在医疗行业迈…...

pip 安装 numpy 报错 AttributeError: module ‘pkgutil‘ has no attribute ‘ImpImporter‘

conda 环境下 pip 安装 numpy 1.x 版本&#xff0c;报如下错误 File "C:\Users\UserName\AppData\Local\Temp\pip-build-env-_lgbq70y\overlay\Lib\site-packages\pkg_resources\__init__.py", line 2191, in <module>register_finder(pkgutil.ImpImporter, fi…...

javascript-es6 (一)

作用域&#xff08;scope&#xff09; 规定了变量能够被访问的“范围”&#xff0c;离开了这个“范围”变量便不能被访问 局部作用域 函数作用域&#xff1a; 在函数内部声明的变量只能在函数内部被访问&#xff0c;外部无法直接访问 function getSum(){ //函数内部是函数作用…...

Babylon.js 中的 setHardwareScalingLevel和getHardwareScalingLevel:作用与配合修改内容

在 Babylon.js 中&#xff0c;Engine类提供了setHardwareScalingLevel和getHardwareScalingLevel方法&#xff0c;用于管理和调整渲染分辨率与屏幕分辨率的比例。这些方法在优化性能和提升画质方面非常有用。尤其是在某些平台不支持硬件反锯齿时&#xff0c;可以考虑使用setHar…...

二十三种设计模式-桥接模式

桥接模式&#xff08;Bridge Pattern&#xff09;是一种结构型设计模式&#xff0c;其核心思想是将抽象与实现解耦&#xff0c;让它们可以独立变化。桥接模式主要用于解决类的继承问题&#xff0c;避免由于继承而带来的类层次结构过于复杂和难以维护的问题。 1. 核心概念 桥接…...

jenkins-k8s pod方式动态生成slave节点

一. 简述&#xff1a; 使用 Jenkins 和 Kubernetes (k8s) 动态生成 Slave 节点是一种高效且灵活的方式来管理 CI/CD 流水线。通过这种方式&#xff0c;Jenkins 可以根据需要在 Kubernetes 集群中创建和销毁 Pod 来执行任务&#xff0c;从而充分利用集群资源并实现更好的隔离性…...

代码工艺:实践 Spring Boot TDD 测试驱动开发

TDD 的核心理念是 “先写测试&#xff0c;再写功能”&#xff0c;其过程遵循一个严格的循环&#xff0c;即 Red-Green-Refactor&#xff1a; TDD 的流程 1. Red&#xff08;编写失败的测试&#xff09; 根据需求&#xff0c;先编写一个测试用例&#xff0c;描述期望的行为。…...

【云安全】云原生-K8S-简介

K8S简介 Kubernetes&#xff08;简称K8S&#xff09;是一种开源的容器编排平台&#xff0c;用于管理容器化应用的部署、扩展和运维。它由Google于2014年开源并交给CNCF&#xff08;Cloud Native Computing Foundation&#xff09;维护。K8S通过提供自动化、灵活的功能&#xf…...

aws(学习笔记第二十六课) 使用AWS Elastic Beanstalk

aws(学习笔记第二十六课) 使用aws Elastic Beanstalk 学习内容&#xff1a; AWS Elastic Beanstalk整体架构AWS Elastic Beanstalk的hands onAWS Elastic Beanstalk部署node.js程序包练习使用AWS Elastic Beanstalk的ebcli 1. AWS Elastic Beanstalk整体架构 官方的guide AWS…...

反向代理模块。。

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当于…...

C语言的灵魂——指针(1)

指针是C语言的灵魂&#xff0c;有了指针C语言才能完成一些复杂的程序&#xff1b;没了指针就相当于C语言最精髓的部分被去掉了&#xff0c;可见指针是多么重要。废话不多讲我们直接开始。 指针 一&#xff0c;内存和地址二&#xff0c;编址三&#xff0c;指针变量和地址1&#…...

14-6-2C++STL的list

(一&#xff09;list对象的带参数构造 1.list&#xff08;elem);//构造函数将n个elem拷贝给本身 #include <iostream> #include <list> using namespace std; int main() { list<int> lst(3,7); list<int>::iterator it; for(itlst.begi…...

Linux 基础1

gcc的编译过程 预处理——编译——汇编——链接 Linux文件类型 普通文件&#xff0c;目录文件&#xff0c;管道文件&#xff0c;链接文件&#xff0c;块设备文件&#xff0c;字符设备文件&#xff0c;套接字文件 Linux系统下的软链接和硬链接有什么异同 linux中软链接和硬…...

Ubuntu Server 安装 XFCE4桌面

Ubuntu Server没有桌面环境&#xff0c;一些软件有桌面环境使用起来才更加方便&#xff0c;所以我尝试安装桌面环境。常用的桌面环境有&#xff1a;GNOME、KDE Plasma、XFCE4等。这里我选择安装XFCE4桌面环境&#xff0c;主要因为它是一个极轻量级的桌面环境&#xff0c;适合内…...

xarray转换nc文件经度范围:0-360更改为-180-180

原文见https://blog.csdn.net/weixin_44237337/article/details/119707332&#xff0c;因为觉得很实用就转载一下。 lon_name longitude #你的nc文件中经度的命名 ds[longitude_adjusted] xr.where(ds[lon_name] > 180,ds[lon_name] - 360,ds[lon_name]) ds (ds.swap_d…...

MySQL 基础学习(1):数据类型与操作数据库和数据表

MySQL 基础学习&#xff1a;数据类型与操作数据库和数据表 在这篇博客中&#xff0c;我们将深入学习 MySQL 的基础操作&#xff0c;重点关注数据库和数据表的操作&#xff0c;以及 MySQL 中常见的数据类型。希望本文能帮助你更好地理解和掌握 MySQL 的基本用法。 一、操作数据…...

一个简单的自适应html5导航模板

一个简单的 HTML 导航模板示例&#xff0c;它包含基本的导航栏结构&#xff0c;同时使用了 CSS 进行样式美化&#xff0c;让导航栏看起来更美观。另外&#xff0c;还添加了一些 JavaScript 代码&#xff0c;用于在移动端实现导航菜单的展开和收起功能。 PHP <!DOCTYPE htm…...

深入解析“Wholesome”的含义及用法

深入解析“Wholesome”的含义及用法 一、引言 在阅读英文材料时&#xff0c;我们经常会遇到一些词汇&#xff0c;它们的含义既有直接的字面意思&#xff0c;又带有丰富的情感色彩。“Wholesome”就是这样一个词。它表面上看似简单&#xff0c;但在不同语境中却有多重内涵。在…...

供水企业满意度调查报告

民安智库作为一家独立的第三方评估机构&#xff0c;致力于为供水企业提供全面的客户满意度调查服务。本报告将详细介绍民安智库的调查方法、结果和建议&#xff0c;以帮助供水企业更好地理解其服务质量和客户需求。 一、调查方法 民安智库首先对客户进行了分类&#xff0c;包括…...

实现B-树

一、概述 1.历史 B树&#xff08;B-Tree&#xff09;结构是一种高效存储和查询数据的方法&#xff0c;它的历史可以追溯到1970年代早期。B树的发明人Rudolf Bayer和Edward M. McCreight分别发表了一篇论文介绍了B树。这篇论文是1972年发表于《ACM Transactions on Database S…...

无人机微波图像传输数据链技术详解

无人机微波图像传输数据链技术是无人机通信系统中的关键组成部分&#xff0c;它确保了无人机与地面站之间高效、可靠的图像数据传输。以下是对该技术的详细解析&#xff1a; 一、技术原理 无人机微波图像传输数据链主要基于微波通信技术实现。在数据链路中&#xff0c;图像数…...

【Leetcode 热题 100】300. 最长递增子序列

问题背景 给你一个整数数组 n u m s nums nums&#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c; [ 3 , 6 , 2 , 7 ] [3,6,2,7] [3,6,2…...