当前位置: 首页 > article >正文

推测解码算法在 MTT GPU 的应用实践

前言​

目前主流的大模型自回归解码每一步都只生成一个token, 尽管kv cache等技术可以提升解码的效率,但是单个样本的解码速度依然受限于访存瓶颈,即模型需要频繁从内存中读取和写入数据,此时GPU的利用率有限。为了解决这种问题,VLLM框架中提出的continues batching的推理方式则是充分利用批量推理来缓解或避免访存瓶颈,极大的提升了推理系统的吞吐量。不同于VLLM等框架在系统层面的加速优化,本文所想要介绍的推测解码(speculative decoding)技术则是聚焦于算法层面的加速优化,其核心想法是借助于更小的模型来并行或者串行生成多个token。

OpenAI 于2024年11月5日提出了 "Predicted Outputs" 特性,在某些情况下,LLM 的大部分输出是事先已知的。例如,如果用户要求模型对某些文本或代码进行仅有少量修改的重写,那么可以通过使用预测输出显著降低延迟,将已有内容作为预测传入。这种技术与prompt-lookup Decoding很相似,即通过匹配prompt中相似的token序列来生成候选token的方法,这种算法的优势是不需要额外的模型来验证候选的token是否是可接受的,但劣势是只适用于特定的任务,即输出大概率在输入中的,如果是翻译任务则大概率会适得其反。因此,"Predicted Outputs" 可以看成是一般性推测解码的一种特殊场景的引用。

对于一般性的推测解码技术而言,目前方法通常都分为两个步骤:一是使用Draft model生成若干个token序列;二是将候选token序列输入到LLM(Target model)中进行验证。因此大部分工作都集中在如何设计和训练一个准确性高同时参数量小的Draft model,以及如何在验证阶段更快的验证那些合理的token序列。一般来说Draft model 要比 Target model 参数量要小很多,每一次迭代至少会生成1个token,最多会生成K+1个Token。speculative decoding取得加速收益的两个关键因素:一是自然语言层面,存在一些比较容易的token可以用更小的代价来生成;二是硬件层面,Batch的情况下硬件不会陷入计算瓶颈。

本文首先会介绍推测解码及其比较经典的EAGLE算法 [1,2],并测试官方开源的权重在A100和S4000上的推理加速结果。接着,我们基于S4000,完成在7B和14B模型在中文数据集上的训练和推理,并报告其推理加速结果。最后是本文的总结。

推测解码与EAGLE算法​

推测解码的验证策略​

贪婪解码(greedy decoding)​

当直接使用贪婪解码来生成,即取概率最大的token时, 只需要直接匹配top-1,遇到不匹配的直接丢弃。如下图,来自Draft model的y2,y3,y4,y5作为input_id输入到大模型中进行验证,最终根据大模型的输出来匹配得到y2,y3,y4这一个序列。

gd

随机解码(Nucleus decoding)​

当使用随机解码时,由于选择下一个token是偏随机性的,此时验证的策略则会更复杂一些。下面介绍deepmind的一篇论文[3]中提出的验证算法,他们严格证明了对于任意分布p(x)p(x)和q(x)q(x),通过从p(x)p(x)和q(x)q(x)进行投机采样所得到的标记的分布与仅从p(x)p(x)进行采样所得到的标记的分布是相同的。

算法: 生成并评分初步标记序列(draft tokens)输入: 
- 一个较小的自回归模型(draft model)
- 一个目标大模型(target model)
- 期望生成的标记序列长度 K输出:
- 最终生成的标记序列步骤:
1. 生成 draft tokensa. 使用 draft model 生成一个长度为 K 的初步标记序列(draft tokens)b. 记录这 K 个草稿标记对应的概率值 p
2. 评分 draft tokensa. 使用 target model 对这 K 个草稿标记进行评分,获得概率值 qb. 评分时间与评分单个标记的时间相当
3. 判断是否接受a. 对于每个草稿标记,计算 min(1, q/p) 作为接受该标记的概率b. 生成 K 个 [0,1] 范围内的均匀随机数c. 如果随机数小于等于 min(1, q/p),则接受该草稿标记,否则拒绝
4. 处理接受/拒绝结果a. 如果所有 K 个草稿标记都被接受,从 target model 直接采样第 K+1 个标记b. 如果第 t 个草稿标记被拒绝,则:- 从 q(x) - p(x) >= 0 的修正概率分布中采样一个新标记- 将之前接受的草稿标记和新采样的标记连接作为最终结果
5. 返回结果

