当前位置: 首页 > news >正文

使用 JAX 进行 LLM 分布式监督微调

LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com)

24年1月25日,Douglas Jia 发布在AMD ROCm 博客上的文章。

在这篇文章中,我们回顾了使用 JAX 对基于双向编码器表示(BERT)的大型语言模型(LLM)进行文本分类任务微调的过程。我们探讨了在多个 AMD GPU 上并行化这一微调过程的技术,然后评估模型在测试数据集上的性能。为此,我们使用了一个基于 BERT的 cased transformer 模型和 General Language Understanding Evaluation(GLUE)基准数据集在多个 AMD GPU 上进行实验。

我们重点关注 JAX 中两个单程序多数据(SPMD)并行化方法。这两个方法是:
- 使用 pmap 函数在单个领先轴上进行简单的数据分发。
- 使用 jit、`Mesh` 和 mesh_utils 函数在设备之间分片数据,提供更大的并行化控制。

我们主要强调第一个方法,并在文章的最后部分提供了第二个方法的详细说明。
在撰写本文时,我们参考了这个教程,我们强烈推荐阅读。

什么是监督微调?

在人工智能(AI)时代,基于Transformer架构的模型(如 BERT、GPT-3 及其后续版本)为实现各种自然语言处理(NLP)任务(如文本分类、文本生成和情感分析)的尖端性能提供了坚实的基础。然而,当这些大型预训练模型单独应用于这些特定任务时,常常表现出一定的局限性。监督微调(SFT)为解决这些局限性提供了方案。

与在大规模、多样化数据集上进行广泛无监督训练的预训练模型不同,SFT采用了一种专注且资源高效的方法。通常,这需要一个相对紧凑、高质量的数据集,该数据集精确地针对特定任务量身定制。SFT可以在不需要长时间训练的情况下,将模型性能提升到最先进的水平,因为它能够利用预训练模型所获得的广泛知识。

SFT过程包括微调模型的现有权重或添加额外参数,以确保与指定任务的复杂性保持一致。通常,这种适应会结合任务特定的层,例如为分类添加一个 softmax 层,从而增强模型解决监督任务的能力。

什么是 JAX?

JAX 是一个高性能的 Python 数值计算库。与传统的机器学习框架(如 TensorFlow 和 PyTorch)相比,JAX 的速度和效率都非常出色。JAX 利用即时编译(JIT),无缝的自动微分,以及高效向量化和并行化代码的能力,使其能简单地适配 AI 加速器(如 GPU 和 TPU)。

为什么使用 AMD GPU?

AMD GPU 因其强大的开源支持而脱颖而出,工具如 ROCm 和 HIP 使其易于适配 AI 工作流程。AMD 具有竞争力的性价比,非常适合寻求成本效益的 AI 和深度学习任务解决方案的用户。随着 AMD 在市场上的影响力不断增长,越来越多的机器学习库和框架正在添加对 AMD GPU 的支持。

硬件要求和运行环境

为了利用完成此任务所需的计算能力,我们使用AMD加速器云平台 (AAC)。AAC 是一个按需提供云计算资源和API的付费平台。具体来说,我们使用一个JAX Docker容器,其在AAC上拥有8个GPU,以充分利用先进的GPU并行计算能力。

本文是硬件无关的,这意味着要成功运行提供的代码示例,不需要访问AAC。只要您有加速器设备(如GPU或TPU),您应该能够以最小的代码修改来运行这些代码示例。如果您使用的是AMD GPU,请确保正确安装了ROCm及其兼容版本的JAX和Jaxlib。参考以下教程进行安装:

  • ROCm 安装

  • JAX and Jaxlib 安装: 您也可以直接通过链接拉取一个JAX Docker镜像。

代码示例:对Transformer模型进行SFT

为了演示,我们使用一个通用语言理解评估(GLUE)基准数据集Quora Question Pairs(QQP)微调一个基于transformer的LLM(如:bert-base-cased)。该数据集包含超过40万对问题,每对问题都有一个二进制注释,指示这两个问题是否是相互的复述。输入变量是两个问题的句子,而输出变量是一个二进制指标,表示这两个问题是否具有相同的含义。

安装

首先,安装所需的软件包 (%%capture 是一个 _cell magic_,它将抑制单元格的输出)。

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort # 单元格中的格式化器;可选项

