当前位置: 首页 > news >正文

使用 JAX 进行 LLM 分布式监督微调

LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com)

24年1月25日,Douglas Jia 发布在AMD ROCm 博客上的文章。

在这篇文章中,我们回顾了使用 JAX 对基于双向编码器表示(BERT)的大型语言模型(LLM)进行文本分类任务微调的过程。我们探讨了在多个 AMD GPU 上并行化这一微调过程的技术,然后评估模型在测试数据集上的性能。为此,我们使用了一个基于 BERT的 cased transformer 模型和 General Language Understanding Evaluation(GLUE)基准数据集在多个 AMD GPU 上进行实验。

我们重点关注 JAX 中两个单程序多数据(SPMD)并行化方法。这两个方法是:
- 使用 pmap 函数在单个领先轴上进行简单的数据分发。
- 使用 jit、`Mesh` 和 mesh_utils 函数在设备之间分片数据,提供更大的并行化控制。

我们主要强调第一个方法,并在文章的最后部分提供了第二个方法的详细说明。
在撰写本文时,我们参考了这个教程,我们强烈推荐阅读。

什么是监督微调?

在人工智能(AI)时代,基于Transformer架构的模型(如 BERT、GPT-3 及其后续版本)为实现各种自然语言处理(NLP)任务(如文本分类、文本生成和情感分析)的尖端性能提供了坚实的基础。然而,当这些大型预训练模型单独应用于这些特定任务时,常常表现出一定的局限性。监督微调(SFT)为解决这些局限性提供了方案。

与在大规模、多样化数据集上进行广泛无监督训练的预训练模型不同,SFT采用了一种专注且资源高效的方法。通常,这需要一个相对紧凑、高质量的数据集,该数据集精确地针对特定任务量身定制。SFT可以在不需要长时间训练的情况下,将模型性能提升到最先进的水平,因为它能够利用预训练模型所获得的广泛知识。

SFT过程包括微调模型的现有权重或添加额外参数,以确保与指定任务的复杂性保持一致。通常,这种适应会结合任务特定的层,例如为分类添加一个 softmax 层,从而增强模型解决监督任务的能力。

什么是 JAX?

JAX 是一个高性能的 Python 数值计算库。与传统的机器学习框架(如 TensorFlow 和 PyTorch)相比,JAX 的速度和效率都非常出色。JAX 利用即时编译(JIT),无缝的自动微分,以及高效向量化和并行化代码的能力,使其能简单地适配 AI 加速器(如 GPU 和 TPU)。

为什么使用 AMD GPU?

AMD GPU 因其强大的开源支持而脱颖而出,工具如 ROCm 和 HIP 使其易于适配 AI 工作流程。AMD 具有竞争力的性价比,非常适合寻求成本效益的 AI 和深度学习任务解决方案的用户。随着 AMD 在市场上的影响力不断增长,越来越多的机器学习库和框架正在添加对 AMD GPU 的支持。

硬件要求和运行环境

为了利用完成此任务所需的计算能力,我们使用AMD加速器云平台 (AAC)。AAC 是一个按需提供云计算资源和API的付费平台。具体来说,我们使用一个JAX Docker容器,其在AAC上拥有8个GPU,以充分利用先进的GPU并行计算能力。

本文是硬件无关的,这意味着要成功运行提供的代码示例,不需要访问AAC。只要您有加速器设备(如GPU或TPU),您应该能够以最小的代码修改来运行这些代码示例。如果您使用的是AMD GPU,请确保正确安装了ROCm及其兼容版本的JAX和Jaxlib。参考以下教程进行安装:

  • ROCm 安装

  • JAX and Jaxlib 安装: 您也可以直接通过链接拉取一个JAX Docker镜像。

代码示例:对Transformer模型进行SFT

为了演示,我们使用一个通用语言理解评估(GLUE)基准数据集Quora Question Pairs(QQP)微调一个基于transformer的LLM(如:bert-base-cased)。该数据集包含超过40万对问题,每对问题都有一个二进制注释,指示这两个问题是否是相互的复述。输入变量是两个问题的句子,而输出变量是一个二进制指标,表示这两个问题是否具有相同的含义。

