当前位置: 首页 > article >正文

大模型中的KV Cache

1. KV Cache的定义与核心原理

KV Cache(Key-Value Cache)是一种在Transformer架构的大模型推理阶段使用的优化技术,通过缓存自注意力机制中的键(Key)和值(Value)矩阵,避免重复计算,从而显著提升推理效率。

原理:

  • 自注意力机制:在Transformer中,注意力计算基于公式:
    Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V = ∑ i = 1 n w i v i (加权求和形式) \begin{split} \text{Attention}(Q, K, V) &= \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \\ &= \sum_{i=1}^n w_i v_i \quad \text{(加权求和形式)} \end{split} Attention(Q,K,V)=softmax(dk QK)V=i=1nwivi(加权求和形式)
    其中,Q(Query)、K(Key)、V(Value)由输入序列线性变换得到。

  • 缓存机制:在生成式任务(如文本生成)中,模型以自回归方式逐个生成token。首次推理时,计算所有输入token的K和V并缓存;后续生成时,仅需为新token计算Q,并从缓存中读取历史K和V进行注意力计算。

  • 复杂度优化:传统方法的计算复杂度为O(n²),而KV Cache将后续生成的复杂度降为O(n),避免重复计算历史token的K和V。

2. KV Cache的核心作用

  • 加速推理:通过复用缓存的K和V,减少矩阵计算量,提升生成速度。例如,某聊天机器人应用响应时间从0.5秒缩短至0.2秒。

  • 降低资源消耗:显存占用减少约30%-50%(例如移动端模型从1GB降至0.6GB),支持在资源受限设备上部署大模型。

  • 支持长文本生成:缓存机制使推理耗时不再随文本长度线性增长,可稳定处理长序列(如1024 token以上)。

  • 保持模型性能:仅优化计算流程,不影响输出质量。

3. 技术实现与优化策略

实现方式:
  • 数据结构

    1. KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为(batch_size, num_heads, seq_len, k_dim)(batch_size, num_heads, seq_len, v_dim)
  • 两阶段推理:

    1. 初始化阶段:计算初始输入的所有K和V,存入缓存。
    2. 迭代阶段:仅计算新token的Q,结合缓存中的K和V生成输出,并更新缓存。
      • 代码示例(Hugging Face Transformers):设置model.generate(use_cache=True)即可启用KV Cache。

优化策略:

  • 稀疏化(Sparse):仅缓存部分重要K和V,减少显存占用。

  • 量化(Quantization):将K和V矩阵从FP32转为INT8/INT4,降低存储需求。

共享机制(MQA/GQA):

  • Multi-Query Attention (MQA):所有注意力头共享同一组K和V,显存占用降低至1/头数。

  • Grouped-Query Attention (GQA):将头分组,组内共享K和V,平衡性能和显存。

4. 挑战与局限性

  • 显存压力:随着序列长度增加,缓存占用显存线性增长(如1024 token占用约1GB显存),可能引发OOM(内存溢出)。

  • 冷启动问题:首次推理仍需完整计算K和V,无法完全避免初始延迟。

5、python实现

