浅谈PyTorch中的DP和DDP
目录
- 1. 引言
- 2. PyTorch 数据并行(Data Parallel, DP)
- 2.1 DP 的优缺点
- 2.2 DP 实现代码示例
- 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
- 3.1 DDP 的优缺点
- 3.2 分布式基本概念
- 3.3 DDP 的应用流程
- 3.5 DDP 实现代码示例
- 4. DP和DDP的对比
1. 引言
在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP) 和 分布式数据并行(Distributed Data Parallel, DDP)。
⚠️ 无论是DP还是DDP都只支持数据并行。
2. PyTorch 数据并行(Data Parallel, DP)
数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:
- 在前向传播时,首先将模型的参数复制到每个 GPU 上。
- 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
- 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
- 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。
2.1 DP 的优缺点
优点:
- 实现简单,使用 PyTorch 提供的
torch.nn.DataParallel接口即可轻松实现。 - 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。
缺点:
- DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
- 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。
2.2 DP 实现代码示例
使用 torch.nn.DataParallel 实现数据并行非常简单。我们只需要将模型封装到 DataParallel 中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。
import torch
import torch.nn as nn
import torchvisionBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()# 2. 生成随机数据
total_steps = 100 # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")
在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel 中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。
3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:
- 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
- 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
- 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
- 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。
3.1 DDP 的优缺点
优点:
- 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
- DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。
缺点:
- 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。
3.2 分布式基本概念
在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:
- node(节点):物理节点,一台机器即为一个节点。
- nnodes(节点数量):表示参与训练的物理节点数量。
- node rank(节点序号):节点的编号,用于区分不同的物理节点。
- nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
- world size(全局进程数量):表示全局并行的进程总数,等于
nnodes * nproc_per_node。 - rank(进程序号):表示每个进程的唯一编号,用于进程间通信,
rank=0的进程为主进程。 - local rank(本地进程序号):在某个节点上的进程的序号,
local_rank=0表示该节点的主进程。
3.3 DDP 的应用流程
使用 DDP 进行分布式训练的步骤如下:
- 初始化分布式训练环境:通过
torch.distributed.init_process_group初始化进程组,指定通信后端和相关配置。 - 创建分布式模型:将模型封装到
torch.nn.parallel.DistributedDataParallel中,进行并行训练。 - 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用
DistributedSampler。 - 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。
3.5 DDP 实现代码示例
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDPBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸if __name__ == "__main__":# 1. 设置分布式变量,初始化进程组rank = int(os.environ["RANK"])local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)dist.init_process_group(backend="nccl")device = torch.device("cuda", local_rank)print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")# 2. 创建模型net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)net = net.to(device)net = DDP(net, device_ids=[local_rank], output_device=local_rank)# 3. 生成随机数据total_steps = 100 # 假设每个 epoch 有 100 个步骤inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)# 4. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True)# 5. 开始训练net.train()for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):print(" == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(idx + 1,total_steps,ep,EPOCHS,train_loss / (idx + 1),100.0 * correct / total,))if rank == 0:print("\n ======= Training Finished ======= \n")
在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel 中,我们可以在多个 GPU 上高效地进行并行训练。
需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。
4. DP和DDP的对比
DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。
相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。
| 数据并行(DP) | 分布式数据并行(DDP) | |
|---|---|---|
| 实现复杂度 | 使用 nn.DataParallel,实现简单,代码改动较少。 | 需要设置分布式环境,使用 torch.distributed,代码实现相对复杂,需要手动管理进程和同步。 |
| 通信开销 | 通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。 | 通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。 |
| 扩展性 | 扩展性有限,适用于单机多卡,不支持多机训练。 | 扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。 |
| 性能 | 主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。 | 各 GPU 负载均衡,资源利用率高,训练速度更快。 |
| 适用场景 | 适合小规模模型和数据集的单机多卡训练。 | 适合大规模模型和数据集的单机或多机多卡训练。 |
| 梯度同步方式 | 梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。 | 梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。 |
| 模型参数广播 | 每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。 | 初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。 |
| 对 Python GIL 的影响 | 受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。 | 采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。 |
| 容错性 | 主 GPU 故障会导致整个训练中断,容错性较差。 | 各进程相对独立,某个进程出错不会影响其他进程,容错性较好。 |
| 调试难度 | 由于是单进程,调试相对容易。 | 多进程调试较为复杂,需要注意进程间的通信和同步问题。 |
| 代码修改量 | 只需在模型外层加上 nn.DataParallel 封装,代码改动少。 | 需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。 |
| 数据加载方式 | 使用常规的数据加载方式,无需特殊处理。 | 需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。 |
| 资源占用 | 主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。 | 各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。 |
| 训练结果一致性 | 由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。 | 各进程的模型参数同步更新,训练结果一致性更好。 |
相关文章:
浅谈PyTorch中的DP和DDP
目录 1. 引言2. PyTorch 数据并行(Data Parallel, DP)2.1 DP 的优缺点2.2 DP 实现代码示例 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)3.1 DDP 的优缺点3.2 分布式基本概念3.3 DDP 的应用流程3.5 DDP 实现代码示…...
在Windows上利用谷歌浏览器进行视频会议和协作
随着远程工作和在线教育的普及,使用谷歌浏览器在Windows上进行视频会议和协作变得越来越常见。本文将为您提供一个详细的教程,教您如何在Windows上利用谷歌浏览器进行视频会议和协作,同时解决一些常见的问题。(本文由https://goog…...
VMware Fusion 13.6.1 发布下载,修复 4 个已知问题
VMware Fusion 13.6.1 发布下载,修复 4 个已知问题 VMware Fusion 13.6.1 for Mac - 领先的免费桌面虚拟化软件 适用于基于 Intel 处理器和搭载 Apple 芯片的 Mac 的桌面虚拟化软件 请访问原文链接:https://sysin.org/blog/vmware-fusion-13/ 查看最新…...
P9751 [CSP-J 2023] 旅游巴士
P 9751 P9751 P9751 部分分思路 题目要求时间必须是 k k k 的非负整数倍,所以想到了升维。这样就变成了一道分层图最短路的题目。用 BFS 算法可以拿到 A i 0 A_i0 Ai0 的 35 35 35 分。 满分思路 其实部分分的思路已经很接近正解了,想要拿到满…...
【Linux】man手册安装使用
目录 man(manual,手册) 手册安装: 章节区分: 指令参数: 使用场景: 手册内容列表: 手册查看快捷键: 实例: 仍致谢:Linux常用命令大全(手册) – 真正好用的Linux命令在线查询网站 提供的命令查询 在开头先提醒一下:在 man 手册中退出的方法很简单…...
mysql学习教程,从入门到精通,SQL处理重复数据(39)
1、SQL处理重复数据 使用GROUP BY和HAVING子句删除重复数据(以SQL Server为例)”的背景和原理的详细解释: 1.1、背景 在数据库管理中,数据重复是一个常见的问题。重复数据可能由于多种原因产生,如数据录入错误、数据…...
mapbox解决wmts请求乱码问题
贴个群号 WebGIS学习交流群461555818,欢迎大家 事故现场 如图所示,wmts请求全是乱码,看起来像是将一个完整的请求拆成一个一个的字母了,而且控制台打印map.getStyle() 查看该source发现不出异常 解决办法 此类问题就是由于更…...
《C++职场中设计模式的学习与应用:开启高效编程之旅》
在 C职场中,设计模式是提升代码质量、增强程序可维护性和可扩展性的强大武器。掌握并正确应用设计模式,不仅能让你在工作中更加得心应手,还能为你的职业发展增添有力的砝码。那么,如何在 C职场中学习和应用设计模式呢?…...
Maya动画--基础约束
005-基础约束02_哔哩哔哩_bilibili 父子约束 移动圆环,球体会跟着移动,并回到初始的相对位置 不同物体间没有层级关系 明确子物体与父物体间的关系 衣服上的纽扣 法线约束 切线约束 碰到中心时会改变方向...
腾讯云License 相关
腾讯云视立方 License 是必须购买的吗? 若您下载的腾讯云视立方功能模块中,包含直播推流(主播开播和主播观众连麦/主播跨房 PK)、短视频(视频录制编辑/视频上传发布)、终端极速高清和腾讯特效功能模块&…...
开放式耳机什么品牌最好?十大超好用开放式耳机排名!
由于长时间使用传统入耳式耳机可能会对耳道健康带来潜在的负面影响,越来越多的用户倾向于选择开放式耳机,这种设计不侵入耳道。它有助于降低耳内湿度、减少细菌滋生,以及缓解耳道因封闭而过热的不适。但是大部分人还是不知道怎么选择开放式耳…...
基于Zynq SDIO WiFi移植二(支持2.4/5G)
1 SDIO设备识别 经过编译,将移植好的uboot、kernel、rootFS、ramdisk等烧录到Flash中,上电启动,在log中,可看到sdio设备 [ 1.747059] mmc1: queuing unknown CIS tuple 0x01 (3 bytes) [ 1.761842] mmc1: queuing unknown…...
Spring Boot敏感数据动态配置:深入实践与安全性提升
在构建Spring Boot应用的过程中,敏感数据的处理与保护是至关重要的。传统上,这些敏感数据(如数据库密码、API密钥、加密密钥等)可能被硬编码在配置文件中,这不仅增加了泄露的风险,也限制了配置的灵活性和可…...
软考数据库部分 ---- (概念数据库模型,三级模式,两级映像,事物管理)
文章目录 一、概念数据库模型二、结构数据库模型三、三级模式四、两级映像五、关系模式基本术语六、关系模式七、关系的数学定义八、数据定义语言九、SQL访问控制十、视图十一、索引十二、关系模式十三、范式十四、数据库设计十五、事物管理(ACID)十六、…...
AI 概念大杂烩
目录 介绍 数据挖掘 / 机器学习 / 深度学习 一、数据挖掘(Data Mining) 1. 定义 2. 目标 3. 常用算法 二、机器学习(Machine Learning) 1. 定义 2. 目标 3. 常用算法 三、深度学习(Deep Learning࿰…...
Composer和PHP有什么关系
Composer是PHP的一个依赖管理工具,以下是对Composer及其与PHP关系的详细解释: Composer简介 核心功能:Composer的核心思想是“依赖管理”,它能够自动下载和安装项目所依赖的库、框架或插件等。这些依赖项可以是PHP本身的库文件&…...
【PGCCC】在 Postgres 上构建图像搜索引擎
我最近看到的最有趣的电子商务功能之一是能够搜索与我手机上的图片相似的产品。例如,我可以拍一双鞋或其他产品的照片,然后搜索产品目录以查找类似商品。使用这样的功能可以是一个相当简单的项目,只要有合适的工具。如果我们可以将问题定义为…...
性能测试之性能问题分析
开始性能测试前需要了解的内容: 1、项目具体需求。 2、指标:响应时间在多少以内,并发数多少,tps多少,总tps多少,稳定性交易总量多少,事务成功率,交易波动范围,稳定运行…...
错过了A股,别再错过AI表情包!N款变现攻略,你选哪个?
本文背景 据 Swyft Media 统计,全世界每天各类聊天 app 发送的表情符号有 60 多亿,我们国家每天表情包发送量大概 6 亿次。 表情包简直就是个大淘金池,最近用 AI 做表情包也挺火。所以今天给大家讲讲一个用 AI 做表情包变现的项目。 以前没…...
SpringBoot驱动的美发沙龙管理系统:优雅地管理您的业务
1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理美发门店管理系统的相关信息成为必然。开发…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南
1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发,使用DevEco Studio作为开发工具,采用Java语言实现,包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
Yolov8 目标检测蒸馏学习记录
yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...
return this;返回的是谁
一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请,不同级别的经理有不同的审批权限: // 抽象处理者:审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...
C#学习第29天:表达式树(Expression Trees)
目录 什么是表达式树? 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持: 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...
[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.
ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #:…...
springboot 日志类切面,接口成功记录日志,失败不记录
springboot 日志类切面,接口成功记录日志,失败不记录 自定义一个注解方法 import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;/***…...
Vue3 PC端 UI组件库我更推荐Naive UI
一、Vue3生态现状与UI库选择的重要性 随着Vue3的稳定发布和Composition API的广泛采用,前端开发者面临着UI组件库的重新选择。一个好的UI库不仅能提升开发效率,还能确保项目的长期可维护性。本文将对比三大主流Vue3 UI库(Naive UI、Element …...
关于 ffmpeg设置摄像头报错“Could not set video options” 的解决方法
若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/148515355 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…...
