当前位置: 首页 > news >正文

用huggingface.Accelerate进行分布式训练

诸神缄默不语-个人CSDN博文目录

本文属于huggingface.transformers全部文档学习笔记博文的一部分。
全文链接:huggingface transformers包 文档学习笔记(持续更新ing…)

本部分网址:https://huggingface.co/docs/transformers/main/en/accelerate
本文介绍如何使用huggingface.accelerate(官方文档:https://huggingface.co/docs/accelerate/index)进行分布式训练。

此外还参考了accelerate的安装文档:https://huggingface.co/docs/accelerate/basic_tutorials/install

一个本文代码可用的Python环境:Python 3.9.7, PyTorch 2.0.1, transformers 4.31.0, accelerate 0.22.0

parallelism能让我们实现在硬件条件受限时训练更大的模型,训练速度能加快几个数量级。

文章目录

  • 1. 安装与配置
  • 2. 在代码中使用

1. 安装与配置

安装:pip install accelerate

配置:accelerate config
然后它会给出一些问题,通过上下键更换选项,用Enter确定
在这里插入图片描述

选错了也没啥关系,反正能改

accelerate env命令可以查看配置环境。

2. 在代码中使用

用accelerate之前的脚本(具体讲解可见我之前写的博文:用huggingface.transformers.AutoModelForSequenceClassification在文本分类任务上微调预训练模型 用的是原生PyTorch那一版,因为Trainer会自动使用分布式训练。metric部分改成新版,并用全部数据来训练):

from tqdm.auto import tqdmimport torch
from torch.utils.data import DataLoader
from torch.optim import AdamWimport datasets,evaluate
from transformers import AutoTokenizer,AutoModelForSequenceClassification,get_schedulerdataset=datasets.load_from_disk("download/yelp_full_review_disk")tokenizer=AutoTokenizer.from_pretrained("/data/pretrained_models/bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length",truncation=True,max_length=512)tokenized_datasets=dataset.map(tokenize_function, batched=True)#Postprocess dataset
tokenized_datasets=tokenized_datasets.remove_columns(["text"])
#删除模型不用的text列tokenized_datasets=tokenized_datasets.rename_column("label", "labels")
#改名label列为labels,因为AutoModelForSequenceClassification的入参键名为label
#我不知道为什么dataset直接叫label就可以啦……tokenized_datasets.set_format("torch")  #将值转换为torch.Tensor对象small_train_dataset=tokenized_datasets["train"].shuffle(seed=42)
small_eval_dataset=tokenized_datasets["test"].shuffle(seed=42)train_dataloader=DataLoader(small_train_dataset,shuffle=True,batch_size=32)
eval_dataloader=DataLoader(small_eval_dataset,batch_size=64)model=AutoModelForSequenceClassification.from_pretrained("/data/pretrained_models/bert-base-cased",num_labels=5)optimizer=AdamW(model.parameters(),lr=5e-5)num_epochs=3
num_training_steps=num_epochs*len(train_dataloader)
lr_scheduler=get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps)device=torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)progress_bar = tqdm(range(num_training_steps))model.train()
for epoch in range(num_epochs):for batch in train_dataloader:batch={k:v.to(device) for k,v in batch.items()}outputs=model(**batch)loss=outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)metric=evaluate.load("accuracy")
model.eval()
for batch in eval_dataloader:batch={k:v.to(device) for k,v in batch.items()}with torch.no_grad():outputs=model(**batch)logits=outputs.logitspredictions=torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])print(metric.compute())

懒得跑完了,总之预计要跑11个小时来着,非常慢。

添加如下代码:

from accelerate import Acceleratoraccelerator = Accelerator()#去掉将模型和数据集放到指定卡上的代码#在建立好数据集、模型和优化器之后:
train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)#训练阶段将loss.backward()替换成
accelerator.backward(loss)

添加后的代码(我用全部数据集出来预计训练时间是4小时(3张卡),但我懒得跑这么久了,我就还是用1000条跑跑,把整个流程跑完意思一下):
accelerate launch Python脚本路径运行
验证部分的情况见代码后面

