当前位置: 首页 > news >正文

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed模型训练

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed(环境没搞起来)模型训练代码,并对比不同方法的训练速度以及GPU内存的使用

代码:pytorch_model_train


FairScale(你真的需要FSDP、DeepSpeed吗?)

在了解各种训练方式之前,先来看一下 FairScale 给出的一个模型训练方式选择的流程,选择适合自己的方式,就是最好的。

在这里插入图片描述


训练环境设置

  • 模型:预训练的Resnet50
  • 数据集:Cifar10
  • 硬件资源:一台4卡Tesla P40
  • 训练设置:5 epoch、128 batch size
  • 观察指标:显存占用、GPU使用率、训练时长、模型训练结果

备注:

  1. 由于P40硬件限制,不支持半精度fp16的训练,在fp16条件下训练的速度会受到影
  2. ResNet50模型较小,batch_size=1时单卡仅占用 0.34G显存,绝大部分显存都被输入数据,以及中间激活占用

测试基准(batch_size=1)

  • 单卡显存占用:0.34 G
  • 单卡GPU使用率峰值:60%

单卡单精度训练

  • 代码文件:pytorch_SingleGPU.py
  • 单卡显存占用:11.24 G
  • 单卡GPU使用率峰值:100%
  • 训练时长(5 epoch):1979 s
  • 训练结果:准确率85%左右

在这里插入图片描述


单卡半精度训练

  • 代码文件:pytorch_half_precision.py
  • 单卡显存占用:5.79 G
  • 单卡GPU使用率峰值:100%
  • 训练时长(5 epoch):1946 s
  • 训练结果:准确率75%左右

在这里插入图片描述

备注: 单卡半精度训练的准确率只有75%,单精度的准确率在85%左右


单卡混合精度训练

AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP

CUDA AUTOMATIC MIXED PRECISION EXAMPLES

PyTorch 源码解读之 torch.cuda.amp: 自动混合精度详解

如何使用 PyTorch 进行半精度、混(合)精度训练

如何使用 PyTorch 进行半精度训练

pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等

Working with Multiple GPUs

  • 代码文件:pytorch_auto_mixed_precision.py
  • 单卡显存占用:6.02 G
  • 单卡GPU使用率峰值:100%
  • 训练时长(5 epoch):1546 s
  • 训练结果:准确率85%左右

在这里插入图片描述

  • 混合精度训练过程

在这里插入图片描述

  • 混合精度训练基本流程
  1. 维护一个 FP32 数值精度模型的副本
  2. 在每个iteration
    • 拷贝并且转换成 FP16 模型
    • 前向传播(FP16 的模型参数)
    • loss 乘 scale factor s
    • 反向传播(FP16 的模型参数和参数梯度)
    • 参数梯度乘 1/s
    • 利用 FP16 的梯度更新 FP32 的模型参数
  • autocast结合GradScaler用法
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()for epoch in epochs:for input, target in data:optimizer.zero_grad()# Runs the forward pass with autocasting.with autocast(device_type='cuda', dtype=torch.float16):output = model(input)loss = loss_fn(output, target)# Scales loss.  Calls backward() on scaled loss to create scaled gradients.# Backward passes under autocast are not recommended.# Backward ops run in the same dtype autocast chose for corresponding forward ops.scaler.scale(loss).backward()# scaler.step() first unscales the gradients of the optimizer's assigned params.# If these gradients do not contain infs or NaNs, optimizer.step() is then called,# otherwise, optimizer.step() is skipped.scaler.step(optimizer)# Updates the scale for next iteration.scaler.update()
  • 基于GradScaler进行梯度裁剪
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
  • autocast用法
# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")with torch.autocast(device_type="cuda"):# torch.mm is on autocast's list of ops that should run in float16.# Inputs are float32, but the op runs in float16 and produces float16 output.# No manual casts are required.e_float16 = torch.mm(a_float32, b_float32)# Also handles mixed input typesf_float16 = torch.mm(d_float32, e_float16)# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())
  • autocast嵌套使用
# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")with torch.autocast(device_type="cuda"):e_float16 = torch.mm(a_float32, b_float32)with torch.autocast(device_type="cuda", enabled=False):# Calls e_float16.float() to ensure float32 execution# (necessary because e_float16 was created in an autocasted region)f_float32 = torch.mm(c_float32, e_float16.float())# No manual casts are required when re-entering the autocast-enabled region.# torch.mm again runs in float16 and produces float16 output, regardless of input types.g_float16 = torch.mm(d_float32, f_float32)

4卡 DP(Data Parallel)

  • 代码文件:pytorch_DP.py
  • 单卡显存占用:3.08 G
  • 单卡GPU使用率峰值:99%
  • 训练时长(5 epoch):742 s
  • 训练结果:准确率85%左右

在这里插入图片描述


4卡 DDP(Distributed Data Parallel)

pytorch-multi-gpu-training
/ddp_train.py

DISTRIBUTED COMMUNICATION PACKAGE - TORCH.DISTRIBUTED

  • 代码文件:pytorch_DDP.py
  • 单卡显存占用:3.12 G
  • 单卡GPU使用率峰值:99%
  • 训练时长(5 epoch):560 s
  • 训练结果:准确率85%左右

在这里插入图片描述

  • 代码启动命令(单机 4 GPU)
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 pytorch_DDP.py    

基于accelerate的 DDP

huggingface/accelerate

Hugging Face开源库accelerate详解

  • 代码文件:accelerate_DDP.py
  • 单卡显存占用:3.15 G
  • 单卡GPU使用率峰值:99%
  • 训练时长(5 epoch):569 s
  • 训练结果:准确率85%左右

在这里插入图片描述

  • accelerate配置文件default_DDP.yml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  • 代码启动命令(单机 4 GPU)
accelerate launch --config_file ./config/default_DDP.yml accelerate_DDP.py    

Pytorch + FSDP(Fully Sharded Data Parallel)

Pytorch FULLY SHARDED DATA PARALLEL (FSDP) 初识

2023 年了,大模型训练还要不要用 PyTorch 的 FSDP ?

GETTING STARTED WITH FULLY SHARDED DATA PARALLEL(FSDP)

  • batch_size == 1

    • 单卡显存占用:0.19 G,相比基准测试的 0.34G 有减少,但是没有达到4倍
    • 单卡GPU使用率峰值:60%
  • batch_size == 128

    • 单卡显存占用:2.88 G
    • 单卡GPU使用率峰值:99%
  • 代码文件:pytorch_FSDP.py

  • 训练时长(5 epoch):581 s

  • 训练结果:准确率85%左右

备注: pytorch里面的FSDP的batchsize是指单张卡上的batch大小

在这里插入图片描述

  • 代码启动命令(单机 4 GPU)
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 pytorch_FSDP.py    
  • FSDP包装后的模型

代码中指定对Resnet50中的Linear和Conv2d层应用FSDP。

在这里插入图片描述


基于accelerate的 FSDP(Fully Sharded Data Parallel)

  • batch_size == 1

    • 单卡显存占用:0.38 G,相比基准测试的 0.34G 并没有减少
    • 单卡GPU使用率峰值:60%
  • batch_size == 128

    • 单卡显存占用:2.90 G
    • 单卡GPU使用率峰值:99%
  • 代码文件:accelerate_FSDP.py

  • 训练时长(5 epoch):576 s,对于这个小模型速度和DDP相当

  • 训练结果:准确率85%左右

在这里插入图片描述

  • accelerate配置文件default_FSDP.yml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:fsdp_auto_wrap_policy: SIZE_BASED_WRAPfsdp_backward_prefetch_policy: BACKWARD_PREfsdp_forward_prefetch: truefsdp_min_num_params: 1000000fsdp_offload_params: falsefsdp_sharding_strategy: 1fsdp_state_dict_type: SHARDED_STATE_DICTfsdp_sync_module_states: truefsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  • 代码启动命令(单机 4 GPU)
accelerate launch --config_file ./config/default_FSDP.yml accelerate_FSDP.py    

