当前位置: 首页 > news >正文

用huggingface.Accelerate进行分布式训练

诸神缄默不语-个人CSDN博文目录

本文属于huggingface.transformers全部文档学习笔记博文的一部分。
全文链接:huggingface transformers包 文档学习笔记(持续更新ing…)

本部分网址:https://huggingface.co/docs/transformers/main/en/accelerate
本文介绍如何使用huggingface.accelerate(官方文档:https://huggingface.co/docs/accelerate/index)进行分布式训练。

此外还参考了accelerate的安装文档:https://huggingface.co/docs/accelerate/basic_tutorials/install

一个本文代码可用的Python环境:Python 3.9.7, PyTorch 2.0.1, transformers 4.31.0, accelerate 0.22.0

parallelism能让我们实现在硬件条件受限时训练更大的模型,训练速度能加快几个数量级。

文章目录

  • 1. 安装与配置
  • 2. 在代码中使用

1. 安装与配置

安装:pip install accelerate

配置:accelerate config
然后它会给出一些问题,通过上下键更换选项,用Enter确定
在这里插入图片描述

选错了也没啥关系,反正能改

accelerate env命令可以查看配置环境。

2. 在代码中使用

用accelerate之前的脚本(具体讲解可见我之前写的博文:用huggingface.transformers.AutoModelForSequenceClassification在文本分类任务上微调预训练模型 用的是原生PyTorch那一版,因为Trainer会自动使用分布式训练。metric部分改成新版,并用全部数据来训练):

from tqdm.auto import tqdmimport torch
from torch.utils.data import DataLoader
from torch.optim import AdamWimport datasets,evaluate
from transformers import AutoTokenizer,AutoModelForSequenceClassification,get_schedulerdataset=datasets.load_from_disk("download/yelp_full_review_disk")tokenizer=AutoTokenizer.from_pretrained("/data/pretrained_models/bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length",truncation=True,max_length=512)tokenized_datasets=dataset.map(tokenize_function, batched=True)#Postprocess dataset
tokenized_datasets=tokenized_datasets.remove_columns(["text"])
#删除模型不用的text列tokenized_datasets=tokenized_datasets.rename_column("label", "labels")
#改名label列为labels,因为AutoModelForSequenceClassification的入参键名为label
#我不知道为什么dataset直接叫label就可以啦……tokenized_datasets.set_format("torch")  #将值转换为torch.Tensor对象small_train_dataset=tokenized_datasets["train"].shuffle(seed=42)
small_eval_dataset=tokenized_datasets["test"].shuffle(seed=42)train_dataloader=DataLoader(small_train_dataset,shuffle=True,batch_size=32)
eval_dataloader=DataLoader(small_eval_dataset,batch_size=64)model=AutoModelForSequenceClassification.from_pretrained("/data/pretrained_models/bert-base-cased",num_labels=5)optimizer=AdamW(model.parameters(),lr=5e-5)num_epochs=3
num_training_steps=num_epochs*len(train_dataloader)
lr_scheduler=get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps)device=torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)progress_bar = tqdm(range(num_training_steps))model.train()
for epoch in range(num_epochs):for batch in train_dataloader:batch={k:v.to(device) for k,v in batch.items()}outputs=model(**batch)loss=outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)metric=evaluate.load("accuracy")
model.eval()
for batch in eval_dataloader:batch={k:v.to(device) for k,v in batch.items()}with torch.no_grad():outputs=model(**batch)logits=outputs.logitspredictions=torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])print(metric.compute())

懒得跑完了,总之预计要跑11个小时来着,非常慢。

添加如下代码:

from accelerate import Acceleratoraccelerator = Accelerator()#去掉将模型和数据集放到指定卡上的代码#在建立好数据集、模型和优化器之后:
train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)#训练阶段将loss.backward()替换成
accelerator.backward(loss)

添加后的代码(我用全部数据集出来预计训练时间是4小时(3张卡),但我懒得跑这么久了,我就还是用1000条跑跑,把整个流程跑完意思一下):
accelerate launch Python脚本路径运行
验证部分的情况见代码后面

