当前位置: 首页 > news >正文

Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练

数据并行:

DP和DDP

这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。

  1. DataParallel(DP)

    • 只支持单进程多线程,单一机器上进行训练。
    • 模型训练开始的时候,先把模型复制到四个GPU上面,然后把数据分配给四个GPU进行前向传播,前向传播之后再汇总到卡0上面,然后在卡0上进行反向传播,参数更新,再将更新好的模型复制到其他几张卡上。

    在这里插入图片描述

  2. DistributedDataParallel(DDP)

    • 支持多线程多进程,单一或者多个机器上进行训练。通常DDP比DP要快。

    • 先把模型载入到四张卡上,每个GPU上都分配一些小批量的数据,再进行前向传播,反向传播,计算完梯度之后再把所有卡上的梯度汇聚到卡0上面,卡0算完梯度的平均值之后广播给所有的卡,所有的卡更新自己的模型,这样传输的数据量会少很多。

      在这里插入图片描述

DDP代码写法

  1. 初始化

    import torch.distributed as dist
    import torch.utils.data.distributed# 进行初始化,backend表示通信方式,可选择的有nccl(英伟达的GPU2GPU的通信库,适用于具有英伟达GPU的分布式训练)、gloo(基于tcp/ip的后端,可在不同机器之间进行通信,通常适用于不具备英伟达GPU的环境)、mpi(适用于支持mpi集群的环境)
    # init_method: 告知每个进程如何发现彼此,默认使用env://
    dist.init_process_group(backend='nccl', init_method="env://")
    
  2. 设置device

    device = torch.device(f'cuda:{args.local_rank}')	# 设置device,local_rank表示当前机器的进程号,该方式为每个显卡一个进程
    torch.cuda.set_device(device)	# 设定device
    
  3. 创建dataloader之前要加一个sampler

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)	# 加一个sampler
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)
    
  4. torch.nn.parallel.DistributedDataParallel包裹模型(先to(device)再包裹模型)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])	# 包裹模型
    
  5. 真正训练之前要set_epoch(),否则将不会shuffer数据

    for epoch in range(10):train_sampler.set_epoch(epoch)		# set_epochfor step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))
    
  6. 模型保存

    if args.local_rank == 0:		# local_rank为0表示master进程torch.save(net, "my_net.pth")
    
  7. 运行

    if __name__ == "__main__":parser = argparse.ArgumentParser()# local_rank参数是必须的,运行的时候不必自己指定,DDP会自行提供parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)
    
  8. 运行命令

    python -m torch.distributed.launch --nproc_per_node=2 多卡训练.py	# --nproc_per_node=2表示当前机器上有两个GPU可以使用
    

完整代码

import os
import argparse
import torch
import torchvision
import torch.distributed as dist
import torch.utils.data.distributedfrom torchvision import transforms
from torch.multiprocessing import Processdef main(args):# nccl: 后端基于NVIDIA的GPU-to-GPU通信库,适用于具有NVIDIA GPU的分布式训练# gloo: 后端是一个基于TCP/IP的后端,可在不同机器之间进行通信,通常适用于不具备NVIDIA GPU的环境。# mpi: 后端使用MPI实现,适用于具备MPI支持的集群环境。# init_method: 告知每个进程如何发现彼此,如何使用通信后端初始化和验证进程组。 默认情况下,如果未指定 init_method,PyTorch 将使用环境变量初始化方法 (env://)。dist.init_process_group(backend='nccl', init_method="env://") # nccl比较推荐device = torch.device(f'cuda:{args.local_rank}')torch.cuda.set_device(device)trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)net = torchvision.models.resnet101(num_classes=10)net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)net = net.to(device)net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])criterion = torch.nn.CrossEntropyLoss()opt = torch.optim.Adam(params=net.parameters(), lr=0.001)for epoch in range(10):train_sampler.set_epoch(epoch)for step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))if args.local_rank == 0:torch.save(net, "my_net.pth")if __name__ == "__main__":parser = argparse.ArgumentParser()# must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDPparser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)

参考:

https://zhuanlan.zhihu.com/p/594046884
https://zhuanlan.zhihu.com/p/358974461

相关文章:

Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练 数据并行: DP和DDP 这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。 DataParallel(DP) 只支持单进程多线程…...

asp.net coremvc+efcore增删改查

下面是一个使用 EF Core 在 ASP.NET Core MVC 中完成增删改查的示例&#xff1a; 创建一个新的 ASP.NET Core MVC 项目。 安装 EF Core 相关的 NuGet 包。在项目文件 (.csproj) 中添加以下依赖项&#xff1a; <ItemGroup><PackageReference Include"Microsoft…...

Java基础面试,什么是面向对象,谈谈你对面向对象的理解