导入剩余的软件包和功能。

import os
from itertools import chain
from typing import Callableimport evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (AutoConfig,AutoTokenizer,FlaxAutoModelForSequenceClassification,
)os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

JAX 预先分配75%的GPU内存以减少首次运行JAX操作时的开销和碎片,但可能会触发内存不足(OOM)错误。为了避免OOM问题,可通过将 XLA_PYTHON_CLIENT_PREALLOCATE 标志设置为 false 来抑制默认行为。

检查是否可以通过JAX检测到GPU设备。如果不能,可能需要重新安装ROCm、JAX和Jaxlib。如果JAX安装正确,你可以看到所有请求的GPU设备,在我们的例子中是8个GPU。

jax.local_devices()
[gpu(id=0),gpu(id=1),gpu(id=2),gpu(id=3),gpu(id=4),gpu(id=5),gpu(id=6),gpu(id=7)]

获取微调数据集和预训练模型检查点

指定你的微调过程的设置:数据集、预训练模型以及每个设备每批次要处理的数据量。

task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64

加载数据集和评估指标模块。

raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

接下来的几段代码展示了如何使用模型特定的分词器对文本数据进行分词,并加载分词后的训练和验证数据。使用与预训练模型相同的分词器确保在微调过程中相同的词会被转换为相同的嵌入向量。

重要的是,我们在原始训练数据中对训练和评估数据集进行了10%的抽样。尽管如此,QQP数据集仍然提供了足够的数据来实现令人满意的性能,并且可以在每个epoch后观察到指标的改进。这种抽样方法还加快了我们的训练过程,便于说明。

使用数据预处理函数和map包装器的批处理和并行处理功能处理训练和评估数据集。你可以在以下输出中查看分词后的数据集。

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
def preprocess_function(examples):texts = (examples["question1"], examples["question2"])processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)processed["labels"] = examples["label"]return processed
# 关于如何处理和操作 huggingface 数据集的详细信息:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"原始训练数据集的形状为: {data.shape}")
print(f"当前训练数据集的形状为: {train_data.shape}")
print(f"当前验证数据集的形状为: {eval_data.shape}")
原始训练数据集的形状为: (363846, 4)
当前训练数据集的形状为: (36384, 4)
当前验证数据集的形状为: (36385, 4)
train_dataset = train_data.map(preprocess_function, batched=True, remove_columns=train_data.column_names
)
eval_dataset = eval_data.map(preprocess_function, batched=True, remove_columns=eval_data.column_names
)
# 你可以在以下单元格的输出中查看已分词的数据集
pd.DataFrame(train_dataset[:3])

从Hugging Face下载预训练模型配置和检查点。注意,你会看到一个警告信息,指出某些模型权重未使用。这是预期的,因为BERT模型检查点是一个PreTraining模型类,而你正在初始化一个
SequenceClassification模型。警告信息指出:你可能需要在下游任务上训练该模型,以便能够将其用于预测和推理。 这就是我们接下来要关注的内容。

num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed
)
某些在bert-base-cased模型检查点中的权重在初始化FlaxBertForSequenceClassification时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 如果您正在从另一个任务或架构的模型检查点初始化FlaxBertForSequenceClassification(例如,从BertForPreTraining模型初始化BertForSequenceClassification模型),这是预期的。
- 如果您正在从您期望完全相同的模型检查点初始化FlaxBertForSequenceClassification(从BertForSequenceClassification模型初始化BertForSequenceClassification模型),这不是预期的。
某些在bert-base-cased模型检查点中的权重未被初始化到FlaxBertForSequenceClassification并被重新初始化: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
您可能需要在下游任务中训练此模型,以便能够使用它进行预测和推理。

定义微调模型的状态

以下代码块展示了如何设置训练参数,比如训练周期数和初始学习率。学习率调度是为了使学习率在训练过程中线性衰减,以确保学习的效率和稳定性。

num_train_epochs = 6
learning_rate = 2e-5
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 512
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochslearning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)

接下来,需要建立训练状态,包括优化器和损失函数的职责,并监督模型参数在训练过程中的更新。

使用状态对象,初始化和更新模型。当调用模型时,将状态作为输入,模型会返回通过新数据批次更新后的状态,同时保留模型实例。

