浅谈PyTorch中的DP和DDP
目录
- 1. 引言
- 2. PyTorch 数据并行(Data Parallel, DP)
- 2.1 DP 的优缺点
- 2.2 DP 实现代码示例
- 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
- 3.1 DDP 的优缺点
- 3.2 分布式基本概念
- 3.3 DDP 的应用流程
- 3.5 DDP 实现代码示例
- 4. DP和DDP的对比
1. 引言
在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP) 和 分布式数据并行(Distributed Data Parallel, DDP)。
⚠️ 无论是DP还是DDP都只支持数据并行。
2. PyTorch 数据并行(Data Parallel, DP)
数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:
- 在前向传播时,首先将模型的参数复制到每个 GPU 上。
- 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
- 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
- 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。
2.1 DP 的优缺点
优点:
- 实现简单,使用 PyTorch 提供的
torch.nn.DataParallel接口即可轻松实现。 - 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。
缺点:
- DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
- 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。
2.2 DP 实现代码示例
使用 torch.nn.DataParallel 实现数据并行非常简单。我们只需要将模型封装到 DataParallel 中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。
import torch
import torch.nn as nn
import torchvisionBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()# 2. 生成随机数据
total_steps = 100 # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")
在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel 中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。
3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:
- 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
- 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
- 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
- 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。
3.1 DDP 的优缺点
优点:
- 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
- DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。
缺点:
- 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。
3.2 分布式基本概念
在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:
- node(节点):物理节点,一台机器即为一个节点。
- nnodes(节点数量):表示参与训练的物理节点数量。
- node rank(节点序号):节点的编号,用于区分不同的物理节点。
- nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
- world size(全局进程数量):表示全局并行的进程总数,等于
nnodes * nproc_per_node。 - rank(进程序号):表示每个进程的唯一编号,用于进程间通信,
rank=0的进程为主进程。 - local rank(本地进程序号):在某个节点上的进程的序号,
local_rank=0表示该节点的主进程。
3.3 DDP 的应用流程
使用 DDP 进行分布式训练的步骤如下:
- 初始化分布式训练环境:通过
torch.distributed.init_process_group初始化进程组,指定通信后端和相关配置。 - 创建分布式模型:将模型封装到
torch.nn.parallel.DistributedDataParallel中,进行并行训练。 - 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用
DistributedSampler。 - 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。
3.5 DDP 实现代码示例
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDPBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸if __name__ == "__main__":# 1. 设置分布式变量,初始化进程组rank = int(os.environ["RANK"])local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)dist.init_process_group(backend="nccl")device = torch.device("cuda", local_rank)print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")# 2. 创建模型net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)net = net.to(device)net = DDP(net, device_ids=[local_rank], output_device=local_rank)# 3. 生成随机数据total_steps = 100 # 假设每个 epoch 有 100 个步骤inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)# 4. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True)# 5. 开始训练net.train()for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):print(" == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(idx + 1,total_steps,ep,EPOCHS,train_loss / (idx + 1),100.0 * correct / total,))if rank == 0:print("\n ======= Training Finished ======= \n")
在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel 中,我们可以在多个 GPU 上高效地进行并行训练。
需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。
4. DP和DDP的对比
DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。
相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。
| 数据并行(DP) | 分布式数据并行(DDP) | |
|---|---|---|
| 实现复杂度 | 使用 nn.DataParallel,实现简单,代码改动较少。 | 需要设置分布式环境,使用 torch.distributed,代码实现相对复杂,需要手动管理进程和同步。 |
| 通信开销 | 通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。 | 通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。 |
| 扩展性 | 扩展性有限,适用于单机多卡,不支持多机训练。 | 扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。 |
| 性能 | 主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。 | 各 GPU 负载均衡,资源利用率高,训练速度更快。 |
| 适用场景 | 适合小规模模型和数据集的单机多卡训练。 | 适合大规模模型和数据集的单机或多机多卡训练。 |
| 梯度同步方式 | 梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。 | 梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。 |
| 模型参数广播 | 每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。 | 初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。 |
| 对 Python GIL 的影响 | 受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。 | 采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。 |
| 容错性 | 主 GPU 故障会导致整个训练中断,容错性较差。 | 各进程相对独立,某个进程出错不会影响其他进程,容错性较好。 |
| 调试难度 | 由于是单进程,调试相对容易。 | 多进程调试较为复杂,需要注意进程间的通信和同步问题。 |
| 代码修改量 | 只需在模型外层加上 nn.DataParallel 封装,代码改动少。 | 需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。 |
| 数据加载方式 | 使用常规的数据加载方式,无需特殊处理。 | 需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。 |
| 资源占用 | 主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。 | 各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。 |
| 训练结果一致性 | 由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。 | 各进程的模型参数同步更新,训练结果一致性更好。 |
相关文章:
浅谈PyTorch中的DP和DDP
目录 1. 引言2. PyTorch 数据并行(Data Parallel, DP)2.1 DP 的优缺点2.2 DP 实现代码示例 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)3.1 DDP 的优缺点3.2 分布式基本概念3.3 DDP 的应用流程3.5 DDP 实现代码示…...
在Windows上利用谷歌浏览器进行视频会议和协作
随着远程工作和在线教育的普及,使用谷歌浏览器在Windows上进行视频会议和协作变得越来越常见。本文将为您提供一个详细的教程,教您如何在Windows上利用谷歌浏览器进行视频会议和协作,同时解决一些常见的问题。(本文由https://goog…...
VMware Fusion 13.6.1 发布下载,修复 4 个已知问题
VMware Fusion 13.6.1 发布下载,修复 4 个已知问题 VMware Fusion 13.6.1 for Mac - 领先的免费桌面虚拟化软件 适用于基于 Intel 处理器和搭载 Apple 芯片的 Mac 的桌面虚拟化软件 请访问原文链接:https://sysin.org/blog/vmware-fusion-13/ 查看最新…...
P9751 [CSP-J 2023] 旅游巴士
P 9751 P9751 P9751 部分分思路 题目要求时间必须是 k k k 的非负整数倍,所以想到了升维。这样就变成了一道分层图最短路的题目。用 BFS 算法可以拿到 A i 0 A_i0 Ai0 的 35 35 35 分。 满分思路 其实部分分的思路已经很接近正解了,想要拿到满…...
【Linux】man手册安装使用
目录 man(manual,手册) 手册安装: 章节区分: 指令参数: 使用场景: 手册内容列表: 手册查看快捷键: 实例: 仍致谢:Linux常用命令大全(手册) – 真正好用的Linux命令在线查询网站 提供的命令查询 在开头先提醒一下:在 man 手册中退出的方法很简单…...
mysql学习教程,从入门到精通,SQL处理重复数据(39)
1、SQL处理重复数据 使用GROUP BY和HAVING子句删除重复数据(以SQL Server为例)”的背景和原理的详细解释: 1.1、背景 在数据库管理中,数据重复是一个常见的问题。重复数据可能由于多种原因产生,如数据录入错误、数据…...
mapbox解决wmts请求乱码问题
贴个群号 WebGIS学习交流群461555818,欢迎大家 事故现场 如图所示,wmts请求全是乱码,看起来像是将一个完整的请求拆成一个一个的字母了,而且控制台打印map.getStyle() 查看该source发现不出异常 解决办法 此类问题就是由于更…...
《C++职场中设计模式的学习与应用:开启高效编程之旅》
在 C职场中,设计模式是提升代码质量、增强程序可维护性和可扩展性的强大武器。掌握并正确应用设计模式,不仅能让你在工作中更加得心应手,还能为你的职业发展增添有力的砝码。那么,如何在 C职场中学习和应用设计模式呢?…...
Maya动画--基础约束
005-基础约束02_哔哩哔哩_bilibili 父子约束 移动圆环,球体会跟着移动,并回到初始的相对位置 不同物体间没有层级关系 明确子物体与父物体间的关系 衣服上的纽扣 法线约束 切线约束 碰到中心时会改变方向...
腾讯云License 相关
腾讯云视立方 License 是必须购买的吗? 若您下载的腾讯云视立方功能模块中,包含直播推流(主播开播和主播观众连麦/主播跨房 PK)、短视频(视频录制编辑/视频上传发布)、终端极速高清和腾讯特效功能模块&…...
开放式耳机什么品牌最好?十大超好用开放式耳机排名!
由于长时间使用传统入耳式耳机可能会对耳道健康带来潜在的负面影响,越来越多的用户倾向于选择开放式耳机,这种设计不侵入耳道。它有助于降低耳内湿度、减少细菌滋生,以及缓解耳道因封闭而过热的不适。但是大部分人还是不知道怎么选择开放式耳…...
基于Zynq SDIO WiFi移植二(支持2.4/5G)
1 SDIO设备识别 经过编译,将移植好的uboot、kernel、rootFS、ramdisk等烧录到Flash中,上电启动,在log中,可看到sdio设备 [ 1.747059] mmc1: queuing unknown CIS tuple 0x01 (3 bytes) [ 1.761842] mmc1: queuing unknown…...
Spring Boot敏感数据动态配置:深入实践与安全性提升
在构建Spring Boot应用的过程中,敏感数据的处理与保护是至关重要的。传统上,这些敏感数据(如数据库密码、API密钥、加密密钥等)可能被硬编码在配置文件中,这不仅增加了泄露的风险,也限制了配置的灵活性和可…...
软考数据库部分 ---- (概念数据库模型,三级模式,两级映像,事物管理)
文章目录 一、概念数据库模型二、结构数据库模型三、三级模式四、两级映像五、关系模式基本术语六、关系模式七、关系的数学定义八、数据定义语言九、SQL访问控制十、视图十一、索引十二、关系模式十三、范式十四、数据库设计十五、事物管理(ACID)十六、…...
AI 概念大杂烩
目录 介绍 数据挖掘 / 机器学习 / 深度学习 一、数据挖掘(Data Mining) 1. 定义 2. 目标 3. 常用算法 二、机器学习(Machine Learning) 1. 定义 2. 目标 3. 常用算法 三、深度学习(Deep Learning࿰…...
Composer和PHP有什么关系
Composer是PHP的一个依赖管理工具,以下是对Composer及其与PHP关系的详细解释: Composer简介 核心功能:Composer的核心思想是“依赖管理”,它能够自动下载和安装项目所依赖的库、框架或插件等。这些依赖项可以是PHP本身的库文件&…...
【PGCCC】在 Postgres 上构建图像搜索引擎
我最近看到的最有趣的电子商务功能之一是能够搜索与我手机上的图片相似的产品。例如,我可以拍一双鞋或其他产品的照片,然后搜索产品目录以查找类似商品。使用这样的功能可以是一个相当简单的项目,只要有合适的工具。如果我们可以将问题定义为…...
性能测试之性能问题分析
开始性能测试前需要了解的内容: 1、项目具体需求。 2、指标:响应时间在多少以内,并发数多少,tps多少,总tps多少,稳定性交易总量多少,事务成功率,交易波动范围,稳定运行…...
错过了A股,别再错过AI表情包!N款变现攻略,你选哪个?
本文背景 据 Swyft Media 统计,全世界每天各类聊天 app 发送的表情符号有 60 多亿,我们国家每天表情包发送量大概 6 亿次。 表情包简直就是个大淘金池,最近用 AI 做表情包也挺火。所以今天给大家讲讲一个用 AI 做表情包变现的项目。 以前没…...
SpringBoot驱动的美发沙龙管理系统:优雅地管理您的业务
1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理美发门店管理系统的相关信息成为必然。开发…...
【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...
工业安全零事故的智能守护者:一体化AI智能安防平台
前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...
练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作
一、上下文切换 即使单核CPU也可以进行多线程执行代码,CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短,所以CPU会不断地切换线程执行,从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...
selenium学习实战【Python爬虫】
selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...
初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式
今天是关于AI如何在教学中增强学生的学习体验,我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育,这并非炒作,而是已经发生的巨大变革。教育机构和教育者不能忽视它,试图简单地禁止学生使…...
C++ 设计模式 《小明的奶茶加料风波》
👨🎓 模式名称:装饰器模式(Decorator Pattern) 👦 小明最近上线了校园奶茶配送功能,业务火爆,大家都在加料: 有的同学要加波霸 🟤,有的要加椰果…...
NPOI操作EXCEL文件 ——CAD C# 二次开发
缺点:dll.版本容易加载错误。CAD加载插件时,没有加载所有类库。插件运行过程中用到某个类库,会从CAD的安装目录找,找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库,就用插件程序加载进…...