前言 马上就要找工作了&#xff0c;从今天开始一天准备1~2道面试题&#xff0c;来打基础&#xff0c;就从Java基础开始吧。 什么是面向对象&#xff0c;谈谈你对面向对象的理解&#xff1f; 谈到面向对象&#xff0c;那就不得不谈到面向过程。面向过程更加注重的是完成一个任…...

Ubuntu系统初始设置

更换国内源 安装截图工具 安装中文输入法 安装QQ 参考&#xff1a; 安装双系统win10Ubuntu20.04LTS&#xff08;详细到我自己都害怕&#xff09; 引导方式磁盘分区方法UEFIGPTLegancyMBR 安装网络助手 sudo apt install net-tools 安装VS Code 使用从官网下载.deb安装包…...

焕新古文化传承之路,AI为古彝文识别赋能

目录 1 古彝文与古典保护 2 古文识别的挑战 2.1 西文与汉文OCR 2.2 古彝文识别难点 3 合合信息&#xff1a;古彝文保护新思路 3.1 图像矫正 3.2 图像增强 3.3 语义理解 3.4 工程技巧 4 总结 1 古彝文与古典保护 彝文指的是云南、贵州、四川等地的彝族人使用的文字&am…...

毛玻璃动画交互效果

效果展示 页面结构组成 从上述的效果展示页面结构来看&#xff0c;页面布局都是比较简单的&#xff0c;只是元素的动画交互比较麻烦。 第一个动画交互是两个圆相互交错来回运动。第二个动画交互是三角绕着圆进行 360 度旋转。 CSS 知识点 animationanimation-delay绝对定位…...

Audio2Face的工作原理

预加载一个3D数字人物模型(Digital Mark),该模型可以通过音频驱动进行面部动画。 用户上传音频文件作为输入。 将音频输入馈送到预训练的深度神经网络中。 Audio2Face加载预制的3d人头mesh 3D数字人物面部模型由大量顶点组成,每个顶点都有xyz坐标。 深度神经网络输入音频特征,…...

【面试题】2023前端面试真题之JS篇

前端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 表妹一键制作自己的五星红旗国庆头像&#xff0c;超好看 世界上只有一种真正的英雄主义&#xff0c;那就是看清生活的真相之后&#xff0c;依然热爱生活。…...

Mysql 分布式序列算法

接上文 Mysql分库分表 1.分布式序列简介 在分布式系统下&#xff0c;怎么保证ID的生成满足以上需求&#xff1f; ShardingJDBC支持以上两种算法自动生成ID。这里&#xff0c;使用ShardingJDBC让主键ID以雪花算法进行生成&#xff0c;首先配置数据库&#xff0c;因为默认的注…...

Windows/Linux双系统卸载Ubuntu

参考&#xff1a;双系统下完全卸载ubuntu...

asp.net core mvc 视图组件viewComponents

ASP.NET Core MVC 视图组件&#xff08;View Components&#xff09;是一种可重用的 UI 组件&#xff0c;用于在视图中呈现某些特定的功能块&#xff0c;例如导航菜单、侧边栏、用户信息等。视图组件提供了一种将视图逻辑与控制器解耦的方式&#xff0c;使视图能够更加灵活、可…...

如何保持终身学习

文章目录 2.1. 了解你的大脑2.2 学习是对神经元网络的塑造2.3 大脑的一生 3.学习的心里基础3.1 固定思维与成长思维3.2 我们为什么要学习 4. 学习路径4.1 构建知识模块4.2 大脑是如何使用注意力的4.3 提高专注力4.4 放松一下&#xff0c;学的更好4.5 巩固你的学习痕迹4.6 被动学…...

【RV1103】RTL8723bs (SD卡形状模块)驱动开发

文章目录 前言硬件分析Luckfox Pico的SD卡接口硬件原理图LicheePi zero WiFiBT模块总结 正文Kernel WiFi驱动支持Kernel 设备树支持修改一&#xff1a;修改二&#xff1a; SDK全局配置支持 wifi全局编译脚本支持编译逻辑拷贝rtl8723bs的固件到文件系统的固定目录里面去 上电后手…...

LeetCode 周赛上分之旅 #49 再探内向基环树

⭐️ 本文已收录到 AndroidFamily&#xff0c;技术和职场问题&#xff0c;请关注公众号 [彭旭锐] 和 BaguTree Pro 知识星球提问。 学习数据结构与算法的关键在于掌握问题背后的算法思维框架&#xff0c;你的思考越抽象&#xff0c;它能覆盖的问题域就越广&#xff0c;理解难度…...

kubernetes-v1.23.3 部署 kafka_2.12-2.3.0

文章目录 [toc]构建 debian 基础镜像部署 zookeeper配置 namespace配置 gfs 的 endpoints配置 pv 和 pvc配置 configmap配置 service配置 statefulset 部署 kafka配置 configmap配置 service配置 statefulset 这里采用的部署方式如下&#xff1a; 使用自定义的 debian 镜像作为…...