Flax 提供了一个用户友好的类(`flax.training.train_state.TrainState`),它将模型参数、损失函数和优化器封装在一起。当提供数据时,它可以使用 apply_gradients 函数更新模型参数。

下面的代码块展示了如何定义和建立训练状态、优化器和损失函数。

class TrainState(train_state.TrainState):logits_function: Callable = flax.struct.field(pytree_node=False)loss_function: Callable = flax.struct.field(pytree_node=False)
# 创建一个 decay_mask_fn 函数,以确保对任何偏置项或 LayerNorm 权重不应用权重衰减,因为这可能不会提高模型性能甚至会有害。def decay_mask_fn(params):flat_params = traverse_util.flatten_dict(params)flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))for path in flat_params}return traverse_util.unflatten_dict(flat_mask)
# 标准的带权重衰减的 Adam 优化器
def adamw(weight_decay):return optax.adamw(learning_rate=learning_rate_function,b1=0.9,b2=0.999,eps=1e-6,weight_decay=weight_decay,mask=decay_mask_fn,)
def loss_function(logits, labels):xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))return jnp.mean(xentropy)def eval_function(logits):return logits.argmax(-1)
# 实例化 TrainState
state = TrainState.create(apply_fn=model.__call__,params=model.params,tx=adamw(weight_decay=0.01),logits_function=eval_function,loss_function=loss_function,
)

定义如何训练、评估模型并启用并行化

train_step 和 eval_step 参数定义了如何训练和评估模型。训练步骤遵循标准的训练过程:

  1. 使用当前的权重计算损失。

  2. 计算损失函数相对于权重的梯度。

  3. 使用梯度和学习率更新权重。

  4. 使用梯度和学习率更新权重。

需要强调的是,`lax.pmean` 函数计算跨所有 8 个 GPU 设备的数据批次梯度的均值。这个关键步骤保证了所有 GPU 设备上的模型参数同步。

def train_step(state, batch, dropout_rng):targets = batch.pop("labels")dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)def loss_function(params):logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]loss = state.loss_function(logits, targets)return lossgrad_function = jax.value_and_grad(loss_function)loss, grad = grad_function(state.params)grad = jax.lax.pmean(grad, "batch")new_state = state.apply_gradients(grads=grad)metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)},axis_name="batch",)return new_state, metrics, new_dropout_rng
def eval_step(state, batch):logits = state.apply_fn(**batch, params=state.params, train=False)[0]return state.logits_function(logits)

接下来,应用 jax.pmap 函数到定义的 train_step 和 eval_step 函数。将 pmap() 应用于函数时,该函数会使用 XLA 编译(类似于 jit()),然后在 XLA 设备上并行运行,例如多 GPU 设备或多 TPU 核。简单来说,这一步将训练和评估函数发送到所有 GPU 设备。你还需要通过 flax.jax_utils.replicate 将训练状态发送到所有 GPU 设备,这些步骤确保你通过分布式训练在所有 GPU 设备上更新模型状态。

parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)

定义数据加载函数,这些函数返回数据批次生成器。在最终的训练和评估循环中,每一步都会输入一个新的数据批次。

def glue_train_data_loader(rng, dataset, batch_size):steps_per_epoch = len(dataset) // batch_sizeperms = jax.random.permutation(rng, len(dataset))perms = perms[: steps_per_epoch * batch_size]  # 跳过不完整的批次。perms = perms.reshape((steps_per_epoch, batch_size))for perm in perms:batch = dataset[perm]batch = {k: jnp.array(v) for k, v in batch.items()}batch = shard(batch)yield batch
def glue_eval_data_loader(dataset, batch_size):for i in range(len(dataset) // batch_size):batch = dataset[i * batch_size : (i + 1) * batch_size]batch = {k: jnp.array(v) for k, v in batch.items()}batch = shard(batch)yield batch

基于整数种子生成伪随机数生成器(PRNG)密钥,然后将其拆分为 8 个新的密钥,以确保每个 GPU 设备都得到不同的密钥。然后运行训练步骤,以根据预定义的训练参数(如训练轮次和总批次大小)更新 state。在每个轮次结束时,运行评估步骤,以查看评估数据集上的准确率和 F1 指标。由于使用的训练数据集比基准中的原始训练数据集要小,可以看到在前几轮训练中,评估指标(训练损失和评估准确率)稳定提升。

rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):rng, input_rng = jax.random.split(rng)# trainwith tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)progress_bar_train.update(1)# 评估with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:for batch in glue_eval_data_loader(eval_dataset, total_batch_size):labels = batch.pop("labels")predictions = parallel_eval_step(state, batch)metric.add_batch(predictions=list(chain(*predictions)), references=list(chain(*labels)))progress_bar_eval.update(1)eval_metric = metric.compute()loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)eval_score1 = round(list(eval_metric.values())[0], 3)metric_name1 = list(eval_metric.keys())[0]eval_score2 = round(list(eval_metric.values())[1], 3)metric_name2 = list(eval_metric.keys())[1]print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805

