当前位置: 首页 > news >正文

Grouped-Query Attention(GQA)详解: Pytorch实现

Grouped-Query Attention(GQA)详解


Grouped-Query Attention(GQA)Multi-Query Attention(MQA) 的改进版,它通过在 多个查询头(Query Heads)之间共享 Key 和 Value,在 Multi-Head Attention(MHA)MQA 之间找到了一种折中方案。GQA 旨在在 推理速度模型质量 之间取得更好的平衡,减少 MQA 带来的模型质量下降问题,同时仍然保留比 MHA 更快的推理速度。

在这里插入图片描述
Source: https://arxiv.org/pdf/2305.13245


1. 为什么需要 Grouped-Query Attention?

在理解 GQA 之前,我们先回顾 MHA 和 MQA 的核心区别。

(1) Multi-Head Attention(MHA)

  • 每个 Query 头都有独立的 Key 和 Value
  • 优势
    • 允许不同的 Query 头关注不同的 Key-Value 信息,提高模型的表达能力。
    • 更适合复杂任务,如长序列建模和复杂推理任务。
  • 劣势
    • 推理速度慢,因为在每一步都要存储和读取 所有 Query 头的 Key 和 Value,导致 KV 缓存(KV Cache)非常大,占用大量显存和内存带宽。

(2) Multi-Query Attention(MQA)

  • 所有 Query 头共享相同的 Key 和 Value
  • 优势
    • 推理速度快,因为只需要存储和读取一个 Key-Value 组,而不是多个。
    • 显存占用低,适用于 大规模语言模型推理(如 ChatGPT)
  • 劣势
    • 不同 Query 头会关注相同的信息,导致模型表达能力下降,尤其是在长序列建模任务上(如机器翻译、摘要生成)。
    • 可能导致训练不稳定,特别是长序列输入时,训练容易出现 Loss spikes(损失值剧烈波动)

(3) GQA 的改进点

Grouped-Query Attention(GQA) 介于 MHA 和 MQA 之间:

  • GQA 不是让所有 Query 头共享同一个 Key-Value,而是分组共享
  • 假设一个模型有 8 个 Query 头
    • MHA:8 个 Query 头,每个头有自己的 Key 和 Value。
    • MQA:8 个 Query 头,所有头共享 1 组 Key 和 Value。
    • GQA(例如 GQA-4):8 个 Query 头被分成 4 组,每组共享一组 Key 和 Value。

因此,GQA 允许:

  • 部分 Query 头共享 Key-Value,但仍然保持了一定的多样性。
  • 推理速度比 MHA 快,但比 MQA 慢
  • 模型质量比 MQA 高,但比 MHA 略低

2. GQA 的数学表达

假设:

  • h 是 Query 头的总数(如 8)。
  • G 是 GQA 分组的数量(如 G=4)。
  • k, v 分别是 Key 和 Value 的维度。

对于 MHA:
Q h = X P Q , h , K h = M P K , h , V h = M P V , h Q_h = X P_{Q,h}, \quad K_h = M P_{K,h}, \quad V_h = M P_{V,h} Qh=XPQ,h,Kh=MPK,h,Vh=MPV,h
logits h = Q h K h T , weights h = softmax ( logits h ) \text{logits}_h = Q_h K_h^T, \quad \text{weights}_h = \text{softmax}(\text{logits}_h) logitsh=QhKhT,weightsh=softmax(logitsh)
O h = weights h V h , Y = ∑ h O h P O , h O_h = \text{weights}_h V_h, \quad Y = \sum_{h} O_h P_{O,h} Oh=weightshVh,Y=hOhPO,h

对于 MQA:
Q h = X P Q , h , K = M P K , V = M P V Q_h = X P_{Q,h}, \quad K = M P_K, \quad V = M P_V Qh=XPQ,h,K=MPK,V=MPV
logits h = Q h K T , weights h = softmax ( logits h ) \text{logits}_h = Q_h K^T, \quad \text{weights}_h = \text{softmax}(\text{logits}_h) logitsh=QhKT,weightsh=softmax(logitsh)
O h = weights h V , Y = ∑ h O h P O , h O_h = \text{weights}_h V, \quad Y = \sum_{h} O_h P_{O,h} Oh=weightshV,Y=hOhPO,h