from tqdm.auto import tqdmimport torch
from torch.utils.data import DataLoader
from torch.optim import AdamWimport datasets
from transformers import AutoTokenizer,AutoModelForSequenceClassification,get_schedulerfrom accelerate import Acceleratoraccelerator = Accelerator()dataset=datasets.load_from_disk("download/yelp_full_review_disk")tokenizer=AutoTokenizer.from_pretrained("/data/pretrained_models/bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length",truncation=True,max_length=512)tokenized_datasets=dataset.map(tokenize_function, batched=True)#Postprocess dataset
tokenized_datasets=tokenized_datasets.remove_columns(["text"])
#删除模型不用的text列tokenized_datasets=tokenized_datasets.rename_column("label", "labels")
#改名label列为labels,因为AutoModelForSequenceClassification的入参键名为label
#我不知道为什么dataset直接叫label就可以啦……tokenized_datasets.set_format("torch")  #将值转换为torch.Tensor对象small_train_dataset=tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset=tokenized_datasets["test"].shuffle(seed=42).select(range(1000))train_dataloader=DataLoader(small_train_dataset,shuffle=True,batch_size=32)
eval_dataloader=DataLoader(small_eval_dataset,batch_size=64)model=AutoModelForSequenceClassification.from_pretrained("/data/pretrained_models/bert-base-cased",num_labels=5)optimizer=AdamW(model.parameters(),lr=5e-5)train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs=3
num_training_steps=num_epochs*len(train_dataloader)
lr_scheduler=get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps)progress_bar = tqdm(range(num_training_steps))model.train()
for epoch in range(num_epochs):for batch in train_dataloader:outputs=model(**batch)loss=outputs.lossaccelerator.backward(loss)optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)

验证部分是这样的,直接用原来的验证部分就也能跑,但是因为脚本会被运行2遍,所以验证部分也会运行2遍。
所以我原则上建议用accelerate的话就光训练,验证的部分还是单卡实现。
如果还是想在训练过程中看一下验证效果,可以正常验证;也可以将验证部分限定在if accelerator.is_main_process:里,这样就只有主进程(通常是第一个GPU)会执行验证代码,而其他GPU不会,这样就只会打印一次指标了。

相关文章:

用huggingface.Accelerate进行分布式训练

诸神缄默不语-个人CSDN博文目录 本文属于huggingface.transformers全部文档学习笔记博文的一部分。 全文链接:huggingface transformers包 文档学习笔记(持续更新ing…) 本部分网址:https://huggingface.co/docs/transformers/m…...

unity 物体至视图中心以及新对象创建位置

如果游戏对象不在视野中心或在视野之外, 一种方法是双击Hierarchy中的对象名称 另一种是选中后按F 新建物体时对象的位置不是在坐标原点,而是在当前屏幕的中心...

