EAGLE代码研读+模型复现
要对代码下手了,加油(ง •_•)ง
作者在他们自己的设备上展现了推理的评估结果,受第三方评估认证,EAGLE为目前最快的投机方法(虽然加速度是评估投机解码方法的主要指标,但其他点也值得关注。比如PLD和Lookahead无需额外参数,更容易和许多模型进行集成),所有用来评估的方法都和Spec-Bench对齐。
设备:一台NVIDIA GeForce RTX 3090 GPU(24GB) ,带12个CPU核
测试环境:Pytorch 2.0.1,CUDA 11.8
环境设置:Vicuna-7B-v1.3,贪心解码,FP16精度,批量大小为1
EAGLE-2会利用草稿模型打出的信心分数去近似接受率,动态调整草稿的树形架构,进一步提升性能。EAGLE-2在两块RTX 3060 GPU上的推理速度,比原始的投机解码在一块A100 GPU上的推理速度要快。

设备:一台NVIDIA A100 GPU(80GB) ,带64个CPU核(比V100提升20倍的AI计算性能,实验室没有-_-,相较于RTX 4090、A40,更适合大模型AI训练)(看人家怎么说的,在8张RTX 3090 GPU上也能训练,1-2天完事儿,“so even the GPU poor can afford it”)(下面3张图看着乐呵一下,俺只有3090(≧∀≦)ゞ)
测试环境:Pytorch 2.0.1,CUDA 11.4
实验设置:贪心解码,FP16精度,批量大小为1



EAGLE已被集成到许多主流的LLM服务框架中了,比如Intel Extension for Transformers、vLLM、SGLang等等。
相应GitHub库的更新如下:2023.12.8发布EAGLE v1.0,2024.1.17起支持Mixtral-8x7B-Instruct、2024.2.25经第三方评估认证为最快的投机方法、2024.6.27发布EAGLE-2、2024.8.8起支持QWen-2(阿里巴巴集团QWen团队开发的第二代LLM系列,旨在提升自然语言处理和生成任务的性能)。不久的将来将发布EAGLE-3(2025.3.19更新:太牛了,我代码还没复现完它就出了)。
我选用的基础模型主要是Llama-2-chat-7B,该模型是由Meta(原Facebook)基于Transformer架构开发的开源LLM,有70亿参数,属于Llama-2系列中的较小版本,专为对话任务微调,适合交互式应用。
配置并安装
git clone https://github.com/SafeAILab/EAGLE.git
cd EAGLE
pip install -r requirements.txt
在安装requirements.txt中的包时,要使用合适的python版本,比如我试过python3.11报错说没有“python3.11/site-packages/torch/include/torch/目录”,而3.9可以丝滑安装。
而且需要提前设置一下pip镜像,参考博客,不然慢不慢另说,还可能找不到某些版本的包而发生报错。
为什么要cd到这个目录呢?其实后面可以发现,许多文件里到处都是相对路径(🤯)也不是说一定要在这个目录吧,只是要改成对应的正确路径。
下载EAGLE权重
作者提供了各式目标模型对应的EAGLE参数的Hugging Face网址。对于在上面托管的模型,点开hugging face界面的“Use this model”-》Transformers库,一般可以看到两种使用transformer库的加载方法(下图microsoft/DialoGPT-small对应的界面这样)

一种使用高级封装pipeline,自动处理tokenization(文本切分)、模型推理、解码这些步骤,简单不灵活;一种手动加载模型和分词器,可自由调整参数,可扩展性强,代码复杂。
1. 使用pipeline作为高级封装
# Use a pipeline as a high-level helper
from transformers import pipelinepipe = pipeline("text-generation", model="yuhuili/EAGLE-Vicuna-7B-v1.3")
pipeline是transformer库提供的一个高级封装,可自动处理tokenization(文本切分)、模型推理、解码这些步骤;"text-generation"任务代表使用自回归文本生成模型,比如GPT类模型(Vicuna也在其中);model="yuhuili/EAGLE-Vicuna-7B-v1.3"告诉pipeline下载该模型。加载后使用方法就是
result=pipe("Hello, how are you?")
print(result)
pipe()直接输入文本,返回生成的文本结果,通常是一个包含文本输出的列表。简单易用,适用于对性能要求不高时的快速部署。但缺乏灵活性,无法自定义tokenizer或model的参数,例如温度、最大长度等,且pipeline默认会尝试自动优化加载方式,可能消耗额外显存。
2. 手动加载模型和分词器
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLMtokenizer = AutoTokenizer.from_pretrained("yuhuili/EAGLE-Vicuna-7B-v1.3")
model = AutoModelForCausalLM.from_pretrained("yuhuili/EAGLE-Vicuna-7B-v1.3")
tokenizer分词器负责将输入的文本转换为token,用于模型计算,并将模型的输出转换回人类可读的文本。从Hugging Face服务器下载该模型的权重和架构时,AutoModelForCausalLM适用于因果语言模型,如GPT、Vicuna这类基于自回归生成的Transformer模型。使用时需要手动处理输入输出
input_text = "Hello, how are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids # 将文本转换为 token ID(张量格式)
output = model.generate(input_ids, max_length=50) # 让模型生成文本,指定最大生成长度为50
result = tokenizer.decode(output[0], skip_special_tokens=True) # 将生成的 token ID序列转换成字符串
print(result)
优点是高度可控,可以自由调整generate()里的参数,如温度、top_k、top_p等,提高生成质量,可以优化显存占用(如启用torch_dtype=torch.float16或device_map="auto"),适合大规模任务,可扩展性强,可与LoRA、DeepSpeed、FSDP等优化技术结合,可扩展性强。但代码复杂度高,默认不会自动优化显存占用,加载大模型时可能会超出GPU负载。
但是,要真直接运行这代码,常会报错(≧∀≦)ゞ,比如EAGLE系列模型的就说没分词器啦~聊天小模型microsoft/DialoGPT-small就说输入内容格式不对啦~因为这代码是网站自动生成的,适用于通常情况。就以yuhuili/EAGLE-Vicuna-7B-v1.3的情况来说吧,点开“Files and versions”,下面确实没有分词器(可以拿上图比对一下)