安装

首先,安装所需的软件包 (%%capture 是一个 _cell magic_,它将抑制单元格的输出)。

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort # 单元格中的格式化器;可选项

导入剩余的软件包和功能。

import os
from itertools import chain
from typing import Callableimport evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (AutoConfig,AutoTokenizer,FlaxAutoModelForSequenceClassification,
)os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

JAX 预先分配75%的GPU内存以减少首次运行JAX操作时的开销和碎片,但可能会触发内存不足(OOM)错误。为了避免OOM问题,可通过将 XLA_PYTHON_CLIENT_PREALLOCATE 标志设置为 false 来抑制默认行为。

检查是否可以通过JAX检测到GPU设备。如果不能,可能需要重新安装ROCm、JAX和Jaxlib。如果JAX安装正确,你可以看到所有请求的GPU设备,在我们的例子中是8个GPU。

jax.local_devices()
[gpu(id=0),gpu(id=1),gpu(id=2),gpu(id=3),gpu(id=4),gpu(id=5),gpu(id=6),gpu(id=7)]

获取微调数据集和预训练模型检查点

指定你的微调过程的设置:数据集、预训练模型以及每个设备每批次要处理的数据量。

task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64

加载数据集和评估指标模块。

raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

接下来的几段代码展示了如何使用模型特定的分词器对文本数据进行分词,并加载分词后的训练和验证数据。使用与预训练模型相同的分词器确保在微调过程中相同的词会被转换为相同的嵌入向量。

重要的是,我们在原始训练数据中对训练和评估数据集进行了10%的抽样。尽管如此,QQP数据集仍然提供了足够的数据来实现令人满意的性能,并且可以在每个epoch后观察到指标的改进。这种抽样方法还加快了我们的训练过程,便于说明。

使用数据预处理函数和map包装器的批处理和并行处理功能处理训练和评估数据集。你可以在以下输出中查看分词后的数据集。

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
def preprocess_function(examples):texts = (examples["question1"], examples["question2"])processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)processed["labels"] = examples["label"]return processed
# 关于如何处理和操作 huggingface 数据集的详细信息:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"原始训练数据集的形状为: {data.shape}")
print(f"当前训练数据集的形状为: {train_data.shape}")
print(f"当前验证数据集的形状为: {eval_data.shape}")
原始训练数据集的形状为: (363846, 4)
当前训练数据集的形状为: (36384, 4)
当前验证数据集的形状为: (36385, 4)
train_dataset = train_data.map(preprocess_function, batched=True, remove_columns=train_data.column_names
)
eval_dataset = eval_data.map(preprocess_function, batched=True, remove_columns=eval_data.column_names
)
# 你可以在以下单元格的输出中查看已分词的数据集
pd.DataFrame(train_dataset[:3])

从Hugging Face下载预训练模型配置和检查点。注意,你会看到一个警告信息,指出某些模型权重未使用。这是预期的,因为BERT模型检查点是一个PreTraining模型类,而你正在初始化一个
SequenceClassification模型。警告信息指出:你可能需要在下游任务上训练该模型,以便能够将其用于预测和推理。 这就是我们接下来要关注的内容。

num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed
)
某些在bert-base-cased模型检查点中的权重在初始化FlaxBertForSequenceClassification时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 如果您正在从另一个任务或架构的模型检查点初始化FlaxBertForSequenceClassification(例如,从BertForPreTraining模型初始化BertForSequenceClassification模型),这是预期的。
- 如果您正在从您期望完全相同的模型检查点初始化FlaxBertForSequenceClassification(从BertForSequenceClassification模型初始化BertForSequenceClassification模型),这不是预期的。
某些在bert-base-cased模型检查点中的权重未被初始化到FlaxBertForSequenceClassification并被重新初始化: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
您可能需要在下游任务中训练此模型,以便能够使用它进行预测和推理。

定义微调模型的状态

以下代码块展示了如何设置训练参数,比如训练周期数和初始学习率。学习率调度是为了使学习率在训练过程中线性衰减,以确保学习的效率和稳定性。