对于 GQA(分组共享 K/V)
Q h = X P Q , h , K g = M P K , g , V g = M P V , g , g = ⌊ h / G ⌋ Q_h = X P_{Q,h}, \quad K_g = M P_{K,g}, \quad V_g = M P_{V,g}, \quad g = \lfloor h/G \rfloor Qh=XPQ,h,Kg=MPK,g,Vg=MPV,g,g=h/G
logits h = Q h K g T , weights h = softmax ( logits h ) \text{logits}_h = Q_h K_g^T, \quad \text{weights}_h = \text{softmax}(\text{logits}_h) logitsh=QhKgT,weightsh=softmax(logitsh)
O h = weights h V g , Y = ∑ h O h P O , h O_h = \text{weights}_h V_g, \quad Y = \sum_{h} O_h P_{O,h} Oh=weightshVg,Y=hOhPO,h

其中:

  • 在 GQA 中,每个 Query 头属于一个组 ( g g g ),每个组 共享 Key 和 Value
  • 当 ( G = 1 G = 1 G=1 ) 时,GQA 退化为 MQA。
  • 当 ( G = h G = h G=h ) 时,GQA 退化为 MHA。

3. 代码解析

GQA 代码与 MQA 类似,只是 Key 和 Value 现在是 按组分配的

def GroupedQueryAttention(X, M, mask, P_q, P_k, P_v, P_o, num_groups):"""Grouped-Query Attention 实现Args:X: 输入查询 [b, n, d]M: 输入键值存储 [b, m, d]mask: 注意力掩码 [b, h, n, m]P_q: 查询投影矩阵 [h, d, k]P_k: 共享键投影矩阵 [num_groups, d, k]P_v: 共享值投影矩阵 [num_groups, d, v]P_o: 输出投影矩阵 [h, d, v]Returns:Y: 输出张量 [b, n, d]"""# 计算 QueryQ = tf.einsum("bnd, hdk->bhnk", X, P_q)# 计算 Key 和 Value,每个组共享K = tf.einsum("bmd, gdk->bmgk", M, P_k)  # g = num_groupsV = tf.einsum("bmd, gdv->bmgv", M, P_v)# 计算注意力 logitslogits = tf.einsum("bhnk, bmgk->bhng", Q, K)# 计算 softmax 权重weights = tf.nn.softmax(logits + mask)# 计算最终的加权 ValueO = tf.einsum("bhng, bmgv->bhnv", weights, V)# 计算最终输出Y = tf.einsum("bhnv, hdv->bnd", O, P_o)return Y

4. GQA 的性能分析

论文中的实验表明:

  • 质量上,GQA 的 BLEU 得分几乎接近 MHA,明显优于 MQA。
  • 推理速度上,GQA 仅比 MQA 略慢,但比 MHA 快得多。
  • 适用于大模型推理,如 T5、GPT-4、Gemini,减少 KV 访问,提高吞吐量。

实验表明,GQA-8(8 组)质量和速度最优的选择,可以接近 MHA 的质量,同时拥有 MQA 级别的推理速度。


5. 总结

GQA 结合了 MHA 的高质量和 MQA 的高效推理,具有:

  • 更低的 KV 存储需求,推理更快。
  • 更高的模型表达能力,减少 MQA 的信息冗余问题。
  • 适用于大规模语言模型(如 LLaMA、PaLM、GPT-4)推理优化

GQA 目前已被 Google 等研究团队广泛应用于大模型推理优化,是 MQA 的重要改进方案。


Grouped-Query Attention(GQA)PyTorch 实现

