当前位置: 首页 > news >正文

《多GPU大模型训练与微调手册》

在这里插入图片描述

全参数微调

Lora微调

PTuning微调

多GPU微调预备知识

1. 参数数据类型 torch.dtype

在这里插入图片描述

1.1 半精度 half-precision
  • torch.float16:fp16 就是 float16,1个 sign(符号位),5个 exponent bits(指数位),10个 mantissa bits(小数位)

  • torch.bfloat16:bf 16 就是 brain float16,1个 :符号位,8个exponent bits(指数位),7个mantissa bits(小数位)

  • 区别:bf16 牺牲了精度(小数位),实现了比 fp16 更大的范围(多了三个指数位)。

1.2 全精度 single-precision
  • torch.float32:fp 32 就是 float32,1个 sign(符号位),8个 exponent bits(指数位),23个 mantissa bits(小数位)

2. 显卡环境

2.1 参数量与显存换算

例如,实验室是单机多卡:8卡A6000(40G)服务器 320G显存

① CUDA_VISIBLE_DEVICES 控制显卡可见性

通过CUDA_VISIBLE_DEVICES环境变量 控制哪些GPU可以被torch调用:

  • 代码控制
# 必须置于 import torch 之前,准确地说在 torch.cuda 的调用之前
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
import torch
torch.cuda.device_count()
# 8
  • 命令行控制
CUDA_VISIBLE_DEVICES=0,1 python train.py
② 推理换算
  • 模型加载
    (1)目前模型的参数绝大多数都是float32类型, 每个参数占用 4 个字节。所以一个粗略的计算方法就是,每10亿个参数(1 billion=10亿),占用4G显存 (实际应该是10^9 * 4 / 1024 / 1024 / 1024 = 3.725G,为了方便可以记为4G),即 1B Params= 4G VRAM。比如LLaMA的参数量为7000559616个Params,那么全精度加载这个模型参数需要的显存为:7000559616 * 4 /1024/1024/1024 = 26.08G
    (2)显存不够,可以用半精度fp16/bf16来加载,这样每个参数只占2个字节,所需显存就降为一半,只需要13.04G。
    (3)如果显存还不够,可以采用int8的精度,显存再降一半,仅需6.5G,但是模型效果会更差一些。
    (4)如果显存还是不够,int4精度显存再降一半,仅需3.26G。int4就是最低精度了,再往下模型推理效果就很难保证了。
    在这里插入图片描述

  • 模型推理:注意上面只是加载模型到显存,模型运算时的一些临时变量也需要申请空间,比如你beam search的时候。所以真正做推理的时候记得留一些Buffer,不然就容易OOM。如果显存还不够,就只能采用Memery Offload的技术,把部分显存的内容给挪到内存,但是这样会显著降低推理速度。

③ 训练换算

模型训练的时候显存使用包括如下几部分:

  1. 模型权重,计算方法和推理一样。
  2. 优化器:(1)如果你采用AdamW,每个参数需要占用8个字节,因为需要维护两个状态。也就说优化器使用显存是全精度(float32)模型权重的2倍。(2)如果采用bitsandbytes优化的AdamW,每个参数需要占用2个字节,也就是全精度(float32)模型权重的一半。(3)如果采用SGD,则优化器占用显存和全精度模型权重一样。
  3. 梯度:梯度占用显存和全精度(float32)模型权重一样。
  4. 计算图内部变量:有时候也叫Forward Activations。

如果模型想要训练,只看前3部分,需要的显存是至少推理的3-4倍。7B的全精度模型加载需要78G ~ 104G。 然后计算图内部变量这一部分只能在运行时候观测了,可以两个不同的batch的占用显存的差值大概估算出来。

优化的思路也就有了,目前市面上主流的一些计算加速的框架如DeepSpeed, Megatron等都在降低显存方面做了很多优化工作,比如量化,模型切分,混合精度计算,Memory Offload等等。

2.2 分布式架构

在这里插入图片描述
3种并行方式

  • 数据并行Data Paralleism:模型复制到不同GPU上,将数据切分后,分配到不同的GPU上。
  • 模型并行Model Paralleism:将模型切分后,分配到不同的GPU上。分为张量并行和流水线并行。
    • 张量并行Tensor Paralleism:对模型参数 tensor 切分,分配到不同的GPU进行计算,在参数更新的时候再进行同步。在这里插入图片描述
    • 流水线并行Pipeline Paralleism:对模型按层layer切分,分配到不同的GPU上进行计算。
      在这里插入图片描述
  • 混合并行Hybrid Paralleism:同时进行数据并行、张量并行、流水线并行。
    在这里插入图片描述