import torch
import torch.nn as nn# 超参数
d_model = 4
n_heads = 1
seq_len = 3
batch_size = 3# 初始化参数(兼容多头形式)
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)# 生成模拟输入(整个序列一次性输入)
input_sequence = torch.randn(batch_size, seq_len, d_model)  # [B, L, D]# 初始化 KV 缓存(兼容多头格式)
kv_cache = {"keys": torch.empty(batch_size, 0, n_heads, d_model // n_heads),  # [B, T, H, D/H]"values": torch.empty(batch_size, 0, n_heads, d_model // n_heads) 
}# 因果掩码预先生成(覆盖最大序列长度)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # [L, L]'''
本循环是将整句话中的token一个一个输入,并更新KV缓存;
所以无需显示的因果掩码,因为因果掩码只用于计算注意力权重时,而计算注意力权重时,KV缓存中的key和value已经包含了因果掩码的信息。'''for step in range(seq_len):# 1. 获取当前时间步的输入(整个批次)current_token = input_sequence[:, step, :]  # [B, 1, D]# 2. 计算当前时间步的 Q/K/V(保持三维结构)q = Wq(current_token)  # [B, 1, D]k = Wk(current_token)  # [B, 1, D]v = Wv(current_token)  # [B, 1, D]# 3. 调整维度以兼容多头格式(关键修改点)def reshape_for_multihead(x):return x.view(batch_size, 1, n_heads, d_model // n_heads).transpose(1, 2)  # [B, H, 1, D/H]# 4. 更新 KV 缓存(增加时间步维度)kv_cache["keys"] = torch.cat([kv_cache["keys"], reshape_for_multihead(k).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)kv_cache["values"] = torch.cat([kv_cache["values"],reshape_for_multihead(v).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)# 5. 多头注意力计算(支持批量处理)q_multi = reshape_for_multihead(q)  # [B, H, 1, D/H]k_multi = kv_cache["keys"].transpose(1, 2)  # [B, H, T+1, D/H]print("q_multi shape:", q_multi.shape)print("k_multi shape:", k_multi.shape)# 6. 计算注意力分数(带因果掩码)attn_scores = torch.matmul(q_multi, k_multi.transpose(-2, -1)) / (d_model ** 0.5)print("attn_scores shape:", attn_scores.shape)# attn_scores = attn_scores.masked_fill(causal_mask[:step+1, :step+1], float('-inf'))# print("attn_scores shape:", attn_scores.shape)# 7. 注意力权重计算attn_weights = torch.softmax(attn_scores, dim=-1)  # [B, H, 1, T+1]# 8. 加权求和output = torch.matmul(attn_weights, kv_cache["values"].transpose(1, 2))  # [B, H, 1, D/H]# 9. 合并多头输出output = output.contiguous().view(batch_size, 1, d_model)  # [B, 1, D]print(f"Step {step} 输出:", output.shape)

相关文章:

大模型中的KV Cache

1. KV Cache的定义与核心原理 KV Cache(Key-Value Cache)是一种在Transformer架构的大模型推理阶段使用的优化技术,通过缓存自注意力机制中的键(Key)和值(Value)矩阵,避免重复计算&…...

FHQ平衡树

FHQ平衡树 大致是这样的题目: 您需要动态地维护一个可重集合 M M M,并且提供以下操作: 向 M M M 中插入一个数 x x x。从 M M M 中删除一个数 x x x(若有多个相同的数,应只删除一个)。查询 M M M 中…...

力扣算法---总结篇

5.13 数组总结 数组是存放在连续内存空间上的相同类型数据的集合。 数组可以方便的通过下标索引的方式获取到下标对应的数据。 正是因为数组在内存空间的地址是连续的,所以我们在删除或者增添元素的时候,就难免要移动其他元素的地址。 数组的元素是不…...

ABAP+旧数据接管的会计年度未确定

导资产主数据时,报错旧数据接管的会计年度未确定 是因为程序里面使用了下列函数AISCO_CALCULATE_FIRST_DAY,输入公司代码,获取会计年度,这个数据是在后台表T093C表中取数的,通过SE16N可以看到后台表数据没有数&#xf…...

Java【10_1】用户注册登录(面向过程与面向对象)

测试题 1、基于文本界面实现登录注册的需求(要求可以满足多个用户的注册和登录) 通过工具去完成 公共类: public class User { private int id;//用户编号 private int username;//用户名 private int password;//密码 private String name;//真…...

养生:打造健康生活的全方位策略

在生活节奏不断加快的当下,养生已成为提升生活质量、维护身心平衡的重要方式。从饮食、运动到睡眠,再到心态调节,各个方面的养生之道共同构建起健康生活的坚实基础。以下为您详细介绍养生的关键要点,助您拥抱健康生活。 饮食养生…...

贪吃蛇游戏排行榜模块开发总结:从数据到视觉的实现

一、项目背景与成果概览 在完成贪吃蛇游戏核心玩法后,本次开发重点聚焦于排行榜系统的实现。该系统具备以下核心特性: 🌐 双数据源支持:本地存储(localStorage)与远程API自由切换 🕒 时间维度统计:日榜/周榜/月榜/全时段数据筛选 🎮 模式区分:闯关模式(关卡进度…...

pytorch 数据预处理和常用工具

文章目录 NumPyNumpy数据结构安装和使用NumPy Matplotlib的安装和导入安装和导入Matplotlib绘制基础图画折线图散点图柱状图图例 数据清洗据清洗的作用Pandas进行数据清洗Pandas数据结构Series 数据结构DataFrame数据结构 Pandas数据清洗常用代码 特征工程主成分分析线性判别分…...

如何界定合法收集数据?

首席数据官高鹏律师团队 在当今数字化时代,数据的价值日益凸显,而合法收集数据成为了企业、机构以及各类组织必须严守的关键准则。作为律师,深入理解并准确界定合法收集数据的范畴,对于保障各方权益、维护法律秩序至关重要。 一…...

企业对数据集成工具的需求及 ETL 工具工作原理详解

当下,数据已然成为企业运营发展过程中的关键生产要素,其重要性不言而喻。 海量的数据分散在企业的各类系统、平台以及不同的业务部门之中,企业要充分挖掘这些数据背后所蕴含的巨大价值,实现数据驱动的精准决策,数据集…...

内核深入学习3——分析ARM32和ARM64体系架构下的Linux内存区域示意图与页表的建立流程

内核深入学习3——ARM32/ARM64在Linux内核中的实现(2) ​ 今天我们来讨论的是一个硬核的内容,也是一个老生常谈的话题——那就是分析ARM32和ARM64体系架构下的Linux内存区域示意图的内容。对于ARM64的部分,我们早就知道一个基本的…...

MapReduce基本介绍

核心思想 分而治之:将大规模的数据处理任务分解成多个可以并行处理的子任务,然后将这些子任务分配到不同的计算节点上进行处理,最后将各个子任务的处理结果合并起来,得到最终的结果。 工作流程 Map 阶段: 输入数据被…...

屏幕与触摸调试

本章配套视频介绍: 《28-屏幕与触摸设置》 【鲁班猫】28-屏幕与触摸设置_哔哩哔哩_bilibili LubanCat-RK3588系列板卡都支持mipi屏以及hdmi显示屏的显示。 19.1. 旋转触摸屏 参考文章 触摸校准 参考文章 旋转触摸方向 配置触摸旋转方向 1 2 # 1.查看触摸输入设备 xinput…...

使用 百度云大模型平台 做 【提示词优化】

1. 百度云大模型平台 百度智能云千帆大模型平台  平台功能:演示了阿里云大模型的百炼平台,该平台提供Prompt工程功能,支持在线创建和优化Prompt模板模板类型:平台提供多种预制模板,同时也支持用户自定义…...

C 语言_常见排序算法全解析

排序算法是计算机科学中的基础内容,本文将介绍 C 语言中几种常见的排序算法,包括实现代码、时间复杂度分析、适用场景和详细解析。 一、冒泡排序(Bubble Sort) 基本思想:重复遍历数组,比较相邻元素,将较大元素交换到右侧。 代码实现: void bubbleSort(int arr[], i…...

IJCAI 2025 | 高德首个原生3D生成基座大模型「G3PT」重塑3D生成的未来

国际人工智能联合会议(IJCAI)是人工智能领域最古老、最具权威性的学术会议之一,自1969年首次举办以来,至今已有近六十年的历史。它见证了人工智能从萌芽到蓬勃发展的全过程,是全球人工智能研究者、学者、工程师和行业专…...

Samtec助力电视广播行业

【摘要前言】 现代广播电视技术最有趣的方面之一就是界限的模糊。过去,音频和视频是通过射频电缆传输的模拟技术采集的,而现在,数字世界已经取代了模拟技术。物理胶片和磁带已让位于数字存储设备和流媒体。 在这个过程中,连接器…...

密码学--仿射密码

一、实验目的 1、通过实现简单的古典密码算法,理解密码学的相关概念 2、理解明文、密文、加密密钥、解密密钥、加密算法、解密算法、流密码与分组密码等。 二、实验内容 1、题目内容描述 ①随机生成加密密钥,并验证密钥的可行性 ②从plain文件读入待…...

生成式图像水印研究综述

生成式图像水印研究综述 一、引言二、生成式图像水印研究背景三、生成式图像水印算法研究进展3.1 基于流模型的方案3.2 基于生成对抗网络的方案3.3 基于扩散模型的方案3.3.1 修改图像数据3.3.2 调整生成模型3.3.3 修改隐变量空间四、算法的性能与评价指标五、常用数据集六、本章…...

TCP协议详细讲解及C++代码实例

目录 一. TCP协议详细讲解及C代码实例1、TCP协议概述2、TCP通信流程1) 三次握手2) 数据传输3) 四次挥手 3、关键点解析1) 套接字创建2) 三次握手实现3) 数据传输4) 四次挥手实现 4、TCP与UDP对比 一. TCP协议详细讲解及…...

深度剖析:Vue2 项目兼容第三方库模块格式的终极解决方案

当我们为 Vue2 项目引入某些现代 JavaScript 库时,常常会遇到这样的报错: error in ./node_modules/some-lib/lib/index.mjs Cant import the named export xxx from non EcmaScript module这类问题的本质是模块格式的世纪之争 —— ES Module&#xff…...

APISQL免费版安装教程(视频)

APISQL 一款通用的API开发管理软件,支持将主流数据库中的表、视图、SQL语句、存储过程等快速封装为标准的 RESTful API,支持多种安全认证方式和可视化管理界面。适用于接口开发、系统集成、数据共享等场景。 支持主流数据库的表、视图、自定义函数、存储…...

SpringBoot整合MQTT实战:基于EMQX实现双向设备通信(附源码)

简言: 在万物互联的时代,MQTT协议凭借其轻量级、高效率的特性,已成为物联网通信的事实标准。本教程将带领您在Ubuntu系统上搭建EMQX 5.9.0消息服务器,并使用Spring Boot快速实现两个客户端的高效通信。通过本指南,您将…...

从零开始掌握FreeRTOS(2)链表之节点的定义

目录 节点 节点定义 节点实现 根节点 根节点定义 精简节点定义 根节点实现 在上篇文章,我们完成了 FreeRTOS 的移植。在创建任务之前,我们需要先了解FreeRTOS的运转机制。 FreeRTOS是一个多任务系统,由操作系统来管理执行每个任务。这些任务全都挂载到一个双向循…...

Java的While循环写的出票简单程序

import java.util.Scanner;public class Hello {public static void main(String[] args) {Scanner in new Scanner(System.in);int balance 0;while(true){System.out.print("请投币: ");int amount in.nextInt();balance balance amount;if(balance >10 )…...

详解Windows(十一)——网络连接设置

Windows网络连接设置完全指南 1. Windows网络连接基础 网络连接类型 有线连接: 通过网线将电脑连接到路由器或调制解调器优点:连接稳定,速度快,延迟低适合:需要高速稳定网络的场景,如游戏、大文件下载、…...

多线程爬虫语言选择与实现

之前文中有人提到:想要一个简单易用、能快速实现多线程爬虫的方案,而且目标是小网站,基本可以确定对反爬虫措施要求不高,这些就比较简单了。 以往我肯定要考虑常见的编程语言中哪些适合爬虫。Python、JavaScript(Node…...

【数据结构】——双向链表

一、链表的分类 我们前面学习了单链表,其是我们链表中的其中一种,我们前面的单链表其实全称是单向无头不循环链表,我们的链表从三个维度进行分类,一共分为八种。 1、单向和双向 可以看到第一个链表,其只能找到其后一个…...

AI助力:零基础开启编程之旅

一、代码调试 三步解决BUG 1. 错误信息翻译 指令模板: 错误诊断模式我遇到【编程语言】报错“粘贴报错信息“ 请: 用小白能懂的话解释问题本质标注可能引发该错误的三个场景给出最可能的修复方案和其他备选方案 2. 上下文分析 进阶指令 结合上下文代…...

mybatis中${}和#{}的区别

先测试&#xff0c;再说结论 userService.selectStudentByClssIds(10000, "wzh or 11");List<StudentEntity> selectStudentByClssIds(Param("stuId") int stuId, Param("field") String field);<select id"selectStudentByClssI…...