浅谈PyTorch中的DP和DDP
目录
- 1. 引言
- 2. PyTorch 数据并行(Data Parallel, DP)
- 2.1 DP 的优缺点
- 2.2 DP 实现代码示例
- 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
- 3.1 DDP 的优缺点
- 3.2 分布式基本概念
- 3.3 DDP 的应用流程
- 3.5 DDP 实现代码示例
- 4. DP和DDP的对比
1. 引言
在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP) 和 分布式数据并行(Distributed Data Parallel, DDP)。
⚠️ 无论是DP还是DDP都只支持数据并行。
2. PyTorch 数据并行(Data Parallel, DP)
数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:
- 在前向传播时,首先将模型的参数复制到每个 GPU 上。
- 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
- 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
- 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。
2.1 DP 的优缺点
优点:
- 实现简单,使用 PyTorch 提供的
torch.nn.DataParallel
接口即可轻松实现。 - 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。
缺点:
- DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
- 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。
2.2 DP 实现代码示例
使用 torch.nn.DataParallel
实现数据并行非常简单。我们只需要将模型封装到 DataParallel
中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。
import torch
import torch.nn as nn
import torchvisionBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()# 2. 生成随机数据
total_steps = 100 # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")
在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel
中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。
3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:
- 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
- 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
- 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
- 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。
3.1 DDP 的优缺点
优点:
- 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
- DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。
缺点:
- 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。
3.2 分布式基本概念
在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:
- node(节点):物理节点,一台机器即为一个节点。
- nnodes(节点数量):表示参与训练的物理节点数量。
- node rank(节点序号):节点的编号,用于区分不同的物理节点。
- nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
- world size(全局进程数量):表示全局并行的进程总数,等于
nnodes * nproc_per_node
。 - rank(进程序号):表示每个进程的唯一编号,用于进程间通信,
rank=0
的进程为主进程。 - local rank(本地进程序号):在某个节点上的进程的序号,
local_rank=0
表示该节点的主进程。
3.3 DDP 的应用流程
使用 DDP 进行分布式训练的步骤如下:
- 初始化分布式训练环境:通过
torch.distributed.init_process_group
初始化进程组,指定通信后端和相关配置。 - 创建分布式模型:将模型封装到
torch.nn.parallel.DistributedDataParallel
中,进行并行训练。 - 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用
DistributedSampler
。 - 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。
3.5 DDP 实现代码示例
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDPBATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224) # ResNet-18 的输入尺寸if __name__ == "__main__":# 1. 设置分布式变量,初始化进程组rank = int(os.environ["RANK"])local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)dist.init_process_group(backend="nccl")device = torch.device("cuda", local_rank)print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")# 2. 创建模型net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)net = net.to(device)net = DDP(net, device_ids=[local_rank], output_device=local_rank)# 3. 生成随机数据total_steps = 100 # 假设每个 epoch 有 100 个步骤inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)# 4. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True)# 5. 开始训练net.train()for ep in range(1, EPOCHS + 1):train_loss = correct = total = 0for idx in range(total_steps):outputs = net(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += targets.size(0)correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):print(" == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(idx + 1,total_steps,ep,EPOCHS,train_loss / (idx + 1),100.0 * correct / total,))if rank == 0:print("\n ======= Training Finished ======= \n")
在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel
中,我们可以在多个 GPU 上高效地进行并行训练。
需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。
4. DP和DDP的对比
DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。
相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。
数据并行(DP) | 分布式数据并行(DDP) | |
---|---|---|
实现复杂度 | 使用 nn.DataParallel ,实现简单,代码改动较少。 | 需要设置分布式环境,使用 torch.distributed ,代码实现相对复杂,需要手动管理进程和同步。 |
通信开销 | 通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。 | 通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。 |
扩展性 | 扩展性有限,适用于单机多卡,不支持多机训练。 | 扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。 |
性能 | 主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。 | 各 GPU 负载均衡,资源利用率高,训练速度更快。 |
适用场景 | 适合小规模模型和数据集的单机多卡训练。 | 适合大规模模型和数据集的单机或多机多卡训练。 |
梯度同步方式 | 梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。 | 梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。 |
模型参数广播 | 每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。 | 初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。 |
对 Python GIL 的影响 | 受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。 | 采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。 |
容错性 | 主 GPU 故障会导致整个训练中断,容错性较差。 | 各进程相对独立,某个进程出错不会影响其他进程,容错性较好。 |
调试难度 | 由于是单进程,调试相对容易。 | 多进程调试较为复杂,需要注意进程间的通信和同步问题。 |
代码修改量 | 只需在模型外层加上 nn.DataParallel 封装,代码改动少。 | 需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。 |
数据加载方式 | 使用常规的数据加载方式,无需特殊处理。 | 需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。 |
资源占用 | 主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。 | 各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。 |
训练结果一致性 | 由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。 | 各进程的模型参数同步更新,训练结果一致性更好。 |
相关文章:
浅谈PyTorch中的DP和DDP
目录 1. 引言2. PyTorch 数据并行(Data Parallel, DP)2.1 DP 的优缺点2.2 DP 实现代码示例 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)3.1 DDP 的优缺点3.2 分布式基本概念3.3 DDP 的应用流程3.5 DDP 实现代码示…...