以下是 Grouped-Query Attention(GQA)PyTorch 实现,它不使用 einsum,而是采用 矩阵乘法(@)、bmm() 方式进行计算,保证代码可以直接运行。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads, num_groups, dropout=0.1):"""Grouped-Query Attention 实现Args:embed_dim: 词嵌入维度 dnum_heads: 查询头的数量 hnum_groups: 组的数量 G (1 表示 MQA, h 表示 MHA)dropout: dropout 率"""super(GroupedQueryAttention, self).__init__()assert num_heads % num_groups == 0, "num_heads 必须是 num_groups 的整数倍"self.embed_dim = embed_dimself.num_heads = num_headsself.num_groups = num_groupsself.head_dim = embed_dim // num_heads  # 每个头的维度 k# 查询(Q)投影矩阵,每个头独立self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)# 键(K)和值(V)投影矩阵,每组共享self.k_proj = nn.Linear(embed_dim, (embed_dim // num_heads) * num_groups, bias=False)self.v_proj = nn.Linear(embed_dim, (embed_dim // num_heads) * num_groups, bias=False)# 输出投影self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)# dropoutself.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):"""前向传播Args:query: 查询张量,形状 [batch, seq_len, embed_dim]key: 键张量,形状 [batch, seq_len_kv, embed_dim]value: 值张量,形状 [batch, seq_len_kv, embed_dim]mask: 掩码张量,形状 [batch, 1, 1, seq_len_kv],默认 NoneReturns:输出张量,形状 [batch, seq_len, embed_dim]"""batch_size, seq_len, _ = query.shape_, seq_len_kv, _ = key.shape# 计算 Query,每个头独立Q = self.q_proj(query)  # [batch, seq_len, embed_dim]Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)  # [batch, seq_len, num_heads, head_dim]Q = Q.permute(0, 2, 1, 3)  # [batch, num_heads, seq_len, head_dim]# 计算 Key 和 Value,按组共享K = self.k_proj(key)  # [batch, seq_len_kv, num_groups * head_dim]V = self.v_proj(value)  # [batch, seq_len_kv, num_groups * head_dim]K = K.view(batch_size, seq_len_kv, self.num_groups, self.head_dim)  # [batch, seq_len_kv, num_groups, head_dim]V = V.view(batch_size, seq_len_kv, self.num_groups, self.head_dim)  # [batch, seq_len_kv, num_groups, head_dim]K = K.permute(0, 2, 1, 3)  # [batch, num_groups, seq_len_kv, head_dim]V = V.permute(0, 2, 1, 3)  # [batch, num_groups, seq_len_kv, head_dim]# 计算注意力权重 (Q @ K^T),Query 按照组进行索引匹配group_size = self.num_heads // self.num_groupsQ_grouped = Q.view(batch_size, self.num_groups, group_size, seq_len, self.head_dim)  # [batch, num_groups, group_size, seq_len, head_dim]# 计算点积注意力attn_logits = torch.matmul(Q_grouped, K.transpose(-2, -1))  # [batch, num_groups, group_size, seq_len, seq_len_kv]# 归一化attn_logits /= self.head_dim ** 0.5# 应用掩码if mask is not None:attn_logits = attn_logits.masked_fill(mask == 0, float("-inf"))# 计算 softmax 注意力分布attn_weights = F.softmax(attn_logits, dim=-1)  # [batch, num_groups, group_size, seq_len, seq_len_kv]attn_weights = self.dropout(attn_weights)# 计算注意力加权的 ValueO = torch.matmul(attn_weights, V)  # [batch, num_groups, group_size, seq_len, head_dim]# 重新排列回原始形状O = O.permute(0, 3, 1, 2, 4).contiguous()  # [batch, seq_len, num_groups, group_size, head_dim]O = O.view(batch_size, seq_len, self.embed_dim)  # [batch, seq_len, embed_dim]# 通过最终的线性变换Y = self.o_proj(O)  # [batch, seq_len, embed_dim]return Y

5. 代码解读

  1. 参数解释

    • embed_dim: 输入嵌入维度(即 d)。
    • num_heads: 注意力头的数量(即 h)。
    • num_groups: 组的数量(如果 num_groups=1,则相当于 MQA;如果 num_groups=num_heads,则相当于 MHA)。
    • dropout: Dropout 率。
  2. 计算 Query

    • Query 使用独立的投影矩阵 self.q_proj 计算,每个 Query 头仍然是独立的。
  3. 计算 Key 和 Value

    • Key 和 Value 共享,但按照 num_groups 进行分组,每组有 head_dim 维度。
  4. 计算注意力

    • Q @ K^T 计算注意力分数。
    • softmax 归一化并应用 dropout。
    • attention_weights @ V 计算加权 Value。
  5. 重塑输出

    • 由于每个 Query 头仍然是独立的,计算完后需要重新排列回原始形状。
    • 通过 self.o_proj 进行最终的线性投影。

6. 运行示例

你可以用下面的代码来测试 GQA:

# 初始化模型
embed_dim = 64
num_heads = 8
num_groups = 4
batch_size = 2
seq_len = 10
seq_len_kv = 12gqa = GroupedQueryAttention(embed_dim, num_heads, num_groups)# 生成随机输入
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len_kv, embed_dim)
value = torch.randn(batch_size, seq_len_kv, embed_dim)# 前向传播
output = gqa(query, key, value)
print("Output shape:", output.shape)  # 预期输出 [batch_size, seq_len, embed_dim]

7. 总结

GQA 的 PyTorch 实现:

  • 完全可运行,不依赖 einsum,使用 matmul 进行计算。
  • 适用于推理优化,减少 KV 存储,提高 LLM 推理效率。
  • 兼容 MHA/MQA,通过 num_groups 控制:
    • num_groups = 1 时,相当于 MQA
    • num_groups = num_heads 时,相当于 MHA
    • num_groups = 4 时,找到 质量与推理速度的最佳平衡