使用JAX设备网格来实现并行化

from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed
)
state = TrainState.create(apply_fn=model.__call__,params=model.params,tx=adamw(weight_decay=0.01),logits_function=eval_function,loss_function=loss_function,
)
一些来自 bert-base-cased 模型检查点的权重在初始化 FlaxBertForSequenceClassification 时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 当你用模型训练其他任务或用另一种架构初始化 FlaxBertForSequenceClassification 时,这是预期中的情况(例如从 BertForPreTraining 模型初始化 BertForSequenceClassification 模型)。
- 当你期望从与 FlaxBertForSequenceClassification 模型完全相同的检查点初始化时(从 BertForSequenceClassification 模型初始化 BertForSequenceClassification 模型),这不是预期情况。
FlaxBertForSequenceClassification 中一些权重没有从 bert-base-cased 模型检查点初始化,是新初始化的: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
应该将这个模型训练到下游任务上以便用于预测和推断。
@jax.jit
def train_step(state, batch, dropout_rng):targets = batch.pop("labels")dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)def loss_function(params):logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]loss = state.loss_function(logits, targets)return lossgrad_function = jax.value_and_grad(loss_function)loss, grad = grad_function(state.params)new_state = state.apply_gradients(grads=grad)metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}return new_state, metrics, new_dropout_rng
@jax.jit
def eval_step(state, batch):logits = state.apply_fn(**batch, params=state.params, train=False)[0]return state.logits_function(logits)
num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))# 数据将沿批处理轴进行分割
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(data_mesh,P("batch",),
)  # 命名网格的轴def glue_train_data_loader(rng, dataset, batch_size):steps_per_epoch = len(dataset) // batch_sizeperms = jax.random.permutation(rng, len(dataset))perms = perms[: steps_per_epoch * batch_size]  # 略过不完整的批处理。perms = perms.reshape((steps_per_epoch, batch_size))for perm in perms:batch = dataset[perm]batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}yield batchdef glue_eval_data_loader(dataset, batch_size):for i in range(len(dataset) // batch_size):batch = dataset[i * batch_size : (i + 1) * batch_size]batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}yield batch
# 在所有设备上复制模型和优化器变量
def get_replicated_train_state(devices, state):# 所有变量将在所有设备上复制var_mesh = Mesh(devices, axis_names=("_"))# 在 NamedSharding 中,未提到的轴将被复制(此处为所有轴)var_replication = NamedSharding(var_mesh, P())# 应用分布设置到模型变量state = jax.device_put(state, var_replication)return statestate = get_replicated_train_state(devices, state)
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):rng, input_rng = jax.random.split(rng)# 训练with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)progress_bar_train.update(1)# 评估with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:for batch in glue_eval_data_loader(eval_dataset, total_batch_size):labels = batch.pop("labels")predictions = eval_step(state, batch)metric.add_batch(predictions=list(predictions), references=list(labels))progress_bar_eval.update(1)eval_metric = metric.compute()loss = round(train_metrics["loss"].item(), 3)eval_score1 = round(list(eval_metric.values())[0], 3)metric_name1 = list(eval_metric.keys())[0]eval_score2 = round(list(eval_metric.values())[1], 3)metric_name2 = list(eval_metric.keys())[1]print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805

相关文章:

使用 JAX 进行 LLM 分布式监督微调

LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com) 24年1月25日&#xff0c;Douglas Jia 发布在AMD ROCm 博客上的文章。 在这篇文章中&#xff0c;我们回顾了使用 JAX 对基于双向编码器表示&#xff08;BERT&#xff09;的大型语言模型&#xff08;LL…...