下面3个分布式框架都是基于 Pytorch 的并行框架:

  • DP(torch.nn.DataParallel)单机-单进程多线程进行实现的,它使用一个进程来计算模型权重,在每个batch处理期间将数据分发到每个GPU,每个GPU 分发到 batch_size/N 个数据,各个GPU的forward结果汇聚到master GPU上计算loss,计算梯度更新master GPU参数,将参数复制给其他GPU。(数据并行
  • DDP(torch.nn.DistributedDataParallel):可以单机/多机-多进程进行实现的,每个GPU对应的进程都有独立的优化器,执行自己的更新过程。每个进程都执行相同的任务,并且每个进程都与所有其他进程通信。进程(GPU)之间只传递梯度,这样网络通信就不再是瓶颈。(数据并行
  • FSDP(torch.distributed.fsdp.FullyShardedDataParallel):Pytorch最新的数据并行方案,在1.11版本引入的新特性,目的主要是用于训练大模型。我们都知道Pytorch DDP用起来简单方便,但是要求整个模型加载到一个GPU上,维护模型参数、梯度和优化器状态的每个 GPU 副本。FSDP则可以在数据并行的基础上,将模型参数和优化器分片分配到 GPU,这使得大模型的训练权重得以加载。(数据并行+模型并行

这些在前面的博客已经讲过:

  • 分布式并行训练(DP、DDP、DeepSpeed)
  • pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed模型训练
2.3 分布式工具

前面的分布式框架使用起来较为麻烦,因此分布式工具在底层对torch的分布式框架进行封装,实现更加方便的分布式训练和微调:

  • DerepSpeed(微软开发)
  • Accelerate(Huggingface开发)
① DerepSpeed—Zero

DerepSpeed的原理是基于微软的研究:Zero(零冗余优化),研究哪些部分是占用存储空间的,并对这些占用存储的数据进行优化。
在这里插入图片描述

存储空间的消耗 Memory Consumption主要包含两部分:

  • Model States(主):模型参数Parameters梯度Gradients优化器Optimizer_State
  • Residual States(次):前向传播激活值Activations临时缓存区Temporal Buffers内存碎片Unusable Fragmented Memory
    在这里插入图片描述

知道了什么东西会占存储,以及它们占了多大的存储之后,我们就可以来谈如何优化存储了。注意到,在整个训练中,有很多states并不会每时每刻都用到;因此提出了三种Zero优化方法:

  • Zero-DP优化Model States):作者采取三个方法优化内存,Pos、Pg、Pp。大体思路都是一样的,把每个模型的参数、梯度、优化器状态分别平均分给所有的gpu,当时计算需要用到其他gpu的内容时,通过GPU之间的通讯传输,以通讯换内存。其中前两个方法不增加通讯成本,第三个方法会增加GPU之间的通信成本。
    在这里插入图片描述

  • Zero-R优化Residual States):(1)激活函数:在前向传播计算完成激活函数之后,对把激活值丢弃,由于计算图还在,等到反向传播的时候,再次计算激活值,算力换内存。或者采取一个与cpu执行一个换入换出的操作。(2)临时缓冲区:模型训练过程中经常会创建一些大小不等的临时缓冲区,比如对梯度进行All Reduce啥的,解决办法就是预先创建一个固定的缓冲区,训练过程中不再动态创建,如果要创建临时数据,在固定缓冲区创建就好。(3)内存碎片:显存出现碎片的一大原因是时候gradient checkpointing后,不断地创建和销毁那些不保存的激活值,解决方法是预先分配一块连续的显存,将常驻显存的模型状态和checkpointed activation存在里面,剩余显存用于动态创建和销毁discarded activation复用了操作系统对内存的优化,不断内存整理。

  • 混合精度训练:对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?于是,混合精度训练(float2hlaf)就产生了,它的步骤如下图:
    在这里插入图片描述

  1. get fp32:存储一份fp32的Model States:parameter,momentum和variance
  2. fp32-to-fp16:在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter。
  3. fp16 computing:正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储。
  4. update fp32 model states:用fp16 gradients去更新fp32下的model states。
  • Zero-OffloadGPU显存不够,CPU内存来凑。如下图,左边是正常的计算图,右侧是Zero-Offload的计算图。(⭕️表示state,正方形表示计算图,箭头表示数据流向、M表示模型参数,float2half表示32位转16位)其实就是forward和backward在GPU上计算参数更新在CPU上。因为CPU与GPU通信数据开销很大,所以CPU和GPU传播的是gradient16,这样保证传播数据量最小。
    在这里插入图片描述

  • Zero-InfinityGPU内存不够,SSD外存来凑
    在这里插入图片描述

②Accelerate—Huggingface

相关文章:

《多GPU大模型训练与微调手册》

全参数微调 Lora微调 PTuning微调 多GPU微调预备知识 1. 参数数据类型 torch.dtype 1.1 半精度 half-precision torch.float16:fp16 就是 float16,1个 sign(符号位),5个 exponent bits(指数位),10个 ma…...

【C++】const与类(const修饰函数的三种位置)

目录 const基本介绍 正文 前: 中: 后: 拷贝构造使用const 目录 const基本介绍 正文 前: 中: 后: 拷贝构造使用const const基本介绍 const 是 C 中的修饰符,用于声明常量或表示不可修改的对象、函数或成员函数。 我们已经了解了const基本用法,我们先进行…...

深度学习在图像识别中的革命性应用

深度学习在图像识别中的革命性应用标志着计算机视觉领域的重大进步。以下是深度学习在图像识别方面的一些革命性应用: 1. **卷积神经网络(CNN)的崭新时代**: - CNN是深度学习在图像识别中的核心技术,通过卷积层、池化…...

R语言读文件“-“变成“.“

R语言读取文件时发生"-"变成"." 如果使用read.table函数&#xff0c;需要 check.namesFALSE data <- read.table("data.tsv", headerTRUE, row.names1, check.namesFALSE)怎样将"."还原为"-" 方法一&#xff1a;gsub函…...

RabbitMQ 基础操作

概念 从计算机术语层面来说&#xff0c;RabbitMQ 模型更像是一种交换机模型。 Queue 队列 Queue&#xff1a;队列&#xff0c;是RabbitMQ 的内部对象&#xff0c;用于存储消息。 RabbitMQ 中消息只能存储在队列中&#xff0c;这一点和Kafka相反。Kafka将消息存储在topic&am…...

自然语言处理:Transformer与GPT

Transformer和GPT&#xff08;Generative Pre-trained Transformer&#xff09;是深度学习和自然语言处理&#xff08;NLP&#xff09;领域的两个重要概念&#xff0c;它们之间存在密切的关系但也有明显的不同。 1 基本概念 1.1 Transformer基本概念 Transformer是一种深度学…...

Ps:裁剪工具 - 裁剪预设的应用

裁剪工具提供了两种类型的裁剪方式。 一种是仅按宽高比&#xff08;比例&#xff09;进行裁剪&#xff0c;常在对图像进行二次构图时采用。 另一种则按指定的图像尺寸&#xff08;宽度值和高度值&#xff09;及分辨率&#xff08;宽 x 高 x 分辨率&#xff09;进行裁剪。其实质…...

前端工程化-什么是构建工具

了解构建工具之前&#xff0c;我们首先要知道的是浏览器只认识html、css、js&#xff0c;而我们开发时用的vue&#xff0c;react框架都只是为了方便我们开发而使用的工具 使用构建工具的原因 vue或react的企业级项目里都会具备这些功能&#xff1a; 1.使用typescript语言&…...

01-论文阅读-Deep learning for anomaly detection in log data: a survey

01-论文阅读-Deep learning for anomaly detection in log data: a survey 文章目录 01-论文阅读-Deep learning for anomaly detection in log data: a survey摘要I 介绍II 背景A 初步定义B 挑战 III 调查方法A 搜索策略B 审查的功能 IV 调查结果A 文献计量学B 深度学习技术C …...

图像处理02 matlab中NSCT的使用

06 matlab中NSCT的使用 最近在学习NSCT相关内容&#xff0c;奈何网上资源太少&#xff0c;简单看了些论文找了一些帖子才懂了一点点&#xff0c;在此分享给大家&#xff0c;希望有所帮助。 一.NSCT流程 首先我们先梳理一下NSCT变换的流程&#xff0c;只有清楚流程才更好的理清…...

提升办公效率,畅享多功能办公笔记软件Notion for Mac

在现代办公环境中&#xff0c;高效的笔记软件对于提高工作效率至关重要。而Notion for Mac作为一款全能的办公笔记软件&#xff0c;将成为你事业成功的得力助手。 Notion for Mac以其多功能和灵活性而脱颖而出。无论你是需要记录会议笔记、管理项目任务、制定流程指南&#xf…...

Apache Airflow (十三) :Airflow分布式集群搭建及使用-原因及

&#x1f3e1; 个人主页&#xff1a;IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 &#x1f6a9; 私聊博主&#xff1a;加入大数据技术讨论群聊&#xff0c;获取更多大数据资料。 &#x1f514; 博主个人B栈地址&#xff1a;豹哥教你大数据的个人空间-豹…...

# 聚类系列(一)——什么是聚类?

目前在做聚类方面的科研工作, 看了很多相关的论文, 也做了一些工作, 于是想出个聚类系列记录一下, 主要包括聚类的概念和相关定义、现有常用聚类算法、聚类相似性度量指标、聚类评价指标、 聚类的应用场景以及共享一些聚类的开源代码 下面正式进入该系列的第一个部分&#xff…...

Android DatePicker(日期选择器)、TimePicker(时间选择器)、CalendarView(日历视图)- 简单应用

示意图&#xff1a; layout布局文件&#xff1a;xml <?xml version"1.0" encoding"utf-8"?> <ScrollView xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http://schemas.android.com/apk/res-auto"…...

linux环境搭建mysql5.7总结

以下安装方式&#xff0c;在阿里云与腾讯云服务器上都测试可用。 一、进入到opt目录下&#xff0c;执行&#xff1a; [rootmaster opt]# wget https://dev.mysql.com/get/Downloads/MySQL-5.7/mysql-5.7.26-linux-glibc2.12-x86_64.tar.gz解压&#xff1a; [rootmaster opt]#…...

SQL Server Count()函数

SQL Server Count()函数 SQL Server COUNT() 是一个聚合函数&#xff0c;它返回在集合中找到的项目数。 COUNT() 函数语法&#xff1a; COUNT([ALL | DISTINCT ] expression)ALL 指示COUNT() 函数应用于所有值。ALL是默认值。返回非NULL值的数量&#xff08;包括重复值&…...

架构探索之路-第一站-clickhouse | 京东云技术团队

一、前言 架构, 软件开发中最熟悉不过的名词, 遍布在我们的日常开发工作中, 大到项目整体, 小到功能组件, 想要实现高性能、高扩展、高可用的目标都需要优秀架构理念辅助. 所以本人尝试编写架构系列文章, 去剖析市面上那些经典优秀的开源项目, 学习优秀的架构理念来积累架构设…...

易航网址引导系统 v1.9 源码:去除弹窗功能的易航网址引导页管理系统

易航自主开发了一款极其优雅的易航网址引导页管理系统&#xff0c;后台采用全新的光年 v5 模板开发。该系统完全开源&#xff0c;摒弃了后门风险&#xff0c;可以管理无数个引导页主题。数据管理采用易航原创的JsonDb数据包&#xff0c;无需复杂的安装解压过程即可使用。目前系…...

创新无界:通义灵码在测试过程中展现的独特魅力

通义灵码基于通义大模型&#xff0c;提供代码智能生成、研发智能问答能力。本文就来介绍下通义灵码在测试过程中的应用。 操作手册&#xff1a; 通义灵码, 阿里云提供的一款基于通义大模型的智能编码辅助工具_云效-阿里云帮助中心 1. 什么是通义灵码 是阿里云出品的一款基于通…...

crmchat安装搭建教程文档 bug问题调试

一、安装PHP插件&#xff1a;fileinfo、redis、swoole4。 二、删除PHP对应版本中的 proc_open禁用函数。 一、设置网站运行目录public&#xff0c; 二、设置PHP版本选择纯静态。 三、可选项如有需求则开启SSL,配置SSL证书&#xff0c;开启强制https域名。 四、添加反向代理。 …...

基础测试工具使用经验

背景 vtune&#xff0c;perf, nsight system等基础测试工具&#xff0c;都是用过的&#xff0c;但是没有记录&#xff0c;都逐渐忘了。所以写这篇博客总结记录一下&#xff0c;只要以后发现新的用法&#xff0c;就记得来编辑补充一下 perf 比较基础的用法&#xff1a; 先改这…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)

Aspose.PDF 限制绕过方案&#xff1a;Java 字节码技术实战分享&#xff08;仅供学习&#xff09; 一、Aspose.PDF 简介二、说明&#xff08;⚠️仅供学习与研究使用&#xff09;三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...

Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档&#xff09;&#xff0c;如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下&#xff0c;风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...