Pytorch + DeepSpeed(环境没搞起来,哈哈哈)

[BUG] error: unrecognized arguments: --deepspeed ./ds_config.json #3961

fused_adam.so: cannot open shared object file: No such file or directory #119

DeepSpeedExamples/training/cifar/

Getting Started

  • 代码文件:pytorch_DeepSpeed.py

  • 单卡显存占用:

  • 单卡GPU使用率峰值:

  • 训练时长(5 epoch):

  • 训练结果:

  • 代码启动命令(单机 4 GPU)

deepspeed pytorch_DeepSpeed.py --deepspeed_config ./config/zero_stage2_config.json    

基于accelerate的 DeepSpeed(环境没搞起来,哈哈哈)

DeepSpeed介绍

深度解析:如何使用DeepSpeed加速PyTorch模型训练

DeepSpeed

  • 代码文件:accelerate_DeepSpeed.py
  • 单卡显存占用:
  • 单卡GPU使用率峰值:
  • 训练时长(5 epoch):
  • 训练结果:

相关文章:

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed模型训练

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed(环境没搞起来)模型训练代码,并对比不同方法的训练速度以及GPU内存的使用 代码:pytorch_model_train FairScale(你真…...

基于PHP的纺织用品商城系统

有需要请加文章底部Q哦 可远程调试 基于PHP的纺织用品商城系统 一 介绍 此纺织用品商城系统基于原生PHP开发,数据库mysql,前端bootstrap。用户可注册登录,购物下单,评论等。管理员登录后台可对纺织用品,用户&#xf…...

Go使用命令行输出二维码

引言 二维码(QR code)是一种矩阵条码的标准,广泛应用于商业、移动支付和数据存储等领域。在开发过程中,我们可能需要在命令行中显示二维码,这可以帮助我们快速生成和分享二维码信息。本文将介绍如何使用Go语言生成二维…...

最长连续序列[中等]

优质博文:IT-BLOG-CN 一、题目 给定一个未排序的整数数组nums,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。请你设计并实现时间复杂度为O(n)的算法解决此问题。 示例 1: 输入:nums […...

设计模式-状态模式-笔记

状态模式State 在组件构建过程中,某些对象的状态经常面临变化,如何对这些变化进行有效的管理?同时又维持高层模块的稳定?“状态变化”模式为这一问题提供了一种解决方案。 经典模式:State、Memento 动机&#xff08…...

Java中for、foreach、stream区别和性能比较

文章目录 性能比较区别使用方式和行为 性能比较 最终总结:如果数据在1万以内的话,for循环效率高于foreach和stream;如果数据量在10万的时候,stream效率最高,其次是foreach,最后是for。另外需要注意的是如果数据达到10…...

[CSS] 文本折行

文本折行一般分为两种情况: CJK(Chinese/Japanese/Korean) 字符和非 CJK 字符。一般非 CJK 字符折行发生在两个单词的空格中间,见下图: 图中文本 “hello world” 包裹容器的宽度为 2rem,但是 hello 并没有…...

033-从零搭建微服务-日志插件(一)

写在最前 如果这个项目让你有所收获,记得 Star 关注哦,这对我是非常不错的鼓励与支持。 源码地址(后端):mingyue: 🎉 基于 Spring Boot、Spring Cloud & Alibaba 的分布式微服务架构基础服务中心 源…...

短期经济波动:均衡国民收入决定理论(三)

短期经济波动:国民收入决定理论(三) 文章目录 短期经济波动:国民收入决定理论(三)[toc]1 总需求曲线及其变动1.1 总需求曲线含义1.2 总需求曲线推导1.2.1 代数推导1.2.2 几何推导 1.3 AD曲线及其变动1.3.1 扩张性财政政策1.3.2 扩张性货币政策 2 总供给曲…...

电力感知边缘计算网关产品设计方案-网关软件架构

边缘计算网关采用ARM定制硬件平台架构,包含上位机端(内网)和FPGA网关端(外网)两部分,通过芯片间的高速信号总线实现边缘计算网关工业数据采集、数据实时传输、数据存储、网关状态信息收集等功能。 边缘计算网关上位机端(内网)重点完成工业数据采集、业务软件运算、客户…...

最新AI创作系统ChatGPT系统运营源码/支持最新GPT-4-Turbo模型/支持DALL-E3文生图

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…...

Java使用Redis的几种客户端介绍

Redis是一种高性能的内存数据库,可以提供快速的数据读写操作。在Java中使用Redis,需要使用Redis客户端。目前,Java中常用的Redis客户端有以下几种: Jedis Jedis是Java中最流行的Redis客户端之一,它提供了丰富的API和…...

程序员的护城河

程序员的护城河 算法,一定是过硬的算法!!!举个栗子:算法不硬吃大亏写在最后 算法,一定是过硬的算法!!! 其实会什么技术不重要,掌握多少种编程语言也不重要&a…...

常见面试题-MySQL软删除以及索引结构

为什么 mysql 删了行记录,反而磁盘空间没有减少? 答: 在 mysql 中,当使用 delete 删除数据时,mysql 会将删除的数据标记为已删除,但是并不去磁盘上真正进行删除,而是在需要使用这片存储空间时…...

信号的机制——信号处理函数的注册

在 Linux 操作系统中,为了响应各种各样的事件,也是定义了非常多的信号。我们可以通过 kill -l 命令,查看所有的信号。 # kill -l1) SIGHUP 2) SIGINT 3) SIGQUIT 4) SIGILL 5) SIGTRAP6) SIGABRT 7) SIGBUS …...

JS-项目实战-鼠标悬浮变手势(鼠标放单价上生效)

1、鼠标悬浮和离开事件.js //当页面加载完成后执行后面的匿名函数 window.onload function () {//get:获取 Element:元素 By:通过...方式//getElementById()根据id值获取某元素let fruitTbl document.getElementById("fruit_tbl");//table.rows:获取这个表格…...

redis运维(十一) python操作redis

一 python操作redis ① 安装pyredis redis常见错误 说明:由于redis服务器是5.0.8的,为了避免出现问题,默认最高版本的即可 --> 适配 ② 操作流程 核心:获取redis数据库连接对象 ③ Python 字符串前面加u,r,b的含义 原因: 字符串在…...

黑马程序员微服务 第五天课程 分布式搜索引擎2

分布式搜索引擎02 在昨天的学习中,我们已经导入了大量数据到elasticsearch中,实现了elasticsearch的数据存储功能。但elasticsearch最擅长的还是搜索和数据分析。 所以今天,我们研究下elasticsearch的数据搜索功能。我们会分别使用DSL和Res…...

什么是UV贴图?

UV 是与几何图形的顶点信息相对应的二维纹理坐标。UV 至关重要,因为它们提供了表面网格与图像纹理如何应用于该表面之间的联系。它们基本上是控制纹理上哪些像素对应于 3D 网格上的哪个顶点的标记点。它们在雕刻中也很重要。 为什么UV映射很重要? 默认情…...

从哪里下载 Oracle database 11g 软件

登入My Oracle Support,选择Patches & Updates 标签页,点击下方的Latest Patchsets链接: 然后单击Oracle Database,就可以下载11g软件了: 安装单实例数据库需要1和2两个zip文件,安装GI需要第3个zip文…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

R语言AI模型部署方案:精准离线运行详解

R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

循环冗余码校验CRC码 算法步骤+详细实例计算

通信过程&#xff1a;&#xff08;白话解释&#xff09; 我们将原始待发送的消息称为 M M M&#xff0c;依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)&#xff08;意思就是 G &#xff08; x ) G&#xff08;x) G&#xff08;x) 是已知的&#xff09;&#xff0…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路

进入2025年以来&#xff0c;尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断&#xff0c;但全球市场热度依然高涨&#xff0c;入局者持续增加。 以国内市场为例&#xff0c;天眼查专业版数据显示&#xff0c;截至5月底&#xff0c;我国现存在业、存续状态的机器人相关企…...