位置编码器

目录 1、位置编码器的作用 2、代码演示 &#xff08;1&#xff09;、使用unsqueeze扩展维度 &#xff08;2&#xff09;、使用squeeze降维 &#xff08;3&#xff09;、显示张量维度 &#xff08;4&#xff09;、随机失活张量中的数值 3、定义位置编码器类&#xff0c;我…...

Lua多脚本执行

--全局变量 a 1 b "123"for i 1,2 doc "Holens" endprint(c) print("*************************************1")--本地变量&#xff08;局部变量&#xff09; for i 1,2 dolocal d "Holens2"print(d) end print(d)function F1( ..…...

Spirng Cloud Alibaba Nacos注册中心的使用 (环境隔离、服务分级存储模型、权重配置、临时实例与持久实例)

文章目录 一、环境隔离1. Namespace&#xff08;命名空间&#xff09;&#xff1a;2. Group&#xff08;分组&#xff09;&#xff1a;3. Services&#xff08;服务&#xff09;&#xff1a;4. DataId&#xff08;数据ID&#xff09;&#xff1a;5. 实战演示&#xff1a;5.1 默…...

26663-2011 大型液压安全联轴器 课堂随笔

声明 本文是学习GB-T 26663-2011 大型液压安全联轴器. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了大型液压安全联轴器的分类、技术要求、试验方法及检验规则等。 本标准适用于联接两同轴线的传动轴系&#xff0c;可起到限制…...

ChatGPT架构师:语言大模型的多模态能力、幻觉与研究经验

来源 | The Robot Brains Podcast OneFlow编译 翻译&#xff5c;宛子琳、杨婷 9月26日&#xff0c;OpenAI宣布ChatGPT新增了图片识别和语音能力&#xff0c;使得ChatGPT不仅可以进行文字交流&#xff0c;还可以给它展示图片并进行互动&#xff0c;这是一次ChatGPT向多模态进化的…...

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

华硕a豆14 Air香氛版,美学与科技的馨香融合

在快节奏的现代生活中&#xff0c;我们渴望一个能激发创想、愉悦感官的工作与生活伙伴&#xff0c;它不仅是冰冷的科技工具&#xff0c;更能触动我们内心深处的细腻情感。正是在这样的期许下&#xff0c;华硕a豆14 Air香氛版翩然而至&#xff0c;它以一种前所未有的方式&#x…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

js 设置3秒后执行

如何在JavaScript中延迟3秒执行操作 在JavaScript中&#xff0c;要设置一个操作在指定延迟后&#xff08;例如3秒&#xff09;执行&#xff0c;可以使用 setTimeout 函数。setTimeout 是JavaScript的核心计时器方法&#xff0c;它接受两个参数&#xff1a; 要执行的函数&…...

LUA+Reids实现库存秒杀预扣减 记录流水 以及自己的思考

目录 lua脚本 记录流水 记录流水的作用 流水什么时候删除 我们在做库存扣减的时候&#xff0c;显示基于Lua脚本和Redis实现的预扣减 这样可以在秒杀扣减的时候保证操作的原子性和高效性 lua脚本 // ... 已有代码 ...Overridepublic InventoryResponse decrease(Inventor…...

Axure零基础跟我学:展开与收回

亲爱的小伙伴,如有帮助请订阅专栏!跟着老师每课一练,系统学习Axure交互设计课程! Axure产品经理精品视频课https://edu.csdn.net/course/detail/40420 课程主题:Axure菜单展开与收回 课程视频:...

uniapp获取当前位置和经纬度信息

1.1. 获取当前位置和经纬度信息&#xff08;需要配置高的SDK&#xff09; 调用uni-app官方API中的uni.chooseLocation()&#xff0c;即打开地图选择位置。 <button click"getAddress">获取定位</button> const getAddress () > {uni.chooseLocatio…...

在MobaXterm 打开图形工具firefox

目录 1.安装 X 服务器软件 2.服务器端配置 3.客户端配置 4.安装并打开 Firefox 1.安装 X 服务器软件 Centos系统 # CentOS/RHEL 7 及之前&#xff08;YUM&#xff09; sudo yum install xorg-x11-server-Xorg xorg-x11-xinit xorg-x11-utils mesa-libEGL mesa-libGL mesa-…...

【QT】qtdesigner中将控件提升为自定义控件后,css设置样式不生效(已解决,图文详情)

目录 0.背景 1.解决思路 2.详细代码 0.背景 实际项目中遇到的问题&#xff0c;描述如下&#xff1a; 我在qtdesigner用界面拖了一个QTableView控件&#xff0c;object name为【tableView_electrode】&#xff0c;然后【提升为】了自定义的类【Steer_Electrode_Table】&…...