当前位置: 首页 > news >正文

浅谈PyTorch中的DP和DDP

目录

  • 1. 引言
  • 2. PyTorch 数据并行(Data Parallel, DP)
    • 2.1 DP 的优缺点
    • 2.2 DP 实现代码示例
  • 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
    • 3.1 DDP 的优缺点
    • 3.2 分布式基本概念
    • 3.3 DDP 的应用流程
    • 3.5 DDP 实现代码示例
  • 4. DP和DDP的对比

1. 引言

在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP)分布式数据并行(Distributed Data Parallel, DDP)

⚠️ 无论是DP还是DDP都只支持数据并行。

2. PyTorch 数据并行(Data Parallel, DP)

数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:

  1. 在前向传播时,首先将模型的参数复制到每个 GPU 上。
  2. 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
  3. 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
  4. 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。

2.1 DP 的优缺点

优点

  • 实现简单,使用 PyTorch 提供的 torch.nn.DataParallel 接口即可轻松实现。
  • 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。

缺点

  • DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
  • 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。

2.2 DP 实现代码示例

使用 torch.nn.DataParallel 实现数据并行非常简单。我们只需要将模型封装到 DataParallel 中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。

import torch
import torch.nn as nn
import torchvisionBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()# 2. 生成随机数据
total_steps = 100  # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")

在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel 中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。

3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)

分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:

  1. 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
  2. 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
  3. 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
  4. 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。

3.1 DDP 的优缺点

优点

  • 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
  • DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。

缺点

  • 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。

3.2 分布式基本概念

在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:

  1. node(节点):物理节点,一台机器即为一个节点。
  2. nnodes(节点数量):表示参与训练的物理节点数量。
  3. node rank(节点序号):节点的编号,用于区分不同的物理节点。
  4. nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
  5. world size(全局进程数量):表示全局并行的进程总数,等于 nnodes * nproc_per_node
  6. rank(进程序号):表示每个进程的唯一编号,用于进程间通信,rank=0 的进程为主进程。
  7. local rank(本地进程序号):在某个节点上的进程的序号,local_rank=0 表示该节点的主进程。

3.3 DDP 的应用流程

使用 DDP 进行分布式训练的步骤如下:

  1. 初始化分布式训练环境:通过 torch.distributed.init_process_group 初始化进程组,指定通信后端和相关配置。
  2. 创建分布式模型:将模型封装到 torch.nn.parallel.DistributedDataParallel 中,进行并行训练。
  3. 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用 DistributedSampler
  4. 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。

3.5 DDP 实现代码示例

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDPBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸if __name__ == "__main__":# 1. 设置分布式变量,初始化进程组rank = int(os.environ["RANK"])local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)dist.init_process_group(backend="nccl")device = torch.device("cuda", local_rank)print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")# 2. 创建模型net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)net = net.to(device)net = DDP(net, device_ids=[local_rank], output_device=local_rank)# 3. 生成随机数据total_steps = 100  # 假设每个 epoch 有 100 个步骤inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)# 4. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True)# 5. 开始训练net.train()for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):print("   == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(idx + 1,total_steps,ep,EPOCHS,train_loss / (idx + 1),100.0 * correct / total,))if rank == 0:print("\n            =======  Training Finished  ======= \n")

在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel 中,我们可以在多个 GPU 上高效地进行并行训练。

需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。

4. DP和DDP的对比

DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。

相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。

数据并行(DP)分布式数据并行(DDP)
实现复杂度使用 nn.DataParallel,实现简单,代码改动较少。需要设置分布式环境,使用 torch.distributed,代码实现相对复杂,需要手动管理进程和同步。
通信开销通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。
扩展性扩展性有限,适用于单机多卡,不支持多机训练。扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。
性能主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。各 GPU 负载均衡,资源利用率高,训练速度更快。
适用场景适合小规模模型和数据集的单机多卡训练。适合大规模模型和数据集的单机或多机多卡训练。
梯度同步方式梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。
模型参数广播每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。
对 Python GIL 的影响受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。
容错性主 GPU 故障会导致整个训练中断,容错性较差。各进程相对独立,某个进程出错不会影响其他进程,容错性较好。
调试难度由于是单进程,调试相对容易。多进程调试较为复杂,需要注意进程间的通信和同步问题。
代码修改量只需在模型外层加上 nn.DataParallel 封装,代码改动少。需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。
数据加载方式使用常规的数据加载方式,无需特殊处理。需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。
资源占用主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。
训练结果一致性由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。各进程的模型参数同步更新,训练结果一致性更好。