num_train_epochs = 6
learning_rate = 2e-5
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 512
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochslearning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)

接下来,需要建立训练状态,包括优化器和损失函数的职责,并监督模型参数在训练过程中的更新。

使用状态对象,初始化和更新模型。当调用模型时,将状态作为输入,模型会返回通过新数据批次更新后的状态,同时保留模型实例。

Flax 提供了一个用户友好的类(`flax.training.train_state.TrainState`),它将模型参数、损失函数和优化器封装在一起。当提供数据时,它可以使用 apply_gradients 函数更新模型参数。

下面的代码块展示了如何定义和建立训练状态、优化器和损失函数。

class TrainState(train_state.TrainState):logits_function: Callable = flax.struct.field(pytree_node=False)loss_function: Callable = flax.struct.field(pytree_node=False)
# 创建一个 decay_mask_fn 函数,以确保对任何偏置项或 LayerNorm 权重不应用权重衰减,因为这可能不会提高模型性能甚至会有害。def decay_mask_fn(params):flat_params = traverse_util.flatten_dict(params)flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))for path in flat_params}return traverse_util.unflatten_dict(flat_mask)
# 标准的带权重衰减的 Adam 优化器
def adamw(weight_decay):return optax.adamw(learning_rate=learning_rate_function,b1=0.9,b2=0.999,eps=1e-6,weight_decay=weight_decay,mask=decay_mask_fn,)
def loss_function(logits, labels):xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))return jnp.mean(xentropy)def eval_function(logits):return logits.argmax(-1)
# 实例化 TrainState
state = TrainState.create(apply_fn=model.__call__,params=model.params,tx=adamw(weight_decay=0.01),logits_function=eval_function,loss_function=loss_function,
)

定义如何训练、评估模型并启用并行化

train_step 和 eval_step 参数定义了如何训练和评估模型。训练步骤遵循标准的训练过程:

  1. 使用当前的权重计算损失。

  2. 计算损失函数相对于权重的梯度。

  3. 使用梯度和学习率更新权重。

  4. 使用梯度和学习率更新权重。

需要强调的是,`lax.pmean` 函数计算跨所有 8 个 GPU 设备的数据批次梯度的均值。这个关键步骤保证了所有 GPU 设备上的模型参数同步。

def train_step(state, batch, dropout_rng):targets = batch.pop("labels")dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)def loss_function(params):logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]loss = state.loss_function(logits, targets)return lossgrad_function = jax.value_and_grad(loss_function)loss, grad = grad_function(state.params)grad = jax.lax.pmean(grad, "batch")new_state = state.apply_gradients(grads=grad)metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)},axis_name="batch",)return new_state, metrics, new_dropout_rng
def eval_step(state, batch):logits = state.apply_fn(**batch, params=state.params, train=False)[0]return state.logits_function(logits)

接下来,应用 jax.pmap 函数到定义的 train_step 和 eval_step 函数。将 pmap() 应用于函数时,该函数会使用 XLA 编译(类似于 jit()),然后在 XLA 设备上并行运行,例如多 GPU 设备或多 TPU 核。简单来说,这一步将训练和评估函数发送到所有 GPU 设备。你还需要通过 flax.jax_utils.replicate 将训练状态发送到所有 GPU 设备,这些步骤确保你通过分布式训练在所有 GPU 设备上更新模型状态。

parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)

定义数据加载函数,这些函数返回数据批次生成器。在最终的训练和评估循环中,每一步都会输入一个新的数据批次。

