PyTorch 分布式训练(Distributed Data Parallel, DDP)简介
PyTorch 分布式训练(Distributed Data Parallel, DDP)
一、DDP 核心概念
torch.nn.parallel.DistributedDataParallel
1. DDP 是什么?
Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:
- 多进程而非多线程:避免 Python GIL 限制
- 更高的效率:每个 GPU 有独立的进程,减少通信开销
- 更好的扩展性:支持多机多卡训练
- 更均衡的负载:无主 GPU 瓶颈问题
2. 核心组件
- 进程组 (Process Group):管理进程间通信
- NCCL 后端:NVIDIA 优化的 GPU 通信库
- Ring-AllReduce:高效的梯度同步算法

二、完整 DDP 训练 Demo
- 官方DDP Dem参考
1. 基础训练脚本 (ddp_demo.py)
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScalerdef setup(rank, world_size):"""初始化分布式环境"""os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():"""清理分布式环境"""dist.destroy_process_group()class SimpleModel(nn.Module):"""简单的CNN模型"""def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc = nn.Linear(9216, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = torch.flatten(x, 1)return self.fc(x)def prepare_dataloader(rank, world_size, batch_size=32):"""准备分布式数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)return loaderdef train(rank, world_size, epochs=2):"""训练函数"""setup(rank, world_size)# 设置当前设备torch.cuda.set_device(rank)# 初始化模型、优化器等model = SimpleModel().to(rank)ddp_model = DDP(model, device_ids=[rank])optimizer = optim.Adam(ddp_model.parameters())scaler = GradScaler() # 混合精度训练criterion = nn.CrossEntropyLoss()train_loader = prepare_dataloader(rank, world_size)for epoch in range(epochs):ddp_model.train()train_loader.sampler.set_epoch(epoch) # 确保每个epoch有不同的shufflefor batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(rank), target.to(rank)optimizer.zero_grad()# 混合精度训练with torch.autocast(device_type='cuda', dtype=torch.float16):output = ddp_model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()if batch_idx % 100 == 0:print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")cleanup()if __name__ == "__main__":# 单机多卡启动时,torchrun会自动设置这些环境变量rank = int(os.environ['LOCAL_RANK'])world_size = int(os.environ['WORLD_SIZE'])train(rank, world_size)
2. 启动训练
使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):
# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py
3. 关键组件解析
3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
- 确保每个 GPU 处理不同的数据子集
- 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
- 自动处理梯度同步
- 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):# 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
- 减少显存占用,加速训练
- 自动管理 float16/float32 转换
三、DDP 最佳实践
-
数据加载
- 必须使用
DistributedSampler - 每个 epoch 前调用
sampler.set_epoch(epoch)保证 shuffle 正确性
- 必须使用
-
模型保存
if rank == 0: # 只在主进程保存torch.save(model.state_dict(), "model.pth") -
多机训练
# 机器1 (主节点) torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py# 机器2 torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py -
性能调优
- 调整
batch_size使各 GPU 负载均衡 - 使用
pin_memory=True加速数据加载 - 考虑梯度累积减少通信频率
- 调整
四、常见问题解决
-
CUDA 内存不足
- 减少
batch_size - 使用梯度累积
for i, (data, target) in enumerate(train_loader):if i % 2 == 0:optimizer.zero_grad()# 前向和反向...if i % 2 == 1:optimizer.step() - 减少
-
进程同步失败
- 检查所有节点的
MASTER_ADDR和MASTER_PORT一致 - 确保防火墙开放相应端口
- 检查所有节点的
-
精度问题
- 混合精度训练时出现 NaN:调整
GradScaler参数
scaler = GradScaler(init_scale=1024, growth_factor=2.0) - 混合精度训练时出现 NaN:调整
相关文章:
PyTorch 分布式训练(Distributed Data Parallel, DDP)简介
PyTorch 分布式训练(Distributed Data Parallel, DDP) 一、DDP 核心概念 torch.nn.parallel.DistributedDataParallel 1. DDP 是什么? Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataPara…...
prism WPF 消息的订阅或发布
prism WPF 消息的订阅或发布 EventMessage using Prism.Events; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace Cjh.PrismWpf {/// <summary>/// 事件消息/// </summary>publ…...
【Unity】记录TMPro使用过程踩的一些坑
1、打包到webgl无法输入中文,编辑器模式可以,但是webgl不行,试过网上的脚本,还是不行 解决方法:暂时没找到 2、针对字体asset是中文时,overflow的效果模式处理奇怪,它会出现除了overflow模式以…...
计算机视觉初步(环境搭建)
1.anaconda 建议安装在D盘,官网正常安装即可,一般可以安装windows版本 安装成功后,可以在电脑应用里找到: 2.创建虚拟环境 打开anaconda prompt, 可以用conda env list 查看现有的环境,一般打开默认bas…...
【go】异常处理panic和recover
panic 和 recover 当然能触发程序宕机退出的,也可以是我们自己,比如经过检查判断,当前环境无法达到我们程序进行的预期条件时(比如一个服务指定监听端口被其他程序占用),可以手动触发 panic,让…...
Sentinel[超详细讲解]-3
主要讲解🚀 - 基于QPS/并发数的流量控制 1、流控规则 流量控制(Flow Control)用于限制某个资源的访问频率,防止系统被瞬时的流量高峰冲垮。流量控制规则可以针对不同的资源进行配置,例如接口、方法、类等。 流量规则的…...
【云原生】Kubernetes CEL 速查表
以下是一份 Kubernetes CEL 速查表(Cheat Sheet),涵盖了常见的语法、宏、标准函数和一些在 Kubernetes 中常见的使用示例。可在编写或调试 CEL 表达式时用作快速参考。 1. 基础概念 概念说明语言特点无副作用、逐渐类型化(Gradua…...
基于聚类与引力斥力优化的选址算法
在众多实际场景中,诸如消防设施选址、基站布局规划以及充电桩站点部署等,都面临着如何利用最少的资源,实现对所有目标对象全面覆盖的难题。为有效解决这类问题,本文提出一种全新的组合算法模型 —— 基于聚类与引力斥力优化的选址…...
深入剖析雪花算法:分布式ID生成的核心方案
深入剖析雪花算法:分布式ID生成的核心方案 深入剖析雪花算法:分布式ID生成的核心方案一、雪花算法(Snowflake)概述二、雪花算法核心组成1. 64位二进制结构2. 时间戳起始点 三、工作原理与代码实现1. 生成逻辑2. Java代码示例3. 代…...
RK3568 pinctrl内容讲解
文章目录 一、pinctrl的概念`pinctrl` 的作用设备树中的 `pinctrl` 节点典型的 `pinctrl` 节点结构例子`pinctrl` 的重要性总结二、RK3568的pinctrl讲解1. `pinctrl` 节点2. `gpio0` 至 `gpio4` 子节点每个 `gpioX` 子节点的结构和作用3. `gpio1` 到 `gpio4` 子节点总结1. `aco…...
主流Web3公链的核心区别对比
以下是当前主流Web3公链的核心区别对比表,涵盖技术架构、性能、生态等关键维度: 特性以太坊 (Ethereum)SolanaBNB ChainPolygonAvalanche共识机制PoS(信标链分片)PoH(历史证明) PoSPoSA(权益证…...
Mac 电脑移动硬盘无法识别的解决方法
在使用 Mac 电脑的过程中,不少用户都遇到过移动硬盘没有正常推出,导致无法识别的问题。这不仅影响了数据的传输,还可能让人担心硬盘内数据的安全。今天,我们就来详细探讨一下针对这一问题的解决方法。 当发现移动硬盘无法识别时&…...
LeetCode Hot100 刷题笔记(4)—— 二叉树、图论
目录 一、二叉树 1. 二叉树的深度遍历(DFS:前序、中序、后序遍历) 2. 二叉树的最大深度 3. 翻转二叉树 4. 对称二叉树 5. 二叉树的直径 6. 二叉树的层序遍历 7. 将有序数组转换为二叉搜索树 8. 验证二叉搜索树 9. 二叉搜索树中第 K 小的元素 …...
安全框架SpringSecurity入门
安全框架 Spring Security 入门 Spring Security 是一个强大的安全框架,广泛用于保护基于 Spring 的应用程序。它提供了全面的安全服务,包括认证、授权、攻击防护等。下面我将为你详细介绍 Spring Security 的主要知识点,帮助你更好地理解和…...
c# 虚函数、接口、抽象区别和应用场景
文章目录 定义和语法实现要求继承和使用场景总结访问修饰符设计目的性能扩展性在 C# 里,虚函数、接口和抽象函数都能助力实现多态性,不过它们的定义、使用场景和特点存在差异,下面为你详细剖析: 定义和语法 虚函数:虚函数在基类里定义,使用 virtual 关键字,且有默认的实…...
MySQL Online DDL 技术深度解析
在MySQL数据库管理体系中,数据定义语言(DDL)和数据操作语言(DML)构成了数据库交互的基础。 DDL用于定义数据库对象,如数据库、表、列、索引等,相关命令包括CREATE、ALTER、DROP;DML则…...
【计算机视觉】YOLO语义分割
一、语义分割简介 1. 定义 语义分割(Semantic Segmentation)是计算机视觉中的一项任务,其目标是对图像中的每一个像素赋予一个类别标签。与目标检测只给出目标的边界框不同,语义分割能够在像素级别上区分不同类别,从…...
【SpringBoot + MyBatis + MySQL + Thymeleaf 的使用】
目录: 一:创建项目二:修改目录三:添加配置四:创建数据表五:创建实体类六:创建数据接口七:编写xml文件八:单元测试九:编写服务层十:编写控制层十一…...
git 按行切割 csv文件
# 进入Git Bash环境 # 基础用法(不保留标题行): split -l 1000 input.csv output_part_# 增强版(保留标题行): header$(head -n1 input.csv) # 提取标题 tail -n 2 input.csv | split -l 5000000 - --filt…...
在ensp进行OSPF+RIP+静态网络架构配置
一、实验目的 1.Ospf与RIP的双向引入路由消息 2.Ospf引入静态路由信息 二、实验要求 需求: 路由器可以互相ping通 实验设备: 路由器router7台 使用ensp搭建实验坏境,结构如图所示 三、实验内容 1.配置R1、R2、R3路由器使用Ospf动态路由…...
Qt实现HTTP GET/POST/PUT/DELETE请求
引言 在现代应用程序开发中,HTTP请求是与服务器交互的核心方式。Qt作为跨平台的C框架,提供了强大的网络模块(QNetworkAccessManager),支持GET、POST、PUT、DELETE等HTTP方法。本文将手把手教你如何用Qt实现这些请求&a…...
从零开始开发HarmonyOS应用并上架
开发环境搭建(1-2天) 硬件准备 操作系统:Windows 10 64位 或 macOS 10.13 内存:8GB以上(推荐16GB) 硬盘:至少10GB可用空间 软件安装 下载 DevEco Studio 3.1(官网:…...
Redis安全与配置问题——AOF文件损坏问题及解决方案
Java 中的 Redis AOF 文件损坏问题全面解析 一、AOF 文件损坏的本质与危害 1.1 AOF 持久化原理 Redis 的 AOF(Append-Only File) 通过记录所有写操作命令实现持久化。文件格式如下: *2\r\n$6\r\nSELECT\r\n$1\r\n0\r\n *3\r\n$3\r\nSET\r\…...
Java 线程池与 Kotlin 协程 高阶学习
以下是Java 线程池与 Kotlin 协程 高阶学习的对比指南,结合具体代码示例,展示两者在异步任务处理中的差异和 Kotlin 的简化优势: 分析: 首先,我们需要回忆Java中线程池的常见用法,比如通过ExecutorService创…...
3.第二阶段x64游戏实战-分析人物移动实现人物加速
免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 上一个内容:2.第二阶段x64游戏实战-x64dbg的使用 想找人物的速度,就需要使用Ch…...
leetcode 746. Min Cost Climbing Stairs
这道题用动态规划解决。这道题乍一看,含义有点模糊。有两个点要搞清楚:1)给定len个台阶的梯子,其实是要爬完(越过)整个梯子才算到达顶部,相当于顶部是第len1层台阶。台阶序号从0开始编号的话&am…...
网络信息安全应急演练方案
信息安全应急演练方案 总则 (一)编制目的 旨在建立并完善应对病毒入侵、Webshell 攻击以及未授权访问等信息安全突发事件的应急机制,提升组织对这类事件的快速响应、协同处理和恢复能力,最大程度降低事件对业务运营、数据安全和…...
H.264编码解析与C++实现详解
一、H.264编码核心概念 1.1 分层编码结构 H.264采用分层设计,包含视频编码层(VCL)和网络抽象层(NAL)。VCL处理核心编码任务,NAL负责封装网络传输数据。 1.2 NALU单元结构 // NAL单元头部结构示例 struc…...
Python入门(5):异常处理
目录 1 异常处理基础概念 1.1 什么是异常? 1.2 异常与错误的区别 2 异常处理基础 2.1 常见内置异常类型 2.2 try-except 基本结构 2.3 捕获多个异常 2.4 抛出异常 2.4.1 使用raise语句 2.4.2 自定义异常类 3 高级异常处理技巧 3.1 不要过度捕…...
Scala(三)
本节课学习了函数式编程,了解到它与Java、C函数式编程的区别;学习了函数的基础,了解到它的基本语法、函数和方法的定义、函数高级。。。学习到函数至简原则,高阶函数,匿名函数等。 函数的定义 函数基本语法 例子&…...
