【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析
目录
- 3 SFT源码分析
- 3.1 accelerate
- 3.1.1 关键特性
- 3.1.2 使用场景
- 3.1.3 简单示例
- 3.2 代码主入口
- 3.3 设置随机种子
- 3.4 设置Log
- 3.5 加载数据集
- 3.6 加载Tokenizer
- 3.7 模型参数配置初始化
- 3.8 初始化SFT Trainer
- 3.9 开始训练
- 3.9.1 主函数
- 3.9.2 核心循环
- 3.9.3 单步训练
- 3.9.4 原始Loss计算方法
- 3.9.5 标签平滑
- 3.9.6 SFT的Loss计算方法
- 3.9.7 计算令牌准确性
- 3.10 保存模型
- 3.11 评估
- 3.12 推送到Hub
【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)
省流:本文重点是【3.9 开始训练】小节。
3 SFT源码分析
HuggingFace已经将很多重要的函数都封装好了,我们只需要掉包就能简单实现SFT了。
前面几篇博文我们详细介绍了如何一步步搭建环境了,感兴趣的话可以翻阅一下,此处不展开细说了:
- 【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
- 【复现DeepSeek-R1之Open R1实战】系列2:没有卡也能训模型!Colab跑OpenR1(附源码)
3.1 accelerate
我们使用了accelerate库来训练模型:
# Train via command line
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \--learning_rate 2.0e-5 \--num_train_epochs 1 \--packing \--max_seq_length 4096 \--per_device_train_batch_size 2 \--gradient_accumulation_steps 8 \--gradient_checkpointing \--bf16 \--output_dir data/Qwen2.5-1.5B-Open-R1-Distill# Train via YAML config
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
Accelerate 是 Hugging Face 开发的一个库,旨在简化深度学习模型的训练过程,特别是在分布式环境或使用不同硬件(如多个GPU、TPU等)时。它提供了一个统一且灵活的接口,使得用户能够轻松地配置和运行训练脚本,而无需深入理解复杂的分布式计算概念。以下是 Accelerate 的一些关键特性和优势:
3.1.1 关键特性
-
简化分布式训练:无论是单机多卡、多机多卡还是TPU训练,
Accelerate都能通过简单的配置文件或者命令行参数进行设置,大大降低了分布式训练的复杂性。 -
灵活性与可扩展性:支持多种深度学习框架,但主要与PyTorch集成得最为紧密。它允许用户在不修改核心训练代码的情况下调整训练策略,包括混合精度训练、梯度累积、梯度检查点等高级功能。
-
易于使用的API:
Accelerate提供了一个高层次的API,使得启动训练任务变得非常简单。例如,你可以使用Accelerator()对象来包裹你的训练循环,它会自动处理设备分配、数据加载器的优化等细节。 -
配置管理:通过一个简单的YAML格式配置文件,用户可以指定训练所需的各种参数,比如使用的设备类型(CPU/GPU/TPU)、是否启用混合精度训练等,这极大地提高了实验的可重复性。
-
兼容性:与Hugging Face Transformers库高度集成,可以直接用于Transformer模型的训练。当然,它也适用于其他类型的神经网络模型。
3.1.2 使用场景
- 当你需要在不同的硬件环境中快速部署训练任务时。
- 在探索不同的训练策略(如改变批大小、学习率等)时,
Accelerate能让你以最小的代码改动实现这些变化。 - 如果你正在寻找一种方法来简化分布式训练的配置和执行流程,
Accelerate是一个很好的选择。
3.1.3 简单示例
下面是一个如何使用Accelerate进行简单训练任务的例子:
from accelerate import Accelerator
accelerator = Accelerator()model, optimizer, train_dataloader, scheduler = accelerator.prepare(model, optimizer, train_dataloader, scheduler
)for epoch in range(num_epochs):for batch in train_dataloader:outputs = model(batch)loss = loss_function(outputs, labels)accelerator.backward(loss)optimizer.step()scheduler.step()optimizer.zero_grad()
在这个例子中,Accelerator对象帮助我们自动化了许多底层细节,如将模型和数据迁移到正确的设备上,以及处理分布式训练中的通信问题。这样,开发者就可以专注于模型设计和训练策略本身。
3.2 代码主入口
if __name__ == "__main__":parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))script_args, training_args, model_args = parser.parse_args_and_config()main(script_args, training_args, model_args)
首先调用了TrlPaser库,将输入的参数归类分成script_args, training_args, model_args这三类,每一类都是封装好的函数,这样便于拓展和迁移使用。
- script_args 主要是一些关于数据集的参数,例如 dataset_name(数据名称/路径)、dataset_config(数据集的配置)、dataset_train_split(训练集)、dataset_test_split(测试集)等等。
- training_args 继承自SFTConfig类,主要是一些关于训练的参数,例如 max_seq_length(tokenized序列的最大长度)、learning_rate等等。
- model_args 主要是一些关于模型的参数,例如 model_name_or_path(模型名称/路径)、torch_dtype(数据类型:bfloat16、float16、float32和auto)。
3.3 设置随机种子
设置随机种子,默认是42。主要是为了确保实验的可重复性,在训练模型时,涉及许多随机过程,例如初始化权重、数据集的shuffle等。通过固定随机种子,可以使得这些随机过程在每次运行时都产生相同的结果,从而保证实验结果的一致性和可重复性。
另外,当模型出现问题或需要调整参数时,固定的随机种子可以帮助开发者更容易地进行调试。因为相同的输入会得到相同的输出,这有助于定位问题。
在进行模型选择或调参时,使用相同的随机种子可以让不同的实验之间只存在因模型架构或参数设置不同而产生的差异,而非由于随机因素导致的变化,这样可以更准确地评估模型性能。
# Set seed for reproducibilityset_seed(training_args.seed)
3.4 设置Log
主要是打印一些关键信息,例如 系统时间、训练和模型参数配置等等。
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%Y-%m-%d %H:%M:%S",handlers=[logging.StreamHandler(sys.stdout)],)log_level = training_args.get_process_log_level()logger.setLevel(log_level)datasets.utils.logging.set_verbosity(log_level)transformers.utils.logging.set_verbosity(log_level)transformers.utils.logging.enable_default_handler()transformers.utils.logging.enable_explicit_format()# Log on each process a small summarylogger.warning(f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}")logger.info(f"Model parameters {model_args}")logger.info(f"Script parameters {script_args}")logger.info(f"Training parameters {training_args}")if "wandb" in training_args.report_to:init_wandb_training(training_args)
此外,还会从output的文件夹中获取最新的checkpoint,打印checkpoint信息。
# Check for last checkpointlast_checkpoint = Noneif os.path.isdir(training_args.output_dir):last_checkpoint = get_last_checkpoint(training_args.output_dir)if last_checkpoint is not None and training_args.resume_from_checkpoint is None:logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
3.5 加载数据集
通过load_dataset加载来自Hugging Face Hub的数据集或本地数据集,我们可以在Hugging Face Hub上找到数据集列表,或者使用[huggingface_hub.list_datasets]进行查找。
这个函数在后台执行以下操作:
- 加载数据集构建器:
- 确定数据集中最常见的数据格式并选择其关联的构建器(例如JSON、CSV、Parquet、Webdataset、ImageFolder等)。
- 根据文件名和目录名或YAML配置确定哪些文件属于哪个分割(例如训练/测试)。
- 也可以手动指定data_files以及要使用的数据集构建器(例如"parquet")。
- 运行数据集构建器:
- 在一般情况下:
- 如果数据文件尚未在本地可用或缓存,则从数据集中下载这些文件。
- 将数据集处理并缓存为类型化的Arrow表以用于缓存。Arrow表是任意长度的、类型化的表格,可以存储嵌套对象,并映射到numpy/pandas/python的通用类型。它们可以直接从磁盘访问、加载到RAM中甚至通过网络流式传输。
- 在流式处理的情况下:
- 不下载或缓存任何内容。相反,数据集将被惰性加载并在迭代时动态流式传输。
- 在一般情况下:
- 返回由split参数(默认:所有)请求的分割构建的数据集。
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
3.6 加载Tokenizer
关于Tokenizer的详细介绍可以看上一篇博文。
执行完这段,就会从预训练的大模型文件夹中自动加载Tokenizer。
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True)tokenizer.pad_token = tokenizer.eos_token
3.7 模型参数配置初始化
主要是完成模型加载时的一些参数配置,例如数据类型、量化配置等等。
logger.info("*** Initializing model kwargs ***")torch_dtype = (model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype))quantization_config = get_quantization_config(model_args)model_kwargs = dict(revision=model_args.model_revision,trust_remote_code=model_args.trust_remote_code,attn_implementation=model_args.attn_implementation,torch_dtype=torch_dtype,use_cache=False if training_args.gradient_checkpointing else True,device_map=get_kbit_device_map() if quantization_config is not None else None,quantization_config=quantization_config,)training_args.model_init_kwargs = model_kwargs
3.8 初始化SFT Trainer
SFT Trainer继承自transformers库的Trainer类,
trainer = SFTTrainer(model=model_args.model_name_or_path,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,processing_class=tokenizer,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)
3.9 开始训练
logger.info("*** Train ***")checkpoint = Noneif training_args.resume_from_checkpoint is not None:checkpoint = training_args.resume_from_checkpointelif last_checkpoint is not None:checkpoint = last_checkpointtrain_result = trainer.train(resume_from_checkpoint=checkpoint)metrics = train_result.metricsmetrics["train_samples"] = len(dataset[script_args.dataset_train_split])trainer.log_metrics("train", metrics)trainer.save_metrics("train", metrics)trainer.save_state()
3.9.1 主函数
train()函数会先加载模型以及完成一些初始化工作,然后通过 find_executable_batch_size 装饰器函数以某种方式调用目标函数 _inner_training_loop:要么是直接使用给定的批处理大小,要么是经过调整找到的最佳批处理大小。find_executable_batch_size 函数的目的是帮助自动找到适合执行的batch size,特别是对于那些可能因为内存不足(out-of-memory)或CUDNN相关异常而失败的操作。
inner_training_loop = find_executable_batch_size(self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size)
3.9.2 核心循环
最关键的是成员函数_inner_training_loop,该方法涵盖了从初始化到训练结束的整个过程。
- 初始化与状态设置
- 记录训练参数如批处理大小、总训练批处理大小、梯度累积步数、优化步骤总数及可训练参数数量。
- 初始化训练状态变量。
if self.args.per_device_train_batch_size != self._train_batch_size:logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
- 检查点恢复
- 如果提供了检查点路径并且存在相应的状态文件,则从检查点恢复训练状态。
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)if not args.ignore_data_skip:steps_trained_in_current_epoch *= args.gradient_accumulation_stepslogger.info(" Continuing training from checkpoint, will skip to saved global_step")logger.info(f" Continuing training from epoch {epochs_trained}")logger.info(f" Continuing training from global step {self.state.global_step}")
- 更新引用
- 更新回调处理器中的模型、优化器、学习率调度器和数据加载器的引用。
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
- 状态更新
- 设置
self.state.max_steps和self.state.num_train_epochs,并确保进程零的状态正确性。
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
- 初始化损失变量
- 初始化
tr_loss和_total_loss_scalar,并将模型梯度置零。
tr_loss = torch.tensor(0.0).to(args.device)
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
- 回调处理
- 调用
on_train_begin回调,并在训练开始时进行一次评估(如果配置了)。
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if args.eval_on_start:self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
- 主训练循环
- 遍历每个epoch,并在每个epoch开始时调用
on_epoch_begin回调。 - 根据是否需要同步梯度设置加速器的状态,并执行单步训练 (
training_step)。 - 对于同步梯度步骤:
- 进行梯度裁剪。
- 执行优化器步骤,并根据情况更新学习率调度器。
- 将模型梯度置零,并更新全局步数和当前epoch。
- 调用
on_step_end回调并可能进行日志记录、保存和评估。
- 对于非同步梯度步骤,调用
on_substep_end回调。
for epoch in range(epochs_trained, num_train_epochs):epoch_dataloader = train_dataloaderif hasattr(epoch_dataloader, "set_epoch"):epoch_dataloader.set_epoch(epoch)steps_in_epoch = len(epoch_dataloader) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_stepsself.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)# 处理从检查点恢复的情况if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:self._load_rng_state(resume_from_checkpoint)rng_to_sync = Falsesteps_skipped = 0if steps_trained_in_current_epoch > 0:epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)steps_skipped = steps_trained_in_current_epochsteps_trained_in_current_epoch = 0rng_to_sync = Truestep = -1epoch_iterator = iter(epoch_dataloader)for _ in range(total_updates):update_step += 1num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainderbatch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)for i, inputs in enumerate(batch_samples):step += 1do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epochif not do_sync_step:self.accelerator.gradient_state._set_sync_gradients(False)else:self.accelerator.gradient_state._set_sync_gradients(True)tr_loss_step = self.training_step(model, inputs, num_items_in_batch)if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)else:tr_loss = tr_loss + tr_loss_stepif do_sync_step:if args.max_grad_norm is not None and args.max_grad_norm > 0:if is_sagemaker_mp_enabled() and args.fp16:_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)elif self.use_apex:_grad_norm = nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), args.max_grad_norm)else:_grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)grad_norm = _grad_norm.item() if hasattr(_grad_norm, "item") else _grad_normself.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)self.optimizer.step()self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)optimizer_was_run = not self.accelerator.optimizer_step_was_skippedif optimizer_was_run and not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):self.lr_scheduler.step()model.zero_grad()self.state.global_step += 1self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epochself.control = self.callback_handler.on_step_end(args, self.state, self.control)self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)else:self.control = self.callback_handler.on_substep_end(args, self.state, self.control)if self.control.should_epoch_stop or self.control.should_training_stop:if is_torch_xla_available():xm.mark_step()breakif self.control.should_epoch_stop or self.control.should_training_stop:if is_torch_xla_available():xm.mark_step()break
- Epoch结束处理
- 调用
on_epoch_end回调并可能进行日志记录、保存和评估。 - 如果启用了TPU调试选项,则打印调试指标报告。
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)if DebugOption.TPU_METRICS_DEBUG in self.args.debug:if is_torch_xla_available():xm.master_print(met.metrics_report())else:logger.warning("You enabled PyTorch/XLA debug metrics but you don't have a TPU configured.")
if self.control.should_training_stop:break
- 训练结束处理
- 输出一条信息提示训练完成。
- 如果配置了在训练结束时加载最佳模型,则加载最佳模型检查点。
- 计算总损失并将结果添加到
self._total_loss_scalar中。 - 计算训练速度指标 (
speed_metrics) 并记录它们。 - 停止内存跟踪器并更新指标。
- 记录最终的训练指标。
- 根据保存限制删除旧的检查点。
- 调用
on_train_end回调并完成当前推送操作。 - 清理嵌入层的前向后钩子(如果使用了NEFTune噪声)。
- 返回包含全局步数、训练损失和指标的
TrainOutput对象。
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:if is_torch_xla_available():xm.rendezvous("load_best_model_at_end")elif args.parallel_mode == ParallelMode.DISTRIBUTED:dist.barrier()elif is_sagemaker_mp_enabled():smp.barrier()self._load_best_model()self._total_loss_scalar += tr_loss.item()
effective_global_step = max(self.state.global_step, 0.001)
train_loss = self._total_loss_scalar / effective_global_stepmetrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps, num_tokens=num_train_tokens)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_lossself.log(metrics)run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:for checkpoint in checkpoints_sorted:if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")shutil.rmtree(checkpoint, ignore_errors=True)self.control = self.callback_handler.on_train_end(args, self.state, self.control)
self._finish_current_push()if self.neftune_noise_alpha is not None:self._deactivate_neftune(self.model)return TrainOutput(self.state.global_step, train_loss, metrics)
3.9.3 单步训练
在 _inner_training_loop 方法中,单步训练是在 training_step 方法内完成的。
with context():tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
training_step 方法是训练过程中对每个批次数据执行单步训练的核心函数。它负责前向传播、计算损失、后向传播等操作,并返回当前批次的训练损失,主要包括以下几个步骤:
- 前向传播:通过 self.compute_loss 方法计算损失,该方法通常包含模型的前向传播和损失函数的计算。
- 后向传播:根据是否使用Apex混合精度训练,选择不同的方式进行后向传播。
- 多GPU处理:如果使用多个GPU进行分布式训练,需要对损失值进行平均。
- 梯度累积:根据配置的梯度累积步数,对损失值进行缩放。
- 内存管理:根据配置,定期清空不同硬件类型的缓存以释放内存。
3.9.4 原始Loss计算方法
损失的计算主要是在单步训练中的compute_loss函数中完成,它处理了标签平滑、自定义损失函数以及多设备(如多GPU)的损失平均等问题。
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- model (
nn.Module): 要训练的模型。 - inputs (
Dict[str, Union[torch.Tensor, Any]]): 包含输入和目标的字典,通常包括输入ID、注意力掩码、标签等。 - return_outputs (
bool): 是否返回模型输出,默认为False。 - num_items_in_batch (
int, optional): 批次中的样本数量(可选参数)。
-
处理标签
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:labels = inputs.pop("labels") else:labels = None- 如果启用了标签平滑器 (
label_smoother) 或者存在自定义损失函数 (compute_loss_func) 并且输入中包含labels,则从中提取标签并从inputs字典中移除。
- 如果启用了标签平滑器 (
-
准备损失计算的关键字参数
if self.model_accepts_loss_kwargs:loss_kwargs = {}if num_items_in_batch is not None:loss_kwargs["num_items_in_batch"] = num_items_in_batchinputs = {**inputs, **loss_kwargs}- 如果模型接受额外的损失关键字参数,则将
num_items_in_batch添加到inputs中。
- 如果模型接受额外的损失关键字参数,则将
-
前向传播
outputs = model(**inputs)- 将输入数据传递给模型进行前向传播,并获取模型输出。
-
保存过去的状态(如果适用)
if self.args.past_index >= 0:self._past = outputs[self.args.past_index]- 如果配置了过去索引 (
past_index),则保存模型输出中的相应部分(例如,对于某些生成任务)。
- 如果配置了过去索引 (
-
计算损失
- 根据是否有标签、是否使用自定义损失函数或标签平滑器,选择不同的方式计算损失。
情况一:有标签且使用自定义损失函数或标签平滑器
if labels is not None:unwrapped_model = self.accelerator.unwrap_model(model)if _is_peft_model(unwrapped_model):model_name = unwrapped_model.base_model.model._get_name()else:model_name = unwrapped_model._get_name()if self.compute_loss_func is not None:loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():loss = self.label_smoother(outputs, labels, shift_labels=True)else:loss = self.label_smoother(outputs, labels)
- 如果存在标签:
- 解包加速器中的模型。
- 判断模型是否为PEFT模型(Parameter-Efficient Fine-Tuning),并获取模型名称。
- 如果存在自定义损失函数 (
compute_loss_func),则调用该函数计算损失。 - 如果模型属于因果语言模型(Causal Language Model),则使用标签平滑器 (
label_smoother) 并设置shift_labels=True。 - 否则,直接使用标签平滑器计算损失。
情况二:无标签或模型未返回损失
else:if isinstance(outputs, dict) and "loss" not in outputs:raise ValueError("The model did not return a loss from the inputs, only the following keys: "f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.")# We don't use .loss here since the model may return tuples instead of ModelOutput.loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
- 如果没有标签或者模型输出中不包含
loss键,则抛出异常提示用户模型未返回损失。 - 否则,从模型输出中提取损失值。
-
多设备损失平均
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:loss *= self.accelerator.num_processes- 如果配置了跨设备平均令牌数 (
average_tokens_across_devices),则根据设备数量调整损失值。
- 如果配置了跨设备平均令牌数 (
-
返回结果
return (loss, outputs) if return_outputs else loss- 如果
return_outputs参数为True,则返回一个元组(loss, outputs);否则仅返回损失值。
- 如果
3.9.5 标签平滑
LabelSmoother 是一个用于在预计算的模型输出上添加标签平滑(label smoothing)的类,标签平滑是一种正则化技术,旨在防止模型对训练数据中的特定标签过度自信,从而提高泛化能力。
@dataclass
class LabelSmoother:"""Adds label-smoothing on a pre-computed output from a Transformers model.Args:epsilon (`float`, *optional*, defaults to 0.1):The label smoothing factor.ignore_index (`int`, *optional*, defaults to -100):The index in the labels to ignore when computing the loss."""epsilon: float = 0.1ignore_index: int = -100
- epsilon (
float): 标签平滑因子,默认值为 0.1。 - ignore_index (
int): 在计算损失时忽略的标签索引,默认值为 -100。
-
提取
logitslogits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]- 从模型输出中提取
logits,如果输出是字典形式,则通过键"logits"获取;否则直接取第一个元素。
- 从模型输出中提取
-
偏移处理(如果需要)
if shift_labels:logits = logits[..., :-1, :].contiguous()labels = labels[..., 1:].contiguous()- 如果需要偏移标签(如因果语言模型),则对
logits和labels进行偏移处理,使它们对齐。
- 如果需要偏移标签(如因果语言模型),则对
-
计算负对数概率
log_probs = -nn.functional.log_softmax(logits, dim=-1)- 使用
log_softmax函数计算负对数概率(即负对数似然)。
- 使用
-
调整标签维度
if labels.dim() == log_probs.dim() - 1:labels = labels.unsqueeze(-1)- 如果
labels的维度比log_probs少一维,则增加一个维度以匹配log_probs的形状。
- 如果
-
创建填充掩码
padding_mask = labels.eq(self.ignore_index) labels = torch.clamp(labels, min=0)- 创建一个填充掩码
padding_mask,标记哪些位置是填充(使用ignore_index)。 - 使用
clamp函数将标签限制在非负值范围内,避免在后续操作中出现错误。
- 创建一个填充掩码
-
计算负对数似然损失(NLL Loss)
nll_loss = log_probs.gather(dim=-1, index=labels)- 使用
gather函数从log_probs中提取对应于真实标签的负对数概率。
- 使用
-
计算平滑损失
smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)- 计算所有类别的负对数概率之和,并保持维度不变。
-
应用填充掩码
nll_loss.masked_fill_(padding_mask, 0.0) smoothed_loss.masked_fill_(padding_mask, 0.0)- 使用填充掩码将填充位置的损失置为零。
-
计算有效元素数量
num_active_elements = padding_mask.numel() - padding_mask.long().sum()- 计算非填充位置的有效元素数量。
-
归一化损失
nll_loss = nll_loss.sum() / num_active_elements smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])- 对负对数似然损失和平滑损失进行归一化处理。
-
组合最终损失
return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss- 组合负对数似然损失和平滑损失,得到最终的标签平滑损失。
3.9.6 SFT的Loss计算方法
SFT Trainer重写了compute_loss方法,不仅计算训练损失,还额外计算了令牌(token)准确性,这对于评估模型在生成任务中的表现特别有用。
-
调用父类的
compute_loss方法(loss, outputs) = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch )- 调用父类的
compute_loss方法来计算损失值和模型输出。这里使用return_outputs=True确保返回模型输出以便后续计算令牌准确性。
- 调用父类的
-
计算令牌准确性(如果适用)
if "labels" in inputs and not self.args.use_liger:shift_logits = outputs.logits[..., :-1, :].contiguous()shift_labels = inputs["labels"][..., 1:].contiguous()- 如果输入中包含标签且未使用 Liger(一种特定的优化器或模型配置),则从模型输出中提取
logits和labels。 - 对于因果语言模型(Causal Language Model),通常需要对
logits和labels进行偏移处理:shift_logits: 将logits的最后一个维度去掉一个位置,使其与labels对齐。shift_labels: 将labels的第一个位置去掉,使其与logits对齐。
- 如果输入中包含标签且未使用 Liger(一种特定的优化器或模型配置),则从模型输出中提取
-
多GPU环境下收集 logits 和 labels
shift_logits = self.accelerator.gather_for_metrics(shift_logits) shift_labels = self.accelerator.gather_for_metrics(shift_labels)- 使用加速器的
gather_for_metrics方法将所有GPU上的logits和labels收集到主进程中。这一步确保了在分布式训练环境中能够正确地计算全局指标。
- 使用加速器的
-
计算令牌准确性
if self.accelerator.is_main_process:accuracy = compute_token_accuracy(shift_logits, shift_labels)self._metrics["mean_token_accuracy"].append(accuracy)- 在主进程中(即
is_main_process为True),调用compute_token_accuracy函数计算令牌准确性,并将其添加到_metrics字典中。
- 在主进程中(即
-
返回结果
return (loss, outputs) if return_outputs else loss- 如果
return_outputs参数为True,则返回一个元组(loss, outputs);否则仅返回损失值。
- 如果
3.9.7 计算令牌准确性
该函数用于计算令牌(token)的准确性,即模型预测的正确率。它通过比较模型输出的预测值和真实标签来计算准确率,并忽略填充(padding)部分的令牌。
def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
- logits (
torch.Tensor): 模型输出的对数概率张量,形状通常为(batch_size, sequence_length, vocab_size)。 - labels (
torch.Tensor): 真实标签张量,形状通常为(batch_size, sequence_length)。 - ignore_index (
int): 忽略的索引,默认值为-100,表示填充部分的标签。
-
获取预测值
predictions = logits.argmax(dim=-1)- 使用
argmax函数从logits中获取每个位置的最大概率对应的索引,作为模型的预测值。dim=-1表示在最后一个维度(词汇表维度)上进行操作,结果是一个形状为(batch_size, sequence_length)的张量。
- 使用
-
创建非填充掩码
mask = labels != ignore_index- 创建一个布尔掩码
mask,标记哪些位置不是填充部分(即不等于ignore_index)。这个掩码用于后续计算时忽略填充部分的令牌。
- 创建一个布尔掩码
-
计算正确的预测
correct_predictions = (predictions == labels) & mask- 计算预测值与真实标签相等的位置,并结合掩码
mask过滤掉填充部分的令牌。结果是一个布尔张量,其中True表示正确预测且非填充位置。
- 计算预测值与真实标签相等的位置,并结合掩码
-
统计有效令牌数量
total_tokens = mask.sum() correct_tokens = correct_predictions.sum()- 使用
sum函数统计掩码中True的数量,得到有效令牌的总数total_tokens。 - 同样地,使用
sum函数统计correct_predictions中True的数量,得到正确预测的令牌数correct_tokens。
- 使用
-
计算准确性
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0- 计算准确性:将正确预测的令牌数除以总的有效令牌数。如果有效令牌数为零,则返回
0.0以避免除零错误。
- 计算准确性:将正确预测的令牌数除以总的有效令牌数。如果有效令牌数为零,则返回
-
返回准确性
return accuracy- 返回计算出的令牌准确性。
3.10 保存模型
logger.info("*** Save model ***")trainer.save_model(training_args.output_dir)logger.info(f"Model saved to {training_args.output_dir}")# Save everything else on main processkwargs = {"dataset_name": script_args.dataset_name,"tags": ["open-r1"],}if trainer.accelerator.is_main_process:trainer.create_model_card(**kwargs)# Restore k,v cache for fast inferencetrainer.model.config.use_cache = Truetrainer.model.config.save_pretrained(training_args.output_dir)
3.11 评估
直接调用trainer的evaluate()函数完成评测。
if training_args.do_eval:logger.info("*** Evaluate ***")metrics = trainer.evaluate()metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])trainer.log_metrics("eval", metrics)trainer.save_metrics("eval", metrics)
3.12 推送到Hub
将训练结果推送到HuggingFace Hub上。
if training_args.push_to_hub:logger.info("Pushing to hub...")trainer.push_to_hub(**kwargs)
相关文章:
【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析
目录 3 SFT源码分析3.1 accelerate3.1.1 关键特性3.1.2 使用场景3.1.3 简单示例 3.2 代码主入口3.3 设置随机种子3.4 设置Log3.5 加载数据集3.6 加载Tokenizer3.7 模型参数配置初始化3.8 初始化SFT Trainer3.9 开始训练3.9.1 主函数3.9.2 核心循环3.9.3 单步训练3.9.4 原始Loss…...
WPF8-常用控件
目录 写在前面:1. 按钮控件1.1. Button 按钮1.2. RepeatButton:长按按钮1.3. RadioButton:单选按钮 2. 数据显示控件2.1. TextBlock:只读文本控件2.2. Lable:标签 显示文本控件2.3. ListBox:显示可选择项的列表2.4. DataGrid&…...
单元测试整理
在国外软件开发中,单元测试必不可少,但是国内并不太重视这一块,一个好的单元测试可以提前发现很多问题,也减去和测试battle的时间 Spring单元测试 JUnit4 RunWith 指明单元测试框架 e.g. RunWith(SpringJUnit4ClassRunner.cla…...
代码随想录刷题day24|(字符串篇)151.反转字符串中的单词
一、题目思路 1.快慢指针移除字符串首尾以及单词中的多余空格 类似前面数组篇--移除元素代码随想录刷题day02|(数组篇)27.移除元素、26.删除有序数组中的重复项_代码随想录网站-CSDN博客 快指针fast遍历整个字符串,慢指针slow指向新字符串…...
六、敏捷开发工具:项目管理工具
一、敏捷开发工具 在敏捷开发过程中,项目管理工具是支持团队高效协作、任务跟踪和项目进度控制的关键因素。随着敏捷方法的普及,市场上出现了多种工具来帮助团队进行需求管理、任务分配、进度跟踪以及反馈收集等任务。本文将对常用的敏捷开发项目管理工具(如Jira、Trello、…...
VMware按照的MacOS升级后无法联网
背景 3年前公司使用Flutter开发了一款app,现在app有微小改动需要重新发布到AppStore 问题 问题是原来的Vmware搭建的开发环境发布App失败了 提示:App需要使用xcode15IOS 17 SDK重新构建,这样的话MacOS至少需要升级到13.5 Xcode - 支持 - Ap…...
I2C、SPI、UART
I2C:串口通信,同步,半双工,双线(数据线SDA时钟线SCL),最大距离1米到几米 SPI(串行外设接口):串口通信,同步,全双工,四线&…...
3.2 Hugging Face Transformers库深度解析:大模型开发的一站式解决方案
Hugging Face Transformers库深度解析:大模型开发的一站式解决方案 一、Transformers库定位:NLP领域的"模型工厂" 1.1 核心定义与技术定位 Hugging Face Transformers 是一个开源的Python库,专为自然语言处理(NLP)、计算机视觉(CV)和语音任务设计。它提供:…...
DeepSeek V3和R1
DeepSeek V3 和 R1 是深度求索(DeepSeek)推出的两款大模型,基于混合专家架构(MoE),但在设计目标、训练方法和应用场景上存在显著差异。以下是两者的详细对比与补充内容: DeepSeek V3和R1 一、模…...
【操作系统】深入理解Linux物理内存
物理内存的组织结构 我们平时所称的内存也叫随机访问存储器也叫 RAM 。RAM 分为两类: 一类是静态 RAM( SRAM ),这类 SRAM 用于 CPU 高速缓存 L1Cache,L2Cache,L3Cache。其特点是访问速度快,访…...
6.【线性代数】—— 列空间和零空间
六 列空间和零空间 1. 列空间 C(A)2. 零空间 N(A)2.1 定义2.2 为什么零空间是一个子空间?2.3 Axb的解空间,是一个子空间吗? 1. 列空间 C(A) [ c o l 11 c o l 21 c o l 31 c o l 12 c o l 22 c o l 32 c o l 13 c o l 23 c o l 33 ] ⏟ A [ a…...
记一次一波三折的众测SRC经历
视频教程和更多福利在我主页简介或专栏里 (不懂都可以来问我 专栏找我哦) 目录: 前言 波折一:RCE漏洞利用失败 波折二:SQL时间盲注 波折三:寻找管理后台 总结 前言 先谈个人SRC心得体会吧,我虽…...
Java中的Thread.sleep(0)你了解多少
在Java中,Thread.sleep(long millis)方法用于使当前线程暂停执行指定的时间(以毫秒为单位)。它通常用于控制线程的执行节奏、避免过度占用CPU资源或实现任务的延迟。然而,Thread.sleep(0)作为Thread.sleep方法的一种特殊用法&…...
POI优化Excel录入
57000单词原始录入时间258S 核心代码: List<Word> wordBookList ExcelUtil.getReader(file.getInputStream()).readAll(Word.class);if (!CollectionUtil.isEmpty(wordBookList)) {for (Word word : wordBookList) {//逐条向数据库中插入单词wordMapper.insert(word);}…...
HarmonyOS进程通信及原理
大家好,我是学徒小z,最近在研究鸿蒙中一些偏底层原理的内容,今天分析进程通信给大家,请用餐😊 文章目录 进程间通信1. 通过公共事件(ohos.commonEventManager)公共事件的底层原理 2. IPC Kit能…...
DeepSeek核心算法解析:如何打造比肩ChatGPT的国产大模型
注:此文章内容均节选自充电了么创始人,CEO兼CTO陈敬雷老师的新书《自然语言处理原理与实战》(人工智能科学与技术丛书)【陈敬雷编著】【清华大学出版社】 文章目录 DeepSeek大模型技术系列一DeepSeek核心算法解析:如何…...
【算法】双指针(上)
目录 双指针 左右指针(对撞指针) 快慢指针 移动零 双指针解题 复写零 暴力解题 双指针解题(快慢指针) 快乐数 双指针解题(快慢指针) 盛最多水的容器 暴力解题(会超时) 双指针解题(左右指针) 有效三角形的个数 暴力解题 双指针解题(左右指针) 双指针 常见的双指…...
深度学习模型常用激活函数集合
激活函数是深度学习模型中的关键组成部分,用于引入非线性特性,使神经网络能够学习复杂的模式和映射关系;神经网络本质上是一个复合函数。如果没有激活函数,无论网络有多少层,其输出都只是输入的线性组合。激活函数通过…...
WebAssembly 3.0发布:浏览器端高性能计算迎来新突破!
“WebAssembly 3.0来了,浏览器端的高性能计算将彻底改变!”2025年,WebAssembly(Wasm)迎来了重大更新——WebAssembly 3.0正式发布。这次更新不仅支持多线程和SIMD指令集,还优化了内存管理,让浏览…...
ERP对制造业务有何价值?
ERP 的定义 在定义 ERP 之前,我们先从其首字母缩写说起,ERP 代表企业资源规划。我们可以将 ERP 定义为一种企业软件,它帮助组织管理日常业务。从根本上讲,ERP 将客户管理、人力资源、商业智能、财务管理、库存以及供应链功能整合…...
MySQL5.7 创建用户并授予超管权限脚本
记录MySQL5.7 创建新用户并授予超管权限脚本 用户与密码可任意设置 创建用户并设置密码 CREATE USER zhangsan % identified by 123456oo;修改用户密码 UPDATE USER set authentication_stringpassword("Abc123!") where user"zhangsan ";授予用户超管权…...
芝加哥学派(Chicago School):金融与经济学的创新力量(中英双语)
芝加哥学派:金融与经济学的创新力量 在经济学和金融学的历史上,有一个学派的影响力不容忽视,那就是芝加哥学派(Chicago School)。芝加哥学派不仅在学术界广受推崇,也深刻影响了全球的经济政策和金融市场。…...
Pytorch实现论文之一种基于扰动卷积层和梯度归一化的生成对抗网络
简介 简介:提出了一种针对鉴别器的梯度惩罚方法和在鉴别器中采用扰动卷积,拟解决锐梯度空间引起的训练不稳定性问题和判别器的记忆问题。 论文题目:A Perturbed Convolutional Layer and Gradient Normalization based Generative Adversarial Network(一种基于扰动卷积层…...
哈希表(C语言版)
文章目录 哈希表原理实现(无自动扩容功能)代码运行结果 分析应用 哈希表 如何统计一段文本中,小写字母出现的次数? 显然,我们可以用数组 int table[26] 来存储每个小写字母出现的次数,而且这样处理,效率奇高。假如我们想知道字…...
3.5 使用Tokenizer编解码文本:从原理到企业级实践
使用Tokenizer编解码文本:从原理到企业级实践 一、Tokenizer核心原理:文本到数字的魔法转换 1.1 分词算法三大流派 # 不同分词算法对比 tokenization_methods = {"WordPiece": "BERT/ELECTRA", "BPE": "GPT/RoBERTa",...
多表关联查询的优化
文章目录 前言1. 数据库设计优化:深入实践**1.1 规范化与反规范化的决策树****1.2 索引设计的实战技巧** **2. SQL 优化:进阶技巧****2.1 JOIN 顺序与执行计划****2.2 分页查询的深度优化** **3. MyBatis Plus 高级用法****3.1 动态 SQL 规避 N1 查询***…...
亚马逊企业购大客户业务拓展经理张越:跨境电商已然成为全球零售电商领域中熠熠生辉的强劲增长点
2024年12月26日-27日,由中国产业海外发展协会上合-海湾双链专委会指导、极新主办的「重度垂直2024极新AIGC峰会」先后在深圳、香港两地顺利开幕。本届峰会以AI的垂直应用与出海为核心主题,旨在深入探讨AI技术在全球范围内的融合应用与发展趋势࿰…...
VirtualBox 中使用 桥接网卡 并设置 MAC 地址
在 VirtualBox 中使用 桥接网卡 并设置 MAC 地址,可以按照以下步骤操作: 步骤 1:设置桥接网卡 打开 VirtualBox,选择你的虚拟机,点击 “设置” (Settings)。进入 “网络” (Network) 选项卡。在 “适配器 1” (Adapt…...
idea无法联网,离线安装插件
插件地址:https://plugins.jetbrains.com/ JetBrains Marketplace 如果无法进入,可以试试 配置hosts 3.163.125.103 plugins.jetbrains.com ip 变了,可以查询个最新的: https://tool.chinaz.com/speedtest/plugins.jetbrai…...
网络安全中的机器学习
当涉及到网络安全时,技术一直是保护系统免受攻击和数据泄露的关键。在这篇论文中,我将介绍一些当前在网络安全领域使用的关键技术,包括加密,身份验证和防火墙。 首先,加密是网络安全中最常见的技术之一。加密是指使用算…...