from tqdm.auto import tqdmimport torch
from torch.utils.data import DataLoader
from torch.optim import AdamWimport datasets
from transformers import AutoTokenizer,AutoModelForSequenceClassification,get_schedulerfrom accelerate import Acceleratoraccelerator = Accelerator()dataset=datasets.load_from_disk("download/yelp_full_review_disk")tokenizer=AutoTokenizer.from_pretrained("/data/pretrained_models/bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length",truncation=True,max_length=512)tokenized_datasets=dataset.map(tokenize_function, batched=True)#Postprocess dataset
tokenized_datasets=tokenized_datasets.remove_columns(["text"])
#删除模型不用的text列tokenized_datasets=tokenized_datasets.rename_column("label", "labels")
#改名label列为labels,因为AutoModelForSequenceClassification的入参键名为label
#我不知道为什么dataset直接叫label就可以啦……tokenized_datasets.set_format("torch")  #将值转换为torch.Tensor对象small_train_dataset=tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset=tokenized_datasets["test"].shuffle(seed=42).select(range(1000))train_dataloader=DataLoader(small_train_dataset,shuffle=True,batch_size=32)
eval_dataloader=DataLoader(small_eval_dataset,batch_size=64)model=AutoModelForSequenceClassification.from_pretrained("/data/pretrained_models/bert-base-cased",num_labels=5)optimizer=AdamW(model.parameters(),lr=5e-5)train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs=3
num_training_steps=num_epochs*len(train_dataloader)
lr_scheduler=get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps)progress_bar = tqdm(range(num_training_steps))model.train()
for epoch in range(num_epochs):for batch in train_dataloader:outputs=model(**batch)loss=outputs.lossaccelerator.backward(loss)optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)

验证部分是这样的,直接用原来的验证部分就也能跑,但是因为脚本会被运行2遍,所以验证部分也会运行2遍。
所以我原则上建议用accelerate的话就光训练,验证的部分还是单卡实现。
如果还是想在训练过程中看一下验证效果,可以正常验证;也可以将验证部分限定在if accelerator.is_main_process:里,这样就只有主进程(通常是第一个GPU)会执行验证代码,而其他GPU不会,这样就只会打印一次指标了。

相关文章:

用huggingface.Accelerate进行分布式训练

诸神缄默不语-个人CSDN博文目录 本文属于huggingface.transformers全部文档学习笔记博文的一部分。 全文链接:huggingface transformers包 文档学习笔记(持续更新ing…) 本部分网址:https://huggingface.co/docs/transformers/m…...

unity 物体至视图中心以及新对象创建位置

如果游戏对象不在视野中心或在视野之外, 一种方法是双击Hierarchy中的对象名称 另一种是选中后按F 新建物体时对象的位置不是在坐标原点,而是在当前屏幕的中心...