咱就下载一下EAGLE weight🙄,直接运行下边这个得了,下载条拉满即可。
git lfs clone https://huggingface.co/yuhuili/EAGLE-llama2-chat-7B
另外作者说了一下,目标模型是QWen时,应采用bf16而非fp16以避免数字溢出,两者都是16位浮点数格式,但指数和位数分布分配不同,BF16为8位指数7位尾数,数值范围接近FP32,FP16为5位指数10位尾数。草稿模型的训练数据集为ShareGPT,数据全英文,如果想要将其用在非英文,比如中文的数据上,需要用相应的数据进行训练。在EAGLE的基础上,EAGLE-2无需额外的训练,直接使用相同的权重。仓库提供的推理代码会自动分配模型权重,在多个GPU上加载模型,使得超出单个GPU内存的模型也能跑起来。(自动分布式吗?还挺牛的!看代码咋实现)
用UI体验
作者提供了网络端口,运行下列命令即可启动(damn!还得事先下载llama2-chat-7B)。等模型被完全加载,终端就会输出网址,点进去就跳到浏览器体验啦。(好好好显存不够,换台新的服务器。每当俺在一台新服务器上注册账号后,需要设置免密登录,安装Miniconda、pip,禁止自动激活base环境、安装扩展,scp文件,有时还得加速这个加速内个🙄)(还得小心代码中给你指定设备的情况)
python -m eagle.application.webui --ea-model-path [path of EAGLE weight]\ --base-model-path [path of the original model]\--model-type ["llama-2-chat","vicuna","mixtral","llama-3-instruct"]\--total-token [int]
total-token这个选项代表草稿token的数量,如果用的是较小的模型或先进的GPUs(没有:p)的话这个值可以大点。反正就根据具体的硬件和模型做调整,以达更好的性能吧。设成-1的话EAGLE-2会自动配置的(😭模型啥都能自动了而本菜鸡怎么活)。
3090GPU的服务器挂掉了,我换了台A40 GPU的服务器,这时仓库提供的requirement.txt中,torch和accelerate版本(2.0.1和0.21.0)会出现“设备映射”的问题,代码里有的地方又用的旧版本,而且轻易更换版本容易冲突,总之各种问题!再换台服务器吧😭
代码运行报错说让你安装什么包、什么版本,就照着报错提示去安装得了。
TODO: 怪我写的太慢,eagle3都出来了我还在这磨蹭!注意到eagle3中添加了一个参数”draft_vocab_size“,这难道可以控制草稿长度?后面看看!然后在eagle3版本的webui.py中有下面这么句代码,光从字面意思上看也很迷惑啊!加个“not”吧。
use_eagle3=args.no_eagle3,
特别解释一下-m选项,用于以模块方式运行Python脚本,告诉python查找eagle.application.webui这个模块并运行它(会去哪找eagle模块呢?当前目录以及PYTHON中),而无需手动cd进入目录再执行python webui.py,更灵活(吗?)
所以直接运行可能会遇到一个问题,如果你不是正处于EAGLE这一级目录(谁知道经历了前面一堆乱七八糟的时候跑到了哪个鬼目录),又在PYTHONPATH找不到eagle.application.webui这个模块的话,就会报错“Error while finding module specification for 'eagle.application.webui' (ModuleNotFoundError: No module named 'eagle')”。要么执行“cd xxx/EAGLE”回到正确目录,要么执行下列指令
export PYTHONPATH=xxx/EAGLE:$PYTHONPATH
(哪个办法好我也是反复横跳,最后结论是觉得前者麻烦,限死了终端执行目录,脚本文件存放位置。我倾向于后者,直接在终端执行的话,就临时添加一下,换个终端便没了,写到~/.bashrc文件里,source一下永久生效吧也行,下次换别的哪个模块的话改一下)
用第一种办法就每次在终端打那么老长的指令,也累,可以直接去改 eagle/application/webui.py下解析参数的代码,设成默认值,指令在“webui”后截断即可
python -m eagle.application.webui --ea-model-path ../EAGLE-llama2-chat-7B/ --base-model-path ../Llama-2-7b-chat-hf/ --model-type llama-2-chat --total-token 2
感谢!下面看看弹出来的界面(两张图才截全,说真的,作者的界面都做得好好看இ௰இ)


