当前位置: 首页 > news >正文

LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666


GQA
GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。

从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:

  • GQA:从头实现 LLaMA3 网络与推理流程
  • KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理

Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QK)V

1. MultiHeadAttention

MultiHeadAttention (多头注意力机制),合计 43 行:

  1. __init__ 初始化 (10行):
    • 输入:heads(头数)、d_model(维度)、dropout (用于 scores)
    • 计算 d_k 每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk
    • 线性层是 QKVO,Dropout 层
  2. attention 注意力 (10行):
    • q q q 的维度 [bs,h,s,d],与 k ⊤ k^{\top} k[bs,h,d,s],mm 之后 scores 是 [bs,h,s,s]
    • mask 的维度是 [bs,s,s],使用 unsqueeze(1),转换成 [bs,1,s,s]
    • QKV 的计算,额外支持 Dropout
  3. forward 推理 (12行):
    • QKV Linear 转换成 [bs,s,h,dk],再转换 [bs,h,s,dk]
    • 计算 attn 的 [bs,h,s,dk]
    • 转换 [bs,s,h,dk],再 contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d
    • 再过 O
  4. 测试 (11行):
    • torch.randn 构建数据
    • Mask 的 torch.tril(torch.ones(bs, s, s))

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):"""多头自注意力机制 MultiHeadAttention"""def __init__(self, heads, d_model, dropout=0.1):  # 10行super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):  # 10行scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):  # 12行bs = q.size(0)# 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数bs, s, h, d = 2, 10, 8, 512dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = MultiHeadAttention(h, d, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

2. GroupQueryAttention

GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention

  1. __init__ :增加参数 kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。
  2. forward:使用 repeat_interleave 扩充 KV 维度,其他相同,增加 3 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output
def main():# 设置超参数, GQA 8//4=2组bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 MultiHeadAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 测试无掩码的情况output_no_mask = attention(q, k, v)print("Output shape without mask:", output_no_mask.shape)# 测试有掩码的情况output_with_mask = attention(q, k, v, mask)print("Output shape with mask:", output_with_mask.shape)# 检查输出是否符合预期assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"print("Test passed!")
if __name__ == '__main__':main()

3. GQA + KVCache

GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:

  1. forward :增加参数 kv_cache,合并 [cached_k, new_k],同时返回 new_kv_cache,用于迭代,增加 5 行。
  2. 设置 cur_qkvcur_mask,迭代序列s维度,合计 8 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):"""分组查询注意力机制(Group Query Attention)"""def __init__(self, heads, d_model, kv_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // headsself.h = headsself.kv_heads = kv_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)self.out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)@staticmethoddef attention(q, k, v, d_k, mask=None, dropout=None):# [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)return outputdef forward(self, q, k, v, mask=None, kv_cache=None):bs = q.size(0)# 进行线性操作q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]# 处理 KV Cacheif kv_cache is not None:cached_k, cached_v = kv_cachenew_k = torch.cat([cached_k, new_k], dim=1)new_v = torch.cat([cached_v, new_v], dim=1)# 复制键值头以匹配查询头的数量group = self.h // self.kv_headsk = new_k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]v = new_v.repeat_interleave(group, dim=2)# 矩阵转置, 将 head 在前# KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]k = k.transpose(1, 2)  # [2, 8, 10, 64]q = q.transpose(1, 2)v = v.transpose(1, 2)# 计算 attentionattn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]print(f"[Info] attn: {attn.shape}")# 连接多个头并输入到最后的线性层concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)# 更新 KV Cachenew_kv_cache = (new_k, new_v)  # 当前的 KV 缓存return output, new_kv_cache
def main():# 设置超参数bs, s, h, d, kv_heads = 2, 10, 8, 512, 4dropout_rate = 0.1# 创建 GroupQueryAttention 实例attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)# 创建随机输入张量q = torch.randn(bs, s, d)k = torch.randn(bs, s, d)v = torch.randn(bs, s, d)# 可选:创建掩码,因果掩码,上三角矩阵mask = torch.tril(torch.ones(bs, s, s))# 模拟逐步生成序列,测试 KV Cacheprint("Testing KV Cache...")kv_cache, output = None, Nonefor i in range(s):cur_q = q[:, i:i+1, :]cur_k = k[:, i:i+1, :]cur_v = v[:, i:i+1, :]cur_mask = mask[:, i:i+1, :i+1]   # q是 i:i+1,k是 :i+1output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)print(f"Output shape at step {i}:", output.shape)# 检查输出是否符合预期assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"print("Test passed!")
if __name__ == "__main__":main()

相关文章:

LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145368666 GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力…...

Vue3 provide/inject用法总结

1. 基本概念 provide/inject 是 Vue3 中实现跨层级组件通信的方案,类似于 React 的 Context。它允许父组件向其所有子孙组件注入依赖,无论层级有多深。 1.1 基本语法 // 提供方(父组件) const value ref(hello) provide(key, …...

Linux——网络基础(1)

文章目录 目录 文章目录 前言 一、文件传输协议 应用层 传输层 网络层 数据链路层 数据接收与解封装 主机与网卡 数据传输过程示意 二、IP和MAC地址 定义与性质 地址格式 分配方式 作用范围 可见性与可获取性 生活例子 定义 用途 特点 联系 四、TCP和UDP协…...

【记录】日常|从零散记录到博客之星Top300的成长之路

文章目录 shandianchengzi 2024 年度盘点概述写作风格简介2024年的创作内容总结 shandianchengzi 2024 年度盘点 概述 2024年及2025年至今我创作了786即84篇文章,加上这篇就是85篇。 很荣幸这次居然能够入选博客之星Top300,这个排名在我之前的所有年份…...

【二分查找】力扣373. 查找和最小的 K 对数字

给定两个以 非递减顺序排列 的整数数组 nums1 和 nums2 , 以及一个整数 k 。 定义一对值 (u,v),其中第一个元素来自 nums1,第二个元素来自 nums2 。 请找到和最小的 k 个数对 (u1,v1), (u2,v2) … (uk,vk) 。 示例 1: 输入: nums1 [1,7,11], nums2 …...

池化层Pooling Layer

1. 定义 池化是对特征图进行的一种压缩操作,通过在一个小的局部区域内进行汇总统计,用一个值来代表这个区域的特征信息,常用于卷积神经网络(CNN)中。 2. 作用 提取代表性信息的同时降低特征维度,具有平移…...

力扣算法题——11.盛最多水的容器

目录 💕1.题目 💕2.解析思路 本题思路总览 借助双指针探索规律 从规律到代码实现的转化 双指针的具体实现 代码整体流程 💕3.代码实现 💕4.完结 二十七步也能走完逆流河吗 💕1.题目 💕2.解析思路…...

自由学习记录(32)

文件里找到切换颜色空间 fgui中的 颜色空间是一种总体使用前的设定 颜色空间,和半透明混合产生的效果有差异,这种问题一般可以产生联系 动效就是在fgui里可以编辑好,然后在unity中也准备了对应的调用手段,可以详细的使用每一个具…...

VScode+Latex (Recipe terminated with fatal error: spawn xelatex ENOENT)

使用VSCode编辑出现Recipe terminated with fatal error: spawn xelatex ENOENT问题咋办? 很好解决,大概率的原因是因为latex没有添加到系统环境变量中,所有设置的编译工具没有办法找到才出现的这种情况。 解决方法: winR 然后输…...

「蓝桥杯题解」蜗牛(Java)

题目链接 这道题我感觉状态定义不太好想,需要一定的经验 import java.util.*; /*** 蜗牛* 状态定义:* dp[i][0]:到达(x[i],0)最小时间* dp[i][1]:到达 xi 上方的传送门最小时间*/public class Main {static Scanner in new Scanner(System.in);static f…...

PHP EOF (Heredoc) 详解

PHP EOF (Heredoc) 详解 PHP 中的 EOF(End Of File)是一种非常有用的语法特性,允许开发者创建多行字符串。它特别适合于创建格式化文本,如配置文件、HTML 模板等。本文将详细讲解 PHP EOF 的用法、优势以及注意事项。 什么是 EOF? EOF 是一种特殊的字符串定义方式,它允…...

pyautogui操控Acrobat DC pro万能PDF转Word,不丢任何PDF格式样式

为了将PDF转换脚本改为多进程异步处理,我们需要确保每个进程独立操作不同的Acrobat窗口。以下是实现步骤: 实现代码 import os import pyautogui import time import subprocess import pygetwindow as gw from multiprocessing import Pooldef conver…...

Day32:字符串的复制

在 Python 中,字符串的复制是指创建一个新的字符串,它的内容与原字符串相同。字符串是不可变的对象,这意味着你不能直接修改字符串的内容,但是可以通过复制来创建新的字符串进行操作。字符串的复制在一些情况下非常有用&#xff0…...

基于Mybatis继承AbstractRoutingDataSource使用自定义注解实现动态数据源

一:实现 方式一:继承AbstractRoutingDataSource使用自定义注解实现 环境:springboot3 MyBatis3 mysql-connector8 DataSourceKeyEnum枚举类 有几个数据源就配置几个枚举类,和数据源数量一一对应 class DataSourceKeyEnum{D…...

ZooKeeper 数据模型

ZooKeeper 数据模型 ZooKeeper 拥有层次化的命名空间,类似分布式文件系统,但每个节点不仅能有子节点,还可关联数据。节点路径为规范的绝对路径,用斜杠分隔,无相对引用。路径命名有如下约束: 路径名不能包…...

【VUE】Vue2中Vue.extend方法

在 Vue.js 2.x 版本中,Vue.extend() 方法被用于创建一个新的 Vue 子类,可以在该子类上扩展一些属性、指令和组件选项等,然后进行实例化。 比如,可以在创建一些类似 loading 式的函数式插件时,使用: 在 Vue…...

MaskGAE论文阅读

What’s Behind the Mask: Understanding Masked Graph Modeling for Graph Autoencoders 碎碎念:一篇论文看四天,效率也没谁了(捂脸) 看一点忘一点,虽然在本子上有记录,但还是忘,下次看一点在博客上记一点启发 本来很…...

Mybatis-plus 更新 Null 的策略踩坑记

一个bug 在一个管理页面,有一个非必填字段被设置成空了并提交更新,再次打开的时候,发现字段还在,并没有被更新成功。 使用的数据库映射框架是 Mybatis-plus ,对于Mybatis 在更新字段的时候会对空进行校验,…...

Oracle迁移DM数据库

Oracle迁移DM数据库 本文记录使用达梦官方数据迁移工具DTS,将Oracle数据库的数据迁移至达梦数据库。 1 数据准备 2 DTS工具操作步骤 2.1 创建工程 打开DTS迁移工具,点击新建工程,填写好工程信息,如图: 2.2 新建迁…...

HTML特殊符号的使用示例

目录 一、基本特殊符号的使用 1、空格符号: 2、小于号 和 大于号: 3、引号: 二、版权、注册商标符号的使用 1、版权符号:© 2、注册商标符号: 三、数学符号的使用 四、箭头符号的使用 五、货币符号的使用…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...

Robots.txt 文件

什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...