def glue_train_data_loader(rng, dataset, batch_size):steps_per_epoch = len(dataset) // batch_sizeperms = jax.random.permutation(rng, len(dataset))perms = perms[: steps_per_epoch * batch_size]  # 跳过不完整的批次。perms = perms.reshape((steps_per_epoch, batch_size))for perm in perms:batch = dataset[perm]batch = {k: jnp.array(v) for k, v in batch.items()}batch = shard(batch)yield batch
def glue_eval_data_loader(dataset, batch_size):for i in range(len(dataset) // batch_size):batch = dataset[i * batch_size : (i + 1) * batch_size]batch = {k: jnp.array(v) for k, v in batch.items()}batch = shard(batch)yield batch

基于整数种子生成伪随机数生成器(PRNG)密钥,然后将其拆分为 8 个新的密钥,以确保每个 GPU 设备都得到不同的密钥。然后运行训练步骤,以根据预定义的训练参数(如训练轮次和总批次大小)更新 state。在每个轮次结束时,运行评估步骤,以查看评估数据集上的准确率和 F1 指标。由于使用的训练数据集比基准中的原始训练数据集要小,可以看到在前几轮训练中,评估指标(训练损失和评估准确率)稳定提升。

rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):rng, input_rng = jax.random.split(rng)# trainwith tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)progress_bar_train.update(1)# 评估with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:for batch in glue_eval_data_loader(eval_dataset, total_batch_size):labels = batch.pop("labels")predictions = parallel_eval_step(state, batch)metric.add_batch(predictions=list(chain(*predictions)), references=list(chain(*labels)))progress_bar_eval.update(1)eval_metric = metric.compute()loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)eval_score1 = round(list(eval_metric.values())[0], 3)metric_name1 = list(eval_metric.keys())[0]eval_score2 = round(list(eval_metric.values())[1], 3)metric_name2 = list(eval_metric.keys())[1]print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805

使用JAX设备网格来实现并行化

from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed
)
state = TrainState.create(apply_fn=model.__call__,params=model.params,tx=adamw(weight_decay=0.01),logits_function=eval_function,loss_function=loss_function,
)
一些来自 bert-base-cased 模型检查点的权重在初始化 FlaxBertForSequenceClassification 时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 当你用模型训练其他任务或用另一种架构初始化 FlaxBertForSequenceClassification 时,这是预期中的情况(例如从 BertForPreTraining 模型初始化 BertForSequenceClassification 模型)。
- 当你期望从与 FlaxBertForSequenceClassification 模型完全相同的检查点初始化时(从 BertForSequenceClassification 模型初始化 BertForSequenceClassification 模型),这不是预期情况。
FlaxBertForSequenceClassification 中一些权重没有从 bert-base-cased 模型检查点初始化,是新初始化的: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
应该将这个模型训练到下游任务上以便用于预测和推断。
@jax.jit
def train_step(state, batch, dropout_rng):targets = batch.pop("labels")dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)def loss_function(params):logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]loss = state.loss_function(logits, targets)return lossgrad_function = jax.value_and_grad(loss_function)loss, grad = grad_function(state.params)new_state = state.apply_gradients(grads=grad)metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}return new_state, metrics, new_dropout_rng
@jax.jit
def eval_step(state, batch):logits = state.apply_fn(**batch, params=state.params, train=False)[0]return state.logits_function(logits)
num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))# 数据将沿批处理轴进行分割
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(data_mesh,P("batch",),
)  # 命名网格的轴def glue_train_data_loader(rng, dataset, batch_size):steps_per_epoch = len(dataset) // batch_sizeperms = jax.random.permutation(rng, len(dataset))perms = perms[: steps_per_epoch * batch_size]  # 略过不完整的批处理。perms = perms.reshape((steps_per_epoch, batch_size))for perm in perms:batch = dataset[perm]batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}yield batchdef glue_eval_data_loader(dataset, batch_size):for i in range(len(dataset) // batch_size):batch = dataset[i * batch_size : (i + 1) * batch_size]batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}yield batch
# 在所有设备上复制模型和优化器变量
def get_replicated_train_state(devices, state):# 所有变量将在所有设备上复制var_mesh = Mesh(devices, axis_names=("_"))# 在 NamedSharding 中,未提到的轴将被复制(此处为所有轴)var_replication = NamedSharding(var_mesh, P())# 应用分布设置到模型变量state = jax.device_put(state, var_replication)return statestate = get_replicated_train_state(devices, state)
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):rng, input_rng = jax.random.split(rng)# 训练with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)progress_bar_train.update(1)# 评估with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:for batch in glue_eval_data_loader(eval_dataset, total_batch_size):labels = batch.pop("labels")predictions = eval_step(state, batch)metric.add_batch(predictions=list(predictions), references=list(labels))progress_bar_eval.update(1)eval_metric = metric.compute()loss = round(train_metrics["loss"].item(), 3)eval_score1 = round(list(eval_metric.values())[0], 3)metric_name1 = list(eval_metric.keys())[0]eval_score2 = round(list(eval_metric.values())[1], 3)metric_name2 = list(eval_metric.keys())[1]print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805

