llama3源码解读之推理-infer
文章目录
- 前言
- 一、整体源码解读
- 1、完整main源码
- 2、tokenizer加载
- 3、llama3模型加载
- 4、llama3测试数据文本加载
- 5、llama3模型推理模块
- 1、模型推理模块的数据处理
- 2、模型推理模块的model.generate预测
- 3、模型推理模块的预测结果处理
- 6、多轮对话
- 二、llama3推理数据处理
- 1、完整数据处理源码
- 2、使用prompt方式询问数据加载
- 3、推理处理数据
- 三、llama3推理generate调用_sample方法
- 1、GenerationMode.SAMPLE方法源码
- 2、huggingface的_sample源码
- 3、_sample的初始化与准备
- 4、_sample的初始化(attention / hidden states / scores)
- 5、encoder-decoder模式模型处理
- 6、保持相应变量内容
- 7、进入while循环模块
- 8、结果返回处理
- 四、_sample的while循环模块内容
- 1、模型输入加工
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 2、模型推理
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 3、预测结果取值
- 4、预测logits惩罚处理
- 5、预测softmax处理
- 6、预测token选择处理
- 7、停止条件更新处理
- 8、再次循环迭代input_ids与attention_mask处理(生成式预测)
- 8、停止条件再次更新与处理
- 五、模型推理(self)
- 1、模型预测输入参数
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaForCausalLM(LlamaPreTrainedModel)
- 4、llama3推理forward的前提说明
- 5、llama3推理forward的设置参数
- 6、llama3推理forward的推理(十分重要)
- 7、llama3推理forward的的lm_head转换
- 8、llama3推理forward的结果包装(CausalLMOutputWithPast)
- 六、llama3模型forward的self.model方法
- 1、模型输入内容
- 1、self.model模型初次输入内容
- 2、self.model模型再次输入内容
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaModel(LlamaPreTrainedModel)
- 4、llama3模型forward的源码
- 5、llama3模型forward的输入准备与embedding
- 1、初次输入llama3模型内容
- 2、再此输入llama3模型内容
- 3、说明
- 6、llama3模型forward的cache
- 1、DynamicCache.from_legacy_cache(past_key_values)函数
- 2、cache = cls()类方法
- 该类位置
- 该类方法
- 2、cache.update的更新
- 7、llama3模型decoder_layer方法
- 1、decoder_layer调用源码
- 2、decoder_layer源码方法与调用self.self_attn方法
- 3、self.self_attn源码方法
- 8、hidden_states的LlamaRMSNorm
- 9、next_cache保存与输出结果
- 七、llama3模型atten方法
- 1、模型位置
- 2、llama3推理模型初始化
- 3、llama3推理的forward方法
- 4、llama3模型结构图
前言
本项目是解读开源github的代码,该项目基于Meta最新发布的新一代开源大模型Llama-3开发,是Chinese-LLaMA-Alpaca开源大模型相关系列项目(一期、二期)的第三期。而本项目开源了中文Llama-3基座模型和中文Llama-3-Instruct指令精调大模型。这些模型在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。因此,我是基于该项目解读训练与推理相关原理与内容,并以代码形式带领读者一步一步解读,理解其大语言模型运行机理。而该博客首先给出llama3推理源码相关内容解读,我将按照源码流程给出解读。
一、整体源码解读
1、完整main源码
我先给出完整的源码,后面推理使用哪些部分代码,我在深度解读。而一些较为简单内容我不在解读了。
if __name__ == '__main__':load_type = torch.float16# Move the model to the MPS device if availableif torch.backends.mps.is_available():device = torch.device("mps")else:if torch.cuda.is_available():device = torch.device(0)else:device = torch.device('cpu')print(f"Using device: {device}")if args.tokenizer_path is None:args.tokenizer_path = args.base_modeltokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]if args.use_vllm:model = LLM(model=args.base_model,tokenizer=args.tokenizer_path,tensor_parallel_size=len(args.gpus.split(',')),dtype=load_type)generation_config["stop_token_ids"] = terminatorsgeneration_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]else:if args.load_in_4bit or args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,bnb_4bit_compute_dtype=load_type,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4")model = AutoModelForCausalLM.from_pretrained(args.base_model,torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa")if device==torch.device('cpu'):model.float()model.eval()# test dataif args.data_file is None:examples = sample_dataelse:with open(args.data_file, 'r') as f:examples = [line.strip() for line in f.readlines()]print("first 10 examples:")for example in examples[:10]:print(example)with torch.no_grad():if args.interactive:print("Start inference with instruction mode.")print('='*85)print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n""+ 如要进行多轮对话,请使用llama.cpp")print('-'*85)print("+ This mode only supports single-turn QA.\n""+ If you want to experience multi-turn dialogue, please use llama.cpp")print('='*85)while True:raw_input_text = input("Input:")if len(raw_input_text.strip())==0:breakif args.with_prompt:input_text = generate_prompt(instruction=raw_input_text)else:input_text = raw_input_textif args.use_vllm:output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)response = output[0].outputs[0].textelse:inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s, skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[-1].strip()else:response = outputprint("Response: ",response)print("\n")else:print("Start inference.")results = []if args.use_vllm:if args.with_prompt is True:inputs = [generate_prompt(example) for example in examples]else:inputs = examplesoutputs = model.generate(inputs, SamplingParams(**generation_config))for index, (example, output) in enumerate(zip(examples, outputs)):response = output.outputs[0].textprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":example,"Output":response})else:for index, example in enumerate(examples):if args.with_prompt:input_text = generate_prompt(instruction=example)else:input_text = exampleinputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s,skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[1].strip()else:response = outputprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":input_text,"Output":response})dirname = os.path.dirname(args.predictions_file)os.makedirs(dirname,exist_ok=True)with open(args.predictions_file,'w') as f:json.dump(results,f,ensure_ascii=False,indent=2)if args.use_vllm:with open(dirname+'/generation_config.json','w') as f:json.dump(generation_config,f,ensure_ascii=False,indent=2)else:generation_config.save_pretrained('./')
2、tokenizer加载
有关tokenzier相关加载可参考博客这里。这里,我直接给出其源码,如下:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]
tokenizer.eos_token_id=128009,而terminators=[128009,128009]。
3、llama3模型加载
huggingface模型加载可参考博客这里。这里,llama3的模型加载不在介绍,如下源码:
model = AutoModelForCausalLM.from_pretrained(args.base_model, # 权重路径文件夹torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):model.float()
model.eval()
注意:model.eval()为固定权重方式,这是pytorch评估类似。
4、llama3测试数据文本加载
# test dataif args.data_file is None:examples = sample_data # ["为什么要减少污染,保护环境?","你有什么建议?"]else:with open(args.data_file, 'r') as f:examples
相关文章:
llama3源码解读之推理-infer
文章目录 前言一、整体源码解读1、完整main源码2、tokenizer加载3、llama3模型加载4、llama3测试数据文本加载5、llama3模型推理模块1、模型推理模块的数据处理2、模型推理模块的model.generate预测3、模型推理模块的预测结果处理6、多轮对话二、llama3推理数据处理1、完整数据…...

【教程】Linux安装Redis步骤记录
下载地址 Index of /releases/ Downloads - Redis 安装redis-7.4.0.tar.gz 1.下载安装包 wget https://download.redis.io/releases/redis-7.4.0.tar.gz 2.解压 tar -zxvf redis-7.4.0.tar.gz 3.进入目录 cd redis-7.4.0/ 4.编译 make 5.安装 make install PREFIX/u…...

全球汽车线控制动系统市场规模预测:未来六年CAGR为17.3%
引言: 随着汽车行业的持续发展和对安全性能需求的增加,汽车线控制动系统作为提升车辆安全性和操控性的关键组件,正逐渐受到市场的广泛关注。本文旨在通过深度分析汽车线控制动系统行业的各个维度,揭示行业发展趋势和潜在机会。 【…...

Ubuntu运行深度学习代码,代码随机epoch中断没有任何报错
深度学习运行代码直接中断 文章目录 深度学习运行代码直接中断问题描述设备信息问题补充解决思路问题发现及正确解决思路新问题出现最终问题:ubuntu系统,4090显卡安装英伟达驱动535.x外的驱动会导致开机无法进入桌面问题记录 问题描述 运行深度学习代码…...

只有4%知道的Linux,看了你也能上手Ubuntu桌面系统,Ubuntu简易设置,源更新,root密码,远程服务...
创作不易 只因热爱!! 热衷分享,一起成长! “你的鼓励就是我努力付出的动力” 最近常提的一句话,那就是“但行好事,莫问前程"! 与辉同行的董工说:守正出奇。坚持分享,坚持付出,坚持奉献,…...

Tomcat部署——个人笔记
Tomcat部署——个人笔记 文章目录 [toc]简介安装配置文件WEB项目的标准结构WEB项目部署IDEA中开发并部署运行WEB项目 本学习笔记参考尚硅谷等教程。 简介 Apache Tomcat 官网 Tomcat是Apache 软件基金会(Apache Software Foundation)的Jakarta 项目中…...
常见且重要的用户体验原则
以下是一些常见且重要的用户体验原则: 1. 以用户为中心 - 深入了解用户的需求、期望、目标和行为习惯。通过用户研究、调查、访谈等方法获取真实的用户反馈,以此来设计产品或服务。 - 例如,在设计一款老年手机时,充分考虑老年…...
web基础及nginx搭建
第四周 上午 静态资源 根据开发者保存在项目资源目录中的路径访问静态资源 html 图片 js css 音乐 视频 f12 ,开发者工具,网络 1 、 web 基本概念 web 服务器( web server ):也称 HTTP 服务器( HTTP …...

C++ 布隆过滤器
1. 布隆过滤器提出 我们在使用新闻客户端看新闻时,它会给我们不停地推荐新的内容,它每次推荐时要去重,去掉 那些已经看过的内容。问题来了,新闻客户端推荐系统如何实现推送去重的? 用服务器记录了用 户看过的所有历史…...
使用HTML创建用户注册表单
在当今数字化时代,网页表单对于收集用户信息和促进网站交互至关重要。无论您设计简单的注册表单还是复杂的调查表,了解HTML的基础知识可以帮助您构建有效的用户界面。在本教程中,我们将详细介绍如何使用HTML创建基本的用户注册表单。 第一步…...
Python零基础入门教程
Python零基础详细入门教程可以从以下几个方面进行学习和掌握: 一、Python基础认知 1. Python简介 由来与发展:Python是一种广泛使用的高级编程语言,由Guido van Rossum(吉多范罗苏姆)于1991年首次发布。Python以其简…...

成为git砖家(10): 根据文件内容生成SHA-1
文章目录 1. .git/objects 目录2. git cat-file 命令3. 根据文件内容生成 sha-14. 结语5. References 1. .git/objects 目录 git 是一个根据文件内容进行检索的系统。 当创建 hello.py, 填入 print("hello, world")的内容, 并执行 git add hello.py gi…...

园区导航小程序:一站式解决园区导航问题,释放存储,优化访客体验
随着园区的规模不断扩大,功能区划分日益复杂,导致访客和新员工在没有有效导航的情况下容易迷路。传统APP导航虽能解决部分问题,但其下载安装繁琐、占用手机内存大、且非高频使用导致的闲置,让许多用户望而却步。园区导航小程序的出…...
对于n进制转十进制的解法及代码(干货!)
对于p进制转十进制,我们有:(x)pa[0]*p^0a[1]*p^1a[2]*p^2...a[n]*p^n 举个例子:(11001)21*10*20*41*81*1625 (9FA)1610*16^015*16^19*16^22554 据此,我们可以编出c代码来解决问题 …...

当代互联网打工人的生存现状,看完泪流满面!
欢迎私信小编,了解更多产品信息呦~...

花几千上万学习Java,真没必要!(三十八)
测试代码1: package iotest.com; import java.nio.charset.StandardCharsets; import java.io.UnsupportedEncodingException; public class StringByteConversion { public static void main(String[] args) throws UnsupportedEncodingException { // 原始字…...

Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师
为了解决非结构化数据处理问题,我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目,还曾登上数据库顶会SIGMOD、VLDB,在全球首届向量检索比赛中夺冠。目前,Milvus 项目已获得超过 2.8w s…...

TwinCAT3 新建项目教程
文章目录 打开TwinCAT 新建项目(通过TcXaeShell) 新建项目(通过VS 2019)...

大模型算法面试题(十九)
本系列收纳各种大模型面试题及答案。 1、SFT(有监督微调)、RM(奖励模型)、PPO(强化学习)的数据集格式? SFT(有监督微调)、RM(奖励模型)、PPO&…...

应用地址信息获取新技巧:Xinstall来助力
在移动互联网时代,应用获取用户地址信息的需求越来越普遍。无论是为了提供个性化服务,还是进行精准营销,地址信息都扮演着至关重要的角色。然而,如何合规、准确地获取这一信息,却是许多开发者面临的挑战。今天…...

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...

遍历 Map 类型集合的方法汇总
1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
React Native在HarmonyOS 5.0阅读类应用开发中的实践
一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强,React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 (1)使用React Native…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...

pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)
目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关࿰…...

蓝桥杯3498 01串的熵
问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798, 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...
MySQL账号权限管理指南:安全创建账户与精细授权技巧
在MySQL数据库管理中,合理创建用户账号并分配精确权限是保障数据安全的核心环节。直接使用root账号进行所有操作不仅危险且难以审计操作行为。今天我们来全面解析MySQL账号创建与权限分配的专业方法。 一、为何需要创建独立账号? 最小权限原则…...
Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?
Redis 的发布订阅(Pub/Sub)模式与专业的 MQ(Message Queue)如 Kafka、RabbitMQ 进行比较,核心的权衡点在于:简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...
【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)
LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 题目描述解题思路Java代码 题目描述 题目链接:LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...