当前位置: 首页 > news >正文

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。

安装依赖

首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。

pip install torch torchvision horovod

或者如果你使用的是 Anaconda:

conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]

编写 PyTorch + Horovod 代码

创建一个新的 Python 文件 train.py,然后添加如下代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

运行代码

要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:

horovodrun -np 4 -H localhost:4 python train.py

这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。

请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。

相关文章:

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...

SQL复杂查询功能介绍及示例

文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...

shell基础用法

shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量&#xff0c;则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释&#xff0c;EOF可以替换为&#xff01;# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...

C#设计模式--策略模式(Strategy Pattern)

策略模式是一种行为设计模式&#xff0c;它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略&#xff0c;并将每个算法封装在独立的类中&#xff0c;使得它们可以互相替换。通过使用策略模式&#xff0c;可以在运行时根据需要选择不同的算法&#xff0c;而不需…...

【opencv入门教程】15. 访问像素的十四种方式

文章选自&#xff1a; 一、像素访问 一张图片由许多个点组成&#xff0c;每个点就是一个像素&#xff0c;每个像素包含不同的值&#xff0c;对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...

【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)

导航&#xff1a; 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念&#xff0c;虽然本文有简单回顾&#xff0c;但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...

根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落

要根据 HTML 段落的长度设置 QTextBrowser 的显示内容&#xff0c;并确保至少显示一个段落&#xff0c;可以通过以下步骤来实现&#xff1a; 加载 HTML 内容&#xff1a;首先&#xff0c;你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...

基于Huffman编码的GPS定位数据无损压缩算法

目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示&#xff1a;文末附定位数据压缩工具和源码 一、引言 车载监控系统中&#xff0c;车载终端需要获取GPS信号&#xff08;经度、纬 度、速度、方向等&#xff09;实时上传…...

php:完整部署Grid++Report到php项目,并实现模板打印

一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...

C标签和 EL表达式的在前端界面的应用

目录 前言 常用的c标签有&#xff1a; for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 &#xff1a; ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...

Linux絮絮叨(四) 系统目录结构

Linux 系统的目录结构&#xff08;Filesystem Hierarchy Standard, FHS&#xff09;定义了 Linux 系统中文件系统的标准布局&#xff0c;以下是一些常见目录的功能&#xff1a; 根目录 / 描述&#xff1a;所有文件和目录的起始点&#xff0c;Linux 文件系统的根。内容&#xf…...

Java基于SpringBoot的网上订餐系统,附源码

博主介绍&#xff1a;✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…...

《Java核心技术I》死锁

死锁 账户1&#xff1a;200元账户2: 300元线程1&#xff1a;从账号1转300到账户2线程2&#xff1a;从账户2转400到账户1 如上&#xff0c;线程1和线程2显然都被阻塞&#xff0c;两个账户的余额都不足以转账&#xff0c;两个线程都无法执行下去。 有可能会因为每一个线程要等…...

【Windows11系统局域网共享文件数据】

【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...

MCU、ARM体系结构,单片机基础,单片机操作

计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备&#xff1a;将其他信号转换为计算机可以识别的信号&#xff08;电信号&#xff09;。输出设备&#xff1a;将电信号&#xff08;&#xff10;、&#xff11;&#xff09;转为人或其他设备能理解的…...

在办公室环境中用HMD替代传统显示器的优势

VR头戴式显示器&#xff08;HMD&#xff09;是进入虚拟现实环境的一把钥匙&#xff0c;拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域&#xff0c;正如如今的互联网一样&#xff0c;虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...

ssm 多数据源 注解版本

application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...

selenium常见接口函数使用

博客主页&#xff1a;花果山~程序猿-CSDN博客 文章分栏&#xff1a;测试_花果山~程序猿的博客-CSDN博客 关注我一起学习&#xff0c;一起进步&#xff0c;一起探索编程的无限可能吧&#xff01;让我们一起努力&#xff0c;一起成长&#xff01; 目录 1. 查找 查找方式 css_s…...

STM32F103单片机使用STM32CubeMX新建IAR工程步骤

打开STM32CubeMX软件&#xff0c;选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号&#xff0c;然后点右上角 start project&#xff0c;开始新建工程。 接下来设置调试接口&#xff0c;在左边System Core中选择 SYS&#xff0c;然后在右右边debu…...