我问这个负责任有道德的模型,为啥它有的字体显示橙色,它说它只会生成文本,没能力调颜色或用视觉特效,可能是设备、浏览器、平台或什么别的的锅。作者用gradio包做的网页,只需看上面的勾选框和解释即可知,橙色高亮部分是EAGLE-2正确猜测的token,上图便是将total-token设成2的结果。
这个纷繁复杂的交互式网页,是作者通过“ea_generate”流式返回模型每次前向的结果,以快速响应用户请求的。其中每次前向的结果output_ids的形状为(batch_size, sequence_length),都存到text里,其中第一个token由原始模型生成,存到naive_text里,剩下的token便是EAGLE-2的功劳。所以高亮哪些token呢?找text中naive_text里没有的。代码简陋一点看就是
for output_ids in model.ea_generate(input_ids, ...):...# decode_ids是截至本轮所有新生的tokendecode_ids = output_ids[0, input_len:].tolist() # 去掉输入部分decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id) # 截断终止符后面的部分...text = model.tokenizer.decode(decode_ids, skip_special_tokens=True,spaces_between_special_tokens=False,clean_up_tokenization_spaces=True, )naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], # 本轮首个tokenskip_special_tokens=True,spaces_between_special_tokens=False,clean_up_tokenization_spaces=True, ))cu_len = output_ids.shape[1]# 将text中naive_token里没有的内容高亮colored_text = highlight_text(text, naive_text, "orange")...
用代码体验
下面用“eagenerate”,一次性返回完整的token序列(体验感不如上面的好,有时会怀疑它卡了或者又要报错),就像用Hugging Face的“generate”那样,如下
from eagle.model.ea_model import EaModel
# EaModel,来自eagle.model.ea_model模块,一个用于NLP任务的模型类,支持从预训练模型加载权重
from fastchat.model import get_conversation_template
# get_conversation_template,来自fastchat.model,用于获取对话模板(如vicuna)
import torchdef warmup(model):# 按照目标模型类型创建对话模板conv = get_conversation_template(args.model_type)if args.model_type == "llama-2-chat":# Llama 2 Chat版本需要一个系统提示词,确保其回答安全无偏见符合道德,其他模型可能就无需这种额外约束了sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."conv.system_message = sys_pelif args.model_type == "mixtral":conv = get_conversation_template("llama-2-chat")conv.system_message = ''conv.sep2 = "</s>" # 特定结束符your_message="who are you?"# 将用户输入“hello”作为第一个角色(通常是用户)的话加入对话conv.append_message(conv.roles[0], your_message)# 给第二个角色(通常是AI模型)留一个空的响应位置,等待模型生成conv.append_message(conv.roles[1], None)# get_prompt()负责将对话格式化成适合EaModel处理的输入文本prompt = conv.get_prompt()if args.model_type == "llama-2-chat":prompt += " "# 分词器将prompt转换成token idinput_ids=model.tokenizer([prompt]).input_ids# 再转换成PyTorch张量,并转移到GPU提高推理效率input_ids = torch.as_tensor(input_ids).cuda()# 进行文本生成output_ids = model.eagenerate(input_ids,temperature=0.5,max_new_tokens=512) # eagenerate一次性返回完整的token序列output=model.tokenizer.decode(output_ids[0])print(output)# 使用命令行参数,添加参数解析(包已内置于python中)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--ea_model_path",type=str,default="EAGLE-llama2-chat-7B",help="The path of EAGLE weight. This can be a local folder or a Hugging Face repo ID(<组织名或用户名>/<模型名>)."
)
parser.add_argument("--base_model_path",type=str,default="Llama-2-7b-chat-hf",help="path of the original model. a local folder or a Hugging Face repo ID"
)
parser.add_argument("--load_in_8bit",action="store_true", # 如果提供该参数,则值为True,否则默认为Falsehelp="use 8-bit quantization"
)
parser.add_argument("--load_in_4bit",action="store_true",help="use 4-bit quantization"
)
parser.add_argument("--model_type",type=str,default="llama-2-chat",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct"]
)
parser.add_argument("--total_token",type=int,default=-1,help=" the number of draft tokens"
)
parser.add_argument("--max_new_token",type=int,default=512,help="the maximum number of new generated tokens",
)
args = parser.parse_args()model = EaModel.from_pretrained(base_model_path=args.base_model_path,ea_model_path=args.ea_model_path,total_token=args.total_token,torch_dtype=torch.float16,low_cpu_mem_usage=True,load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,device_map="auto",
)# 让模型进入推理模式,防止dropout等影响推理
model.eval()warmup(model)
eagle.application.webui里有解析参数的代码,但eagle.model.ea_model里没有,得在作者提供的调用代码里再额外填上,另外Vicuna、LLaMA2-Chat、LLaMA3-Instruct都是聊天模型,咱需要使用对应正确的聊天模板,不然会从模型里产生异常输出,影响EAGLE性能。我仿照webui.py的代码改了改。(不要两种体验同时用,一旦内存不够了它就给你报奇奇怪怪的错误)(一个来回的对话,而且比较久才回你)

训练
生成训练数据
先说一下文件调用结构,在原项目中,allocation.py将待处理的数据平均分配,各GPU执行各自任务,并行调用ge_data_all_llama2chat.py,使用llama2-chat-7B进行数据预处理、加载数据集、推理并保存结果。数据集的存储格式如下
new_examples = {"conversation": [], # 存储对话文本"input_ids": [], # 输入token ID"loss_mask": [] # 标记哪些token需要计算loss
}
预测结果被存储成.ckpt(checkpoint)文件,存储内容如下:
input_ids: Tensor, shape=torch.Size([151]), dtype=torch.int64
hidden_state: Tensor, shape=torch.Size([151, 4096]), dtype=torch.float16
loss_mask: Tensor, shape=torch.Size([151]), dtype=torch.int64
其中最重要的信息就是
hidden_state_big = outs_big.hidden_states[-1] # 提取最后一层隐藏状态
作者没有提供原数据集,咱去hugging face上随便找了个——shareGPT/computer_en_26k.jsonl。检查了一下,总共20692条数据,单拎第一条出来看看
{"conversation_id": "vtu3ZfW", "category": "Program and Code", "conversation": [{"human": "Give Landing page HTML and CSS with bootstrap framework that takes email address and phone number in the lead generation form", "assistant": "Here is an example of a basic landing page HTML structure using the Bootstrap framework:\n...server."}]
}
得据此修改ge_data_all_llama2chat.py中用到的键名,另外这个数据集没有清理的好,“category”的值存在字符串和列表混用的现象,得自己检查一遍,检查代码如下:
import pandas as pd
original_file = '/home/xxx/EAGLE_test/original_data/computer_en_26k.jsonl'
cleaned_file = '/home/xxx/EAGLE_test/original_data/computer_en_26k_cleaned.jsonl'try:df = pd.read_json(original_file, lines=True)print(df.head())
except Exception as e:print(f"Error loading JSON: {e}")# 检查 'category' 列是否存在类型不一致
print(df['category'].apply(type).value_counts())# 遍历 'category' 列并查找出现 list 的位置
for index, value in enumerate(df['category']):if isinstance(value, list):print(f"First occurrence of list in 'category' at row {index}")print(value)def clean_category(value):if isinstance(value, list):return ' '.join(value)elif isinstance(value, str):return valueelse:return str(value)df['category'] = df['category'].apply(clean_category)
df.to_json(cleaned_file, orient='records', lines=True)
用原代码中读取json文件的方式读取jsonl文件也行得通。额外说一下,jsonl文件的每行都是一个JSON(JavaScript Object Notation)格式数据。JSON是一种轻量级的数据交换格式,使用键值对存储数据,类似于Python的dict或JavaScipt的对象(Object)。.json文件通常是配置文件、模型索引、数据存储或API响应。
我摘取了前68条数据,在一台服务器的4张GPU上进行处理,运行指令(参数同样可以在文件中设置默认值)
python -m eagle.ge_data.allocation --outdir [path of data]
得到的训练数据如下