船舶稳定性和静水力计算——绘图体平面图,静水力,GZ计算(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

Python 网页爬虫的原理是怎样的?

网页爬虫是一种自动化工具,用于从互联网上获取和提取信息。它们被广泛用于搜索引擎、数据挖掘、市场研究等领域。 网页爬虫的工作原理可以分为以下几个步骤:URL调度、页面下载、页面解析和数据提取。 URL调度: 网页爬虫首先需要一个初始的U…...

python技术面试题合集(二)

python技术面试题 1、简述django FBV和CBV FBV是基于函数编程,CBV是基于类编程,本质上也是FBV编程,在Djanog中使用CBV,则需要继承View类,在路由中指定as_view函数,返回的还是一个函数 在DRF中的使用的就是…...

【linux命令讲解大全】089.使用tree命令快速查看目录结构的方法

文章目录 tree补充说明语法选项列表选项文件选项排序选项图形选项XML / HTML / JSON 选项杂项选项 参数实例 从零学 python tree 树状图列出目录的内容 补充说明 tree 命令以树状图列出目录的内容。 语法 tree [选项] [参数]选项 列表选项 -a:显示所有文件和…...

【C++】—— 单例模式详解

前言: 本期,我将要讲解的是有关C中常见的设计模式之单例模式的相关知识!! 目录 (一)设计模式的六⼤原则 (二)设计模式的分类 (三)单例模式 1、定义 2、…...

TheRouter 框架原理

TheRouter 框架入口方法 通过InnerTheRouterContentProvider 注册在AndroidManifest.xml中&#xff0c;在应用启动时初始化 <application><providerandroid:name"com.therouter.InnerTheRouterContentProvider"android:authorities"${applicationId}.…...

系列十二、Java操作RocketMQ之带标签Tag的消息

一、带标签的Tag消息 1.1、概述 RocketMQ提供消息过滤的功能&#xff0c;通过Tag或者Key进行区分。我们往一个主题里面发送消息的时候&#xff0c;根据业务逻辑可能需要区分&#xff0c;比如带有tagA标签的消息被消费者A消费&#xff0c;带有tagB标签的消息被消费者B消费&…...

Java面向对象学习笔记-1

前言 “Java 学习笔记” 是为初学者和希望加深对Java编程语言的理解的人们编写的。Java是一门广泛应用于软件开发领域的强大编程语言&#xff0c;它的语法和概念对于初学者来说可能有些复杂。这份学习笔记的目的是帮助读者逐步学习Java的基本概念&#xff0c;并提供了一系列示…...

el-table根据data动态生成列和行

css //el-table-column加上fixed后会导致悬浮样式丢失&#xff0c;用下面方法可以避免 .el-table__body .el-table__row.hover-row td{background-color: #083a78 !important; } .el-table tbody tr:hover>td {background: #171F34 !important; }html <el-table ref&quo…...

【c++】如何有效地利用命名空间?

​ &#x1f331;博客主页&#xff1a;青竹雾色间 &#x1f618;博客制作不易欢迎各位&#x1f44d;点赞⭐收藏➕关注 ​✨人生如寄&#xff0c;多忧何为 ✨ 目录 前言什么是命名空间&#xff1f;命名空间的语法命名空间的使用避免命名冲突命名空间的嵌套总结 前言 当谈到C编…...

Go语言传参

为了让新手尽快熟悉go的使用,特记录此文,不必谢我,转载请注明! Go 语言中参数传递的各种效果,主要内容包括: 传值效果指针传递结构体传递map 传递channel 传递切片传递错误传递传递效果示例传递方式选择原文连接:https://mp.weixin.qq.com/s?__biz=MzA5Mzk4Njk1OA==&…...

SAP PI 配置SSL链接接口报错问题处理Peer certificate rejected by ChainVerifier

出现这种情况一般无非是没有正确导入证书或者证书过期的情况 第一种&#xff0c;如果没有导入证书的话&#xff0c;需要在NWA中的证书与验证-》CAs中导入管理员提供的证书&#xff0c;这里需要注意的是&#xff0c;需要导入完整的证书链。 第二种如果是证书过期的&#xff0c…...

【MyBatisⅡ】动态 SQL

目录 &#x1f392;1 if 标签 &#x1fad6;2 trim 标签 &#x1f460;3 where 标签 &#x1f9ba;4 set 标签 &#x1f3a8;5 foreach 标签 动态 sql 是Mybatis的强⼤特性之⼀&#xff0c;能够完成不同条件下不同的 sql 拼接。 在 xml 里面写判断条件。 动态SQL 在数据库里…...

音视频入门基础理论知识

文章目录 前言一、视频1、视频的概念2、常见的视频格式3、视频帧4、帧率5、色彩空间6、采用 YUV 的优势7、RGB 和 YUV 的换算 二、音频1、音频的概念2、采样率和采样位数①、采样率②、采样位数 3、音频编码4、声道数5、码率6、音频格式 三、编码1、为什么要编码2、视频编码①、…...

Pytorch中如何加载数据、Tensorboard、Transforms的使用

一、Pytorch中如何加载数据 在Pytorch中涉及到如何读取数据&#xff0c;主要是两个类一个类是Dataset、Dataloader Dataset 提供一种方式获取数据&#xff0c;及其对应的label。主要包含以下两个功能&#xff1a; 如何获取每一个数据以及label 告诉我们总共有多少的数据 Datal…...

python如何使用打开文件对话框选择文件?

python如何使用打开文件对话框选择文件&#xff1f; ━━━━━━━━━━━━━━━━━━━━━━ 在Python中&#xff0c;可以使用Tkinter库中的filedialog子模块来打开一个文件对话框以供用户选择文件。以下是一个简单的例子&#xff0c;演示如何使用tkinter.filedialog打…...

虚拟化和容器

文章目录 1 介绍1.1 简介1.2 虚拟化工作原理1.3 两大核心组件&#xff1a;QEMU、KVMQEMUKVM 1.4 发展历史1.5 虚拟化类型1.6 云计算与虚拟化1.7 HypervisorHypervisor分为两大类 1.8 虚拟化 VS 容器 2 虚拟化应用dockerdocker 与虚拟机的区别 K8Swine 参考 1 介绍 1.1 简介 虚…...

LeetCode-78-子集

题目描述&#xff1a; 给你一个整数数组 nums &#xff0c;数组中的元素 互不相同。返回该数组所有可能的子集&#xff08;幂集&#xff09;。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 题目链接&#xff1a;LeetCode-78-子集 解题思路&#xff1a;递归回溯 题…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇&#xff0c;在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下&#xff1a; 【Note】&#xff1a;如果你已经完成安装等操作&#xff0c;可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作&#xff0c;重…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询操作。 缓存逻辑分析&#xff1a; ①每个分类下的菜品保持一份缓存数据…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全&#xff1a;零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言&#xff1a;云原生安全的范式革命 随着云原生技术的普及&#xff0c;安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测&#xff0c;到2025年&#xff0c;零信任架构将成为超…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

GC1808高性能24位立体声音频ADC芯片解析

1. 芯片概述 GC1808是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持8kHz~96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于高保真音频采集场景。 2. 核心特性 高精度&#xff1a;24位分辨率&#xff0c…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成&#xff0c;用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机&#xff1a; ​onCreate()​​ ​调用时机​&#xff1a;Activity 首次创建时调用。​…...