【简单版】通过 Window.performance 实现前端页面(性能)监控

1 背景 前端监控系统告警xx接口fetchError 问题&#xff1a;前端监控系统没有更多的错误信息&#xff0c;查询该fetch请求对应的接口日志返回200状态码、无请求异常记录&#xff0c;且后台能查到通过该fetch请求成功发送的数据。那是前端页面的错误还是前端监控系统的问题&…...

微信小程序考试系统(lw+演示+源码+运行)

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了微信小程序考试系统的开发全过程。通过分析微信小程序考试系统管理的不足&#xff0c;创建了一个计算机管理微信小程序考试系统的方案。文章介绍了微信小程序考…...

手机摄影入门

感觉会摄影的人是能够从生活中发现美的人。 我不太会拍照&#xff0c;觉得拍好的照片比较浪费时间&#xff0c;而且缺乏审美也缺乏技巧&#xff0c;所以拍照的时候总是拍不好。但有时候还是需要拍一些好看的照片的。 心态和审美可能需要比较长时间提升&#xff0c;但一些基础…...

微信小程序手机号授权获取(aes加密手机号)

<view class="container"> <view class=topTabSwiper> <view class=tab {{currentData == 0 ? "tabBorer" : ""}} data-current = "0" bindtap=checkCurrent>一键授权<span class="tab_bor"><…...

asyn queueRequest使用实例

使用queueRequest读写端口驱动的示例&#xff0c;驱动驱动程序使用一个基于asyn实现了asynCommon和asynOctet的驱动程序-CSDN博客中编写的驱动程序&#xff0c;本程序的C代码如下&#xff1a; #include <stdlib.h> #include <stdio.h> #include <string.h>#…...

关于jmeter设置为中文问题之后无法保存设置的若干问题