在Windows上利用谷歌浏览器进行视频会议和协作
随着远程工作和在线教育的普及,使用谷歌浏览器在Windows上进行视频会议和协作变得越来越常见。本文将为您提供一个详细的教程,教您如何在Windows上利用谷歌浏览器进行视频会议和协作,同时解决一些常见的问题。(本文由https://goog…...

VMware Fusion 13.6.1 发布下载,修复 4 个已知问题
VMware Fusion 13.6.1 发布下载,修复 4 个已知问题 VMware Fusion 13.6.1 for Mac - 领先的免费桌面虚拟化软件 适用于基于 Intel 处理器和搭载 Apple 芯片的 Mac 的桌面虚拟化软件 请访问原文链接:https://sysin.org/blog/vmware-fusion-13/ 查看最新…...
P9751 [CSP-J 2023] 旅游巴士
P 9751 P9751 P9751 部分分思路 题目要求时间必须是 k k k 的非负整数倍,所以想到了升维。这样就变成了一道分层图最短路的题目。用 BFS 算法可以拿到 A i 0 A_i0 Ai0 的 35 35 35 分。 满分思路 其实部分分的思路已经很接近正解了,想要拿到满…...

【Linux】man手册安装使用
目录 man(manual,手册) 手册安装: 章节区分: 指令参数: 使用场景: 手册内容列表: 手册查看快捷键: 实例: 仍致谢:Linux常用命令大全(手册) – 真正好用的Linux命令在线查询网站 提供的命令查询 在开头先提醒一下:在 man 手册中退出的方法很简单…...
mysql学习教程,从入门到精通,SQL处理重复数据(39)
1、SQL处理重复数据 使用GROUP BY和HAVING子句删除重复数据(以SQL Server为例)”的背景和原理的详细解释: 1.1、背景 在数据库管理中,数据重复是一个常见的问题。重复数据可能由于多种原因产生,如数据录入错误、数据…...

mapbox解决wmts请求乱码问题
贴个群号 WebGIS学习交流群461555818,欢迎大家 事故现场 如图所示,wmts请求全是乱码,看起来像是将一个完整的请求拆成一个一个的字母了,而且控制台打印map.getStyle() 查看该source发现不出异常 解决办法 此类问题就是由于更…...
《C++职场中设计模式的学习与应用:开启高效编程之旅》
在 C职场中,设计模式是提升代码质量、增强程序可维护性和可扩展性的强大武器。掌握并正确应用设计模式,不仅能让你在工作中更加得心应手,还能为你的职业发展增添有力的砝码。那么,如何在 C职场中学习和应用设计模式呢?…...

Maya动画--基础约束
005-基础约束02_哔哩哔哩_bilibili 父子约束 移动圆环,球体会跟着移动,并回到初始的相对位置 不同物体间没有层级关系 明确子物体与父物体间的关系 衣服上的纽扣 法线约束 切线约束 碰到中心时会改变方向...

腾讯云License 相关
腾讯云视立方 License 是必须购买的吗? 若您下载的腾讯云视立方功能模块中,包含直播推流(主播开播和主播观众连麦/主播跨房 PK)、短视频(视频录制编辑/视频上传发布)、终端极速高清和腾讯特效功能模块&…...

开放式耳机什么品牌最好?十大超好用开放式耳机排名!
由于长时间使用传统入耳式耳机可能会对耳道健康带来潜在的负面影响,越来越多的用户倾向于选择开放式耳机,这种设计不侵入耳道。它有助于降低耳内湿度、减少细菌滋生,以及缓解耳道因封闭而过热的不适。但是大部分人还是不知道怎么选择开放式耳…...

基于Zynq SDIO WiFi移植二(支持2.4/5G)
1 SDIO设备识别 经过编译,将移植好的uboot、kernel、rootFS、ramdisk等烧录到Flash中,上电启动,在log中,可看到sdio设备 [ 1.747059] mmc1: queuing unknown CIS tuple 0x01 (3 bytes) [ 1.761842] mmc1: queuing unknown…...
Spring Boot敏感数据动态配置:深入实践与安全性提升
在构建Spring Boot应用的过程中,敏感数据的处理与保护是至关重要的。传统上,这些敏感数据(如数据库密码、API密钥、加密密钥等)可能被硬编码在配置文件中,这不仅增加了泄露的风险,也限制了配置的灵活性和可…...

软考数据库部分 ---- (概念数据库模型,三级模式,两级映像,事物管理)
文章目录 一、概念数据库模型二、结构数据库模型三、三级模式四、两级映像五、关系模式基本术语六、关系模式七、关系的数学定义八、数据定义语言九、SQL访问控制十、视图十一、索引十二、关系模式十三、范式十四、数据库设计十五、事物管理(ACID)十六、…...
AI 概念大杂烩
目录 介绍 数据挖掘 / 机器学习 / 深度学习 一、数据挖掘(Data Mining) 1. 定义 2. 目标 3. 常用算法 二、机器学习(Machine Learning) 1. 定义 2. 目标 3. 常用算法 三、深度学习(Deep Learning࿰…...
Composer和PHP有什么关系
Composer是PHP的一个依赖管理工具,以下是对Composer及其与PHP关系的详细解释: Composer简介 核心功能:Composer的核心思想是“依赖管理”,它能够自动下载和安装项目所依赖的库、框架或插件等。这些依赖项可以是PHP本身的库文件&…...

【PGCCC】在 Postgres 上构建图像搜索引擎
我最近看到的最有趣的电子商务功能之一是能够搜索与我手机上的图片相似的产品。例如,我可以拍一双鞋或其他产品的照片,然后搜索产品目录以查找类似商品。使用这样的功能可以是一个相当简单的项目,只要有合适的工具。如果我们可以将问题定义为…...

性能测试之性能问题分析
开始性能测试前需要了解的内容: 1、项目具体需求。 2、指标:响应时间在多少以内,并发数多少,tps多少,总tps多少,稳定性交易总量多少,事务成功率,交易波动范围,稳定运行…...

错过了A股,别再错过AI表情包!N款变现攻略,你选哪个?
本文背景 据 Swyft Media 统计,全世界每天各类聊天 app 发送的表情符号有 60 多亿,我们国家每天表情包发送量大概 6 亿次。 表情包简直就是个大淘金池,最近用 AI 做表情包也挺火。所以今天给大家讲讲一个用 AI 做表情包变现的项目。 以前没…...
SpringBoot驱动的美发沙龙管理系统:优雅地管理您的业务
1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理美发门店管理系统的相关信息成为必然。开发…...

7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...

使用VSCode开发Django指南
使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径
目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...
三维GIS开发cesium智慧地铁教程(5)Cesium相机控制
一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点: 路径验证:确保相对路径.…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

Keil 中设置 STM32 Flash 和 RAM 地址详解
文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...
GitHub 趋势日报 (2025年06月08日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...