当前位置: 首页 > news >正文

deepspeed+transformers模型微调

一、目录

  1. 代码讲解

二、实现。

1、代码讲解,trainer 实现。
transformers通过trainer 集成deepspeed功能,所以中需要进行文件配置,即可实现deepspeed的训练。
微调代码: 参数定义—>数据处理---->模型创建/评估方式---->trainer 框架训练
注意: V100 显卡,不包括float16 精度训练。

import deepspeed
deepspeed.ops.op_builder.CPUAdamBuilder().load()
import nltk
import torch
import evaluate
import datasets
import numpy as np
from nltk.tokenize import sent_tokenize
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
nltk.download("punkt")
import gc
import torch######################################定义参数####################################
dataset_name = "samsum" # 数据集名称
#model_name="google/flan-t5-xxl" # 模型名称
model_name="google/flan-t5-xl" # 模型名称
max_input_length = 256
max_gen_length = 128
output_dir = "checkpoints"
num_train_epochs = 5
learning_rate = 5e-5
deepspeed_config = "ds_config.json" #          deepspeed配置文件
per_device_train_batch_size=5 # batch size设置为1,因为太大导致OOM
per_device_eval_batch_size=5
gradient_accumulation_steps=10 # 由于单卡的batch size为1,为了扩展batch size,使用梯度累加#################################加载数据集,与数据预处理#########################################
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = datasets.load_dataset(dataset_name)
print(dataset["train"][0])def preprocess(examples):dialogues = ["summarize:" + dia for dia in examples["dialogue"]]# summaries = [summ for summ in examples["summary"]]model_inputs = tokenizer(dialogues, max_length=max_input_length, truncation=True)labels = tokenizer(text_target=examples["summary"], max_length=max_gen_length, truncation=True)model_inputs["labels"] = labels["input_ids"]return model_inputstokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["dialogue", "summary", "id"])
# print(tokenized_dataset["train"]["input_ids"][0]) # 打印结果    对map后的数据进行查看。def collate_fn(features):batch_input_ids = [torch.LongTensor(feature["input_ids"]) for feature in features]batch_attention_mask = [torch.LongTensor(feature["attention_mask"]) for feature in features]batch_labels = [torch.LongTensor(feature["labels"]) for feature in features]batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)batch_attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0)batch_labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100)return {"input_ids": batch_input_ids,"attention_mask": batch_attention_mask,"labels": batch_labels}##############################加载模型,采用seq2seqLM模型,并进行测试##############################
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#用于测试的代码
#dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn)
# batch = next(iter(dataloader))
# #print(batch)
# # 用于测试的代码
# dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn)
# batch = next(iter(dataloader))
# output = model(**batch)
#print(output)
#############################################模型训练,并采用trainer 架构####################################
print("==========train....================")
metric = evaluate.load("rouge")
def compute_metrics(eval_preds):preds, labels = eval_predsif isinstance(preds, tuple):preds = preds[0]decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)labels = np.where(labels != -100, labels, tokenizer.pad_token_id)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)result = {k: round(v * 100, 4) for k, v in result.items()}prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]result["gen_len"] = np.mean(prediction_lens)return resulttraining_args = Seq2SeqTrainingArguments(output_dir=output_dir,per_device_train_batch_size=per_device_train_batch_size,per_device_eval_batch_size=per_device_eval_batch_size,gradient_accumulation_steps=gradient_accumulation_steps,eval_accumulation_steps=1, # 防止评估时导致OOMpredict_with_generate=True,learning_rate=learning_rate,num_train_epochs=num_train_epochs,# logging & evaluation strategieslogging_dir="logs",logging_strategy="steps",logging_steps=50, # 每50个step打印一次logevaluation_strategy="steps",eval_steps=500, # 每500个step进行一次评估save_steps=500,save_total_limit=2,load_best_model_at_end=True,deepspeed=deepspeed_config,                        # deepspeed配置文件的位置report_to="all"
)trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_dataset["train"],eval_dataset=tokenized_dataset["validation"],data_collator=collate_fn,compute_metrics=compute_metrics,
)
trainer.train()
gc.collect()
torch.cuda.empty_cache()
# 打印验证集上的结果
# print(trainer.evaluate(tokenized_dataset["validation"]))
# # 打印测试集上的结果
# print(trainer.evaluate(tokenized_dataset["test"]))
# 保存最优模型
trainer.save_model("best.pt")
#export NCCL_IB_DISABLE=1; export NCCL_P2P_DISABLE=1; NCCL_DEBUG=INFO deepspeed --include=localhost:0,1 test1.py