相关文章:

浅谈PyTorch中的DP和DDP

目录 1. 引言2. PyTorch 数据并行(Data Parallel, DP)2.1 DP 的优缺点2.2 DP 实现代码示例 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)3.1 DDP 的优缺点3.2 分布式基本概念3.3 DDP 的应用流程3.5 DDP 实现代码示…...

在Windows上利用谷歌浏览器进行视频会议和协作

随着远程工作和在线教育的普及,使用谷歌浏览器在Windows上进行视频会议和协作变得越来越常见。本文将为您提供一个详细的教程,教您如何在Windows上利用谷歌浏览器进行视频会议和协作,同时解决一些常见的问题。(本文由https://goog…...

VMware Fusion 13.6.1 发布下载,修复 4 个已知问题

VMware Fusion 13.6.1 发布下载,修复 4 个已知问题 VMware Fusion 13.6.1 for Mac - 领先的免费桌面虚拟化软件 适用于基于 Intel 处理器和搭载 Apple 芯片的 Mac 的桌面虚拟化软件 请访问原文链接:https://sysin.org/blog/vmware-fusion-13/ 查看最新…...

P9751 [CSP-J 2023] 旅游巴士

P 9751 P9751 P9751 部分分思路 题目要求时间必须是 k k k 的非负整数倍,所以想到了升维。这样就变成了一道分层图最短路的题目。用 BFS 算法可以拿到 A i 0 A_i0 Ai​0 的 35 35 35 分。 满分思路 其实部分分的思路已经很接近正解了,想要拿到满…...

【Linux】man手册安装使用

目录 man(manual,手册) 手册安装: 章节区分: 指令参数: 使用场景: 手册内容列表: 手册查看快捷键: 实例: 仍致谢:Linux常用命令大全(手册) – 真正好用的Linux命令在线查询网站 提供的命令查询 在开头先提醒一下:在 man 手册中退出的方法很简单…...

mysql学习教程,从入门到精通,SQL处理重复数据(39)

1、SQL处理重复数据 使用GROUP BY和HAVING子句删除重复数据(以SQL Server为例)”的背景和原理的详细解释: 1.1、背景 在数据库管理中,数据重复是一个常见的问题。重复数据可能由于多种原因产生,如数据录入错误、数据…...

mapbox解决wmts请求乱码问题

贴个群号 WebGIS学习交流群461555818,欢迎大家 事故现场 如图所示,wmts请求全是乱码,看起来像是将一个完整的请求拆成一个一个的字母了,而且控制台打印map.getStyle() 查看该source发现不出异常 解决办法 此类问题就是由于更…...

《C++职场中设计模式的学习与应用:开启高效编程之旅》

在 C职场中,设计模式是提升代码质量、增强程序可维护性和可扩展性的强大武器。掌握并正确应用设计模式,不仅能让你在工作中更加得心应手,还能为你的职业发展增添有力的砝码。那么,如何在 C职场中学习和应用设计模式呢?…...

Maya动画--基础约束

005-基础约束02_哔哩哔哩_bilibili 父子约束 移动圆环,球体会跟着移动,并回到初始的相对位置 不同物体间没有层级关系 明确子物体与父物体间的关系 衣服上的纽扣 法线约束 切线约束 碰到中心时会改变方向...

腾讯云License 相关

腾讯云视立方 License 是必须购买的吗? 若您下载的腾讯云视立方功能模块中,包含直播推流(主播开播和主播观众连麦/主播跨房 PK)、短视频(视频录制编辑/视频上传发布)、终端极速高清和腾讯特效功能模块&…...

开放式耳机什么品牌最好?十大超好用开放式耳机排名!

由于长时间使用传统入耳式耳机可能会对耳道健康带来潜在的负面影响,越来越多的用户倾向于选择开放式耳机,这种设计不侵入耳道。它有助于降低耳内湿度、减少细菌滋生,以及缓解耳道因封闭而过热的不适。但是大部分人还是不知道怎么选择开放式耳…...

基于Zynq SDIO WiFi移植二(支持2.4/5G)

1 SDIO设备识别 经过编译,将移植好的uboot、kernel、rootFS、ramdisk等烧录到Flash中,上电启动,在log中,可看到sdio设备 [ 1.747059] mmc1: queuing unknown CIS tuple 0x01 (3 bytes) [ 1.761842] mmc1: queuing unknown…...

Spring Boot敏感数据动态配置:深入实践与安全性提升

在构建Spring Boot应用的过程中,敏感数据的处理与保护是至关重要的。传统上,这些敏感数据(如数据库密码、API密钥、加密密钥等)可能被硬编码在配置文件中,这不仅增加了泄露的风险,也限制了配置的灵活性和可…...

软考数据库部分 ---- (概念数据库模型,三级模式,两级映像,事物管理)

文章目录 一、概念数据库模型二、结构数据库模型三、三级模式四、两级映像五、关系模式基本术语六、关系模式七、关系的数学定义八、数据定义语言九、SQL访问控制十、视图十一、索引十二、关系模式十三、范式十四、数据库设计十五、事物管理(ACID)十六、…...

AI 概念大杂烩

目录 介绍 数据挖掘 / 机器学习 / 深度学习 一、数据挖掘(Data Mining) 1. 定义 2. 目标 3. 常用算法 二、机器学习(Machine Learning) 1. 定义 2. 目标 3. 常用算法 三、深度学习(Deep Learning&#xff0…...

Composer和PHP有什么关系

Composer是PHP的一个依赖管理工具,以下是对Composer及其与PHP关系的详细解释: Composer简介 核心功能:Composer的核心思想是“依赖管理”,它能够自动下载和安装项目所依赖的库、框架或插件等。这些依赖项可以是PHP本身的库文件&…...

【PGCCC】在 Postgres 上构建图像搜索引擎

我最近看到的最有趣的电子商务功能之一是能够搜索与我手机上的图片相似的产品。例如,我可以拍一双鞋或其他产品的照片,然后搜索产品目录以查找类似商品。使用这样的功能可以是一个相当简单的项目,只要有合适的工具。如果我们可以将问题定义为…...

性能测试之性能问题分析

开始性能测试前需要了解的内容: 1、项目具体需求。 2、指标:响应时间在多少以内,并发数多少,tps多少,总tps多少,稳定性交易总量多少,事务成功率,交易波动范围,稳定运行…...

错过了A股,别再错过AI表情包!N款变现攻略,你选哪个?

本文背景 据 Swyft Media 统计,全世界每天各类聊天 app 发送的表情符号有 60 多亿,我们国家每天表情包发送量大概 6 亿次。 表情包简直就是个大淘金池,最近用 AI 做表情包也挺火。所以今天给大家讲讲一个用 AI 做表情包变现的项目。 以前没…...

SpringBoot驱动的美发沙龙管理系统:优雅地管理您的业务

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理美发门店管理系统的相关信息成为必然。开发…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

【JavaWeb】Docker项目部署

引言 之前学习了Linux操作系统的常见命令,在Linux上安装软件,以及如何在Linux上部署一个单体项目,大多数同学都会有相同的感受,那就是麻烦。 核心体现在三点: 命令太多了,记不住 软件安装包名字复杂&…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观,可持续性好高效率高复用,可移植性好高内聚,低耦合没有冗余规范性,代码有规可循,可以看出自己当时的思考过程特殊排版,特殊语法,特殊指令,必须…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下: 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载,下载地址:https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)

目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 (1)输入单引号 (2)万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...

【FTP】ftp文件传输会丢包吗?批量几百个文件传输,有一些文件没有传输完整,如何解决?

FTP(File Transfer Protocol)本身是一个基于 TCP 的协议,理论上不会丢包。但 FTP 文件传输过程中仍可能出现文件不完整、丢失或损坏的情况,主要原因包括: ✅ 一、FTP传输可能“丢包”或文件不完整的原因 原因描述网络…...