EAGLE算法​

动机​
  • 相比于其他工作如MEDUSA直接预测token,预测“特征”比预测token更简单,特征指的是 LLM 倒数第二层的feature

  • 保留特征层可以更好的克服采样过程中的不确定性。如下图,在输出 I 之后,会按概率采样输出 am 或是 always。在进一步寻找 always 的后续输出时,如果能保留 I 的特征层输出,就能保留住采样过程中丢掉的关于 am 的信息 

    alt text

整体流程​
  • 产生候选token阶段

    在EAGLE算法中,由一个参数量较小的draft model来完成后续token的生成,是比较标准的transformers layer的结构,它的作用是对于输入的最后一层隐藏层的特征和token的embedding特征预测下一个token的隐藏层特征,然后通过原始LM head来预测token在draft model进行预测时。EAGLE draft model是以自回归的形式来迭代预测的,假设它预测了kk个step,并且我们每个step只保留概率最高的mm个token,那么我们就可以得到kmkm个token序列,这kmkm个token序列接下来会一次性送进大模型中进行验证。

  • 验证阶段

    对于验证阶段得到的kmkm条路径序列,如果每条路径都要过一次大模型进行验证,这样代价是很大的。得益于tree attention这种对mask的巧妙设计,我们可以将这kmkm条路径组成一棵树的形式并修改对应节点的attention mask,这样就可以大模型前向计算一次就可以完成所有路径的验证[4]。

  • 验证阶段的优化

    根据上面的介绍,如果我们朴素的将kmkm条路径全部进行验证,也可能会有不少冗余的验证计算。因此,在EAGLE-1[1]中,作者经验性的将验证树进行裁剪,只保留m条固定的路径。而在EAGLE-2[2]中,作者利用了Draft model给出的token置信度得分来动态的对草稿树进行裁剪,来选择最有可能的验证序列,从而尽可能实现接受token数量的最大化。

    alt text

草稿模型(draft model)的训练​

在训练draft model时,有两个损失函数一起联合训练,一个是回归损失,即用L1 loss来计算Draft model预测出来的i+1i+1时刻的特征和真实特征之间的差异;另一个是用于增强的分类损失,用交叉熵损失计算i+1i+1时刻最终的预测token与真实token的差异。

Lreg=SmoothL1(fi+1,Draft_Model(T2:i+1,F1:i))Lreg​=SmoothL1(fi+1​,Draft_Model(T2:i+1,F1:i​​))

