使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。
安装依赖
首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。
pip install torch torchvision horovod
或者如果你使用的是 Anaconda:
conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]
编写 PyTorch + Horovod 代码
创建一个新的 Python 文件 train.py
,然后添加如下代码:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
运行代码
要运行这段代码,你需要使用 horovodrun
命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:
horovodrun -np 4 -H localhost:4 python train.py
这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。
请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。
相关文章:

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...

SQL复杂查询功能介绍及示例
文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...

shell基础用法
shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量,则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释,EOF可以替换为!# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...

C#设计模式--策略模式(Strategy Pattern)
策略模式是一种行为设计模式,它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略,并将每个算法封装在独立的类中,使得它们可以互相替换。通过使用策略模式,可以在运行时根据需要选择不同的算法,而不需…...

【opencv入门教程】15. 访问像素的十四种方式
文章选自: 一、像素访问 一张图片由许多个点组成,每个点就是一个像素,每个像素包含不同的值,对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...

【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)
导航: 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念,虽然本文有简单回顾,但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...

根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落
要根据 HTML 段落的长度设置 QTextBrowser 的显示内容,并确保至少显示一个段落,可以通过以下步骤来实现: 加载 HTML 内容:首先,你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...

基于Huffman编码的GPS定位数据无损压缩算法
目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示:文末附定位数据压缩工具和源码 一、引言 车载监控系统中,车载终端需要获取GPS信号(经度、纬 度、速度、方向等)实时上传…...

php:完整部署Grid++Report到php项目,并实现模板打印
一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...

C标签和 EL表达式的在前端界面的应用
目录 前言 常用的c标签有: for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 : ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...

Linux絮絮叨(四) 系统目录结构
Linux 系统的目录结构(Filesystem Hierarchy Standard, FHS)定义了 Linux 系统中文件系统的标准布局,以下是一些常见目录的功能: 根目录 / 描述:所有文件和目录的起始点,Linux 文件系统的根。内容…...

Java基于SpringBoot的网上订餐系统,附源码
博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…...

《Java核心技术I》死锁
死锁 账户1:200元账户2: 300元线程1:从账号1转300到账户2线程2:从账户2转400到账户1 如上,线程1和线程2显然都被阻塞,两个账户的余额都不足以转账,两个线程都无法执行下去。 有可能会因为每一个线程要等…...

【Windows11系统局域网共享文件数据】
【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...

MCU、ARM体系结构,单片机基础,单片机操作
计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备:将其他信号转换为计算机可以识别的信号(电信号)。输出设备:将电信号(0、1)转为人或其他设备能理解的…...

在办公室环境中用HMD替代传统显示器的优势
VR头戴式显示器(HMD)是进入虚拟现实环境的一把钥匙,拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域,正如如今的互联网一样,虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...

ssm 多数据源 注解版本
application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...

selenium常见接口函数使用
博客主页:花果山~程序猿-CSDN博客 文章分栏:测试_花果山~程序猿的博客-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 1. 查找 查找方式 css_s…...

STM32F103单片机使用STM32CubeMX新建IAR工程步骤
打开STM32CubeMX软件,选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号,然后点右上角 start project,开始新建工程。 接下来设置调试接口,在左边System Core中选择 SYS,然后在右右边debu…...

刷题重开:找出字符串中第一个匹配项的下标——解题思路记录
问题描述: 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 示例 1: 输入&…...

product/admin/list?page=0size=10field=jancodevalue=4562249292272
文章目录 1、ProductController2、AdminCommonService3、ProductApiService4、ProductCommonService5、ProductSqlService https://api.crossbiog.com/product/admin/list?page0&size10&fieldjancode&value45622492922721、ProductController GetMapping("ad…...

人工智能机器学习无监督学习概念及应用详解
无监督学习:深入解析 引言 在人工智能和机器学习的领域中,无监督学习(Unsupervised Learning)是一种重要的学习范式。与监督学习不同,无监督学习不依赖于标签数据,而是通过模型从无标签的数据中学习数据的…...

APM装机教程(五):测绘无人船
文章目录 前言一、元生惯导RTK使用二、元厚HXF260测深仪使用三、云卓H2pro遥控器四、海康威视摄像头 前言 船体:超维USV-M1000 飞控:pix6c mini 测深仪:元厚HXF160 RTK:元生惯导RTK 遥控器:云卓H12pro 摄像头…...

微信小程序 运行出错 弹出提示框(获取token失败,请重试 或者 请求失败)
原因是:需要登陆微信公众平台在开发管理 中设置 相应的 服务器域名 中的 request合法域名 // index.jsPage({data: {products:[],cardLayout: grid, // 默认卡片布局为网格模式isGrid: true, // 默认为网格布局page: 0, // 当前页码size: 10, // 每页大小hasMore…...

IDEA的service窗口中启动类是灰色且容易消失
大家在学习Spring Cloud的过程中,随着项目的深入,会分出很多个微服务,当我们的服务数量大于等于三个的时候,IDEA会给我们的服务整理起来,类似于这样 但是当我们的微服务数量达到5个以上的时候,再启动服务的时候,服务的启动类就会变成灰色,而且还容易丢失 解决方法 我们按住…...

R中利用ggplot2绘制气泡图
闲来无事,整理了一下自己的绘图笔记,顺便分享到CSDN上。 一、介绍 气泡图(Bubble Plot)是一种常用的数据可视化方法,用于展示三个变量之间的关系。气泡图的特点是通过气泡的大小、颜色和位置来表达数据中的多维信息。…...

CID引流电商
ClickID技术是基于多家媒体平台开发的电商引流服务,通过媒体提供的宏参数,间接解决电商平台订单数据的回传问题,帮助账户收集到极致精准的数据模型,搭建不同媒体往各平台引流的桥梁。简单来说就是通过ClickID数据监测到另外一个平…...

在google cloud虚拟机上配置anaconda虚拟环境简单教程
下载anaconda安装包 wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh 安装 bash Anaconda3-2022.10-Linux-x86_64.sh 进入base环境 eval "$(/home/xmxhuihui/anaconda3/bin/conda shell.bash hook)" source ~/.bashrc 安装虚拟环境…...

windows下用vs搭配clang一起生成抽象语法树
如果你使用的是 Visual Studio 环境,并且想要通知 Clang 关于 C 语言标准库的位置,你可以通过以下几种方法来实现。Visual Studio 提供了完整的 C/C 标准库,Clang 可以与之协同工作。以下是具体步骤: 1. 使用 clang-cl Visual S…...

输入法:点三下输入一个汉字
作者常用的双拼输入法,需要26键。虽然也有9键的方案,但重码率较高。计算一下,9键点2下,共81种排列组合。而汉字的读音,不计声调,有400多个。相差甚多。 所以,设计了“三拼输入法”,…...