这个实现可以直接用于 大模型推理加速,如 LLaMA、GPT-4、Gemini 等模型的优化!🚀

Grouped-Query Attention(GQA)结合 KV Cache 的推理优化


大语言模型(LLM) 的自回归推理过程中,每生成一个新 token,都需要计算 注意力(attention)。然而,标准 Multi-Head Attention(MHA) 需要存储并加载 所有 Key(K)和 Value(V),这会带来 显存占用过大内存带宽受限 的问题。

Grouped-Query Attention(GQA) 结合 KV Cache(Key-Value 缓存) 可以 减少存储、提高推理速度,特别适用于 GPT-4、Gemini 等大模型


1. 为什么推理时需要 KV Cache?

Transformer 自回归推理 中:

  • 训练时,模型可以并行计算整个序列(一次性输入所有 token)。
  • 推理时,只能逐步生成新 token,每次只能访问过去的 Key-Value 并计算新的 Query。

标准 MHA 推理(带 KV Cache)

在推理时:

  • 之前生成的 tokens 的 Key 和 Value 可以缓存,不需要重新计算。
  • 新的 Query 需要与 缓存中的 Key/Value 计算注意力

对于 标准 MHA

  • 每个头都有独立的 Key/Value,所以 缓存大小为
    KV Cache Size = O ( b × h × seq_len × d k ) \text{KV Cache Size} = \mathcal{O}(b \times h \times \text{seq\_len} \times d_k) KV Cache Size=O(b×h×seq_len×dk)
    这对于 大模型推理来说,KV 缓存占用显存过大,特别是 h=32 或更大时。

2. GQA 如何优化推理中的 KV Cache?

Grouped-Query Attention(GQA) 中:

  • 每个 Query 组共享同一个 Key 和 Value
  • 减少了 KV 缓存大小,让推理更高效。

对于 GQA(num_groups = G)

  • 只需要 G 组 Key-Value,而不是 h 组
  • 缓存大小降低 (h/G) 倍
    KV Cache Size = O ( b × G × seq_len × d k ) \text{KV Cache Size} = \mathcal{O}(b \times G \times \text{seq\_len} \times d_k) KV Cache Size=O(b×G×seq_len×dk)
  • 例如:
    • MHA(h=32) → 需要存储 32 组 K/V
    • GQA(G=8) → 只需要存储 8 组 K/V,减少 4 倍显存占用。

这样,GQA 在推理时可以大幅减少 KV Cache 访问和存储,提高解码速度!


3. PyTorch 实现:GQA 推理(结合 KV Cache)

下面是完整的 PyTorch 实现,支持 KV Cache,并可用于 增量推理

