llama3源码解读之推理-infer
文章目录
- 前言
- 一、整体源码解读
- 1、完整main源码
- 2、tokenizer加载
- 3、llama3模型加载
- 4、llama3测试数据文本加载
- 5、llama3模型推理模块
- 1、模型推理模块的数据处理
- 2、模型推理模块的model.generate预测
- 3、模型推理模块的预测结果处理
- 6、多轮对话
- 二、llama3推理数据处理
- 1、完整数据处理源码
- 2、使用prompt方式询问数据加载
- 3、推理处理数据
- 三、llama3推理generate调用_sample方法
- 1、GenerationMode.SAMPLE方法源码
- 2、huggingface的_sample源码
- 3、_sample的初始化与准备
- 4、_sample的初始化(attention / hidden states / scores)
- 5、encoder-decoder模式模型处理
- 6、保持相应变量内容
- 7、进入while循环模块
- 8、结果返回处理
- 四、_sample的while循环模块内容
- 1、模型输入加工
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 2、模型推理
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 3、预测结果取值
- 4、预测logits惩罚处理
- 5、预测softmax处理
- 6、预测token选择处理
- 7、停止条件更新处理
- 8、再次循环迭代input_ids与attention_mask处理(生成式预测)
- 8、停止条件再次更新与处理
- 五、模型推理(self)
- 1、模型预测输入参数
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaForCausalLM(LlamaPreTrainedModel)
- 4、llama3推理forward的前提说明
- 5、llama3推理forward的设置参数
- 6、llama3推理forward的推理(十分重要)
- 7、llama3推理forward的的lm_head转换
- 8、llama3推理forward的结果包装(CausalLMOutputWithPast)
- 六、llama3模型forward的self.model方法
- 1、模型输入内容
- 1、self.model模型初次输入内容
- 2、self.model模型再次输入内容
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaModel(LlamaPreTrainedModel)
- 4、llama3模型forward的源码
- 5、llama3模型forward的输入准备与embedding
- 1、初次输入llama3模型内容
- 2、再此输入llama3模型内容
- 3、说明
- 6、llama3模型forward的cache
- 1、DynamicCache.from_legacy_cache(past_key_values)函数
- 2、cache = cls()类方法
- 该类位置
- 该类方法
- 2、cache.update的更新
- 7、llama3模型decoder_layer方法
- 1、decoder_layer调用源码
- 2、decoder_layer源码方法与调用self.self_attn方法
- 3、self.self_attn源码方法
- 8、hidden_states的LlamaRMSNorm
- 9、next_cache保存与输出结果
- 七、llama3模型atten方法
- 1、模型位置
- 2、llama3推理模型初始化
- 3、llama3推理的forward方法
- 4、llama3模型结构图
前言
本项目是解读开源github的代码,该项目基于Meta最新发布的新一代开源大模型Llama-3开发,是Chinese-LLaMA-Alpaca开源大模型相关系列项目(一期、二期)的第三期。而本项目开源了中文Llama-3基座模型和中文Llama-3-Instruct指令精调大模型。这些模型在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。因此,我是基于该项目解读训练与推理相关原理与内容,并以代码形式带领读者一步一步解读,理解其大语言模型运行机理。而该博客首先给出llama3推理源码相关内容解读,我将按照源码流程给出解读。
一、整体源码解读
1、完整main源码
我先给出完整的源码,后面推理使用哪些部分代码,我在深度解读。而一些较为简单内容我不在解读了。
if __name__ == '__main__':load_type = torch.float16# Move the model to the MPS device if availableif torch.backends.mps.is_available():device = torch.device("mps")else:if torch.cuda.is_available():device = torch.device(0)else:device = torch.device('cpu')print(f"Using device: {device}")if args.tokenizer_path is None:args.tokenizer_path = args.base_modeltokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]if args.use_vllm:model = LLM(model=args.base_model,tokenizer=args.tokenizer_path,tensor_parallel_size=len(args.gpus.split(',')),dtype=load_type)generation_config["stop_token_ids"] = terminatorsgeneration_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]else:if args.load_in_4bit or args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,bnb_4bit_compute_dtype=load_type,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4")model = AutoModelForCausalLM.from_pretrained(args.base_model,torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa")if device==torch.device('cpu'):model.float()model.eval()# test dataif args.data_file is None:examples = sample_dataelse:with open(args.data_file, 'r') as f:examples = [line.strip() for line in f.readlines()]print("first 10 examples:")for example in examples[:10]:print(example)with torch.no_grad():if args.interactive:print("Start inference with instruction mode.")print('='*85)print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n""+ 如要进行多轮对话,请使用llama.cpp")print('-'*85)print("+ This mode only supports single-turn QA.\n""+ If you want to experience multi-turn dialogue, please use llama.cpp")print('='*85)while True:raw_input_text = input("Input:")if len(raw_input_text.strip())==0:breakif args.with_prompt:input_text = generate_prompt(instruction=raw_input_text)else:input_text = raw_input_textif args.use_vllm:output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)response = output[0].outputs[0].textelse:inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s, skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[-1].strip()else:response = outputprint("Response: ",response)print("\n")else:print("Start inference.")results = []if args.use_vllm:if args.with_prompt is True:inputs = [generate_prompt(example) for example in examples]else:inputs = examplesoutputs = model.generate(inputs, SamplingParams(**generation_config))for index, (example, output) in enumerate(zip(examples, outputs)):response = output.outputs[0].textprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":example,"Output":response})else:for index, example in enumerate(examples):if args.with_prompt:input_text = generate_prompt(instruction=example)else:input_text = exampleinputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s,skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[1].strip()else:response = outputprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":input_text,"Output":response})dirname = os.path.dirname(args.predictions_file)os.makedirs(dirname,exist_ok=True)with open(args.predictions_file,'w') as f:json.dump(results,f,ensure_ascii=False,indent=2)if args.use_vllm:with open(dirname+'/generation_config.json','w') as f:json.dump(generation_config,f,ensure_ascii=False,indent=2)else:generation_config.save_pretrained('./')
2、tokenizer加载
有关tokenzier相关加载可参考博客这里。这里,我直接给出其源码,如下:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]
tokenizer.eos_token_id=128009,而terminators=[128009,128009]。
3、llama3模型加载
huggingface模型加载可参考博客这里。这里,llama3的模型加载不在介绍,如下源码:
model = AutoModelForCausalLM.from_pretrained(args.base_model, # 权重路径文件夹torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):model.float()
model.eval()
注意:model.eval()为固定权重方式,这是pytorch评估类似。
4、llama3测试数据文本加载
# test dataif args.data_file is None:examples = sample_data # ["为什么要减少污染,保护环境?","你有什么建议?"]else:with open(args.data_file, 'r') as f:examples 相关文章:
llama3源码解读之推理-infer
文章目录 前言一、整体源码解读1、完整main源码2、tokenizer加载3、llama3模型加载4、llama3测试数据文本加载5、llama3模型推理模块1、模型推理模块的数据处理2、模型推理模块的model.generate预测3、模型推理模块的预测结果处理6、多轮对话二、llama3推理数据处理1、完整数据…...
【教程】Linux安装Redis步骤记录
下载地址 Index of /releases/ Downloads - Redis 安装redis-7.4.0.tar.gz 1.下载安装包 wget https://download.redis.io/releases/redis-7.4.0.tar.gz 2.解压 tar -zxvf redis-7.4.0.tar.gz 3.进入目录 cd redis-7.4.0/ 4.编译 make 5.安装 make install PREFIX/u…...
全球汽车线控制动系统市场规模预测:未来六年CAGR为17.3%
引言: 随着汽车行业的持续发展和对安全性能需求的增加,汽车线控制动系统作为提升车辆安全性和操控性的关键组件,正逐渐受到市场的广泛关注。本文旨在通过深度分析汽车线控制动系统行业的各个维度,揭示行业发展趋势和潜在机会。 【…...
Ubuntu运行深度学习代码,代码随机epoch中断没有任何报错
深度学习运行代码直接中断 文章目录 深度学习运行代码直接中断问题描述设备信息问题补充解决思路问题发现及正确解决思路新问题出现最终问题:ubuntu系统,4090显卡安装英伟达驱动535.x外的驱动会导致开机无法进入桌面问题记录 问题描述 运行深度学习代码…...
只有4%知道的Linux,看了你也能上手Ubuntu桌面系统,Ubuntu简易设置,源更新,root密码,远程服务...
创作不易 只因热爱!! 热衷分享,一起成长! “你的鼓励就是我努力付出的动力” 最近常提的一句话,那就是“但行好事,莫问前程"! 与辉同行的董工说:守正出奇。坚持分享,坚持付出,坚持奉献,…...
Tomcat部署——个人笔记
Tomcat部署——个人笔记 文章目录 [toc]简介安装配置文件WEB项目的标准结构WEB项目部署IDEA中开发并部署运行WEB项目 本学习笔记参考尚硅谷等教程。 简介 Apache Tomcat 官网 Tomcat是Apache 软件基金会(Apache Software Foundation)的Jakarta 项目中…...
常见且重要的用户体验原则
以下是一些常见且重要的用户体验原则: 1. 以用户为中心 - 深入了解用户的需求、期望、目标和行为习惯。通过用户研究、调查、访谈等方法获取真实的用户反馈,以此来设计产品或服务。 - 例如,在设计一款老年手机时,充分考虑老年…...
web基础及nginx搭建
第四周 上午 静态资源 根据开发者保存在项目资源目录中的路径访问静态资源 html 图片 js css 音乐 视频 f12 ,开发者工具,网络 1 、 web 基本概念 web 服务器( web server ):也称 HTTP 服务器( HTTP …...
C++ 布隆过滤器
1. 布隆过滤器提出 我们在使用新闻客户端看新闻时,它会给我们不停地推荐新的内容,它每次推荐时要去重,去掉 那些已经看过的内容。问题来了,新闻客户端推荐系统如何实现推送去重的? 用服务器记录了用 户看过的所有历史…...
使用HTML创建用户注册表单
在当今数字化时代,网页表单对于收集用户信息和促进网站交互至关重要。无论您设计简单的注册表单还是复杂的调查表,了解HTML的基础知识可以帮助您构建有效的用户界面。在本教程中,我们将详细介绍如何使用HTML创建基本的用户注册表单。 第一步…...
Python零基础入门教程
Python零基础详细入门教程可以从以下几个方面进行学习和掌握: 一、Python基础认知 1. Python简介 由来与发展:Python是一种广泛使用的高级编程语言,由Guido van Rossum(吉多范罗苏姆)于1991年首次发布。Python以其简…...
成为git砖家(10): 根据文件内容生成SHA-1
文章目录 1. .git/objects 目录2. git cat-file 命令3. 根据文件内容生成 sha-14. 结语5. References 1. .git/objects 目录 git 是一个根据文件内容进行检索的系统。 当创建 hello.py, 填入 print("hello, world")的内容, 并执行 git add hello.py gi…...
园区导航小程序:一站式解决园区导航问题,释放存储,优化访客体验
随着园区的规模不断扩大,功能区划分日益复杂,导致访客和新员工在没有有效导航的情况下容易迷路。传统APP导航虽能解决部分问题,但其下载安装繁琐、占用手机内存大、且非高频使用导致的闲置,让许多用户望而却步。园区导航小程序的出…...
对于n进制转十进制的解法及代码(干货!)
对于p进制转十进制,我们有:(x)pa[0]*p^0a[1]*p^1a[2]*p^2...a[n]*p^n 举个例子:(11001)21*10*20*41*81*1625 (9FA)1610*16^015*16^19*16^22554 据此,我们可以编出c代码来解决问题 …...
当代互联网打工人的生存现状,看完泪流满面!
欢迎私信小编,了解更多产品信息呦~...
花几千上万学习Java,真没必要!(三十八)
测试代码1: package iotest.com; import java.nio.charset.StandardCharsets; import java.io.UnsupportedEncodingException; public class StringByteConversion { public static void main(String[] args) throws UnsupportedEncodingException { // 原始字…...
Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师
为了解决非结构化数据处理问题,我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目,还曾登上数据库顶会SIGMOD、VLDB,在全球首届向量检索比赛中夺冠。目前,Milvus 项目已获得超过 2.8w s…...
TwinCAT3 新建项目教程
文章目录 打开TwinCAT 新建项目(通过TcXaeShell) 新建项目(通过VS 2019)...
大模型算法面试题(十九)
本系列收纳各种大模型面试题及答案。 1、SFT(有监督微调)、RM(奖励模型)、PPO(强化学习)的数据集格式? SFT(有监督微调)、RM(奖励模型)、PPO&…...
应用地址信息获取新技巧:Xinstall来助力
在移动互联网时代,应用获取用户地址信息的需求越来越普遍。无论是为了提供个性化服务,还是进行精准营销,地址信息都扮演着至关重要的角色。然而,如何合规、准确地获取这一信息,却是许多开发者面临的挑战。今天…...
Opencv中的addweighted函数
一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...
蓝桥杯 2024 15届国赛 A组 儿童节快乐
P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...
IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...
SpringTask-03.入门案例
一.入门案例 启动类: package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...
C语言中提供的第三方库之哈希表实现
一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...
windows系统MySQL安装文档
概览:本文讨论了MySQL的安装、使用过程中涉及的解压、配置、初始化、注册服务、启动、修改密码、登录、退出以及卸载等相关内容,为学习者提供全面的操作指导。关键要点包括: 解压 :下载完成后解压压缩包,得到MySQL 8.…...
在 Spring Boot 中使用 JSP
jsp? 好多年没用了。重新整一下 还费了点时间,记录一下。 项目结构: pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...
uni-app学习笔记三十五--扩展组件的安装和使用
由于内置组件不能满足日常开发需要,uniapp官方也提供了众多的扩展组件供我们使用。由于不是内置组件,需要安装才能使用。 一、安装扩展插件 安装方法: 1.访问uniapp官方文档组件部分:组件使用的入门教程 | uni-app官网 点击左侧…...
Java后端检查空条件查询
通过抛出运行异常:throw new RuntimeException("请输入查询条件!");BranchWarehouseServiceImpl.java // 查询试剂交易(入库/出库)记录Overridepublic List<BranchWarehouseTransactions> queryForReagent(Branch…...
