llama3源码解读之推理-infer
文章目录
- 前言
- 一、整体源码解读
- 1、完整main源码
- 2、tokenizer加载
- 3、llama3模型加载
- 4、llama3测试数据文本加载
- 5、llama3模型推理模块
- 1、模型推理模块的数据处理
- 2、模型推理模块的model.generate预测
- 3、模型推理模块的预测结果处理
- 6、多轮对话
- 二、llama3推理数据处理
- 1、完整数据处理源码
- 2、使用prompt方式询问数据加载
- 3、推理处理数据
- 三、llama3推理generate调用_sample方法
- 1、GenerationMode.SAMPLE方法源码
- 2、huggingface的_sample源码
- 3、_sample的初始化与准备
- 4、_sample的初始化(attention / hidden states / scores)
- 5、encoder-decoder模式模型处理
- 6、保持相应变量内容
- 7、进入while循环模块
- 8、结果返回处理
- 四、_sample的while循环模块内容
- 1、模型输入加工
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 2、模型推理
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 3、预测结果取值
- 4、预测logits惩罚处理
- 5、预测softmax处理
- 6、预测token选择处理
- 7、停止条件更新处理
- 8、再次循环迭代input_ids与attention_mask处理(生成式预测)
- 8、停止条件再次更新与处理
- 五、模型推理(self)
- 1、模型预测输入参数
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaForCausalLM(LlamaPreTrainedModel)
- 4、llama3推理forward的前提说明
- 5、llama3推理forward的设置参数
- 6、llama3推理forward的推理(十分重要)
- 7、llama3推理forward的的lm_head转换
- 8、llama3推理forward的结果包装(CausalLMOutputWithPast)
- 六、llama3模型forward的self.model方法
- 1、模型输入内容
- 1、self.model模型初次输入内容
- 2、self.model模型再次输入内容
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaModel(LlamaPreTrainedModel)
- 4、llama3模型forward的源码
- 5、llama3模型forward的输入准备与embedding
- 1、初次输入llama3模型内容
- 2、再此输入llama3模型内容
- 3、说明
- 6、llama3模型forward的cache
- 1、DynamicCache.from_legacy_cache(past_key_values)函数
- 2、cache = cls()类方法
- 该类位置
- 该类方法
- 2、cache.update的更新
- 7、llama3模型decoder_layer方法
- 1、decoder_layer调用源码
- 2、decoder_layer源码方法与调用self.self_attn方法
- 3、self.self_attn源码方法
- 8、hidden_states的LlamaRMSNorm
- 9、next_cache保存与输出结果
- 七、llama3模型atten方法
- 1、模型位置
- 2、llama3推理模型初始化
- 3、llama3推理的forward方法
- 4、llama3模型结构图
前言
本项目是解读开源github的代码,该项目基于Meta最新发布的新一代开源大模型Llama-3开发,是Chinese-LLaMA-Alpaca开源大模型相关系列项目(一期、二期)的第三期。而本项目开源了中文Llama-3基座模型和中文Llama-3-Instruct指令精调大模型。这些模型在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。因此,我是基于该项目解读训练与推理相关原理与内容,并以代码形式带领读者一步一步解读,理解其大语言模型运行机理。而该博客首先给出llama3推理源码相关内容解读,我将按照源码流程给出解读。
一、整体源码解读
1、完整main源码
我先给出完整的源码,后面推理使用哪些部分代码,我在深度解读。而一些较为简单内容我不在解读了。
if __name__ == '__main__':load_type = torch.float16# Move the model to the MPS device if availableif torch.backends.mps.is_available():device = torch.device("mps")else:if torch.cuda.is_available():device = torch.device(0)else:device = torch.device('cpu')print(f"Using device: {device}")if args.tokenizer_path is None:args.tokenizer_path = args.base_modeltokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]if args.use_vllm:model = LLM(model=args.base_model,tokenizer=args.tokenizer_path,tensor_parallel_size=len(args.gpus.split(',')),dtype=load_type)generation_config["stop_token_ids"] = terminatorsgeneration_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]else:if args.load_in_4bit or args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,bnb_4bit_compute_dtype=load_type,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4")model = AutoModelForCausalLM.from_pretrained(args.base_model,torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa")if device==torch.device('cpu'):model.float()model.eval()# test dataif args.data_file is None:examples = sample_dataelse:with open(args.data_file, 'r') as f:examples = [line.strip() for line in f.readlines()]print("first 10 examples:")for example in examples[:10]:print(example)with torch.no_grad():if args.interactive:print("Start inference with instruction mode.")print('='*85)print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n""+ 如要进行多轮对话,请使用llama.cpp")print('-'*85)print("+ This mode only supports single-turn QA.\n""+ If you want to experience multi-turn dialogue, please use llama.cpp")print('='*85)while True:raw_input_text = input("Input:")if len(raw_input_text.strip())==0:breakif args.with_prompt:input_text = generate_prompt(instruction=raw_input_text)else:input_text = raw_input_textif args.use_vllm:output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)response = output[0].outputs[0].textelse:inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s, skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[-1].strip()else:response = outputprint("Response: ",response)print("\n")else:print("Start inference.")results = []if args.use_vllm:if args.with_prompt is True:inputs = [generate_prompt(example) for example in examples]else:inputs = examplesoutputs = model.generate(inputs, SamplingParams(**generation_config))for index, (example, output) in enumerate(zip(examples, outputs)):response = output.outputs[0].textprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":example,"Output":response})else:for index, example in enumerate(examples):if args.with_prompt:input_text = generate_prompt(instruction=example)else:input_text = exampleinputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s,skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[1].strip()else:response = outputprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":input_text,"Output":response})dirname = os.path.dirname(args.predictions_file)os.makedirs(dirname,exist_ok=True)with open(args.predictions_file,'w') as f:json.dump(results,f,ensure_ascii=False,indent=2)if args.use_vllm:with open(dirname+'/generation_config.json','w') as f:json.dump(generation_config,f,ensure_ascii=False,indent=2)else:generation_config.save_pretrained('./')
2、tokenizer加载
有关tokenzier相关加载可参考博客这里。这里,我直接给出其源码,如下:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]
tokenizer.eos_token_id=128009,而terminators=[128009,128009]。
3、llama3模型加载
huggingface模型加载可参考博客这里。这里,llama3的模型加载不在介绍,如下源码:
model = AutoModelForCausalLM.from_pretrained(args.base_model, # 权重路径文件夹torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):model.float()
model.eval()
注意:model.eval()为固定权重方式,这是pytorch评估类似。
4、llama3测试数据文本加载
# test dataif args.data_file is None:examples = sample_data # ["为什么要减少污染,保护环境?","你有什么建议?"]else:with open(args.data_file, 'r') as f:examples 相关文章:
llama3源码解读之推理-infer
文章目录 前言一、整体源码解读1、完整main源码2、tokenizer加载3、llama3模型加载4、llama3测试数据文本加载5、llama3模型推理模块1、模型推理模块的数据处理2、模型推理模块的model.generate预测3、模型推理模块的预测结果处理6、多轮对话二、llama3推理数据处理1、完整数据…...
【教程】Linux安装Redis步骤记录
下载地址 Index of /releases/ Downloads - Redis 安装redis-7.4.0.tar.gz 1.下载安装包 wget https://download.redis.io/releases/redis-7.4.0.tar.gz 2.解压 tar -zxvf redis-7.4.0.tar.gz 3.进入目录 cd redis-7.4.0/ 4.编译 make 5.安装 make install PREFIX/u…...
全球汽车线控制动系统市场规模预测:未来六年CAGR为17.3%
引言: 随着汽车行业的持续发展和对安全性能需求的增加,汽车线控制动系统作为提升车辆安全性和操控性的关键组件,正逐渐受到市场的广泛关注。本文旨在通过深度分析汽车线控制动系统行业的各个维度,揭示行业发展趋势和潜在机会。 【…...
Ubuntu运行深度学习代码,代码随机epoch中断没有任何报错
深度学习运行代码直接中断 文章目录 深度学习运行代码直接中断问题描述设备信息问题补充解决思路问题发现及正确解决思路新问题出现最终问题:ubuntu系统,4090显卡安装英伟达驱动535.x外的驱动会导致开机无法进入桌面问题记录 问题描述 运行深度学习代码…...
只有4%知道的Linux,看了你也能上手Ubuntu桌面系统,Ubuntu简易设置,源更新,root密码,远程服务...
创作不易 只因热爱!! 热衷分享,一起成长! “你的鼓励就是我努力付出的动力” 最近常提的一句话,那就是“但行好事,莫问前程"! 与辉同行的董工说:守正出奇。坚持分享,坚持付出,坚持奉献,…...
Tomcat部署——个人笔记
Tomcat部署——个人笔记 文章目录 [toc]简介安装配置文件WEB项目的标准结构WEB项目部署IDEA中开发并部署运行WEB项目 本学习笔记参考尚硅谷等教程。 简介 Apache Tomcat 官网 Tomcat是Apache 软件基金会(Apache Software Foundation)的Jakarta 项目中…...
常见且重要的用户体验原则
以下是一些常见且重要的用户体验原则: 1. 以用户为中心 - 深入了解用户的需求、期望、目标和行为习惯。通过用户研究、调查、访谈等方法获取真实的用户反馈,以此来设计产品或服务。 - 例如,在设计一款老年手机时,充分考虑老年…...
web基础及nginx搭建
第四周 上午 静态资源 根据开发者保存在项目资源目录中的路径访问静态资源 html 图片 js css 音乐 视频 f12 ,开发者工具,网络 1 、 web 基本概念 web 服务器( web server ):也称 HTTP 服务器( HTTP …...
C++ 布隆过滤器
1. 布隆过滤器提出 我们在使用新闻客户端看新闻时,它会给我们不停地推荐新的内容,它每次推荐时要去重,去掉 那些已经看过的内容。问题来了,新闻客户端推荐系统如何实现推送去重的? 用服务器记录了用 户看过的所有历史…...
使用HTML创建用户注册表单
在当今数字化时代,网页表单对于收集用户信息和促进网站交互至关重要。无论您设计简单的注册表单还是复杂的调查表,了解HTML的基础知识可以帮助您构建有效的用户界面。在本教程中,我们将详细介绍如何使用HTML创建基本的用户注册表单。 第一步…...
Python零基础入门教程
Python零基础详细入门教程可以从以下几个方面进行学习和掌握: 一、Python基础认知 1. Python简介 由来与发展:Python是一种广泛使用的高级编程语言,由Guido van Rossum(吉多范罗苏姆)于1991年首次发布。Python以其简…...
成为git砖家(10): 根据文件内容生成SHA-1
文章目录 1. .git/objects 目录2. git cat-file 命令3. 根据文件内容生成 sha-14. 结语5. References 1. .git/objects 目录 git 是一个根据文件内容进行检索的系统。 当创建 hello.py, 填入 print("hello, world")的内容, 并执行 git add hello.py gi…...
园区导航小程序:一站式解决园区导航问题,释放存储,优化访客体验
随着园区的规模不断扩大,功能区划分日益复杂,导致访客和新员工在没有有效导航的情况下容易迷路。传统APP导航虽能解决部分问题,但其下载安装繁琐、占用手机内存大、且非高频使用导致的闲置,让许多用户望而却步。园区导航小程序的出…...
对于n进制转十进制的解法及代码(干货!)
对于p进制转十进制,我们有:(x)pa[0]*p^0a[1]*p^1a[2]*p^2...a[n]*p^n 举个例子:(11001)21*10*20*41*81*1625 (9FA)1610*16^015*16^19*16^22554 据此,我们可以编出c代码来解决问题 …...
当代互联网打工人的生存现状,看完泪流满面!
欢迎私信小编,了解更多产品信息呦~...
花几千上万学习Java,真没必要!(三十八)
测试代码1: package iotest.com; import java.nio.charset.StandardCharsets; import java.io.UnsupportedEncodingException; public class StringByteConversion { public static void main(String[] args) throws UnsupportedEncodingException { // 原始字…...
Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师
为了解决非结构化数据处理问题,我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目,还曾登上数据库顶会SIGMOD、VLDB,在全球首届向量检索比赛中夺冠。目前,Milvus 项目已获得超过 2.8w s…...
TwinCAT3 新建项目教程
文章目录 打开TwinCAT 新建项目(通过TcXaeShell) 新建项目(通过VS 2019)...
大模型算法面试题(十九)
本系列收纳各种大模型面试题及答案。 1、SFT(有监督微调)、RM(奖励模型)、PPO(强化学习)的数据集格式? SFT(有监督微调)、RM(奖励模型)、PPO&…...
应用地址信息获取新技巧:Xinstall来助力
在移动互联网时代,应用获取用户地址信息的需求越来越普遍。无论是为了提供个性化服务,还是进行精准营销,地址信息都扮演着至关重要的角色。然而,如何合规、准确地获取这一信息,却是许多开发者面临的挑战。今天…...
日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻
在如今就业市场竞争日益激烈的背景下,越来越多的求职者将目光投向了日本及中日双语岗位。但是,一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧?面对生疏的日语交流环境,即便提前恶补了…...
STM32+rt-thread判断是否联网
一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...
iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版分享
平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...
【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
SpringTask-03.入门案例
一.入门案例 启动类: package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...
GruntJS-前端自动化任务运行器从入门到实战
Grunt 完全指南:从入门到实战 一、Grunt 是什么? Grunt是一个基于 Node.js 的前端自动化任务运行器,主要用于自动化执行项目开发中重复性高的任务,例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...
怎么让Comfyui导出的图像不包含工作流信息,
为了数据安全,让Comfyui导出的图像不包含工作流信息,导出的图像就不会拖到comfyui中加载出来工作流。 ComfyUI的目录下node.py 直接移除 pnginfo(推荐) 在 save_images 方法中,删除或注释掉所有与 metadata …...
【Linux手册】探秘系统世界:从用户交互到硬件底层的全链路工作之旅
目录 前言 操作系统与驱动程序 是什么,为什么 怎么做 system call 用户操作接口 总结 前言 日常生活中,我们在使用电子设备时,我们所输入执行的每一条指令最终大多都会作用到硬件上,比如下载一款软件最终会下载到硬盘上&am…...