船舶稳定性和静水力计算——绘图体平面图,静水力,GZ计算(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

Python 网页爬虫的原理是怎样的?

网页爬虫是一种自动化工具,用于从互联网上获取和提取信息。它们被广泛用于搜索引擎、数据挖掘、市场研究等领域。 网页爬虫的工作原理可以分为以下几个步骤:URL调度、页面下载、页面解析和数据提取。 URL调度: 网页爬虫首先需要一个初始的U…...

python技术面试题合集(二)

python技术面试题 1、简述django FBV和CBV FBV是基于函数编程,CBV是基于类编程,本质上也是FBV编程,在Djanog中使用CBV,则需要继承View类,在路由中指定as_view函数,返回的还是一个函数 在DRF中的使用的就是…...

【linux命令讲解大全】089.使用tree命令快速查看目录结构的方法

文章目录 tree补充说明语法选项列表选项文件选项排序选项图形选项XML / HTML / JSON 选项杂项选项 参数实例 从零学 python tree 树状图列出目录的内容 补充说明 tree 命令以树状图列出目录的内容。 语法 tree [选项] [参数]选项 列表选项 -a:显示所有文件和…...

【C++】—— 单例模式详解

前言: 本期,我将要讲解的是有关C中常见的设计模式之单例模式的相关知识!! 目录 (一)设计模式的六⼤原则 (二)设计模式的分类 (三)单例模式 1、定义 2、…...

TheRouter 框架原理

TheRouter 框架入口方法 通过InnerTheRouterContentProvider 注册在AndroidManifest.xml中&#xff0c;在应用启动时初始化 <application><providerandroid:name"com.therouter.InnerTheRouterContentProvider"android:authorities"${applicationId}.…...

系列十二、Java操作RocketMQ之带标签Tag的消息

一、带标签的Tag消息 1.1、概述 RocketMQ提供消息过滤的功能&#xff0c;通过Tag或者Key进行区分。我们往一个主题里面发送消息的时候&#xff0c;根据业务逻辑可能需要区分&#xff0c;比如带有tagA标签的消息被消费者A消费&#xff0c;带有tagB标签的消息被消费者B消费&…...

Java面向对象学习笔记-1

前言 “Java 学习笔记” 是为初学者和希望加深对Java编程语言的理解的人们编写的。Java是一门广泛应用于软件开发领域的强大编程语言&#xff0c;它的语法和概念对于初学者来说可能有些复杂。这份学习笔记的目的是帮助读者逐步学习Java的基本概念&#xff0c;并提供了一系列示…...

el-table根据data动态生成列和行

css //el-table-column加上fixed后会导致悬浮样式丢失&#xff0c;用下面方法可以避免 .el-table__body .el-table__row.hover-row td{background-color: #083a78 !important; } .el-table tbody tr:hover>td {background: #171F34 !important; }html <el-table ref&quo…...

【c++】如何有效地利用命名空间?

​ &#x1f331;博客主页&#xff1a;青竹雾色间 &#x1f618;博客制作不易欢迎各位&#x1f44d;点赞⭐收藏➕关注 ​✨人生如寄&#xff0c;多忧何为 ✨ 目录 前言什么是命名空间&#xff1f;命名空间的语法命名空间的使用避免命名冲突命名空间的嵌套总结 前言 当谈到C编…...

Go语言传参

为了让新手尽快熟悉go的使用,特记录此文,不必谢我,转载请注明! Go 语言中参数传递的各种效果,主要内容包括: 传值效果指针传递结构体传递map 传递channel 传递切片传递错误传递传递效果示例传递方式选择原文连接:https://mp.weixin.qq.com/s?__biz=MzA5Mzk4Njk1OA==&…...

SAP PI 配置SSL链接接口报错问题处理Peer certificate rejected by ChainVerifier

出现这种情况一般无非是没有正确导入证书或者证书过期的情况 第一种&#xff0c;如果没有导入证书的话&#xff0c;需要在NWA中的证书与验证-》CAs中导入管理员提供的证书&#xff0c;这里需要注意的是&#xff0c;需要导入完整的证书链。 第二种如果是证书过期的&#xff0c…...

【MyBatisⅡ】动态 SQL

目录 &#x1f392;1 if 标签 &#x1fad6;2 trim 标签 &#x1f460;3 where 标签 &#x1f9ba;4 set 标签 &#x1f3a8;5 foreach 标签 动态 sql 是Mybatis的强⼤特性之⼀&#xff0c;能够完成不同条件下不同的 sql 拼接。 在 xml 里面写判断条件。 动态SQL 在数据库里…...

音视频入门基础理论知识

文章目录 前言一、视频1、视频的概念2、常见的视频格式3、视频帧4、帧率5、色彩空间6、采用 YUV 的优势7、RGB 和 YUV 的换算 二、音频1、音频的概念2、采样率和采样位数①、采样率②、采样位数 3、音频编码4、声道数5、码率6、音频格式 三、编码1、为什么要编码2、视频编码①、…...

Pytorch中如何加载数据、Tensorboard、Transforms的使用

一、Pytorch中如何加载数据 在Pytorch中涉及到如何读取数据&#xff0c;主要是两个类一个类是Dataset、Dataloader Dataset 提供一种方式获取数据&#xff0c;及其对应的label。主要包含以下两个功能&#xff1a; 如何获取每一个数据以及label 告诉我们总共有多少的数据 Datal…...

python如何使用打开文件对话框选择文件?

python如何使用打开文件对话框选择文件&#xff1f; ━━━━━━━━━━━━━━━━━━━━━━ 在Python中&#xff0c;可以使用Tkinter库中的filedialog子模块来打开一个文件对话框以供用户选择文件。以下是一个简单的例子&#xff0c;演示如何使用tkinter.filedialog打…...

虚拟化和容器

文章目录 1 介绍1.1 简介1.2 虚拟化工作原理1.3 两大核心组件&#xff1a;QEMU、KVMQEMUKVM 1.4 发展历史1.5 虚拟化类型1.6 云计算与虚拟化1.7 HypervisorHypervisor分为两大类 1.8 虚拟化 VS 容器 2 虚拟化应用dockerdocker 与虚拟机的区别 K8Swine 参考 1 介绍 1.1 简介 虚…...

LeetCode-78-子集

题目描述&#xff1a; 给你一个整数数组 nums &#xff0c;数组中的元素 互不相同。返回该数组所有可能的子集&#xff08;幂集&#xff09;。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 题目链接&#xff1a;LeetCode-78-子集 解题思路&#xff1a;递归回溯 题…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候&#xff0c;难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵&#xff0c;或者买了二手 iPhone 却被原来的 iCloud 账号锁住&#xff0c;这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store&#xff1a; 我们在使用异步的时候理应是要使用中间件的&#xff0c;但是configureStore 已经自动集成了 redux-thunk&#xff0c;注意action里面要返回函数 import { configureS…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...

水泥厂自动化升级利器:Devicenet转Modbus rtu协议转换网关

在水泥厂的生产流程中&#xff0c;工业自动化网关起着至关重要的作用&#xff0c;尤其是JH-DVN-RTU疆鸿智能Devicenet转Modbus rtu协议转换网关&#xff0c;为水泥厂实现高效生产与精准控制提供了有力支持。 水泥厂设备众多&#xff0c;其中不少设备采用Devicenet协议。Devicen…...

基于鸿蒙(HarmonyOS5)的打车小程序

1. 开发环境准备 安装DevEco Studio (鸿蒙官方IDE)配置HarmonyOS SDK申请开发者账号和必要的API密钥 2. 项目结构设计 ├── entry │ ├── src │ │ ├── main │ │ │ ├── ets │ │ │ │ ├── pages │ │ │ │ │ ├── H…...

在golang中如何将已安装的依赖降级处理,比如:将 go-ansible/v2@v2.2.0 更换为 go-ansible/@v1.1.7

在 Go 项目中降级 go-ansible 从 v2.2.0 到 v1.1.7 具体步骤&#xff1a; 第一步&#xff1a; 修改 go.mod 文件 // 原 v2 版本声明 require github.com/apenella/go-ansible/v2 v2.2.0 替换为&#xff1a; // 改为 v…...

《信号与系统》第 6 章 信号与系统的时域和频域特性

目录 6.0 引言 6.1 傅里叶变换的模和相位表示 6.2 线性时不变系统频率响应的模和相位表示 6.2.1 线性与非线性相位 6.2.2 群时延 6.2.3 对数模和相位图 6.3 理想频率选择性滤波器的时域特性 6.4 非理想滤波器的时域和频域特性讨论 6.5 一阶与二阶连续时间系统 6.5.1 …...

Linux安全加固:从攻防视角构建系统免疫

Linux安全加固:从攻防视角构建系统免疫 构建坚不可摧的数字堡垒 引言:攻防对抗的新纪元 在日益复杂的网络威胁环境中,Linux系统安全已从被动防御转向主动免疫。2023年全球网络安全报告显示,高级持续性威胁(APT)攻击同比增长65%,平均入侵停留时间缩短至48小时。本章将从…...