相关文章:

使用 JAX 进行 LLM 分布式监督微调

LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com) 24年1月25日&#xff0c;Douglas Jia 发布在AMD ROCm 博客上的文章。 在这篇文章中&#xff0c;我们回顾了使用 JAX 对基于双向编码器表示&#xff08;BERT&#xff09;的大型语言模型&#xff08;LL…...

【简单版】通过 Window.performance 实现前端页面(性能)监控

1 背景 前端监控系统告警xx接口fetchError 问题&#xff1a;前端监控系统没有更多的错误信息&#xff0c;查询该fetch请求对应的接口日志返回200状态码、无请求异常记录&#xff0c;且后台能查到通过该fetch请求成功发送的数据。那是前端页面的错误还是前端监控系统的问题&…...

微信小程序考试系统(lw+演示+源码+运行)

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了微信小程序考试系统的开发全过程。通过分析微信小程序考试系统管理的不足&#xff0c;创建了一个计算机管理微信小程序考试系统的方案。文章介绍了微信小程序考…...

手机摄影入门

感觉会摄影的人是能够从生活中发现美的人。 我不太会拍照&#xff0c;觉得拍好的照片比较浪费时间&#xff0c;而且缺乏审美也缺乏技巧&#xff0c;所以拍照的时候总是拍不好。但有时候还是需要拍一些好看的照片的。 心态和审美可能需要比较长时间提升&#xff0c;但一些基础…...

微信小程序手机号授权获取(aes加密手机号)

<view class="container"> <view class=topTabSwiper> <view class=tab {{currentData == 0 ? "tabBorer" : ""}} data-current = "0" bindtap=checkCurrent>一键授权<span class="tab_bor"><…...

asyn queueRequest使用实例

使用queueRequest读写端口驱动的示例&#xff0c;驱动驱动程序使用一个基于asyn实现了asynCommon和asynOctet的驱动程序-CSDN博客中编写的驱动程序&#xff0c;本程序的C代码如下&#xff1a; #include <stdlib.h> #include <stdio.h> #include <string.h>#…...

关于jmeter设置为中文问题之后无法保存设置的若干问题