配置文件:ds_config.json

{"fp16": {"enabled": "auto"},"optimizer": {"type": "AdamW","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"scheduler": {"type": "WarmupLR","params": {"warmup_min_lr": "auto","warmup_max_lr": "auto","warmup_num_steps": "auto"}},"zero_optimization": {"stage": 3,"offload_optimizer": {"device": "cpu","pin_memory": true},"offload_param": {"device": "cpu","pin_memory": true},"overlap_comm": true,"contiguous_gradients": true,"sub_group_size": 1e9,"reduce_bucket_size": "auto","stage3_prefetch_bucket_size": "auto","stage3_param_persistence_threshold": "auto","stage3_max_live_parameters": 1e9,"stage3_max_reuse_distance": 1e9,"stage3_gather_16bit_weights_on_model_save": false},"gradient_accumulation_steps": "auto","gradient_clipping": "auto","steps_per_print": 2000,"train_batch_size": "auto","train_micro_batch_size_per_gpu": "auto","wall_clock_breakdown": false
}

启动: 单机多卡

export NCCL_IB_DISABLE=1; export NCCL_P2P_DISABLE=1; NCCL_DEBUG=INFO deepspeed --include=localhost:0,1 test1.py>output.log 2>&1 &

二、代码讲解,peft微调+trainer 实现。

import os
import torch
import random
import datasets
import numpy as np
from typing import Dict
from transformers import (AutoModelForCausalLM,AutoTokenizer,DataCollatorForSeq2Seq,TrainingArguments,Trainer
)
from peft import (LoraConfig,TaskType,get_peft_model,get_peft_model_state_dict,
)def set_random_seed(seed):if seed is not None and seed > 0:random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.random.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueset_random_seed(1234)# 1. 设置参数
# LoRA参数
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
# 训练参数
EPOCHS=3
LEARNING_RATE=5e-5
OUTPUT_DIR="./checkpoints"
BATCH_SIZE=4 # 2
GRADIENT_ACCUMULATION_STEPS=3
# 其他参数
MODEL_PATH = "bigscience/bloomz-7b1-mt"
DATA_PATH = "./data/belle_open_source_1M.train.json"
MAX_LENGTH = 512
PATTERN = "{}\n{}"
DS_CONFIG = "ds_zero2_config.json"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) # 加载tokenizer
# 加载数据
dataset = datasets.load_dataset("json", data_files=DATA_PATH)
# print(dataset["train"][0])# 2. tokenize
def tokenize(text: str, add_eos_token=True):result = tokenizer(text,truncation=True,max_length=MAX_LENGTH,padding=False,return_tensors=None)# 判断是否要添加eos_tokenif (result["input_ids"][-1] != tokenizer.eos_token_idand len(result["input_ids"]) < MAX_LENGTHand add_eos_token):result["input_ids"].append(tokenizer.eos_token_id)result["attention_mask"].append(1)result["labels"] = result["input_ids"].copy()return resultdef preprocess(example: Dict, train_on_inputs: bool = False):prompt = example["input"]response = example["target"]text = PATTERN.format(prompt, response)tokenized_inp = tokenize(text)# 若train_on_inputs为False,则将label中与input相关的token替换为-100if not train_on_inputs:tokenized_prompt = tokenize(prompt,add_eos_token=False)prompt_tokens_len = len(tokenized_prompt["input_ids"])tokenized_inp["labels"] = [-100]*prompt_tokens_len + tokenized_inp["labels"][prompt_tokens_len:]return tokenized_inptrain_data = dataset["train"].shuffle().map(preprocess, remove_columns=["id", "input", "target"])
print(train_data[0])# pad_to_multiple_of=8表示padding的长度是8的倍数
collate_fn = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)# 2. 加载模型
evice_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
# device_map指定模型加载的GPU;troch_dtype=torch.float16表示半精度加载模型
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map=device_map)# 3. LoRA相关
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=LORA_R,                         # LoRA中低秩近似的秩lora_alpha=LORA_ALPHA,            # 见上文中的低秩矩阵缩放超参数lora_dropout=LORA_DROPOUT,       # LoRA层的dropout
)
# 转换模型
model = get_peft_model(model, lora_config)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# 打印模型中的可训练参数
model.print_trainable_parameters()# 4. 训练参数
args = TrainingArguments(output_dir=OUTPUT_DIR, # checkpoint的存储目录per_device_train_batch_size=BATCH_SIZE, # 单设备上的batch sizegradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, # 梯度累加的step数warmup_steps=100,num_train_epochs=EPOCHS,learning_rate=LEARNING_RATE,fp16=True, # 使用混合精度训练logging_steps=50,evaluation_strategy="no", # 不进行评估save_strategy="steps",save_steps=2000, # 保存checkpoint的step数save_total_limit=5, # 最多保存5个checkpointdeepspeed=DS_CONFIG                               #deepspeed 配置
)# 5. 模型训练
trainer = Trainer(model=model,train_dataset=train_data,eval_dataset=None,args=args,data_collator=collate_fn
)
trainer.train()
model.save_pretrained("best_model")
{"train_micro_batch_size_per_gpu": "auto","gradient_accumulation_steps": "auto","steps_per_print": 50,"gradient_clipping": 1.0,"zero_optimization": {"stage": 2,"offload_optimizer": {"device": "cpu"},"contiguous_gradients": true,"overlap_comm": true},"zero_allow_untested_optimizer": true,"fp16": {"enabled": true,"loss_scale": 0,"loss_scale_window": 1000,"hysteresis": 2,"min_loss_scale": 1},"optimizer": {"type": "Adam","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"activation_checkpointing": {"partition_activations": true,"contiguous_memory_optimization": true},"wall_clock_breakdown": false
}

相关文章:

deepspeed+transformers模型微调

一、目录 代码讲解 二、实现。 1、代码讲解&#xff0c;trainer 实现。 transformers通过trainer 集成deepspeed功能&#xff0c;所以中需要进行文件配置&#xff0c;即可实现deepspeed的训练。 微调代码&#xff1a; 参数定义—>数据处理---->模型创建/评估方式----&…...

无人机摄影测量数据处理、三维建模及在土方量计算中的应用

专题一、无人机摄影测量技术应用现状及其发展 1、无人机摄影测量技术概述 2、摄影测量系统的发展 3、无人机摄影测量技术应用分析 专题二、基本原理和关键技术讲解 1、摄影测量基础知识 1&#xff09;航空摄影 2&#xff09;航摄像片的方位元素 3&#xff09;共…...

《ESP8266通信指南》15-MQTT连接、订阅MQTT主题并打印消息(基于Lua|适合新手|非常简单)

往期 《ESP8266通信指南》14-连接WIFI&#xff08;基于Lua&#xff09;-CSDN博客 《ESP8266通信指南》13-Lua 简单入门&#xff08;打印数据&#xff09;-CSDN博客 《ESP8266通信指南》12-Lua 固件烧录-CSDN博客 《ESP8266通信指南》11-Lua开发环境配置-CSDN博客 《ESP826…...

LeetCode:两数之和

文章收录于LeetCode专栏 LeetCode地址 两数之和 给定一个整数数组nums和一个整数目标值target&#xff0c;请你在该数组中找出和为目标值的那两个整数&#xff0c;并返回它们的数组下标。   你可以假设每种输入只会对应一个答案。但是数组中同一个元素在答案里不能重复出现。…...

CSDN我的创作纪念日128天||不忘初心|努力上进|勇往直前

机缘 Hello&#xff0c;大家好&#xff0c;我是景天&#xff0c;其实很早之前我就加入到了CSND的大军&#xff0c;彼时我还是个刚毕业的小白白&#xff0c;时常过来CSND汲取养料&#xff0c;就这样&#xff0c;慢慢的来提升自己&#xff0c;强大自己。工作锻炼了我&#xff0c…...

MySQL数据库中的浮点类型和高精度类型有什么区别?为什么不推荐使用浮点类型?

在软件开发中&#xff0c;作为后端&#xff0c;无可避免的需要熟练使用 MySQL 数据库进行数据存储和读取。对于信息系统而言&#xff0c;数据库的的地位不言而喻。那作为软件开发工程师&#xff0c;在使用 MySQL 过程中&#xff0c;又有哪些需要注意的呢&#xff1f;我们从实际…...

C++ 抽象与封装

一 抽象 抽象实例&#xff1a;时钟 数据抽象&#xff1a; 具有表面当前时间的时、分、秒 行为抽象&#xff1a; 具有设置时间和显示时间两个最基本的功能。 抽象实例&#xff1a;人 数据抽象&#xff1a;姓名、年龄、性别等。 行为抽象&#xff1a; 生物属性&#xff1a;吃…...

antV X6的简要使用教程

&#x1f9d1;‍&#x1f393; 个人主页&#xff1a;《爱蹦跶的大A阿》 &#x1f525;当前正在更新专栏&#xff1a;《VUE》 、《JavaScript保姆级教程》、《krpano》、《krpano中文文档》 ​ ​ ✨ 前言 在我们的日常开发工作中&#xff0c;我们经常需要构建复杂的交互式图…...

【LLM 论文】Step-Back Prompting:先解决更高层次的问题来提高 LLM 推理能力

论文&#xff1a;Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models ⭐⭐⭐⭐ Google DeepMind, ICLR 2024, arXiv:2310.06117 论文速读 该论文受到的启发是&#xff1a;人类再解决一个包含很多细节的具体问题时&#xff0c;先站在更高的层次上解…...

Java——接口的补充

目录 一&#xff1a;接口的注意事项 1. 接口中不能有方法块&#xff1b; 2. 接口没有构造方法&#xff1a; 3.接口是可以多继承的&#xff1b; 4. 多个接口抽象方法重复 5. 类的父类方法与接口方法重复 二&#xff1a;类与接口 1. 继承与实现 2. 多个父接口的抽象…...

word转pdf的java实现(documents4j)

一、多余的话 java实现word转pdf可用的jar包不多&#xff0c;很多都是收费的。最近发现com.documents4j挺好用的&#xff0c;它支持在本机转换&#xff0c;也支持远程服务转换。但它依赖于微软的office。电脑需要安装office才能转换。鉴于没在linux中使用office&#xff0c;本…...

基于K8S构建Jenkins持续集成平台

文章目录 安装和配置NFSNFS简介NFS安装 在Kubernetes安装Jenkins-Master创建NFS client provisioner安装Jenkins-Master Jenkins与Kubernetes整合实现Jenkins与Kubernetes整合构建Jenkins-Slave自定义镜像 JenkinsKubernetesDocker完成微服务持续集成拉取代码&#xff0c;构建镜…...

PHPStudy 访问网页 403 Forbidden禁止访问

涉及靶场 upload-labd sqli-labs pikachu dvwa 以及所有部署在phpstudy中的靶场 注意&#xff1a;一定要安装解压软件 很多同学解压靶场代码以后访问报错的原因是&#xff1a;电脑上没有解压软件。 这个时候压缩包看起来就是黄色公文包的样子&#xff0c;右键只有“全部提取…...

热爱电子值得做的电子制作实验

加我zkhengyang&#xff0c;进嵌入式音频系统研究开发交流答疑群(课题组) AM/FM收音机散件制作&#xff0c;磁带随声听散件&#xff0c;黑白电视机散件制作&#xff0c;功放散件制作&#xff0c;闪光灯散件制作&#xff0c;声控灯散件&#xff0c;等等&#xff0c;可提高动手能…...

.class文件启动过程以及文件内容结构讲解

当你直接启动一个.class文件时&#xff0c;实际上是在操作系统中调用Java虚拟机&#xff08;JVM&#xff09;&#xff0c;并将该.class文件传递给JVM以执行。现在让我们来解释一下.class文件的启动过程以及文件内容结构&#xff1a; 启动过程&#xff1a;操作系统通过指定的命…...

解锁楼宇自动化新维度西门子Insight+BACnet IP I/O控制器

数字城市的楼宇自动化已不再是一个遥不可及的概念&#xff0c;而是成为了现代建筑的标配。特别是在大型商业综合体、高端写字楼和公共设施中&#xff0c;高效的楼宇管理系统是确保环境舒适度与能源效率的关键。当提及楼宇自动化领域的佼佼者&#xff0c;西门子Insight楼宇自动化…...

2024.05.10作业

TCP服务器 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTcpServer> #include <QTcpSocket> #include <QList> #include <QMessageBox> #include <QDebug>QT_BEGIN_NAMESPACE namespace Ui { class Widget; …...

基于POSIX标准库的读者-写者问题的简单实现

文章目录 实验要求分析保证读写、写写互斥保证多个读者同时进行读操作读者优先实例代码分析写者优先读写公平法示例代码分析实验要求 创建一个控制台进程,此进程包含n个线程。用这n个线程来表示n个读者或写者。每个线程按相应测试数据文件的要求进行读写操作。用信号量机制分别…...

重生我是嵌入式大能之串口调试UART

什么是串口 串口是一种在数据通讯中广泛使用的通讯接口&#xff0c;通常我们叫做UART (通用异步收发传输器Universal Asynchronous Receiver/Transmitter)&#xff0c;其具有数据传输速度稳定、可靠性高、适用范围广等优点。在嵌入式系统中&#xff0c;串口常用于与外部设备进…...

【智能优化算法】蜜獾优化算法(Honey Badger Algorithm,HBA)

蜜獾优化算法(Honey Badger Algorithm,HBA)是期刊“MATHEMATICS AND COMPUTERS IN SIMULATION”&#xff08;IF 3.6&#xff09;的2022年智能优化算法 01.引言 蜜獾优化算法(Honey Badger Algorithm,HBA)受蜜獾智能觅食行为的启发&#xff0c;从数学上发展出一种求解优化问题的…...

《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》

在注意力分散、内容高度同质化的时代&#xff0c;情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现&#xff0c;消费者对内容的“有感”程度&#xff0c;正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中&#xff0…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

Windows安装Miniconda

一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用

在工业制造领域&#xff0c;无损检测&#xff08;NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统&#xff0c;以非接触式光学麦克风技术为核心&#xff0c;打破传统检测瓶颈&#xff0c;为半导体、航空航天、汽车制造等行业提供了高灵敏…...

掌握 HTTP 请求:理解 cURL GET 语法

cURL 是一个强大的命令行工具&#xff0c;用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中&#xff0c;cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...

【Veristand】Veristand环境安装教程-Linux RT / Windows

首先声明&#xff0c;此教程是针对Simulink编译模型并导入Veristand中编写的&#xff0c;同时需要注意的是老用户编译可能用的是Veristand Model Framework&#xff0c;那个是历史版本&#xff0c;且NI不会再维护&#xff0c;新版本编译支持为VeriStand Model Generation Suppo…...

从零开始了解数据采集(二十八)——制造业数字孪生

近年来&#xff0c;我国的工业领域正经历一场前所未有的数字化变革&#xff0c;从“双碳目标”到工业互联网平台的推广&#xff0c;国家政策和市场需求共同推动了制造业的升级。在这场变革中&#xff0c;数字孪生技术成为备受关注的关键工具&#xff0c;它不仅让企业“看见”设…...