当前位置: 首页 > news >正文

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。

安装依赖

首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。

pip install torch torchvision horovod

或者如果你使用的是 Anaconda:

conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]

编写 PyTorch + Horovod 代码

创建一个新的 Python 文件 train.py,然后添加如下代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

运行代码

要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:

horovodrun -np 4 -H localhost:4 python train.py

这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。

请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。

相关文章:

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...

SQL复杂查询功能介绍及示例

文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...

shell基础用法

shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量&#xff0c;则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释&#xff0c;EOF可以替换为&#xff01;# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...

C#设计模式--策略模式(Strategy Pattern)

策略模式是一种行为设计模式&#xff0c;它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略&#xff0c;并将每个算法封装在独立的类中&#xff0c;使得它们可以互相替换。通过使用策略模式&#xff0c;可以在运行时根据需要选择不同的算法&#xff0c;而不需…...

【opencv入门教程】15. 访问像素的十四种方式

文章选自&#xff1a; 一、像素访问 一张图片由许多个点组成&#xff0c;每个点就是一个像素&#xff0c;每个像素包含不同的值&#xff0c;对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...

【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)

导航&#xff1a; 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念&#xff0c;虽然本文有简单回顾&#xff0c;但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...

根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落

要根据 HTML 段落的长度设置 QTextBrowser 的显示内容&#xff0c;并确保至少显示一个段落&#xff0c;可以通过以下步骤来实现&#xff1a; 加载 HTML 内容&#xff1a;首先&#xff0c;你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...

基于Huffman编码的GPS定位数据无损压缩算法

目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示&#xff1a;文末附定位数据压缩工具和源码 一、引言 车载监控系统中&#xff0c;车载终端需要获取GPS信号&#xff08;经度、纬 度、速度、方向等&#xff09;实时上传…...

php:完整部署Grid++Report到php项目,并实现模板打印

一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...

C标签和 EL表达式的在前端界面的应用

目录 前言 常用的c标签有&#xff1a; for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 &#xff1a; ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...

Linux絮絮叨(四) 系统目录结构

Linux 系统的目录结构&#xff08;Filesystem Hierarchy Standard, FHS&#xff09;定义了 Linux 系统中文件系统的标准布局&#xff0c;以下是一些常见目录的功能&#xff1a; 根目录 / 描述&#xff1a;所有文件和目录的起始点&#xff0c;Linux 文件系统的根。内容&#xf…...

Java基于SpringBoot的网上订餐系统,附源码

博主介绍&#xff1a;✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…...

《Java核心技术I》死锁

死锁 账户1&#xff1a;200元账户2: 300元线程1&#xff1a;从账号1转300到账户2线程2&#xff1a;从账户2转400到账户1 如上&#xff0c;线程1和线程2显然都被阻塞&#xff0c;两个账户的余额都不足以转账&#xff0c;两个线程都无法执行下去。 有可能会因为每一个线程要等…...

【Windows11系统局域网共享文件数据】

【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...

MCU、ARM体系结构,单片机基础,单片机操作

计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备&#xff1a;将其他信号转换为计算机可以识别的信号&#xff08;电信号&#xff09;。输出设备&#xff1a;将电信号&#xff08;&#xff10;、&#xff11;&#xff09;转为人或其他设备能理解的…...

在办公室环境中用HMD替代传统显示器的优势

VR头戴式显示器&#xff08;HMD&#xff09;是进入虚拟现实环境的一把钥匙&#xff0c;拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域&#xff0c;正如如今的互联网一样&#xff0c;虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...

ssm 多数据源 注解版本

application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...

selenium常见接口函数使用

博客主页&#xff1a;花果山~程序猿-CSDN博客 文章分栏&#xff1a;测试_花果山~程序猿的博客-CSDN博客 关注我一起学习&#xff0c;一起进步&#xff0c;一起探索编程的无限可能吧&#xff01;让我们一起努力&#xff0c;一起成长&#xff01; 目录 1. 查找 查找方式 css_s…...

STM32F103单片机使用STM32CubeMX新建IAR工程步骤

打开STM32CubeMX软件&#xff0c;选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号&#xff0c;然后点右上角 start project&#xff0c;开始新建工程。 接下来设置调试接口&#xff0c;在左边System Core中选择 SYS&#xff0c;然后在右右边debu…...

刷题重开:找出字符串中第一个匹配项的下标——解题思路记录

问题描述&#xff1a; 给你两个字符串 haystack 和 needle &#xff0c;请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标&#xff08;下标从 0 开始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;则返回 -1 。 示例 1&#xff1a; 输入&…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank&#xff1f;由于时间太久&#xff0c;我真忘记了。搜搜发现&#xff0c;还真有人和我一样。见下面的链接&#xff1a;https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析

Java求职者面试指南&#xff1a;Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问&#xff08;基础概念问题&#xff09; 1. 请解释Spring框架的核心容器是什么&#xff1f;它在Spring中起到什么作用&#xff1f; Spring框架的核心容器是IoC容器&#…...

虚拟电厂发展三大趋势:市场化、技术主导、车网互联

市场化&#xff1a;从政策驱动到多元盈利 政策全面赋能 2025年4月&#xff0c;国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》&#xff0c;首次明确虚拟电厂为“独立市场主体”&#xff0c;提出硬性目标&#xff1a;2027年全国调节能力≥2000万千瓦&#xff0…...

【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)

前言&#xff1a; 双亲委派机制对于面试这块来说非常重要&#xff0c;在实际开发中也是经常遇见需要打破双亲委派的需求&#xff0c;今天我们一起来探索一下什么是双亲委派机制&#xff0c;在此之前我们先介绍一下类的加载器。 目录 ​编辑 前言&#xff1a; 类加载器 1. …...

淘宝扭蛋机小程序系统开发:打造互动性强的购物平台

淘宝扭蛋机小程序系统的开发&#xff0c;旨在打造一个互动性强的购物平台&#xff0c;让用户在购物的同时&#xff0c;能够享受到更多的乐趣和惊喜。 淘宝扭蛋机小程序系统拥有丰富的互动功能。用户可以通过虚拟摇杆操作扭蛋机&#xff0c;实现旋转、抽拉等动作&#xff0c;增…...

android RelativeLayout布局

<?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android:gravity&…...

sshd代码修改banner

sshd服务连接之后会收到字符串&#xff1a; SSH-2.0-OpenSSH_9.5 容易被hacker识别此服务为sshd服务。 是否可以通过修改此banner达到让人无法识别此服务的目的呢&#xff1f; 不能。因为这是写的SSH的协议中的。 也就是协议规定了banner必须这么写。 SSH- 开头&#xff0c…...

2025年低延迟业务DDoS防护全攻略:高可用架构与实战方案

一、延迟敏感行业面临的DDoS攻击新挑战 2025年&#xff0c;金融交易、实时竞技游戏、工业物联网等低延迟业务成为DDoS攻击的首要目标。攻击呈现三大特征&#xff1a; AI驱动的自适应攻击&#xff1a;攻击流量模拟真实用户行为&#xff0c;差异率低至0.5%&#xff0c;传统规则引…...