import torch
import torch.nn as nn
import torch.nn.functional as Fclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads, num_groups, dropout=0.1):"""Grouped-Query Attention 结合 KV CacheArgs:embed_dim: 词嵌入维度 dnum_heads: 查询头的数量 hnum_groups: 组的数量 G (1 表示 MQA, h 表示 MHA)dropout: dropout 率"""super(GroupedQueryAttention, self).__init__()assert num_heads % num_groups == 0, "num_heads 必须是 num_groups 的整数倍"self.embed_dim = embed_dimself.num_heads = num_headsself.num_groups = num_groupsself.head_dim = embed_dim // num_heads  # 每个头的维度 k# 查询(Q)投影矩阵,每个头独立self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)# 键(K)和值(V)投影矩阵,每组共享self.k_proj = nn.Linear(embed_dim, (embed_dim // num_heads) * num_groups, bias=False)self.v_proj = nn.Linear(embed_dim, (embed_dim // num_heads) * num_groups, bias=False)# 输出投影self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)# dropoutself.dropout = nn.Dropout(dropout)def forward(self, query, key, value, kv_cache=None, mask=None):"""推理时结合 KV CacheArgs:query: 查询张量 [batch, 1, embed_dim] (推理时单个 token)key: 当前 token 的键 [batch, 1, embed_dim]value: 当前 token 的值 [batch, 1, embed_dim]kv_cache: 之前的 Key-Value 缓存 (字典: {'key': K, 'value': V})mask: 注意力掩码 [batch, 1, 1, seq_len_kv]Returns:输出张量 [batch, 1, embed_dim]更新后的 KV Cache"""batch_size, _, _ = query.shape# 计算 Query,每个头独立Q = self.q_proj(query)  # [batch, 1, embed_dim]Q = Q.view(batch_size, 1, self.num_heads, self.head_dim)  # [batch, 1, num_heads, head_dim]Q = Q.permute(0, 2, 1, 3)  # [batch, num_heads, 1, head_dim]# 计算当前步的 Key 和 Value,按组共享K_new = self.k_proj(key).view(batch_size, 1, self.num_groups, self.head_dim)  # [batch, 1, num_groups, head_dim]V_new = self.v_proj(value).view(batch_size, 1, self.num_groups, self.head_dim)  # [batch, 1, num_groups, head_dim]K_new = K_new.permute(0, 2, 1, 3)  # [batch, num_groups, 1, head_dim]V_new = V_new.permute(0, 2, 1, 3)  # [batch, num_groups, 1, head_dim]# 更新 KV Cacheif kv_cache is None:K = K_newV = V_newelse:K = torch.cat([kv_cache['key'], K_new], dim=2)  # [batch, num_groups, seq_len_kv, head_dim]V = torch.cat([kv_cache['value'], V_new], dim=2)# 计算注意力 logitsgroup_size = self.num_heads // self.num_groupsQ_grouped = Q.view(batch_size, self.num_groups, group_size, 1, self.head_dim)  # [batch, num_groups, group_size, 1, head_dim]attn_logits = torch.matmul(Q_grouped, K.transpose(-2, -1))  # [batch, num_groups, group_size, 1, seq_len_kv]attn_logits /= self.head_dim ** 0.5# 应用掩码if mask is not None:attn_logits = attn_logits.masked_fill(mask == 0, float("-inf"))# 计算 softmax 注意力分布attn_weights = F.softmax(attn_logits, dim=-1)  # [batch, num_groups, group_size, 1, seq_len_kv]attn_weights = self.dropout(attn_weights)# 计算注意力加权的 ValueO = torch.matmul(attn_weights, V)  # [batch, num_groups, group_size, 1, head_dim]O = O.permute(0, 3, 1, 2, 4).contiguous()  # [batch, 1, num_groups, group_size, head_dim]O = O.view(batch_size, 1, self.embed_dim)  # [batch, 1, embed_dim]# 通过最终的线性变换Y = self.o_proj(O)  # [batch, 1, embed_dim]return Y, {'key': K, 'value': V}

4. 结论

GQA 结合 KV Cache

  • 减少存储,比 MHA 降低 ( h/G ) 倍 KV Cache 占用
  • 加速推理,减少 Key-Value 访问,适用于 大模型优化(GPT-4、Gemini)
  • PyTorch 实现可直接运行,适用于 增量推理(Streaming Inference)

GQA+KV Cache 是当前 LLM 高效推理的重要优化方向!🚀

Grouped-Query Attention(GQA)中 matmul(Q_grouped, K.transpose(-2, -1)) 的计算解析


GQA 计算注意力 logits 的过程中,我们使用了:

attn_logits = torch.matmul(Q_grouped, K.transpose(-2, -1))  

这个操作的核心是计算 Query 和 Key 之间的点积注意力分数,即:
logits = Q ⋅ K T \text{logits} = Q \cdot K^T logits=QKT
但在 GQA 中,由于 Query 头是按组共享 Key 的,因此计算方式比标准 MHA 更复杂。


1. 形状分析

首先,我们看看 Q_groupedK 的形状:

  • Q_grouped(Grouped Query)

    Q_grouped = Q.view(batch_size, num_groups, group_size, 1, head_dim)  
    

    形状变为:
    ( b a t c h , num_groups , group_size , 1 , head_dim ) (batch, \text{num\_groups}, \text{group\_size}, 1, \text{head\_dim}) (batch,num_groups,group_size,1,head_dim)
    其中:

    • num_groups:查询被分成的组数。
    • group_size:每个组的 Query 头数(num_heads / num_groups)。
    • 1:表示当前推理的单个 token(因为推理是自回归的,每次只计算一个新 token)。
    • head_dim:每个头的维度。
  • K(Key 缓存)

    K = K.transpose(-2, -1)  # 转置 K,使其可以与 Q 进行点积
    

    形状为:
    ( b a t c h , num_groups , seq_len_kv , head_dim ) (batch, \text{num\_groups}, \text{seq\_len\_kv}, \text{head\_dim}) (batch,num_groups,seq_len_kv,head_dim)
    其中:

    • seq_len_kv:当前 Key-Value 缓存中的 token 数量。
    • head_dim:每个 Key 头的维度。

2. matmul(Q_grouped, K.transpose(-2, -1)) 计算过程