4个进程并行时的输出、文件命名方式都没有太大更改,原代码还是蛮清晰的,但有些逻辑错误,到后面就会发现他妈的loss_mask全零🤯!主要是修改ge_data_all_llama2chat.py中的映射代码
def preprocess_function(examples):new_examples = {"conversation":[], # 存储对话文本"input_ids": [], # token ID"loss_mask": [] # 标记哪些token需要计算loss}# 获取LLaMA-2对话模板conv = get_conversation_template("llama-2-chat")# 设定AI助手的形为准则sys_p="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " \"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " \"Please ensure that your responses are socially unbiased and positive in nature.\n\n" \"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " \"If you don't know the answer to a question, please don't share false information."conv.system_message=sys_p# 遍历数据集中的所有对话num_of_examples = len(examples['conversation_id'])for i in range(num_of_examples):source= examples['conversation'][i]conv.messages = []# 处理human和assistant的每轮对话for j, sentence in enumerate(source):if "human" in sentence and "assistant" in sentence:conv.append_message(conv.roles[0], sentence["human"])conv.append_message(conv.roles[1], " " + sentence["assistant"])else:print(f"Warning: Invalid or incomplete dialogue at index {j}: {sentence}")# 获取最终格式化后的对话文本conversation=conv.get_prompt()# 将对话转成token IDinput_ids = tokenizer(conversation,return_tensors="pt", # 结果返回PyTorch张量max_length=2048,truncation=True,).input_ids[0]# 创建一个形同input_ids的张量loss_mask,初始值全为1,默认所有token都会计算lossloss_mask=torch.ones_like(input_ids)# AI回复前的分隔符:“[/INST] ”sep = conv.sep + conv.roles[1] + " "# 将对话拆分为不同轮次,conv.sep2:“ </s><s>”turns = conversation.split(conv.sep2)# 忽略开始符<s>的loss计算cur_len = 1loss_mask[: cur_len] = 0# 处理每轮对话for j, turn in enumerate(turns):# print(f"本轮内容:{turn}") # 两轮对话三个turnif turn == "":breakturn_len = len(tokenizer(turn).input_ids) # 当前轮对话的token长度parts = turn.split(sep) # 拆分成用户输入和AI回复if len(parts) != 2:breakparts[0] += sepinst_len = len(tokenizer(parts[0]).input_ids) - 2 # 去掉开始符和结尾空格# 忽略(第一轮对话还额外有提示词部分)用户输入部分的loss计算loss_mask[cur_len : cur_len + inst_len] = 0cur_len += (turn_len + 2)if j != 0 and not tokenizer.legacy:# print("是新版,</s><s>只占2个token")cur_len -= 1# print(tokenizer.decode(input_ids[cur_len-2:cur_len]))# 忽略padding位置的loss计算(实际上似乎没有填充)loss_mask[cur_len:] = 0# 把格式化后的对话、token ID、loss_mask存入new_examplesnew_examples["conversation"].append(conversation)new_examples["input_ids"].append(input_ids[None,:])new_examples["loss_mask"].append(loss_mask[None,:])# print(f"loss_mask是什么类型啊现在?{type(new_examples['loss_mask'])}") # listreturn new_examples
而且用这么点数据训练肯定是不够的,后面会出现在训练集上的准确率嘎嘎提高,而在测试集上表现平平的情况,也就是过拟合了。干脆将jsonl文件中所有数据一并处理了,总20692条样本
wc -l path/to/file
好的,显存不够了😭原本我想换一台更加空闲的服务器,但已经换了几次服务器了,深知要想转到一台全新服务器有多麻烦,然后就想用docker把实验依赖的所有环境之类的一起打包,获取镜像、实例化容器,在新服务器上拉取巴拉巴拉……但是家人们,我这条菜狗又困在网络这一关了😭docker学习进程再次搁置。
更换策略,启用4-bit NF4量化,节省显存,如下更改代码
'''启用4-bit NF4量化,最节省显存'''
quantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.bfloat16, # 若GPU不支持bfloat16,改用torch.float16bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4", # Normalized Float 4
)
bigmodel = AutoModelForCausalLM.from_pretrained(bigname,quantization_config=quantization_config,device_map="auto"
)# "auto"会自动将模型的不同层分配到可用的GPU上,以实现模型并行
# bigmodel = AutoModelForCausalLM.from_pretrained(bigname, device_map="auto", torch_dtype=torch.float16)
额外说一下,在AutoModelForCausalLM.from_pretrained中,device_map参数用于控制模型在多个GPU上的分配方式。“auto”和“balanced”是两种不同的分配策略,核心区别在于显存分配的智能程度和均匀性。
device_map="auto"时按层顺序分配,模型从第一层开始依次分配到当前显存最充足的GPU,直到占满该GPU的显存,再切换到下一个GPU,实现简单分配快,但可能导致现存利用率不均衡,某些GPU显存剩余较多,对显存碎片化较敏感。适用于模型层之间显存占用差异较小(如大部分Transformer模型)和GPU显存容量相同(下图我用的服务器的显存情况可能就不太符合?)的场景。(“gpustat -i 1”每隔1s刷新一次GPU使用信息)

