跟代码执行流程,读Megatron源码(二)训练入口pretrain_gpt.py
Megatron-LM默认支持GPT、T5、BERT等多个常见模型的预训练,当下大模型流行,故以pretrain_gpt.py为例做源码的走读。
一. 启动pretrain_gpt.py
pretrain_gpt.py为GPT类模型的训练入口,它通过命令行形式被调用,其精确执行路径位于Megatron-LM框架的examples/gpt3目录下。具体而言,启动过程依赖于train_gpt3_175b_distributed.sh这一脚本,该脚本专为部署GPT-3模型在分布式环境下训练而设计(当然,也可以参照编写自定义的启动脚本)。
在train_gpt3_175b_distributed.sh脚本内部,核心操作是通过torchrun命令实现的,该命令是PyTorch分布式训练的一部分,用于在多个计算节点上高效并行地执行pretrain_gpt.py。此过程确保了模型训练任务能够充分利用集群资源,加速训练过程,代码如下图:
二. torchrun简介
trochrun是PyTorch官方推荐用于替代torch.distributed.launch的分布式数据并行训练模块。它旨在提供一种更灵活、更健壮的方式来启动和管理分布式训练任务。
trochrun启动并行训练任务的原理如下:
1. 初始化分布式环境
trochrun首先负责初始化分布式训练所需的环境。这包括设置通信后端(如NCCL、GLOO等)、分配工作进程的RANK和WORLD_SIZE(即参与训练的总进程数),以及处理其他与分布式训练相关的配置。
2. 分配工作进程
trochrun会根据指定的参数(如--nnodes、--nproc-per-node等)来分配工作进程。这些进程可以是同一台机器上的多个 GPU,也可以是跨多台机器的GPU。每个进程都会加载相同的训练脚本(如pretrain_gpt.py),但会处理不同的数据子集,以实现并行训练。
3. 同步与通信
在训练过程中,torchrun 管理下的各个工作进程需要频繁地进行同步和通信。这包括梯度同步(在反向传播后同步各GPU上的梯度)、参数更新(使用同步后的梯度更新模型参数)等。PyTorch提供了丰富的API(如torch.distributed.all_reduce、torch.distributed.barrier等)来支持这些操作。
4. 优雅处理故障
trochrun相比torch.distributed.launch的一大改进是它能够更优雅地处理工作进程的故障。例如,如果某个工作进程因为某种原因崩溃了,torchrun可以尝试重新启动该进程,以确保训练任务的连续性。此外,torchrun还支持弹性训练(elastic training),即允许在训练过程中动态地增加或减少工作进程的数量。
5. 简化配置与启动
trochrun通过提供命令行接口和配置文件选项来简化分布式训练的配置和启动过程。用户只需指定少量的参数(如节点数、每节点进程数等),即可启动复杂的分布式训练任务。此外,torchrun还支持从环境变量中读取配置信息,这使得在不同环境中部署训练任务变得更加灵活。
6. 自动化资源分配
在某些情况下,torchrun还可以与资源管理器(如Kubernetes、Slurm等)集成,以自动化地分配和管理训练所需的计算资源。这包括GPU、CPU、内存和存储等资源。通过集成资源管理器,torchrun可以进一步提高分布式训练的可扩展性和灵活性。
总之,torchrun通过以上机制共同作用,使得使用PyTorch进行分布式训练变得更加高效、可靠和易于管理。
三. 主要函数
pretrain_gpt.py脚本封装了多个核心功能组件,具体包括model_provider(),forward_step(),train_valid_test_datasets_provider(),以及pretrain()等主要函数。
其中,model_provider()负责提供预训练所需的模型实例对象;forward_step()定义了模型前向传播的具体步骤,包括输入处理、模型计算等;train_valid_test_datasets_provider()则负责准备训练、验证及测试所需的数据集,确保数据的有效供给。
值得注意的是,前三个函数model_provider(),forward_step(),train_valid_test_datasets_provider()更是作为pretrain()函数的入参,共同构成了GPT模型训练入口。
这种设计确保了预训练过程的模块化、灵活性与可扩展性。下面会从model_provider()开始逐行解析源码。
四. 源码分析
1. model_provider
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:"""Builds the model.If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.Args:pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.Returns:Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model"""args = get_args()use_te = args.transformer_impl == "transformer_engine"print_rank_0('building GPT model ...')# Experimental loading arguments from yamlif args.yaml_cfg is not None:config = core_transformer_config_from_yaml(args, "language_model")else:config = core_transformer_config_from_args(args)if args.use_legacy_models:model = megatron.legacy.model.GPTModel(config,num_tokentypes=0,parallel_output=True,pre_process=pre_process,post_process=post_process,)else: # using core modelsif args.spec is not None:transformer_layer_spec = import_module(args.spec)else:if use_te:transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)else:transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)model = GPTModel(config=config,transformer_layer_spec=transformer_layer_spec,vocab_size=args.padded_vocab_size,max_sequence_length=args.max_position_embeddings,pre_process=pre_process,post_process=post_process,fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,parallel_output=True,share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,position_embedding_type=args.position_embedding_type,rotary_percent=args.rotary_percent,)return model
model_provider函数是用于构建GPT(生成预训练Transformer)模型实例的函数,它会以类似函数指针的形式,作为pretrain()的入参传递到后续的训练代码中,供训练过程调用。
该函数主要代码流程包含以下几个步骤:
a. 获取参数和配置:
通过 get_args() 函数获取命令行参数和配置文件中的参数。根据 args.transformer_impl 的值确定是否使用 Transformer Engine (use_te)。
b. 配置模型:
如果指定了 YAML 配置文件 (args.yaml_cfg),则从 YAML 文件中加载模型结构。否则,根据命令行参数 (args) 加载。
c. 选择模型类型:
如果args.use_legacy_models为True,则使用megatron.legacy.model.GPTModel构建模型。这通常用于向后兼容或测试旧版本的模型。
如果不使用旧版模型,则直接运行构建GPTModel,如图红框部分。
其中,GPTModel的参数包括配置 (config)、词汇表大小 (vocab_size)、最大序列长度 (max_sequence_length)、是否进行前处理和后处理、是否使用 FP16 进行语言模型交叉熵计算、是否并行输出等。这些参数基本上都来源于配置文件,关于配置文件的内容和解析将于下文详述。
注:此处的model_provider是作为函数指针传递到pretrain()中,函数指针只有在调用时才会真正执行,故,GPTModel的具体实现待到执行时再具体分析。
2. forward_step
def forward_step(data_iterator, model: GPTModel):"""Forward training step.Args:data_iterator : Input data iteratormodel (GPTModel): The GPT Model"""args = get_args()timers = get_timers()# Get the batch.timers('batch-generator', log_level=2).start()global stimerwith stimer(bdata=True):tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)timers('batch-generator').stop()with stimer:output_tensor = model(tokens, position_ids, attention_mask,labels=labels)return output_tensor, partial(loss_func, loss_mask)
forward_step顾名思义,这个函数是GPT模型训练过程中前向处理函数,负责处理一批输入数据并通过模型进行前向传播。
该函数的核心实现,仍然是对model(forward_step函数的入参)的forward的调用,只是在调用之前封装了计时器计时逻辑以及批次数据获取的逻辑(这部分逻辑会根据不同的业务场景变化而变化,故不能直接封装到model的forward函数中,而是应该在pretrain脚本中实现),具体代码流程如下:
a. 获取参数和计时器:
通过get_args()和get_timers()函数分别获取训练参数和计时器对象,用于控制训练过程和记录时间消耗。
b. 获取批次数据:
使用timers对象记录获取批次数据的时间(可选,通过log_level=2控制)。
调用get_batch函数从data_iterator中获取一批数据,包括tokens(输入文本对应的token IDs)、labels(训练标签,通常用于计算损失,对于语言模型任务,labels通常是tokens的右移一位版本)、loss_mask(损失掩码,用于忽略某些位置的损失计算,如填充位置)、attention_mask(注意力掩码,用于指示哪些位置需要参与注意力计算)和position_ids(位置ID,用于模型中的位置编码)。
c. 模型前向传播:
使用stimer(可能是一个自定义的计时器)记录模型前向传播的时间。
将获取到的数据(tokens, position_ids, attention_mask, labels)传递给模型model进行前向传播。这里labels是可选的,用于计算损失,但在前向传播阶段不一定需要。
模型输出output_tensor,通常包含模型的预测结果(如logits)。
d. 返回输出和损失函数:
返回output_tensor和partial函数,该函数需要loss_mask作为参数来计算损失。这种方式允许延迟损失的计算,直到所有相关的数据都已准备好。
3. train_valid_test_datasets_provider
def train_valid_test_datasets_provider(train_val_test_num_samples):"""Build the train test and validation datasets.Args:train_val_test_num_samples : A list containing the number of samples in train test and validation."""args = get_args()config = core_gpt_dataset_config_from_args(args)if args.mock_data:dataset_type = MockGPTDatasetelse:dataset_type = GPTDatasetprint_rank_0("> building train, validation, and test datasets for GPT ...")train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(dataset_type,train_val_test_num_samples,is_dataset_built_on_rank,config).build()print_rank_0("> finished creating GPT datasets ...")return train_ds, valid_ds, test_ds
该函数接收一个参数train_val_test_num_samples,这是一个列表,包含了训练集、验证集和测试集的样本数量。函数的目的是根据提供的参数和配置,构建GPT模型的训练、验证和测试数据集。主要代码流程如下:
a. 获取参数和配置:
使用get_args()函数获取训练过程中的全局参数,并通过core_gpt_dataset_config_from_args根据这些参数生成数据集配置对象config。
b. 确定数据集类型:
根据args.mock_data的值决定使用哪种数据集类型。如果mock_data为True,则使用MockGPTDataset,这是一种模拟数据集,可能用于测试或快速原型开发。如果mock_data为False,则使用GPTDataset,这是实际的数据集类型,包含真实的训练数据。
c. 构建数据集:
使用BlendedMegatronDatasetBuilder类来构建数据集,传递给BlendedMegatronDatasetBuilder的参数包括数据集类型dataset_type、训练/验证/测试集的样本数量train_val_test_num_samples、is_dataset_built_on_rank(用于检查当前处理单元是否负责构建数据集),以及配置对象config。
调用build()方法实际构建数据集,该方法返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds。
其中BlendedMegatronDatasetBuilder来源于包“megatron.core.datasets.blended_megatron_dataset_builder”,由于数据集构建逻辑比较简单,故,在此不做详述,有兴趣的同学可以自行查看。
d. 返回值
函数返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds,这些对象可以用于后续的训练、验证和测试过程。
4. pretrain
pretrain函数是megatron/pretrain_gpt.py文件中的一个执行入口,通常会将该函数写于文件的末尾。该函数被第一章中的启动脚本调用,进而开启训练流程。
pretrain函数的入参如下:
train_valid_test_datasets_provider:这是第3小节分析的函数指针,负责提供训练、验证和测试数据集。
model_provider:这是第1小节分析的函数指针,负责提供GPT模型的实例。它可能根据传入的配置或参数来初始化模型。
ModelType.encoder_or_decoder:这个参数指定了模型的类型,这里是编码器或解码器(对于GPT模型,它实际上是一个解码器)。
forward_step:这是第2小节分析的函数指针,定义了模型训练过程中的一个前向传播步骤,包括数据的前向传递和损失的计算。
args_defaults:这是一个字典,包含了预训练过程中一些默认参数的键值对。在这个例子中,它指定了默认的tokenizer_type为GPT2BPETokenizer,这意味着在文本预处理时将使用基于BPE(Byte Pair Encoding)的GPT-2分词器。
至此,pretrain_gpt.py的源码基本解析完毕,下一篇文章将以pretrain函数为入口,跟随代码运行流程,深入其内部实现,详细解析。
相关文章:

跟代码执行流程,读Megatron源码(二)训练入口pretrain_gpt.py
Megatron-LM默认支持GPT、T5、BERT等多个常见模型的预训练,当下大模型流行,故以pretrain_gpt.py为例做源码的走读。 一. 启动pretrain_gpt.py pretrain_gpt.py为GPT类模型的训练入口,它通过命令行形式被调用,其精确执行路径位于M…...
MATLAB练习题——矩阵(2)
逻辑运算 a [5 0.2 0 -8 -0.7 ],在进行逻辑运算时,a 相当于什么样的逻辑量。 相当于 a[1 1 0 1 1] 角度运算 在 sin(x)运算中,x 是角度还是弧度? 在 sin(x)运算中,x 是弧度,MATLAB 规定所有…...
arm、AArch64、x86、amd64、x86_64 的区别
arm vs AArch64 vs amd64 vs x86_64 vs x86 的区别 当涉及到 CPU 的时候,有许多术语:AArch64、x86_64、amd64、arm 等等。了解它们是什么以及它们之间的区别。 当你查看数据表或软件下载页面时是否被 ARM、AArch64、x86_64、i386 等术语混淆?…...

【SpringBoot】 jasypt配置文件密码加解密
目前我们对yml配置文件中的密码都是明文显示,显然这不安全,有的程序员离职了以后可能会做一些非法骚操作,所以我们最好要做一个加密,只能让领导架构师或者技术经理知道这个密码。所以这节课就需要来实现一下。 我们可以使用jasypt…...