现在,我们来看点积计算:

attn_logits = torch.matmul(Q_grouped, K.transpose(-2, -1))  

这个操作等价于:
logits = Q × K T \text{logits} = Q \times K^T logits=Q×KT

矩阵计算规则

假设:

  • Q_grouped 形状为 (batch, num_groups, group_size, 1, head_dim)
  • K^T 形状为 (batch, num_groups, head_dim, seq_len_kv)

由于 矩阵乘法的规则
( A ∈ R m × k ) × ( B ∈ R k × n ) = C ∈ R m × n (A \in \mathbb{R}^{m \times k}) \times (B \in \mathbb{R}^{k \times n}) = C \in \mathbb{R}^{m \times n} (ARm×k)×(BRk×n)=CRm×n
所以计算后:
logits ∈ R batch , num_groups , group_size , 1 , seq_len_kv \text{logits} \in \mathbb{R}^{\text{batch}, \text{num\_groups}, \text{group\_size}, 1, \text{seq\_len\_kv}} logitsRbatch,num_groups,group_size,1,seq_len_kv

即:

  • batch:批大小,不变。
  • num_groups:每个组独立计算注意力分数。
  • group_size:组内的 Query 头。
  • 1:当前 Query 的 token 数(因为推理时每次处理一个 token)。
  • seq_len_kv:Key 缓存的长度(即 Query 需要关注的所有历史 tokens)。

3. 举例计算