1、jemeter如何设置中文模式 Options--->Choose Language--->Chinese(Simplifies), 如此设置后就可显示中文模式(缺点&#xff1a;下次打开还是英文)&#xff1b;如下图所示&#xff1a; 操作完成之后&#xff1a; 但是下次重启之后依旧是英文&#xff1b; 2、在jmeter.…...

基于FPGA的信号发生器verilog实现,可以输出方波,脉冲波,m序列以及正弦波,可调整输出信号频率

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 输出方波 输出脉冲波 输出m随机序列 输出正弦波 2.算法运行软件版本 vivado2019.2 3.部分核心程序 &#xff08;完整…...

背景全文及翻译

背景 Oracle数据向MySQL同步&#xff0c;没有最新数据&#xff0c;于是在plsql手敲SQL筛选最新数据时&#xff0c;执行报错。 问题描述 通过日期字段筛选最近的数据&#xff0c;我用了类似这样的语句&#xff1a; SELECT * FROM orders WHERE order_date > 2022/01/01;我…...

JAVA地狱级笑话

为什么Java开发者总是不怕黑暗&#xff1f; 因为他们总是有null指针来照亮路。 Java程序员最讨厌的音乐是什么&#xff1f; Garbage Collection旋律&#xff0c;节奏总是让他们烦躁。 为什么Java中的HashMap很擅长社交&#xff1f; 因为它总是能快速找到key对应的朋友。 Java开…...

宝塔PHP8.1安装fileinfo拓展失败解决办法

在宝塔面板中安装PHP8.1后&#xff0c;安装fileinfo扩展一直安装不上&#xff0c;查看日志有报错&#xff0c;于是手动来安装也报错。 宝塔报错&#xff1a; 手动命令行编译安装同&#xff0c;也有报错 cd /www/server/php/81/src/ext/fileinfo/ make distclean ./configure …...

Python 魔术方法

在Python中&#xff0c;魔术方法&#xff08;Magic Methods&#xff09;或称为双下划线方法&#xff08;Dunder Methods&#xff09;&#xff0c;是一类具有特殊用途的方法&#xff0c;其名称前后都带有两个下划线&#xff08;如 __init__、__str__ 等&#xff09;。这些方法定…...

03 go语言(golang) - fmt包基本类型

fmt包 在Go语言中&#xff0c;fmt 包是一个非常重要且广泛使用的标准库包&#xff0c;它提供了格式化I/O&#xff08;输入/输出&#xff09;功能&#xff0c;类似于C语言中的 printf 和 scanf。通过这个包&#xff0c;你可以读取输入并将数据格式化输出到标准输出或其他写入器…...

Docker本地镜像发布到阿里云镜像服务的简易指南

1 阿里云容器镜像服务 阿里云容器镜像服务&#xff08;Alibaba Cloud Container Registry&#xff0c;简称ACR&#xff09;是一个为容器镜像、Helm Chart等云原生资产提供安全托管及高效分发的平台。它支持多架构容器镜像&#xff0c;包括Linux、Windows、ARM等&#xff0c;以…...

大数据学习---快速了解clickhouse数据库

ClickHouse数据库介绍 ClickHouse是一款由Yandex开发的列式数据库管理系统&#xff08;DBMS&#xff09;&#xff0c;适用于在线分析处理&#xff08;OLAP&#xff09;场景。它具有高性能、可扩展性、实时更新等特点&#xff0c;适用于处理大规模数据。 特点 列式存储&#x…...

哪些方法可以缓解面试紧张?

面试紧张是许多人在面对重要职业机会时的一种常见情绪。虽然一定程度的紧张可能激发人的潜能&#xff0c;但过度的紧张则可能影响到面试表现。为了缓解面试紧张&#xff0c;以下是一些有效的方法&#xff1a; 1.充分准备&#xff1a; 深入了解公司背景、职位要求以及公司文化…...

即时通讯未读消息计数

单聊未读消息计数 未读消息的计数&#xff0c;分为两个部分&#xff1a;增加和减少 其中&#xff0c;未读消息计数的增加&#xff0c;是由数据库&#xff08;redis&#xff09;在写入消息的同时&#xff0c;增加对应接收方的未读消息计数 在线 用户在线时&#xff0c;客户端…...

在Openshift(K8S)上通过EMQX Operator部署Emqx集群

EMQX Operator 简介 EMQX Broker/Enterprise 是一个云原生的 MQTT 消息中间件。 我们提供了 EMQX Kubernetes Operator 来帮助您在 Kubernetes 的环境上快速创建和管理 EMQX Broker/Enterprise 集群。 它可以大大简化部署和管理 EMQX 集群的流程&#xff0c;对于管理和配置的知…...

Python酷玩之旅_数据分析入门(matplotlib)

导览 前言matplotlib入门1. 简介1.1 Pairwise data1.2 Statistical distributions1.3 Gridded data1.4 Irregularly gridded data1.5 3D and volumetric data 2. 实践2.1 安装2.2 示例 结语系列回顾 前言 翻看日历&#xff0c;今年的日子已划到了2024年10月19日&#xff0c;今天…...

uiautomatorviewer安卓9以上正常使用及问题处理

一、安卓9以上使用uiautomatorviewer问题现象 打开Unexpected error while obtaining UI hierarchy 问题详情 Unexpected error while obtaining UI hierarchy java.lang.reflect.InvocationTargetException 二、问题处理 需要的是替换对应D:\software\android-sdk-windows…...

Go语言gRPC快速入门

文章目录 前言gRPC是什么Go语言的gRPC技术栈准备工作接口定义代码生成服务端代码编写客户端代码编写效果演示完整代码链接最后 前言 你好&#xff0c;我是醉墨居士&#xff0c;这篇博客想帮助初学者能够快速入门gRPC&#xff0c;希望能够为你节省宝贵的时间&#xff0c;让时间…...

Golang | Leetcode Golang题解之第479题最大回文数乘积

题目&#xff1a; 题解&#xff1a; func largestPalindrome(n int) int {if n 1 {return 9}upper : int(math.Pow10(n)) - 1for left : upper; ; left-- { // 枚举回文数的左半部分p : leftfor x : left; x > 0; x / 10 {p p*10 x%10 // 翻转左半部分到其自身末尾&…...

UDP协议讲解

预备知识&#xff1a; 端口号port&#xff1a; 我们在正常网络通信时&#xff0c;实际上是进程在互相通信。 我们所有的网络通信的行为&#xff0c;本质上都是进程间通信。 对双方而言&#xff0c;1.先保证数据能到达自己的机器 ip解决 2.找到指定的进程 端口号 ip地址用来…...

交叉注意力融合时域、频域特征的FFT + CNN -BiLSTM-CrossAttention轴承故障识别模型

往期精彩内容&#xff1a; Python-凯斯西储大学&#xff08;CWRU&#xff09;轴承数据解读与分类处理 Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客 Pytorch-CNN轴承故障一维信号分类(二)-CSDN博客 Pytorch-Transformer轴承故障一维信号分类(三)-CSDN博客 三十多个开源…...

CSDN Markdown 编辑器语法大全

Markdown 是一种轻量级标记语言&#xff0c;它以简洁、易读易写的特点&#xff0c;被广泛应用于技术文档、博客文章、笔记等领域。CSDN 的 Markdown 编辑器为用户提供了丰富的功能&#xff0c;让用户能够轻松地创建格式规范、内容丰富的文档。以下是一份详细的 CSDN Markdown 编…...

TCP/IP 协议【四次挥手】简要说明

四次挥手是为了确保数据的完整性和可靠性&#xff0c;解决的主要问题是双方在断开连接时&#xff0c;可能还有未完成传输的数据或者未被接收的数据。 具体来说&#xff0c;四次挥手解决的问题是&#xff1a; 第一次挥手&#xff08;发送方向接收方发送FIN包&#xff09;&#…...

第11篇:网络安全协议

目录 引言 11.1 安全套接字层&#xff08;SSL&#xff09;和传输层安全&#xff08;TLS&#xff09;协议 11.1.1 SSL/TLS 的工作原理 11.1.2 SSL/TLS 的应用场景 11.2 虚拟专用网&#xff08;VPN&#xff09;和 IP 安全协议&#xff08;IPSec&#xff09; 11.2.1 VPN 的工…...

ES-入门-javaApi-文档-新增-删除

新增指定索引的文档数据的代码如下&#xff1a; package com.atgulgu.es.test;import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.http.HttpHost; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRe…...

【视频生成大模型】 视频生成大模型 THUDM/CogVideoX-2b

【视频生成大模型】 视频生成大模型 THUDM/CogVideoX-2b CogVideoX-2b 模型介绍发布时间模型测试生成的demo视频生成视频限制 运行环境安装运行模型下载开源协议参考 CogVideoX-2b 模型介绍 CogVideoX是 清影 同源的开源版本视频生成模型。 基础信息&#xff1a; 发布时间 2…...

【MR开发】在Pico设备上接入MRTK3(三)——在Unity中运行MRTK示例

在前面的文档中&#xff0c;介绍了如何在Unity工程中配置号MRTK和Pico SDK 【MR开发】在Pico设备上接入MRTK3&#xff08;一&#xff09;在Unity中导入MRTK3依赖【MR开发】在Pico设备上接入MRTK3&#xff08;二&#xff09;在Unity中配置Pico SDK 本文将介绍如何运行一个简单…...