复杂网络的任意子节点的网络最短距离
复杂网络的任意子节点的网络最短距离 题目要求介绍 本文算法测试用的数据集为空手道俱乐部,其中空手道俱乐部的数据集可通过这个链接进行下载•http://vlado.fmf.uni-lj.si/pub/networks/data/Ucinet/UciData.htm#zachary 摘要 本文旨在解决复杂网络中任意子节点…...

(Qt) 文件读写基础
文章目录 🗂️前言📄ref📄访问标记🗃️enum 标记 🗂️Code📄demo📄分点讲解🗃️继承体系🗃️打开/关闭🗃️写🗃️读 🗂️END…...

全产业布局对穿戴甲品牌连锁店的意义
对于美甲行业来说,穿戴甲虽然不是什么新生事物,但也就是近两年才流行开来。面对井喷的市场需求,相应的从业者,不管是品牌连锁店,还是做批发、外贸,美甲周边、亦或是OEM的,大家都忙得不亦乐乎&am…...

git的一些使用技巧(git fetch 和 git pull的区别,git merge 和 git rebase的区别)
最近闲来无聊,虽然会使用git操作,但是 git fetch 和 git pull 的区别,git merge 和 git rebase的区别只是一知半解,稍微研究一下; git fetch 和 git pull 的区别 git fetch git fetch 是将远程仓库中的改动拉到本地…...