刷题重开:找出字符串中第一个匹配项的下标——解题思路记录

问题描述&#xff1a; 给你两个字符串 haystack 和 needle &#xff0c;请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标&#xff08;下标从 0 开始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;则返回 -1 。 示例 1&#xff1a; 输入&…...

product/admin/list?page=0size=10field=jancodevalue=4562249292272

文章目录 1、ProductController2、AdminCommonService3、ProductApiService4、ProductCommonService5、ProductSqlService https://api.crossbiog.com/product/admin/list?page0&size10&fieldjancode&value45622492922721、ProductController GetMapping("ad…...

人工智能机器学习无监督学习概念及应用详解

无监督学习&#xff1a;深入解析 引言 在人工智能和机器学习的领域中&#xff0c;无监督学习&#xff08;Unsupervised Learning&#xff09;是一种重要的学习范式。与监督学习不同&#xff0c;无监督学习不依赖于标签数据&#xff0c;而是通过模型从无标签的数据中学习数据的…...

APM装机教程(五):测绘无人船

文章目录 前言一、元生惯导RTK使用二、元厚HXF260测深仪使用三、云卓H2pro遥控器四、海康威视摄像头 前言 船体&#xff1a;超维USV-M1000 飞控&#xff1a;pix6c mini 测深仪&#xff1a;元厚HXF160 RTK&#xff1a;元生惯导RTK 遥控器&#xff1a;云卓H12pro 摄像头&#xf…...

微信小程序 运行出错 弹出提示框(获取token失败,请重试 或者 请求失败)

原因是&#xff1a;需要登陆微信公众平台在开发管理 中设置 相应的 服务器域名 中的 request合法域名 // index.jsPage({data: {products:[],cardLayout: grid, // 默认卡片布局为网格模式isGrid: true, // 默认为网格布局page: 0, // 当前页码size: 10, // 每页大小hasMore…...

IDEA的service窗口中启动类是灰色且容易消失

大家在学习Spring Cloud的过程中,随着项目的深入,会分出很多个微服务,当我们的服务数量大于等于三个的时候,IDEA会给我们的服务整理起来,类似于这样 但是当我们的微服务数量达到5个以上的时候,再启动服务的时候,服务的启动类就会变成灰色,而且还容易丢失 解决方法 我们按住…...

R中利用ggplot2绘制气泡图

闲来无事&#xff0c;整理了一下自己的绘图笔记&#xff0c;顺便分享到CSDN上。 一、介绍 气泡图&#xff08;Bubble Plot&#xff09;是一种常用的数据可视化方法&#xff0c;用于展示三个变量之间的关系。气泡图的特点是通过气泡的大小、颜色和位置来表达数据中的多维信息。…...

CID引流电商

ClickID技术是基于多家媒体平台开发的电商引流服务&#xff0c;通过媒体提供的宏参数&#xff0c;间接解决电商平台订单数据的回传问题&#xff0c;帮助账户收集到极致精准的数据模型&#xff0c;搭建不同媒体往各平台引流的桥梁。简单来说就是通过ClickID数据监测到另外一个平…...

在google cloud虚拟机上配置anaconda虚拟环境简单教程

下载anaconda安装包 wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh 安装 bash Anaconda3-2022.10-Linux-x86_64.sh 进入base环境 eval "$(/home/xmxhuihui/anaconda3/bin/conda shell.bash hook)" source ~/.bashrc 安装虚拟环境…...

windows下用vs搭配clang一起生成抽象语法树

如果你使用的是 Visual Studio 环境&#xff0c;并且想要通知 Clang 关于 C 语言标准库的位置&#xff0c;你可以通过以下几种方法来实现。Visual Studio 提供了完整的 C/C 标准库&#xff0c;Clang 可以与之协同工作。以下是具体步骤&#xff1a; 1. 使用 clang-cl Visual S…...

输入法:点三下输入一个汉字

作者常用的双拼输入法&#xff0c;需要26键。虽然也有9键的方案&#xff0c;但重码率较高。计算一下&#xff0c;9键点2下&#xff0c;共81种排列组合。而汉字的读音&#xff0c;不计声调&#xff0c;有400多个。相差甚多。 所以&#xff0c;设计了“三拼输入法”&#xff0c;…...