device_map="balanced"时则是显存均匀分配,尽可能让每个GPU的显存占用接近均等,计算所有层的显存需求后,动态规划最优分配方案。其现存利用率高,可减少OOM风险,但分配计算开销稍大(首次加载略慢),适合模型层显存需求差异大(如MoE)和GPU显存容量不同的场景,对模型结构的适应性要求较高了。
建议就是优先尝试“balanced”,尤其是面对OOM时,若仍失败,就结合量化(4-bit)或手动分配,超大模型还可考虑offload_to_cpu(如accelerate库的device_map="auto"就支持CPU卸载)
(每个进程处理5173条数据)(如果还是OOM运行失败,但输出文件下的东西看着有模有样的,不要被骗啦!像我的,OOM时输出文件大小71G)(文件夹大小查看指令如下)
du -sh /path/to/directory
allocation分配不同GPU并行,每个GPU上运行ge_data_all_llama2chat,让多个CPU进行数据处理,ctrl+z暂停指令(c才是终止)时,显存不一定释放,可用如下指令筛选一下自己的进程
ps aux | grep "$USER" # 筛选属于当前用户的所有进程
或直接批量终止
ps aux | grep "my_ge_data_all_llama2chat.py" | grep -v "grep" | awk '{print $2}' | xargs kill -9
记得在代码中常加入下列指令:
torch.cuda.empty_cache()
为什么需要这个PyTorch提供的函数?
《PyTorch的显存管理机制》:PyTorch会使用CUDA内存缓存机制来加速张量分配和计算,当释放张量时(如del tensor或变量超出作用域)时,PyTorch不会立即将显存归还给系统,而是保留在内部的缓存池中,当后续需要重新分配新张量时,优先从缓存中复用显存,而非重新向CUDA申请,以便快速重用,减少显存分配的开销。具体地,PyTorch使用分块内存池(Block-based Memory Pool)管理显存,显存被划分为不同大小的块,分配张量时PyTorch会寻找大小最匹配的块,减少碎片化,释放张量时,块会被标记为空闲,供后续使用。但这种缓存机制就可能会导致显存占用看起来比实际需求更高。
《显存碎片化问题》若频繁分配和释放不同大小的张量,显存可能变得碎片化,即剩余显存被分割为小块,无法满足大块请求,此时即使显存总量足够,PyTorch也可能因找不到足够大的连续显存块而报CUDA out of memory错误。PyTorch的显存管理机制正是通过块复用、合并相邻空闲块、按大小分类管理等方式减少碎片化。
empty_cache()的作用就在于强制释放PyTorch缓存的未使用显存,使其归还给系统,减少显存碎片化,提高显存利用率。
另外进一步优化显存管理的方法还有复用张量(避免在循环中反复创建临时张量),还有按照报错“torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.68 GiB total capacity; 1.06 GiB already allocated; 12.06 MiB free; 1.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF”提示的那样,如下调整分配策略,减少每次分配的内存块最大大小,减少碎片化:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
可以通过如下代码监控显存使用
'''检查显存占用情况'''
def print_memory():allocated = torch.cuda.memory_allocated() / 1024**3reserved = torch.cuda.memory_reserved() / 1024**3print(f"Allocated: {allocated:.2f}GB | Reserved: {reserved:.2f}GB")
终于不报错了,可能在1h的推理等待后,成功输出文件的大小在138G,友友们可以帮我检查一下!
训练自回归头
接下来用前面获取的checkpoint训练自回归头。训练代码和设置可参考eagle/train/
期间有几点说明,.safetensors是一种专门用于存储神经网络权重的二进制文件格式(不能直接用文本编辑器打开,需用safetensors库解析),由Hugging Face团队开发,作为PyTorch的.pt或.bin文件更安全高效的替代方案。具体地,
1.更安全。其仅支持张量数据的存储,即使部分文件损坏,仍可加载其他部分,存储结构固定,而.bin(通常是torch.save生成的PyTorch权重文件)可以包含任意Python对象,若加载不受信任的.bin文件,可能导致代码执行漏洞(Pickle反序列化攻击),若文件损坏,可能整个模型都无法加载,若PyTorch版本不同,可能导致序列化格式变化,影响兼容性。
2. 更快。其支持按需加载(lazy loading),避免不必要数据传输,而.bin必须一次性全部加载
3. 更节约内存。其通过零拷贝(zero-copy)方式映射到内存,而.bin需要将数据从磁盘拷贝到内存,增加了额外开销
存储权重数据的二进制.safetensors文件通常有索引文件.index.json,内容如下:(本节在预训练模型上较关心的正是它的语言建模头lm_head)
{"metadata": {"total_size": 13476835328},"weight_map": {"lm_head.weight": "model-00002-of-00002.safetensors","model.embed_tokens.weight": "model-00001-of-00002.safetensors","model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors","model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",...}
}
这些文件在原先下载的预训练模型下都配好了👌

将上图高亮的.safetensors文件中部分张量的键名和形状输出,有如下图