展厅中控系统有哪些优势呢
格芬科技的展厅中控系统具有多方面的优势,主要体现在以下几个方面: 一、高度集成与灵活控制 全终端网络可编程:格芬科技的展厅中控系统采用全终端网络可编程技术,能够实现对展厅内各种设备的集中控制和管理,包括电脑…...

FPGA开发在verilog中关于阻塞和非阻塞赋值的区别
一、概念 阻塞赋值:阻塞赋值的赋值号用“”表示,对应的是串行执行。 对应的电路结构往往与触发沿没有关系,只与输入电平的变化有关系。阻塞赋值的操作可以认为是只有一个步骤的操作,即计算赋值号右边的语句并更新赋值号左边的语句…...
动态特征转换的艺术:在Mojo模型中实现自定义变换的策略
动态特征转换的艺术:在Mojo模型中实现自定义变换的策略 在机器学习中,特征转换是数据预处理的关键步骤,它直接影响模型的性能和结果的准确性。Mojo模型,作为一种高效的模型部署形式,允许在不同环境中运行模型并进行预…...

如何让Python爬虫在遇到异常时继续运行
概述 在数据收集和数据挖掘中,爬虫技术是一项关键技能。然而,爬虫在运行过程中不可避免地会遇到各种异常情况,如网络超时、目标网站变化、数据格式不一致等。如果不加以处理,这些异常可能会导致爬虫程序中断,影响数据…...

