LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666

GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。
从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:
- GQA:从头实现 LLaMA3 网络与推理流程
- KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理
Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dkQK⊤)V
1. MultiHeadAttention
MultiHeadAttention (多头注意力机制),合计 43 行:
__init__初始化 (10行):- 输入:
heads(头数)、d_model(维度)、dropout(用于scores) - 计算
d_k每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk - 线性层是 QKVO,Dropout 层
- 输入:
attention注意力 (10行):- q q q 的维度
[bs,h,s,d],与 k ⊤ k^{\top} k⊤ 的[bs,h,d,s],mm 之后 scores 是[bs,h,s,s] - mask 的维度是
[bs,s,s],使用unsqueeze(1),转换成[bs,1,s,s] - QKV 的计算,额外支持 Dropout
- q q q 的维度
forward推理 (12行):- QKV Linear 转换成
[bs,s,h,dk],再转换[bs,h,s,dk] - 计算 attn 的
[bs,h,s,dk] - 转换
[bs,s,h,dk],再contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d - 再过 O
- QKV Linear 转换成
- 测试 (11行):
torch.randn构建数据- Mask 的
torch.tril(torch.ones(bs, s, s))
即:
import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):"""多头自注意力机制 MultiHeadAttention"""def __init__(self, heads, d_model, dropout=0.1): # 10行super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None): # 10行scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None): # 12行bs = q.size(0)# 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1, 2) # [bs,h,s,d] = [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数bs, s, h, d = 2, 10, 8, 512dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = MultiHeadAttention(h, d, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()
2. GroupQueryAttention
GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention
__init__:增加参数kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。forward:使用repeat_interleave扩充 KV 维度,其他相同,增加 3 行。
即:
import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k) # [2, 10, 8, 64]k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k) # [2, 10, 4, 64]v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = k.repeat_interleave(group, dim=2) # [2, 10, 4, 64] -> [2, 10, 8, 64]v = v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前k = k.transpose(1, 2) # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数, GQA 8//4=2组bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()
3. GQA + KVCache
GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:
forward:增加参数kv_cache,合并[cached_k, new_k],同时返回new_kv_cache,用于迭代,增加 5 行。- 设置
cur_qkv与cur_mask,迭代序列s维度,合计 8 行。
即:
import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None, kv_cache=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k) # [2, 1, 8, 64]new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k) # [2, 1, 4, 64]new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k) # [2, 1, 4, 64]# 处理 KV Cacheif kv_cache is not None:cached_k, cached_v = kv_cachenew_k = torch.cat([cached_k, new_k], dim=1)new_v = torch.cat([cached_v, new_v], dim=1)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = new_k.repeat_interleave(group, dim=2) # [2, 10, 4, 64] -> [2, 10, 8, 64]v = new_v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前# KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]k = k.transpose(1, 2) # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout) # [2, 8, 1, 64]print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)# 更新 KV Cachenew_kv_cache = (new_k, new_v) # 当前的 KV 缓存return output, new_kv_cache
def main():# 设置超参数bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 GroupQueryAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 模拟逐步生成序列,测试 KV Cacheprint("Testing KV Cache...")kv_cache, output = None, Nonefor i in range(s):cur_q = q[:, i:i+1, :]cur_k = k[:, i:i+1, :]cur_v = v[:, i:i+1, :]cur_mask = mask[:, i:i+1, :i+1] # q是 i:i+1,k是 :i+1output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)print(f"Output shape at step {i}:", output.shape)# 检查输出是否符合预期assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"print("Test passed!")
if __name__ == "__main__":main()
相关文章:
LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145368666 GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力…...
Vue3 provide/inject用法总结
1. 基本概念 provide/inject 是 Vue3 中实现跨层级组件通信的方案,类似于 React 的 Context。它允许父组件向其所有子孙组件注入依赖,无论层级有多深。 1.1 基本语法 // 提供方(父组件) const value ref(hello) provide(key, …...
Linux——网络基础(1)
文章目录 目录 文章目录 前言 一、文件传输协议 应用层 传输层 网络层 数据链路层 数据接收与解封装 主机与网卡 数据传输过程示意 二、IP和MAC地址 定义与性质 地址格式 分配方式 作用范围 可见性与可获取性 生活例子 定义 用途 特点 联系 四、TCP和UDP协…...
【记录】日常|从零散记录到博客之星Top300的成长之路
文章目录 shandianchengzi 2024 年度盘点概述写作风格简介2024年的创作内容总结 shandianchengzi 2024 年度盘点 概述 2024年及2025年至今我创作了786即84篇文章,加上这篇就是85篇。 很荣幸这次居然能够入选博客之星Top300,这个排名在我之前的所有年份…...
【二分查找】力扣373. 查找和最小的 K 对数字
给定两个以 非递减顺序排列 的整数数组 nums1 和 nums2 , 以及一个整数 k 。 定义一对值 (u,v),其中第一个元素来自 nums1,第二个元素来自 nums2 。 请找到和最小的 k 个数对 (u1,v1), (u2,v2) … (uk,vk) 。 示例 1: 输入: nums1 [1,7,11], nums2 …...
池化层Pooling Layer
1. 定义 池化是对特征图进行的一种压缩操作,通过在一个小的局部区域内进行汇总统计,用一个值来代表这个区域的特征信息,常用于卷积神经网络(CNN)中。 2. 作用 提取代表性信息的同时降低特征维度,具有平移…...
力扣算法题——11.盛最多水的容器
目录 💕1.题目 💕2.解析思路 本题思路总览 借助双指针探索规律 从规律到代码实现的转化 双指针的具体实现 代码整体流程 💕3.代码实现 💕4.完结 二十七步也能走完逆流河吗 💕1.题目 💕2.解析思路…...
自由学习记录(32)
文件里找到切换颜色空间 fgui中的 颜色空间是一种总体使用前的设定 颜色空间,和半透明混合产生的效果有差异,这种问题一般可以产生联系 动效就是在fgui里可以编辑好,然后在unity中也准备了对应的调用手段,可以详细的使用每一个具…...
VScode+Latex (Recipe terminated with fatal error: spawn xelatex ENOENT)
使用VSCode编辑出现Recipe terminated with fatal error: spawn xelatex ENOENT问题咋办? 很好解决,大概率的原因是因为latex没有添加到系统环境变量中,所有设置的编译工具没有办法找到才出现的这种情况。 解决方法: winR 然后输…...
「蓝桥杯题解」蜗牛(Java)
题目链接 这道题我感觉状态定义不太好想,需要一定的经验 import java.util.*; /*** 蜗牛* 状态定义:* dp[i][0]:到达(x[i],0)最小时间* dp[i][1]:到达 xi 上方的传送门最小时间*/public class Main {static Scanner in new Scanner(System.in);static f…...
PHP EOF (Heredoc) 详解
PHP EOF (Heredoc) 详解 PHP 中的 EOF(End Of File)是一种非常有用的语法特性,允许开发者创建多行字符串。它特别适合于创建格式化文本,如配置文件、HTML 模板等。本文将详细讲解 PHP EOF 的用法、优势以及注意事项。 什么是 EOF? EOF 是一种特殊的字符串定义方式,它允…...
pyautogui操控Acrobat DC pro万能PDF转Word,不丢任何PDF格式样式
为了将PDF转换脚本改为多进程异步处理,我们需要确保每个进程独立操作不同的Acrobat窗口。以下是实现步骤: 实现代码 import os import pyautogui import time import subprocess import pygetwindow as gw from multiprocessing import Pooldef conver…...
Day32:字符串的复制
在 Python 中,字符串的复制是指创建一个新的字符串,它的内容与原字符串相同。字符串是不可变的对象,这意味着你不能直接修改字符串的内容,但是可以通过复制来创建新的字符串进行操作。字符串的复制在一些情况下非常有用࿰…...
基于Mybatis继承AbstractRoutingDataSource使用自定义注解实现动态数据源
一:实现 方式一:继承AbstractRoutingDataSource使用自定义注解实现 环境:springboot3 MyBatis3 mysql-connector8 DataSourceKeyEnum枚举类 有几个数据源就配置几个枚举类,和数据源数量一一对应 class DataSourceKeyEnum{D…...
ZooKeeper 数据模型
ZooKeeper 数据模型 ZooKeeper 拥有层次化的命名空间,类似分布式文件系统,但每个节点不仅能有子节点,还可关联数据。节点路径为规范的绝对路径,用斜杠分隔,无相对引用。路径命名有如下约束: 路径名不能包…...
【VUE】Vue2中Vue.extend方法
在 Vue.js 2.x 版本中,Vue.extend() 方法被用于创建一个新的 Vue 子类,可以在该子类上扩展一些属性、指令和组件选项等,然后进行实例化。 比如,可以在创建一些类似 loading 式的函数式插件时,使用: 在 Vue…...
MaskGAE论文阅读
What’s Behind the Mask: Understanding Masked Graph Modeling for Graph Autoencoders 碎碎念:一篇论文看四天,效率也没谁了(捂脸) 看一点忘一点,虽然在本子上有记录,但还是忘,下次看一点在博客上记一点启发 本来很…...
Mybatis-plus 更新 Null 的策略踩坑记
一个bug 在一个管理页面,有一个非必填字段被设置成空了并提交更新,再次打开的时候,发现字段还在,并没有被更新成功。 使用的数据库映射框架是 Mybatis-plus ,对于Mybatis 在更新字段的时候会对空进行校验,…...
Oracle迁移DM数据库
Oracle迁移DM数据库 本文记录使用达梦官方数据迁移工具DTS,将Oracle数据库的数据迁移至达梦数据库。 1 数据准备 2 DTS工具操作步骤 2.1 创建工程 打开DTS迁移工具,点击新建工程,填写好工程信息,如图: 2.2 新建迁…...
HTML特殊符号的使用示例
目录 一、基本特殊符号的使用 1、空格符号: 2、小于号 和 大于号: 3、引号: 二、版权、注册商标符号的使用 1、版权符号:© 2、注册商标符号: 三、数学符号的使用 四、箭头符号的使用 五、货币符号的使用…...
OriginPro 2023保姆级教程:三步搞定柱状图+点线图组合,让你的科研图表颜值飙升
OriginPro 2023科研图表优化实战:从基础绘图到期刊级组合图表 科研图表是学术论文的"门面",一张精心设计的图表往往能让审稿人和读者眼前一亮。OriginPro作为科研绘图领域的标杆工具,其2023版本在图表组合和视觉优化方面带来了诸多…...
终极指南:如何使用dnstwist与模糊哈希精准识别钓鱼网站攻击
终极指南:如何使用dnstwist与模糊哈希精准识别钓鱼网站攻击 【免费下载链接】dnstwist Domain name permutation engine for detecting homograph phishing attacks, typo squatting, and brand impersonation 项目地址: https://gitcode.com/gh_mirrors/dn/dnstw…...
OpenClaw 实用指南-节假日系统巡检全自动化(下)
前言 在上一篇文章中,我们已详细讲解了节假日系统巡检全自动化的前三个核心部分,分别是:Part1:AI节假日智能判断、Part2:目标服务器稳定连接、Part3:借助“小龙虾”工具批量部署软件,并利用部署…...
【DIY小记】解决MacOS上Edge浏览器bilibili全屏卡顿的问题
近日笔者发现自己Macbook-Pro播放B站视频,全屏的时候必然卡顿,退出全屏就没事。笔者电脑的参数是: 芯片:M3系统:Tahoe 26.4浏览器:Edge 到网上一查发现《Edge浏览器在MacOS 26(Tahoe)系统上看B站卡顿》一…...
科技企业如何利用智能手段提升研发效率?
观点作者:科易网-国家科技成果转化(厦门)示范基地 现状概述:传统研发模式的瓶颈与挑战 在全球科技创新加速迭代的背景下,科技企业面临的核心挑战之一是如何提升研发效率。传统研发模式往往存在以下痛点: 信…...
【硬件小达人-基础篇(1)】-电阻那些事儿
文章目录什么是电阻电阻的功率一定要降额使用电阻的额定电压和精度额定电压精度PCB设计中,电阻的作用1.限流电阻保护敏感元件常用经验2.分压电阻电压反馈ADC采集电路一些经验3.分流电阻4.上拉电阻/下拉电阻什么是上下拉作用一、 防止引脚悬空,消除外部干…...
memtest_vulkan显存检测终极指南:从问题识别到健康管理
memtest_vulkan显存检测终极指南:从问题识别到健康管理 【免费下载链接】memtest_vulkan Vulkan compute tool for testing video memory stability 项目地址: https://gitcode.com/gh_mirrors/me/memtest_vulkan 一、显存故障问题识别:图形渲染异…...
nVisual设备板卡关联
在线模型库导入:ODF-12x2 这个型号的设备打开模型库点左侧模型搜索需要添加板卡设备型号,点击建模双击板卡搜索板卡名称点击绿色按钮添加添加完成点应用到实例...
Linunx常用命令
一. 通用1.1系统单元启动# 创建系统用户,不允许登录,不创建 home 目录 sudo useradd -r -s /sbin/nologin xxl-job#将 /data/middleware/xxl-job 目录的归属权改为 xxl-job 用户: sudo chown -R xxl-job:xxl-job /data/middleware/xxl-job#检…...
N_m3u8DL-RE:跨平台流媒体解决方案的全方位技术指南
N_m3u8DL-RE:跨平台流媒体解决方案的全方位技术指南 【免费下载链接】N_m3u8DL-RE Cross-Platform, modern and powerful stream downloader for MPD/M3U8/ISM. English/简体中文/繁體中文. 项目地址: https://gitcode.com/GitHub_Trending/nm3/N_m3u8DL-RE …...