再简要介绍一下W&B(这种工具本菜狗才发现இ௰இ)。它是一个机器学习实验管理平台,提供了强大的工具来追踪、可视化和优化模型训练过程,允许用户记录和管理实验的超参数、损失函数、评估指标等,支持实现可视化训练曲线,并为团队提供协作和版本控制功能。通过与常见机器学习框架(如TensorFlow、PyTorch)无缝集成,W&B可帮助开发者更高效地管理实验、提高模型性能,并加速团队的研究与开发进程。!!!想用它得用点魔法访问网站!!!
1. 在Python环境中安装W&B
pip install wandb
2. 在W&B官网注册一个账号,并登录
wandb login <API key>
# 然后终端会输出:wandb: Appending key for api.wandb.ai to your netrc file: /xx/xx/.netrc
3. 在代码中导入wandb库并初始化一个新的实验
import wandb
wandb.init(project="your_project_name")
4. 记录实验参数(config)
wandb.config.batch_size = 32
wandb.config.epochs = 10
wandb.config.learning_rate = 0.001
5. 记录训练过程中的数据(log追踪)
# 假设在训练过程中记录每个epoch的损失和准确率
for epoch in range(10):loss = train_one_epoch()accuracy = evaluate_model()wandb.log({"epoch": epoch, "loss": loss, "accuracy": accuracy})
6. 若想在训练结束后保存模型,也可将其与W&B关联,进行版本控制
wandb.save('path_to_your_model.h5')
7. 结束实验
wandb.finish()
每次运行完wandb.init()后,W&B会自动创建一个新页面,记录该实验的所有信息 。
关于BF16:Bfloat16是一种16-bit浮点数格式,主要由Google提出,并在其TPU和一些NVIDIA A100等加速器中得到了广泛使用。与FP16相比,其在数值范围和精度表示上做了不同的权衡,FP16使用5位指数位,10位尾数位;BF16使用8位指数位,7位尾数位,其指数范围和FP32相当,使其能表示更广泛的数值范围,于神经网络训练中的梯度计算和大数值表示非常有用。
面对OOM问题时还可以用一个PyTorch的梯度检查点(gradient checkpointing)技术,以显著减少显存占用(↓50%~70%),但会略增加计算时间(↑20%~30%)。
问题背景:训练深度神经网络时,前向传播计算的中间结果需要保存,以便在反向传播时计算梯度,这些中间结果就会占用大量GPU显存,尤其是大模型(如Llama-2-7B)。
解决方法:只保留部分关键中间结果,其余的在反向传播时重新计算,用计算时间换显存。
from torch.utils.checkpoint import checkpoint # 导入检查点功能def _forward(self, x):return self.layer2(self.layer1(x)) # 模型的计算逻辑(如Transformer层的堆叠)def forward(self, x):return checkpoint(self._forward, x) # 使用检查点包装前向计算'''关键点解释:
1. checkpoint(func, *args)func:要优化的前向计算函数(如self._forward)*args:传给func的输入(如x)作用:PyTorch不会保存func的中间激活值,而是在反向传播时重新计算它们
2. self._forward(x)是你模型的实际前向计算逻辑(如nn.Module的forward方法)
'''''' 进阶操作如下 '''# 只对显存占用高的部分使用检查点
def forward(self, x):x = checkpoint(self.layer1, x)x = self.layer2(x) # layer2正常计算return x# 调整检查点频率
def forward(self, x):if self.training: # 仅在训练时使用检查点return checkpoint(self._forward, x)else: # 推理时不使用,避免额外计算return self._forward(x)
事先说一下我们将要处理的数据,原始数据集shareGPT/computer_en_26k.jsonl有20692条对话样本,把里面的“category”类处理一下,清洗后的数据集文件cleaned.jsonl文件中样本数是一样的(.jsonl文件每行是一个JSON样本,可用wc -l file_path进行查看),处理成20692个ckpt文件,从中抽取0.95用作训练的话,训练样本数就是19657。然后accelerator检测到服务器上有3块GPU,咱为了避免OOM设置了较小的批处理大小,为2,故每个训练epoch的进度条显示有“/3277”,测试时显示“/173”。
为什么说OOM时老说要减少批处理大小?(原先通过LLM推理获取数据集时批处理大小就是1了)加载模型本身需要固定显存,例如LLaMA-2-7B的4-bit量化后,7B参数,每个0.5个字节,再加些七七八八的,实际推理时的经验值为6-8GB;然后每增加一个样本,前向传播的中间结果就会线性增加显存占用,故总显存≈模型参数+批处理大小×单样本激活值显存。
好不容易准备好数据集,运行下列代码进行训练(我和网络和OOM问题不共戴天!)
accelerate launch -m --mixed_precision=bf16 eagle.train.main --tmpdir [path of data]\
--cpdir [path of checkpoints] --configpath [path of config file]
跑了超4h,约10min/训练epoch(明白为啥原代码只20个epoch了),可以将代码放到后台开始运行,即使电脑关闭也不影响服务器继续当牛做马。首先创建一个新会话,运行下列指令后便会进入一个新终端,此时再打开另一个终端窗口,输入tmux ls便能看到包含它在内的会话列表。
tmux new -s <会话名>
在tmux会话中输入要在后台运行的指令。如果想滚动查看窗口内容,可以配置一下鼠标支持,即在~/.tmux.conf中添加如下配置后重载(tmux source-file ~/.tmux.conf)
set -g mouse on # 允许鼠标滚动和选择面板
创建会话后键入Ctrl+B, 松手,再按D就能从前台分离会话,通过如下指令又能重新连接
tmux attach -t <会话名>
创建会话后键入Ctrl+B,松手,再按C就能新建又一个窗口,执行exit就能退出当前窗口,当最后一个窗口也被关闭时,整个会话会被自动删除。
利用tmux就能方向地关闭电脑了吗?No,还记得前面提到的,为了使用W&B,做了到本地代理的端口映射吗?这个Clash还不能关!或者说我设置了透明代理,代理IP就是本机的IPv4地址,那本机就不能关!我们的确将程序放在服务器上由tmux托管,一般本机关闭并不影响程序本体的继续运行,但这里一旦关机,目标代理IP无法连接,网络请求就会失败。W&B还挺好用的,又自动给你画图,又记录日志的,我们的服务器又不能科学上网,自己的这破电脑就一直开着呗。
区分一下几个指标
交叉熵损失:(plogp是 目标概率分布 与 模型预测的对数概率 的乘积,loss_mask用于忽略无效位置的损失)用于衡量模型预测的token分布和真实token分布的差异,适用于分类任务
out_head = head(predict)
out_logp = nn.LogSoftmax(dim=2)(out_head)
plogp = target_p * out_logp
cross_entropy_loss = -torch.sum(torch.sum(loss_mask * plogp, 2)) / (loss_mask.sum() + 1e-5)
平滑L1损失:计算模型预测的隐藏状态predict与目标隐藏状态target之间的差异
# 平滑L1损失(又称Huber损失)(介于绝对误差损失和均方误差之间的损失函数)
criterion = nn.SmoothL1Loss(reduction="none")
...
smooth_L1_loss = criterion(predict, target)
smooth_L1_loss = torch.sum(torch.mean(loss_mask * smooth_L1_loss, 2)) / (loss_mask.sum() + 1e-5)
Top-k准确率:top-1即标准准确率,指示预测的最高概率token是否正确,top-2/3表示前2/3个token是否包含正确答案,用于评估语言模型的预测能力
def top_kaccuracy(output, target, topk=(1, 2, 3)):
'''
预测头输出output的形状为(bs, num_classes)
目标输出target的形状为(bs,)
'''with torch.no_grad():maxk = max(topk) # 3_, pred = output.topk(maxk, 1, True, True) # 取预测分数前3名,分数忽略,索引保留pred = pred.t() # pred形状转置为(maxk, bs)correct = pred.eq(target.view(1, -1).expend_as(pred)
关于deepspeed:它是由Microsoft开发的深度学习优化库,旨在通过高效的内存管理、分布式训练和混合精度训练等技术,显著提高大规模模型的训练效率。利用零冗余优化器(ZeRO)减少显存占用,支持训练数十亿参数的超大模型(如GPT-3),还支持数据并行、模型并行和多机多卡训练,能在有限的硬件资源上加速训练过程,并广泛用于NLP、CV等领域的大规模深度学习任务。
很抱歉写到这里我打算放弃了,有点戛然而止的感觉,而且前面写的也乱七八糟。主要是在完成主线任务的时候本人太菜了会遇到很多支线问题,而且原论文提供的代码里也存在一些很明显的错误,会让人不禁怀疑这是否是一篇high level的论文,然后实验室挂了两台有3090的服务器成了“最后一根稻草”,我想过换另一台显存很局限的3090服务器,或一台有A40的服务器、或拿师兄买的4090服务器,但停下来审视一番后还是算了,其实也算走完了训练这一步,最后也是在整理和修改原作者的代码。我想再去探索一下别的科研方向,很感谢eagle的作者🙇,研究了代码之后我对论文的理解会更加深刻,也会感慨现在的论文能被接收很不容易,同时在和代码搏斗的这个月里俺也学到了很多东西,这就够了。最后最后,谢谢友友们能看到这里🌼
相关文章:
EAGLE代码研读+模型复现
要对代码下手了,加油(ง •_•)ง 作者在他们自己的设备上展现了推理的评估结果,受第三方评估认证,EAGLE为目前最快的投机方法(虽然加速度是评估投机解码方法的主要指标,但其他点也值得关注。比如PLD和Lookahead无需额…...
2024期刊综述论文 Knowledge Graphs and Semantic Web Tools in Cyber Threat Intelligence
发表在期刊Journal of Cybersecurity and Privacy上,专门讲知识图谱技术和语义Web工具在网络威胁情报领域的作用,还把本体和知识图谱放在相同的地位上讨论。 此处可以明确一点:本体和知识图谱都可以用于网络威胁情报的应用,当然也…...
vue3+vite 多个环境配置
同一套代码 再也不用在不同的环境里来回切换请求地址了 然后踩了一个坑 就是env的文件路径是在当前项目下 不是在views内 因为公司项目需求只有dev和pro两个环境 虽然我新增了3个 但是只在这两个里面配置了 .env是可以配置一些公共配置的 目前需求来说不需要 所以我也懒得配了。…...
秒杀系统解决两个核心问题的思路方法总结:1.库存超卖问题;2.用户重复抢购问题。
秒杀系统解决两个核心问题 秒杀系统解决两个核心问题:一、解决库存超卖的核心逻辑:解释:原子性保证: 二、如何避免重复抢购:使用 Redis 做唯一标识判断优点: 三、流程完整梳理:四、通过数据库建…...
linux socket编程之udp(实现客户端和服务端消息的发送和接收)
目录 一.创建socket套接字(服务器端) 二.bind将prot与端口号进行绑定(服务器端) 2.1填充sockaddr_in结构 2.2bind绑定端口 三.直接通信(服务器端) 3.1接收客户端发送的消息 3.2给客户端发送消息 四.客户端通信 4.1创建socket套接字 4.2客户端bind问题 4.3直接通信即可…...
SAP HANA使用命令行快速导出导入
楔子 今天折腾了接近一下午,就为了使用SAP HANA自带的命令行工具来导出数据备份。 SAP HANA(后续简称Hana)是内存数据库,性能这一方面上还真没怕过谁。 由于SAP HANA提供了Hana Studio这个桌面工具来方便运维和DBA使用…...
goland做验证码识别时报“undefined: gosseract.NewClient”
gosseract 应该有 和 c 相关的配置库因此需要安装 cgo 并且启用 CGO_ENABLED 在cmd下面输入这个 go env -w CGO_ENABLED1 接着输入 go env 验证是否设置成功 解决了这个问题后 “undefined: gosseract.NewClient” 又出现了 # runtime/cgo …...
计算机网络 实验四 静态路由的配置与应用
一、实验目的 掌握路由器基础工作原理及静态路由协议机制熟练使用华为ENSP网络模拟器进行拓扑设计与设备配置建立系统化的网络故障排除思维通过实践验证静态路由在中小型网络中的部署优势 二、实验环境 硬件配置:标准PC终端软件工具:华为企业网络模拟…...
Vue自定义指令-防抖节流
Vue2版本 // 防抖 // <el-button v-debounce"[reset,click,300]" ></el-button> // <el-button v-debounce"[reset]" ></el-button> Vue.directive(debounce, { inserted: function (el, binding) { let [fn, event "cl…...
[每周一更]-(第140期):sync.Pool 使用详解:性能优化的利器
文章目录 一、什么是 sync.Pool?二、sync.Pool 的基本作用三、sync.Pool 的主要方法四、sync.Pool 的内部工作原理五、sync.Pool 适用场景六、使用示例示例 1:基本使用输出示例:示例 2:并发使用 七、一个基于 sync.Pool 的 **Benc…...
3.QT-信号和槽|自定义槽函数|自定义信号}自定义的语法}带参数的信号和槽(C++)
信号和槽 Linux信号 Signal 系统内部的通知机制. 进程间通信的方式. 信号源:谁发的信号.信号的类型:哪种类别的信号信号的处理方式:注册信号处理函数,在信号被触发的时候自动调用执行. Qt中的信号和Linux中的信号,虽…...
健康养生之道
在快节奏的现代生活中,健康养生不再是中老年人的专属话题,越来越多的人开始意识到,合理的养生方式是保持良好身体状态和生活质量的关键。 饮食养生是健康的基石。遵循 “食物多样、谷类为主” 的原则,保证每天摄入足够的蔬菜、…...
Spark-SQL核心编程3
数据加载与保存 通用方式: SparkSQL 提供了通用的保存数据和数据加载的方式。这里的通用指的是使用相同的API,根据不同的参数读取和保存不同格式的数据,SparkSQL 默认读取和保存的文件格式为parquet 数据加载方法: spark.read.lo…...
TVM计算图分割--Collage
1 背景 为满足高效部署的需要,整合大量优化的tensor代数库和运行时做为后端成为必要之举。现在的深度学习后端可以分为两类:1)算子库(operator kernel libraries),为每个DL算子单独提供高效地低阶kernel实现。这些库一般也支持算…...
elementUI中MessageBox.confirm()默认不聚焦问题处理
在项目中使用elementUI的MessageBox.confirm()出现了默认不聚焦的问题,默认确认按钮是浅色的,需要点击一下才会变成正常。面对这种问题,创建新组件,实现聚焦。替换默认的MessageBox.confirm() 解决 创建components/MessageBoxCo…...
【刷题Day20】TCP和UDP(浅)
TCP 和 UDP 有什么区别? TCP提供了可靠、面向连接的传输,适用于需要数据完整性和顺序的场景。 UDP提供了更轻量、面向报文的传输,适用于实时性要求高的场景。 特性TCPUDP连接方式面向连接无连接可靠性提供可靠性,保证数据按顺序…...
sql server 预估索引大小
使用deepseek工具预估如下: 问题: 如果建立一个数据类型是datetime的索引,需要多大的空间? 回答: 如果建立一个数据类型是 datetime 的索引,索引的大小取决于以下因素: 索引键的大小&#…...
利用 i2c 快速从 Interface 生成 Class
利用 i2c 快速从 Interface 生成 Class(支持 TS & ArkTS) 在日常 TypeScript 或 ArkTS 开发中,需要根据 interface 定义手动实现对应的 class,这既重复又容易出错。分享一个命令行工具 —— interface2class,简称…...
MCGS昆仑通太屏笔记
4.3寸:4013ef/e1 7寸:7032kw 特点: 如果是使用组态屏进行调试使用,选择com1如果是实际项目使用,选择com2 操作步骤: 先创建设备窗口,再创建用户界面 在设备窗口界面,依次设置如下…...
服务治理-搭建Nacos注册中心
运行nacos.sql文件。 将准备好的nacos目录和nacos.tar包上传。 192.168.59.101是我的虚拟机ip,8848是我们设置的访问端口号。...
网络--socket编程(2)
Socket 编程 TCP TCP 网络程序 和刚才 UDP 类似 . 实现一个简单的英译汉的功能 TCP socket API 详解 下面介绍程序中用到的 socket API, 这些函数都在 sys/socket.h 中。 socket(): • socket() 打开一个网络通讯端口 , 如果成功的话 , 就像 open() 一样返回一个…...
【FreeRTOS进阶】优先级翻转现象详解及解决方案
【FreeRTOS进阶】优先级翻转现象详解及解决方案 接下来我们聊聊优先级翻转这个经典问题。这个问题在实时系统中经常出现,尤其是在任务较多的场景下,而且问题定位起来比较麻烦。 什么是优先级翻转? 优先级翻转的核心定义很简单:…...
结合建筑业务讲述TOGAF标准处理哪种架构
TOGAF标准处理哪种架构 内容介绍业务架构业务策略,治理,组织和关键业务流程数据架构组织的逻辑和物理数据资产以及数据管理资源的结构应用架构待部署的各个应用程序,它们之间的交互以及与组织核心业务流程的关系的蓝图技术架构支持业务&#…...
C++入门小馆: 深入string类(一)
嘿,各位技术潮人!好久不见甚是想念。生活就像一场奇妙冒险,而编程就是那把超酷的万能钥匙。此刻,阳光洒在键盘上,灵感在指尖跳跃,让我们抛开一切束缚,给平淡日子加点料,注入满满的pa…...
NHANES指标推荐:WWI
文章题目:Weight-adjusted waist circumference index with hepatic steatosis and fibrosis in adult females: a cross-sectional, nationally representative study (NHANES 2017-2020) DOI:10.1186/s12876-025-03706-4 中文标题:体重调整…...
2025.04.18|【Map】地图绘图技巧全解
Add circles Add circles on a Leaflet map Change tile Several background tiles are offered by leaflet. Learn how to load them, and check the possibilities. 文章目录 Add circlesChange tile 2025.04.18【Map】| 地图绘图技巧全解1. 准备工作2. 地理区域着色图&…...
PR第一课
目录 1.新建 2.PR内部设置 3.导入素材 4.关于素材窗口 5.关于编辑窗口 6.序列的创建 7.视频、图片、音乐 7.1 带有透明通道的素材 8.导出作品 8.1 打开方法 8.2 导出时,需要修改的参数 1.新建 2.PR内部设置 随意点开 编辑->首选项 中的任意内容&a…...
C# 预定义类型全解析
在 C# 编程中,预定义类型是基础且重要的概念。下面我们来详细了解 C# 的预定义类型。 预定义类型概述 C# 提供了 16 种预定义类型,包含 13 种简单类型和 3 种非简单类型。所有预定义类型的名称都由全小写字母组成。 预定义简单类型 预定义简单类型表…...
@EnableAsync+@Async源码学习笔记之六
接上文,我们本文分析 AsyncExecutionAspectSupport 的源码: package org.springframework.aop.interceptor;import java.lang.reflect.Method; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFu…...
Java CMS和G1垃圾回收器
举个真带劲的例子:把JVM内存比作你家的祖传旱厕 想象你有个祖传旱厕,分三个坑: 新坑区(年轻代):刚拉的屎热乎着(新对象)陈年坑(老年代):风干的屎…...