手把手带你搭建Snort入侵检测系统
在当今数字化社会,网络安全问题日益突出。为了有效防范网络攻击,部署入侵检测系统(IDS)是必要的防护措施。Snort作为一款功能强大的开源IDS工具,被广泛应用于各种网络环境中。本文将手把手教您如何从零开始实现Snort入…...

小程序内嵌uniapp页面跳转回小程序指定页面方式
使用微信小程序提供的Api:wx.miniProgram.navigateTo 在小程序中嵌套uniapp的H5页面,并使用wx.miniProgram.navigateTo进行页面跳转,需要确保满足以下条件: 你的小程序必须是通过uniapp构建的,并且支持小程序嵌套。 你…...

基于 Three.js 的 3D 模型加载优化
作者:来自 vivo 互联网前端团队- Su Ning 作为一个3D的项目,从用户打开页面到最终模型的渲染需要经过多个流程,加载的时间也会比普通的H5项目要更长一些,从而造成大量的用户流失。为了提升首屏加载的转化率,需要尽可能…...

Jlink下载与适配keil ccs theia教程 用jlink代替ti自己的下载仿真器
用jlink代替ti自己的下载仿真器,然后你去买立创的m0g3507才19.9包赚160 安装 J-Link 软件包 J-Link 软件包 v7.88i 或更高版本支持 MSPM0。 从 Segger 网站下载安装程序 按照安装程序说明操作 安装程序将自动请求更新 IAR 或 Keil(如果已安装&#x…...
C# 进制之间的转换(二进制,八进制,十进制,十六进制)
常用的方法是:Convert.ToString(byte value, int toBase), 并且有多个重载方法, value的类型可以为short,int 等,但必须是整数且不能为负数, 一般默认为十进制 toBase: 返回值的基数,必须是 2、…...

Linux 基础开发工具 : Vim编辑器
Vim 是 Linux 和其他类 Unix 系统上广泛使用的文本编辑器之一。它基于更早的 vi 编辑器,但添加了许多增强功能和扩展。Vim 是“Vi IMproved”的缩写,意为“改进的 Vi”,我们常使用Vim编辑器编写c/c代码。 ps:该篇介绍均为最基础介…...

Delphi 11.2 配置Android SDK 环境
打开 Delphi 11 点击 Tools–Options… 然后点击 Deployment–SDK Manager–Add… 这里如果配置64位就选 Android 64-bit,如果配置32位就选 Android 32-bit 点击 Select an SDK version–Add New… 有警告图标的就是有问题的项,需要手动更新一下…...

Spring Boot 学习(10)——固基(Idea 配置 git 访问 gitee)
几转眼就过了两个月,其实也没有闲着,学也学了,只是繁杂事多,学的不如以前多,也没有做过笔记了。 以前做开发因条件受限,没有什么 git ,也没有 gitee。现在出来混要跟上形势才行,学习…...

使用VSCode开发Django指南
使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

大话软工笔记—需求分析概述
需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1
每日一言 生活的美好,总是藏在那些你咬牙坚持的日子里。 硬件:OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写,"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

Spring数据访问模块设计
前面我们已经完成了IoC和web模块的设计,聪明的码友立马就知道了,该到数据访问模块了,要不就这俩玩个6啊,查库势在必行,至此,它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据(数据库、No…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...
JavaScript基础-API 和 Web API
在学习JavaScript的过程中,理解API(应用程序接口)和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能,使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...