1、jemeter如何设置中文模式 Options--->Choose Language--->Chinese(Simplifies), 如此设置后就可显示中文模式(缺点&#xff1a;下次打开还是英文)&#xff1b;如下图所示&#xff1a; 操作完成之后&#xff1a; 但是下次重启之后依旧是英文&#xff1b; 2、在jmeter.…...

基于FPGA的信号发生器verilog实现,可以输出方波,脉冲波,m序列以及正弦波,可调整输出信号频率

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 输出方波 输出脉冲波 输出m随机序列 输出正弦波 2.算法运行软件版本 vivado2019.2 3.部分核心程序 &#xff08;完整…...

背景全文及翻译

背景 Oracle数据向MySQL同步&#xff0c;没有最新数据&#xff0c;于是在plsql手敲SQL筛选最新数据时&#xff0c;执行报错。 问题描述 通过日期字段筛选最近的数据&#xff0c;我用了类似这样的语句&#xff1a; SELECT * FROM orders WHERE order_date > 2022/01/01;我…...

JAVA地狱级笑话

为什么Java开发者总是不怕黑暗&#xff1f; 因为他们总是有null指针来照亮路。 Java程序员最讨厌的音乐是什么&#xff1f; Garbage Collection旋律&#xff0c;节奏总是让他们烦躁。 为什么Java中的HashMap很擅长社交&#xff1f; 因为它总是能快速找到key对应的朋友。 Java开…...

宝塔PHP8.1安装fileinfo拓展失败解决办法

在宝塔面板中安装PHP8.1后&#xff0c;安装fileinfo扩展一直安装不上&#xff0c;查看日志有报错&#xff0c;于是手动来安装也报错。 宝塔报错&#xff1a; 手动命令行编译安装同&#xff0c;也有报错 cd /www/server/php/81/src/ext/fileinfo/ make distclean ./configure …...

Python 魔术方法

在Python中&#xff0c;魔术方法&#xff08;Magic Methods&#xff09;或称为双下划线方法&#xff08;Dunder Methods&#xff09;&#xff0c;是一类具有特殊用途的方法&#xff0c;其名称前后都带有两个下划线&#xff08;如 __init__、__str__ 等&#xff09;。这些方法定…...

03 go语言(golang) - fmt包基本类型

fmt包 在Go语言中&#xff0c;fmt 包是一个非常重要且广泛使用的标准库包&#xff0c;它提供了格式化I/O&#xff08;输入/输出&#xff09;功能&#xff0c;类似于C语言中的 printf 和 scanf。通过这个包&#xff0c;你可以读取输入并将数据格式化输出到标准输出或其他写入器…...

Docker本地镜像发布到阿里云镜像服务的简易指南

1 阿里云容器镜像服务 阿里云容器镜像服务&#xff08;Alibaba Cloud Container Registry&#xff0c;简称ACR&#xff09;是一个为容器镜像、Helm Chart等云原生资产提供安全托管及高效分发的平台。它支持多架构容器镜像&#xff0c;包括Linux、Windows、ARM等&#xff0c;以…...

大数据学习---快速了解clickhouse数据库

ClickHouse数据库介绍 ClickHouse是一款由Yandex开发的列式数据库管理系统&#xff08;DBMS&#xff09;&#xff0c;适用于在线分析处理&#xff08;OLAP&#xff09;场景。它具有高性能、可扩展性、实时更新等特点&#xff0c;适用于处理大规模数据。 特点 列式存储&#x…...

哪些方法可以缓解面试紧张?

面试紧张是许多人在面对重要职业机会时的一种常见情绪。虽然一定程度的紧张可能激发人的潜能&#xff0c;但过度的紧张则可能影响到面试表现。为了缓解面试紧张&#xff0c;以下是一些有效的方法&#xff1a; 1.充分准备&#xff1a; 深入了解公司背景、职位要求以及公司文化…...

即时通讯未读消息计数

单聊未读消息计数 未读消息的计数&#xff0c;分为两个部分&#xff1a;增加和减少 其中&#xff0c;未读消息计数的增加&#xff0c;是由数据库&#xff08;redis&#xff09;在写入消息的同时&#xff0c;增加对应接收方的未读消息计数 在线 用户在线时&#xff0c;客户端…...

在Openshift(K8S)上通过EMQX Operator部署Emqx集群

EMQX Operator 简介 EMQX Broker/Enterprise 是一个云原生的 MQTT 消息中间件。 我们提供了 EMQX Kubernetes Operator 来帮助您在 Kubernetes 的环境上快速创建和管理 EMQX Broker/Enterprise 集群。 它可以大大简化部署和管理 EMQX 集群的流程&#xff0c;对于管理和配置的知…...

Python酷玩之旅_数据分析入门(matplotlib)

导览 前言matplotlib入门1. 简介1.1 Pairwise data1.2 Statistical distributions1.3 Gridded data1.4 Irregularly gridded data1.5 3D and volumetric data 2. 实践2.1 安装2.2 示例 结语系列回顾 前言 翻看日历&#xff0c;今年的日子已划到了2024年10月19日&#xff0c;今天…...

uiautomatorviewer安卓9以上正常使用及问题处理

一、安卓9以上使用uiautomatorviewer问题现象 打开Unexpected error while obtaining UI hierarchy 问题详情 Unexpected error while obtaining UI hierarchy java.lang.reflect.InvocationTargetException 二、问题处理 需要的是替换对应D:\software\android-sdk-windows…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等

&#x1f50d; 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术&#xff0c;可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势&#xff0c;还能有效评价重大生态工程…...

BLEU评分:机器翻译质量评估的黄金标准

BLEU评分&#xff1a;机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域&#xff0c;衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标&#xff0c;自2002年由IBM的Kishore Papineni等人提出以来&#xff0c;…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

Visual Studio Code 扩展

Visual Studio Code 扩展 change-case 大小写转换EmmyLua for VSCode 调试插件Bookmarks 书签 change-case 大小写转换 https://marketplace.visualstudio.com/items?itemNamewmaurer.change-case 选中单词后&#xff0c;命令 changeCase.commands 可预览转换效果 EmmyLua…...

基于单片机的宠物屋智能系统设计与实现(论文+源码)

本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢&#xff0c;连接红外测温传感器&#xff0c;可实时精准捕捉宠物体温变化&#xff0c;以便及时发现健康异常&#xff1b;水位检测传感器时刻监测饮用水余量&#xff0c;防止宠物…...

数据库正常,但后端收不到数据原因及解决

从代码和日志来看&#xff0c;后端SQL查询确实返回了数据&#xff0c;但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离&#xff0c;并且ai辅助开发的时候&#xff0c;很容易出现前后端变量名不一致情况&#xff0c;还不报错&#xff0c;只是单…...