假设输入数据

  • Query Q_grouped

    • 形状:(batch=1, num_groups=2, group_size=2, 1, head_dim=3)
    • 假设值:
      Q_grouped = torch.tensor([[[  # Group 1[[1, 2, 3]],   # Query Head 1[[4, 5, 6]]    # Query Head 2],[  # Group 2[[7, 8, 9]],   # Query Head 3[[10, 11, 12]] # Query Head 4]]
      ], dtype=torch.float32)
      
  • Key K

    • 形状:(batch=1, num_groups=2, seq_len_kv=2, head_dim=3)
    • 假设值:
      K = torch.tensor([[[  # Group 1[0, 1, 0],  # Key 1[1, 0, 1]   # Key 2],[  # Group 2[1, 1, 1],  # Key 1[2, 2, 2]   # Key 2]]
      ], dtype=torch.float32)
      

计算步骤

  1. Key 转置K.transpose(-2, -1)

    K_T = K.transpose(-2, -1)
    

    变为:

    K_T = torch.tensor([[[  # Group 1[0, 1],  # Key Head 1[1, 0],  [0, 1]   ],[  # Group 2[1, 2],  # Key Head 2[1, 2],[1, 2]]]
    ], dtype=torch.float32)
    
  2. 矩阵乘法

    attn_logits = torch.matmul(Q_grouped, K_T)
    

    计算方式如下:

Group 1
Query Head 1 ([1, 2, 3]) 与 Key 矩阵点积:
[ 1 , 2 , 3 ] ⋅ [ 0 1 1 0 0 1 ] = [ 2 , 4 ] [1, 2, 3] \cdot \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = [2, 4] [1,2,3] 010101 =[2,4]
Query Head 2 ([4, 5, 6]):

[ 4 , 5 , 6 ] ⋅ [ 0 1 1 0 0 1 ] = [ 5 , 9 ] [4, 5, 6] \cdot \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = [5, 9] [4,5,6] 010101 =[5,9]

Group 2

Query Head 3 ([7, 8, 9]):
[ 7 , 8 , 9 ] ⋅ [ 1 2 1 2 1 2 ] = [ 24 , 48 ] [7, 8, 9] \cdot \begin{bmatrix} 1 & 2 \\ 1 & 2 \\ 1 & 2 \end{bmatrix} = [24, 48] [7,8,9] 111222 =[24,48]
Query Head 4 ([10, 11, 12]):
[ 10 , 11 , 12 ] ⋅ [ 1 2 1 2 1 2 ] = [ 33 , 66 ] [10, 11, 12] \cdot \begin{bmatrix} 1 & 2 \\ 1 & 2 \\ 1 & 2 \end{bmatrix} = [33, 66] [10,11,12] 111222 =[33,66]


最终结果

计算出的 attn_logits

attn_logits = torch.tensor([[[[[2, 4]],  # Query Head 1[[5, 9]]   # Query Head 2],[[[24, 48]], # Query Head 3[[33, 66]]  # Query Head 4]]
], dtype=torch.float32)
  • 形状:(batch=1, num_groups=2, group_size=2, 1, seq_len_kv=2)

4. 结论

  • GQA 中,Query 按组匹配共享 Key,减少计算复杂度。
  • KV 缓存中仅存储 num_groups 组 Key,而非 num_heads 组 Key,节省显存。
  • 矩阵计算遵循 Query-Key 点积规则,matmul(Q_grouped, K.transpose(-2, -1)) 计算注意力分数

后记

2025年2月23日10点08分于上海,在GPT4o大模型辅助下完成。

相关文章:

Grouped-Query Attention(GQA)详解: Pytorch实现

Grouped-Query Attention(GQA)详解 Grouped-Query Attention(GQA) 是 Multi-Query Attention(MQA) 的改进版,它通过在 多个查询头(Query Heads)之间共享 Key 和 Value&am…...

DeepSeek AI人工智能该如何学习?

人工智能(Artificial Intelligence, AI)是当今科技领域的热门话题,它涵盖了机器学习、深度学习、自然语言处理、计算机视觉等多个子领域。 作为中国科技发展的核心方向之一,AI在国家战略规划中占据了重要地位,特别是在…...

【数据库】【MySQL】索引

MySQL中索引的概念 索引(MySQL中也叫做"键(key)")是一种数据结构,用于存储引擎快速定找到记录。 简单来说,它类似于书籍的目录,通过索引可以快速找到对应的数据行,而无需…...

SprinBoot整合HTTP API:从零开始的实战指南

在现代 Web 开发中,HTTP API 是前后端交互的核心。Spring Boot 作为 Java 生态中备受欢迎的框架,提供了简洁而强大的方式来构建和整合 HTTP API。本文将带你从零开始,通过具体代码示例,展示如何在 Spring Boot 中整合 HTTP API,实现高效、稳定的前后端通信。 一、为什么选…...

可狱可囚的爬虫系列课程 13:Requests使用代理IP

一、什么是代理 IP 代理 IP(Proxy IP)是一个充当“中间人”的服务器IP地址,用于代替用户设备(如电脑、手机等)直接与目标网站或服务通信。用户通过代理IP访问互联网时,目标网站看到的是代理服务器的IP地址&…...

DBeaver下载安装及数据库连接(MySQL)

1. DBeaver下载 官网下载地址:Download | DBeaver Community 2. 安装 1. 双击下载的安装包,选择简体中文。 2. 点击下一步。 3. 点击我接受。 4. 如下勾选为所有用户安装,点击下一步。 5. 需重复做1~3 的步骤。 6. 选择组件,默认即可&…...

国产开源PDF解析工具MinerU

前言 PDF的数据解析是一件较困难的事情,几乎所有商家都把PDF转WORD功能做成付费产品。 PDF是基于PostScript子集渲染的,PostScript是一门图灵完备的语言。而WORD需要的渲染,本质上是PDF能力的子集。大模型领域,我们的目标文件格…...

消息中间件的开源实现

根据你的需求,以下是一些可以实现类似阿里巴巴 MetaQ 功能的消息中间件和相关项目,这些项目可以帮助你实现消息的动态配置和管理: 1. RocketMQ RocketMQ 是一个分布式消息中间件,支持高吞吐量、低延迟的消息传递,适合…...

AcWing 299 裁剪序列

这道题算是我做过所有的单调队列优化 d p dp dp 题目中最难想的一道题,所以写篇题解再捋捋思路。 暴力 首先很容易想到设 d p i dp_i dpi​ 表示将前 i i i 个数划分成若干序列,【每个序列的最大值之和】的最小值。 那么就会有: d p i …...

P2889 [USACO07NOV] Milking Time S

题目大意 有 N N N 个小时可以挤奶。其中有 m m m 个时间段可以给 Bessis 奶牛挤奶。第 i i i 个时间段为 s i s_i si​ ~ t i t_i ti​,可以获得 E f f i Eff_i Effi​ 滴奶。每次挤完奶后,人都要休息 R R R 小时。最后问,一共能挤出…...

基于Spring Boot的健康医院门诊在线挂号系统设与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...

PyTorch-基础(CUDA、Dataset、transforms、卷积神经网络、VGG16)

PyTorch-基础 环境准备 CUDA Toolkit安装(核显跳过此步骤) CUDA Toolkit是NVIDIA的开发工具,里面提供了各种工具、如编译器、调试器和库 首先通过NVIDIA控制面板查看本机显卡驱动对应的CUDA版本,如何去下载对应版本的Toolkit工…...

复现论文:DPStyler: Dynamic PromptStyler for Source-Free Domain Generalization

论文:[2403.16697] DPStyler: Dynamic PromptStyler for Source-Free Domain Generalization github: TYLfromSEU/DPStyler: DPStyler: Dynamic PromptStyler for Source-Free Domain Generalization 论文: 这篇论文还是在PromptStyler:Prompt-driven Style Gener…...

6.将cr打包成网络服务|使用postman进行测试|编写oj_server的服务路由功能(C++)

将cr打包成网络服务 compile_server.cc #include "compile_run.hpp" #include "../comm/httplib.h"using namespace ns_compile_and_run; using namespace httplib;//编译服务随时可能被多个人请求,必须保证传递上来的code,形成源…...

基于SpringBoot + Vue的共享汽车(单车)管理系统设计与实现+毕业论文+开题报告+指导搭建视频

本系统包含管理员、用户两个角色。 管理员角色:个人中心管理、用户管理、投放地区管理、汽车信息管理、汽车投放管理、汽车入库管理、使用订单管理、汽车归还管理。 用户角色:注册登录、汽车使用下单、汽车归还。 本共享汽车管理系统有管理员和用户。管…...

Day54(补)【AI思考】-SOA,Web服务以及无状态分步解析与示例说明

文章目录 **SOA,Web服务以及无状态**分步解析与示例说明**分步解析与示例说明****1. 核心概念解析****2. 为什么说SOA与Web服务是“正交的”?****3. 架构风格 vs. 实现技术****4. 接口(Interface)的核心作用****5. Web服务的“被认…...

回溯算法之组合和排列问题

文章目录 1.什么是回溯算法2.回溯算法解题步骤3.回溯算法解决组合问题4.回溯算法解决排列问题 1.什么是回溯算法 回溯算法是一种通过尝试所有可能的解决方案来解决问题的算法策略,它通常用于求解组合优化、排列组合、路径搜索等类型的问题,是一种暴力求解的算法。 2…...

gihub上适合练手的Python项目

GitHub 上有许多适合练手的 Python 项目,涵盖了从初学者到中级开发者的不同难度级别。以下是一些推荐的项目类型和具体示例,帮助你提升 Python 编程技能: 1. 基础项目 适合初学者,帮助掌握 Python 基础语法和常用库。 示例项目&…...

解锁CSnakes:.NET与Python的融合魔法

一、引言 在软件开发的广袤领域中,我们常常面临各种复杂的业务需求和技术挑战。不同的编程语言犹如各具特色的工具,它们在不同的场景下展现出独特的优势。例如,C# 以其强大的类型系统和丰富的类库,在企业级应用开发中占据重要地位…...

Python常见面试题的详解16

1. 如何强行关闭客户端和服务器之间的连接? 在网络编程中,有时需要强行中断客户端和服务器之间的连接。对于基于 TCP 协议的连接,由于其面向连接的特性,需要采取特定的步骤来确保连接被正确关闭;而 UDP 是无连接协议&a…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署,直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型,但是目前国内可能使用不多,至少实践例子很少看见。开发训练模型就不介绍了&am…...

嵌入式学习笔记DAY33(网络编程——TCP)

一、网络架构 C/S (client/server 客户端/服务器):由客户端和服务器端两个部分组成。客户端通常是用户使用的应用程序,负责提供用户界面和交互逻辑 ,接收用户输入,向服务器发送请求,并展示服务…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

AI+无人机如何守护濒危物种?YOLOv8实现95%精准识别

【导读】 野生动物监测在理解和保护生态系统中发挥着至关重要的作用。然而,传统的野生动物观察方法往往耗时耗力、成本高昂且范围有限。无人机的出现为野生动物监测提供了有前景的替代方案,能够实现大范围覆盖并远程采集数据。尽管具备这些优势&#xf…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事,必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后,我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集,就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

SpringAI实战:ChatModel智能对话全解

一、引言:Spring AI 与 Chat Model 的核心价值 🚀 在 Java 生态中集成大模型能力,Spring AI 提供了高效的解决方案 🤖。其中 Chat Model 作为核心交互组件,通过标准化接口简化了与大语言模型(LLM&#xff0…...

【Linux】Linux安装并配置RabbitMQ

目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...

CTF show 数学不及格

拿到题目先查一下壳,看一下信息 发现是一个ELF文件,64位的 ​ 用IDA Pro 64 打开这个文件 ​ 然后点击F5进行伪代码转换 可以看到有五个if判断,第一个argc ! 5这个判断并没有起太大作用,主要是下面四个if判断 ​ 根据题目…...