当前位置: 首页 > news >正文

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。

安装依赖

首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。

pip install torch torchvision horovod

或者如果你使用的是 Anaconda:

conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]

编写 PyTorch + Horovod 代码

创建一个新的 Python 文件 train.py,然后添加如下代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

运行代码

要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:

horovodrun -np 4 -H localhost:4 python train.py

这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。

请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。

相关文章:

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...

SQL复杂查询功能介绍及示例

文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...

shell基础用法

shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量&#xff0c;则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释&#xff0c;EOF可以替换为&#xff01;# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...

C#设计模式--策略模式(Strategy Pattern)

策略模式是一种行为设计模式&#xff0c;它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略&#xff0c;并将每个算法封装在独立的类中&#xff0c;使得它们可以互相替换。通过使用策略模式&#xff0c;可以在运行时根据需要选择不同的算法&#xff0c;而不需…...

【opencv入门教程】15. 访问像素的十四种方式

文章选自&#xff1a; 一、像素访问 一张图片由许多个点组成&#xff0c;每个点就是一个像素&#xff0c;每个像素包含不同的值&#xff0c;对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...

【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)

导航&#xff1a; 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念&#xff0c;虽然本文有简单回顾&#xff0c;但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...

根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落

要根据 HTML 段落的长度设置 QTextBrowser 的显示内容&#xff0c;并确保至少显示一个段落&#xff0c;可以通过以下步骤来实现&#xff1a; 加载 HTML 内容&#xff1a;首先&#xff0c;你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...

基于Huffman编码的GPS定位数据无损压缩算法

目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示&#xff1a;文末附定位数据压缩工具和源码 一、引言 车载监控系统中&#xff0c;车载终端需要获取GPS信号&#xff08;经度、纬 度、速度、方向等&#xff09;实时上传…...

php:完整部署Grid++Report到php项目,并实现模板打印

一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...

C标签和 EL表达式的在前端界面的应用

目录 前言 常用的c标签有&#xff1a; for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 &#xff1a; ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...

Linux絮絮叨(四) 系统目录结构

Linux 系统的目录结构&#xff08;Filesystem Hierarchy Standard, FHS&#xff09;定义了 Linux 系统中文件系统的标准布局&#xff0c;以下是一些常见目录的功能&#xff1a; 根目录 / 描述&#xff1a;所有文件和目录的起始点&#xff0c;Linux 文件系统的根。内容&#xf…...

Java基于SpringBoot的网上订餐系统,附源码

博主介绍&#xff1a;✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…...

《Java核心技术I》死锁

死锁 账户1&#xff1a;200元账户2: 300元线程1&#xff1a;从账号1转300到账户2线程2&#xff1a;从账户2转400到账户1 如上&#xff0c;线程1和线程2显然都被阻塞&#xff0c;两个账户的余额都不足以转账&#xff0c;两个线程都无法执行下去。 有可能会因为每一个线程要等…...

【Windows11系统局域网共享文件数据】

【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...

MCU、ARM体系结构,单片机基础,单片机操作

计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备&#xff1a;将其他信号转换为计算机可以识别的信号&#xff08;电信号&#xff09;。输出设备&#xff1a;将电信号&#xff08;&#xff10;、&#xff11;&#xff09;转为人或其他设备能理解的…...

在办公室环境中用HMD替代传统显示器的优势

VR头戴式显示器&#xff08;HMD&#xff09;是进入虚拟现实环境的一把钥匙&#xff0c;拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域&#xff0c;正如如今的互联网一样&#xff0c;虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...

ssm 多数据源 注解版本

application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...

selenium常见接口函数使用

博客主页&#xff1a;花果山~程序猿-CSDN博客 文章分栏&#xff1a;测试_花果山~程序猿的博客-CSDN博客 关注我一起学习&#xff0c;一起进步&#xff0c;一起探索编程的无限可能吧&#xff01;让我们一起努力&#xff0c;一起成长&#xff01; 目录 1. 查找 查找方式 css_s…...

STM32F103单片机使用STM32CubeMX新建IAR工程步骤

打开STM32CubeMX软件&#xff0c;选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号&#xff0c;然后点右上角 start project&#xff0c;开始新建工程。 接下来设置调试接口&#xff0c;在左边System Core中选择 SYS&#xff0c;然后在右右边debu…...

刷题重开:找出字符串中第一个匹配项的下标——解题思路记录

问题描述&#xff1a; 给你两个字符串 haystack 和 needle &#xff0c;请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标&#xff08;下标从 0 开始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;则返回 -1 。 示例 1&#xff1a; 输入&…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

AspectJ 在 Android 中的完整使用指南

一、环境配置&#xff08;Gradle 7.0 适配&#xff09; 1. 项目级 build.gradle // 注意&#xff1a;沪江插件已停更&#xff0c;推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

宇树科技,改名了!

提到国内具身智能和机器人领域的代表企业&#xff0c;那宇树科技&#xff08;Unitree&#xff09;必须名列其榜。 最近&#xff0c;宇树科技的一项新变动消息在业界引发了不少关注和讨论&#xff0c;即&#xff1a; 宇树向其合作伙伴发布了一封公司名称变更函称&#xff0c;因…...

Caliper 配置文件解析:fisco-bcos.json

config.yaml 文件 config.yaml 是 Caliper 的主配置文件,通常包含以下内容: test:name: fisco-bcos-test # 测试名称description: Performance test of FISCO-BCOS # 测试描述workers:type: local # 工作进程类型number: 5 # 工作进程数量monitor:type: - docker- pro…...

Unity UGUI Button事件流程

场景结构 测试代码 public class TestBtn : MonoBehaviour {void Start(){var btn GetComponent<Button>();btn.onClick.AddListener(OnClick);}private void OnClick(){Debug.Log("666");}}当添加事件时 // 实例化一个ButtonClickedEvent的事件 [Formerl…...