Lcls=Cross_Entropy(Softmax(LM_Head(fi+1),Softmax(LM_Head(fi+1′))Lcls​=Cross_Entropy(Softmax(LM_Head(fi+1​),Softmax(LM_Head(fi+1′​))

作者实验分析表明EAGLE对训练数据的敏感程度很低,因此我们可以用一批提前计算好的数据来并行的训练draft model,从而可以极大的节省训练的代价,类比于训练LLM的next token prediction, 我们可以称训练draft model为next token feature prediction.

EAGLE的加速效果​

在本小节中,基于官方训练好的EAGLE-Qwen2-7B-Instruct和EAGLE-Vicuna-13b的权重,我们测试7B和13B模型在A100和S4000上的加速效果。我们选取了英文的alpaca(通用问答)、gsm8k(数学)、humaneval(代码)和sum(文本摘要)四种类型的任务作为测试数据集。其中,一个一般的7B模型对应的EAGLE draft model权重的参数量大约为0.25B,14B模型对应的权重参数量大约为0.38B。(注:官方放出的权重中未在中文训练集上训练,同时本博客的实验均是使用贪婪解码策略下的加速结果。)

7B模型的EAGLE推理加速结果

alpacagsm8khumanevalsum
A1002.92x3.0x3.09x2.91x
S40001.90x1.94x1.98x1.76x

14B模型的EAGLE推理加速结果

alpacagsm8khumanevalsum
A1003.03x3.14x3.50x2.47x
S40002.23x2.30x2.46x2.10x

EAGLE on S4000​

在本节中,我们基于S4000完成在qwen2-7B-instruct和 Qwen2.5-14B-instruct模型在中文数据集上的EAGLE draft model的分布式训练,并测试其在中文测试集上的加速效果。

训练​

首先,我们利用了开源的Magpie-Qwen2-Pro-200K-Chinese和Sharegpt_zh数据集,从中抽取了约70k条数据作为训练集,利用官方仓库[5]中的ge_data_all_qwen2.py文件来提前生成好训练数据。 其次,我们基于kuae1.3环境做分布式训练,我们既可以使用acclerate(accelerate>=0.33已经支持了musa后端)来训练,也可以使用适配好的deepspeed来训练。如果使用deepspeed,我们只需要在官方仓库中的main_deepspeed.py的启动脚本中加入以下几行musa的环境变量,即可启动musa的单机多卡训练。

export MUSA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
export DS_ACCELERATOR=musa
export MCCL_PROTOS=2
export MUSA_KERNEL_TIMEOUT=18000000

最终,我们可以得到如下的14b模型的eagle模型训练曲线图,top-3 acc约0.95,top-1 acc约为0.84。

alt text

推理​

我们使用了两类任务来评测EAGLE在S4000上的推理加速结果,一类是alpaca,即从开源的alpaca通用问答数据集收集了50条prompt;另一类是writing,开源的创意写作数据集中抽取了50条prompt,其特点是生成的文字都比较长。

7B模型的EAGLE推理加速结果

alpacawriting
A1002.73x2.55x
S40001.75x1.94x

14B模型的EAGLE推理加速结果

alpacawriting
A1002.92x2.79x
S40002.00x1.90x

最后,我们可以简单分析一下影响EAGLE算法加速效果的两点主要原因。

  • Draft Model预测的准确率。一般而言由于中文的token粒度要大于英文token的粒度,因此在中文上预测下一个token的难度要大一些,准备率可能会低一些。在本实验中我们发现,14B模型的EAGLE的中文平均接受token长度约为3.1,而论文中平均约达到了3.8以上,因此,本实验所训练的EAGLE模型,应该说也还有较大的优化和提升空间。
  • batch情况下的计算耗时。在推理时,我们利用EAGLE模型会得到很多条序列路径,并根据得分选择最终的K条路径,以batch的形式给大模型前向计算做验证,选择最终可接受的token序列。在这个过程中,如果K=1与K=m(m>1)情况下,大模型前向计算一次的耗时是大致相近的,那么此时获得的理论收益是可观的;否则,如果随着K增大,大模型的耗时会按照某种比例增加,那么此时获得的收益也会递减。总的来说,具体的收益取决于不同GPU的算力,带宽及其软件栈的实现和优化方式。

总结​

在这篇博客中,我们介绍了传统的自回归解码和推测解码算法EAGLE,其中推测解码的加速主要来源于一些比较容易的token可以用更小的代价来生成以及Batch的情况下GPU不会陷入计算瓶颈,本质上是一种利用冗余算力换取速度的方法。基于S4000,我们完成了中文上的7B和14B模型的EAGLE模型的训练与推理,并且分别取得了平均约1.80x和1.95x的加速收益。可以看到,基于S4000或者MT-GPU,已经可以很方便的完成大模型的训练和推理实验。

当然,推测解码也不无缺点,它需要额外训练一个草稿模型,不像flash attention拿来即用,同时推测解码这种用算力换时间的方式,可能会影响推理系统的吞吐量。因此,在像VLLM等推理加速框架比较成熟的情况下,将推测解码算法与像VLLM这种加速框架结合来进一步提升系统吞吐量可能是另一个值得研究的问题,此时需要考虑系统吞吐量和系统延迟之间的权衡。不过,随着近期openai O1,以及deepseek-R1等具有超长思维链过程的擅长逻辑推理的大模型的兴起,给大模型的inference带来了新的挑战,也给推测解码这类技术带来新的机遇。

参考文献​

  1. Li, Yuhui, et al. "Eagle: Speculative sampling requires rethinking feature uncertainty." arXiv preprint arXiv:2401.15077 (2024).↩
  2. Li, Yuhui, et al. "Eagle-2: Faster inference of language models with dynamic draft trees." arXiv preprint arXiv:2406.16858 (2024).↩
  3. Chen, Charlie, et al. "Accelerating large language model decoding with speculative sampling." arXiv preprint arXiv:2302.01318 (2023).↩
  4. Cai, Tianle, et al. "Medusa: Simple framework for accelerating llm generation with multiple decoding heads." 2023,(2023).↩
  5. https://github.com/SafeAILab/EAGLE.↩

相关文章:

推测解码算法在 MTT GPU 的应用实践

前言​ 目前主流的大模型自回归解码每一步都只生成一个token, 尽管kv cache等技术可以提升解码的效率,但是单个样本的解码速度依然受限于访存瓶颈,即模型需要频繁从内存中读取和写入数据,此时GPU的利用率有限。为了解决这种问题,…...

Axure酒店管理系统原型

酒店管理系统通常被设计为包含多个模块或界面,以支持酒店运营的不同方面和参与者。其中,管理端和商户端是两个核心组成部分,它们各自承担着不同的职责和功能。 软件版本:Axure RP 9 预览地址:https://556i1e.axshare.…...

写实交互数字人在AI招聘中的应用方案

随着科技的进步,越来越多的行业开始探索如何利用人工智能提升效率和服务质量。其中,写实交互数字人技术以其高度拟真的交互体验和丰富的情感表达能力,在人力资源领域特别是招聘环节中展现出了巨大潜力。本文将探讨写实交互数字人在AI招聘中的…...

C++中IO类(iostream、fstream和sstream)知识详解和应用

一、C I/O 类体系概览 C 的 I/O 功能由一组 流&#xff08;stream&#xff09; 类封装&#xff0c;位于头文件 <iostream>、<fstream>、<sstream> 等。核心类别及其继承关系简图如下&#xff1a; ios_base↑basic_ios<CharT,Traits>↑┌───────…...

Spring Boot中如何对密码等敏感信息进行脱敏处理

以下是常见的脱敏方法及实现步骤&#xff0c;涵盖配置、日志和API响应等多个层面&#xff1a; ​1. 配置文件敏感信息脱敏​ (1) 使用加密库&#xff08;如Jasypt&#xff09; ​步骤​&#xff1a; 添加依赖&#xff1a; <dependency><groupId>com.github.ulise…...

React从基础入门到高级实战:React 基础入门 - JSX与组件基础

JSX 与组件基础 引言 在 React 开发中&#xff0c;JSX 和 组件 是两个最基础且核心的概念。JSX 是一种独特的语法&#xff0c;让你在 JavaScript 中编写类似 HTML 的代码&#xff0c;而组件则是 React 应用的基本构建块&#xff0c;帮助你将复杂的界面拆分为可复用的模块。本…...

房贷利率计算前端小程序

利率计算前端小程序 视图效果展示如下&#xff1a; 在这里插入代码片 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0&qu…...

在Visual Studio中进行cuda编程

首先下载与CUDA Toolkit匹配的Visual Studio版本 比如我的CUDA Toolkit版本是12.6&#xff0c;那么我可以使用2022的Visual Studio。 查看Toolkit版本 nvcc -V 配置 ok&#xff0c;让我们开始Visual Studio的nvcc编译器配置 参考例文https://github.com/apachecn/succinc…...

Fastrace:Rust 中分布式追踪的现代化方案

原文链接&#xff1a;Fastrace: A Modern Approach to Distributed Tracing in Rust | FastLabs / Blog 摘要 在微服务架构中&#xff0c;分布式追踪对于理解应用程序的行为至关重要。虽然 tokio-rs/tracing 在 Rust 中被广泛使用&#xff0c;但它存在一些显著的挑战&#xf…...

Linux云计算训练营笔记day13【CentOS 7 find、vim、vimdiff、ping、wget、curl、RPM、YUM】

Linux云计算训练营笔记day13[CentOS 7 find、vim、vimdiff、ping、wget、curl、RPM、YUM]] 目录 Linux云计算训练营笔记day13[CentOS 7 find、vim、vimdiff、ping、wget、curl、RPM、YUM]]1.find练习2.vim高级使用2.1 命令模式:2.2 插入模式:2.3 末行模式: 3. vimdiff4. ping5.…...

黑马Java基础笔记-15

Set 无索引&#xff0c;无序&#xff0c;不可重复 HashSet object类中默认hashCode的方法是根据地址值。 如果集合中存储的是自定义对象&#xff0c;必须要重写hashCode和equals方法。 底层原理 jdk8以前&#xff1a;数组 链表 jdk8及以后&#xff1a;数组 链表 红黑…...

Elasticsearch简单集成java框架方式。

Elasticsearch 在 Java 中最常用的客户端是什么&#xff1f;如何初始化一个 RestHighLevelClient&#xff1f;如何用 Spring Boot 快速集成 Elasticsearch&#xff1f;Spring Data Elasticsearch 如何定义实体类与索引的映射&#xff1f; 最常用的 Java 客户端 目前官方推荐使用…...

【RAG文档切割】从基础拆分到语义分块实战指南

目录 &#x1f31f; 前言&#x1f3d7;️ 技术背景与价值&#x1fa79; 当前技术痛点&#x1f6e0;️ 解决方案概述&#x1f465; 目标读者说明 &#x1f9e0; 一、技术原理剖析&#x1f4ca; 分块流程架构图&#x1f4a1; 核心分块策略&#x1f527; 关键技术模块 &#x1f6e…...

stream数据流

核心知识点&#xff1a;数据流&#xff08;Stream Data Flow&#xff09; 1. 通俗易懂的解释 想象一下你正在用花园里的水管浇花。水管里的水不是一次性全部倒出来的&#xff0c;而是持续不断地从水龙头流出&#xff0c;经过水管&#xff0c;最终从喷头喷洒到花上。在这个过程…...

利用 XML 外部实体注入(XXE)读取文件和探测内部网络

利用 XML 外部实体注入&#xff08;XXE&#xff09;读取文件和探测内部网络 引言 XML 外部实体注入&#xff08;XXE&#xff09;是一种常见的安全漏洞&#xff0c;攻击者可以通过这种漏洞读取服务器上的文件或探测内部网络。本文将通过一个实际的 Python 代码示例&#xff0c…...

软件设计师“排序算法”真题考点分析——求三连

一、考点分值占比与趋势分析 综合知识题分值统计表 年份考题数量总分值分值占比考察重点2018222.67%时间复杂度/稳定性判断2019334.00%算法特性对比分析2020222.67%空间复杂度要求2021111.33%算法稳定性判断2022334.00%综合特性应用2023222.67%时间复杂度计算2024222.67%分治…...

Visual Studio 2019/2022:当前不会命中断点,还没有为该文档加载任何符号。

1、打开调试的模块窗口&#xff0c;该窗口一定要在调试状态下才会显示。 vs2019打开调试的模块窗口 2、Visual Studio 2019提示未使用调试信息生成二进制文件 未使用调试信息生成二进制文件 3、然后到debug目录下看下确实未生成CoreCms.Net.Web.WebApi.pdb文件。 那下面的…...

vue--ofd/pdf预览实现

背景 实现预览ofd/pdf超链接功能 业务实现 pdf的预览 实现方式&#xff1a; 直接使用 <iframe :src"${url}#navpanes0&toolbar0" /> 实现pdf的预览。 navpanes0 隐藏侧边栏toolbar0 隐藏顶部工具栏 使用pdf.js&#xff0c;代码先行&#xff1a; <tem…...

Python 爬虫之requests 模块的应用

requests 是用 python 语言编写的一个开源的HTTP库&#xff0c;可以通过 requests 库编写 python 代码发送网络请求&#xff0c;其简单易用&#xff0c;是编写爬虫程序时必知必会的一个模块。 requests 模块的作用 发送网络请求&#xff0c;获取响应数据。 中文文档&#xf…...

【MySQL】CRUD

CRUD 简介 CRUD是对数据库中的记录进行基本的增删改查操作 Create&#xff08;创建&#xff09;Retrieve&#xff08;读取&#xff09;Update&#xff08;更新&#xff09;Delete&#xff08;删除&#xff09; 一、新增&#xff08;Create&#xff09; 语法&#xff1a; I…...

Spring Boot微服务架构(三):Spring Initializr创建CRM项目

使用Spring Initializr创建CRM项目 一、创建项目前的准备 访问Spring Initializr网站&#xff1a; 打开浏览器访问 https://start.spring.io/或者直接使用IDE&#xff08;如IntelliJ IDEA或Eclipse&#xff09;内置的Spring Initializr功能 项目基本信息配置&#xff1a; Proj…...

【笔记】PyCharm 中创建Poetry解释器

#工作记录 在使用 PyCharm 进行 Python 项目开发时&#xff0c;为项目配置合适的 Python 解释器至关重要。Poetry 作为一款强大的依赖管理和打包工具&#xff0c;能帮助我们更便捷地管理项目的依赖项与虚拟环境。下面将详细记录在 PyCharm 中创建 Poetry 解释器的步骤。 前提条…...

SDL2常用函数SDL事件处理:SDL_Event|SDL_PollEvent

SDL_Event SDL_Event是个联合体&#xff0c;是SDL中所有事件处理的核心。 SDL_Event是SDL中使用的所有事件结构的并集。 只要知道了那个事件类型对应SDL_Event结构的那个成员&#xff0c;使用它是一个简单的事情。 下表罗列了所有SDL_Event的所有成员和对应类型。 Uint32typ…...

RAID技术全解析:从基础到实战应用指南

一、RAID核心概念与级别对比 1. RAID的核心目标 数据冗余&#xff1a;通过镜像或校验机制防止数据丢失。 性能提升&#xff1a;利用条带化技术实现并行读写。 存储扩展&#xff1a;聚合多块磁盘容量&#xff0c;突破单盘限制。 2. 常见RAID级别对比 RAID级别最小磁盘数容…...

word通配符表

目录 一、word查找栏代码&通配符一览表二、word替换栏代码&通配符一览表三、参考文献 一、word查找栏代码&通配符一览表 序号清除使用通配符复选框勾选使用通配符复选框特殊字符代码特殊字符代码or通配符1任意单个字符^?一个任意字符?2任意数字^#任意数字&#…...

python中的numpy(数组)

&#xff08;0&#xff09;numpy介绍 NumPy是Python中用于科学计算的基础库&#xff0c;提供高效的多维数组对象ndarray&#xff0c;支持向量化运算&#xff0c;能大幅提高数值计算效率。它集成了大量数学函数&#xff08;如线性代数、傅里叶变换等&#xff09;&#xff0c;可…...

C++ 正则表达式简介

1. 正则表达式简介 正则表达式&#xff08;Regular Expression&#xff0c;简称Regex&#xff09;是一种用于匹配和处理文本的强大工具。它通过特定的符号组合形成匹配规则&#xff0c;常用于表单验证、文本搜索与替换、数据清洗等场景。 C11标准引入了 <regex> 头文件…...

iOS知识复习

block原理 OC block 是个结构体&#xff0c;内部有个一个结构体成员 专门保存 捕捉对象 Swift闭包 是个函数&#xff0c;捕获了全局上下文的常量或者变量 修改数组存储的内容&#xff0c;不需要加_block,修改数组对象本身时需要 weak原理 Weak 哈希表 &#xff08;散列表&a…...

rce命令执行原理及靶场实战(详细)

2. 原理 在根源上应用系统从设计上要给用户提供一个指定的远程命令操作的接口。漏洞主要出现在常见的路由器、防火墙、入侵检测等设备的web管理界面上。在管理界面提供了一个ping服务。提交后&#xff0c;系统对该IP进行ping&#xff0c;并且返回结果。如果后台服务器并没有对…...

Fuzz 模糊测试篇JS 算法口令隐藏参数盲 Payload未知文件目录

1 、 Fuzz 是一种基于黑盒的自动化软件模糊测试技术 , 简单的说一种懒惰且暴力的技术融合了常见 的以及精心构建的数据文本进行网站、软件安全性测试。 2 、 Fuzz 的核心思想 : 口令 Fuzz( 弱口令 ) 目录 Fuzz( 漏洞点 ) 参数 Fuzz( 利用参数 ) PayloadFuzz(Bypass)…...