使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。
安装依赖
首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。
pip install torch torchvision horovod
或者如果你使用的是 Anaconda:
conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]
编写 PyTorch + Horovod 代码
创建一个新的 Python 文件 train.py,然后添加如下代码:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
运行代码
要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:
horovodrun -np 4 -H localhost:4 python train.py
这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。
请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。
相关文章:
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...
SQL复杂查询功能介绍及示例
文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...
shell基础用法
shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量,则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释,EOF可以替换为!# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...
C#设计模式--策略模式(Strategy Pattern)
策略模式是一种行为设计模式,它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略,并将每个算法封装在独立的类中,使得它们可以互相替换。通过使用策略模式,可以在运行时根据需要选择不同的算法,而不需…...
【opencv入门教程】15. 访问像素的十四种方式
文章选自: 一、像素访问 一张图片由许多个点组成,每个点就是一个像素,每个像素包含不同的值,对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...
【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)
导航: 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念,虽然本文有简单回顾,但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...
根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落
要根据 HTML 段落的长度设置 QTextBrowser 的显示内容,并确保至少显示一个段落,可以通过以下步骤来实现: 加载 HTML 内容:首先,你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...
基于Huffman编码的GPS定位数据无损压缩算法
目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示:文末附定位数据压缩工具和源码 一、引言 车载监控系统中,车载终端需要获取GPS信号(经度、纬 度、速度、方向等)实时上传…...
php:完整部署Grid++Report到php项目,并实现模板打印
一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...
C标签和 EL表达式的在前端界面的应用
目录 前言 常用的c标签有: for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 : ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...
Linux絮絮叨(四) 系统目录结构
Linux 系统的目录结构(Filesystem Hierarchy Standard, FHS)定义了 Linux 系统中文件系统的标准布局,以下是一些常见目录的功能: 根目录 / 描述:所有文件和目录的起始点,Linux 文件系统的根。内容…...
Java基于SpringBoot的网上订餐系统,附源码
博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…...
《Java核心技术I》死锁
死锁 账户1:200元账户2: 300元线程1:从账号1转300到账户2线程2:从账户2转400到账户1 如上,线程1和线程2显然都被阻塞,两个账户的余额都不足以转账,两个线程都无法执行下去。 有可能会因为每一个线程要等…...
【Windows11系统局域网共享文件数据】
【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...
MCU、ARM体系结构,单片机基础,单片机操作
计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备:将其他信号转换为计算机可以识别的信号(电信号)。输出设备:将电信号(0、1)转为人或其他设备能理解的…...
在办公室环境中用HMD替代传统显示器的优势
VR头戴式显示器(HMD)是进入虚拟现实环境的一把钥匙,拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域,正如如今的互联网一样,虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...
ssm 多数据源 注解版本
application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...
selenium常见接口函数使用
博客主页:花果山~程序猿-CSDN博客 文章分栏:测试_花果山~程序猿的博客-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 1. 查找 查找方式 css_s…...
STM32F103单片机使用STM32CubeMX新建IAR工程步骤
打开STM32CubeMX软件,选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号,然后点右上角 start project,开始新建工程。 接下来设置调试接口,在左边System Core中选择 SYS,然后在右右边debu…...
刷题重开:找出字符串中第一个匹配项的下标——解题思路记录
问题描述: 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 示例 1: 输入&…...
家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
oracle与MySQL数据库之间数据同步的技术要点
Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异ÿ…...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
基于 TAPD 进行项目管理
起因 自己写了个小工具,仓库用的Github。之前在用markdown进行需求管理,现在随着功能的增加,感觉有点难以管理了,所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD,需要提供一个企业名新建一个项目&#…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
uniapp 集成腾讯云 IM 富媒体消息(地理位置/文件)
UniApp 集成腾讯云 IM 富媒体消息全攻略(地理位置/文件) 一、功能实现原理 腾讯云 IM 通过 消息扩展机制 支持富媒体类型,核心实现方式: 标准消息类型:直接使用 SDK 内置类型(文件、图片等)自…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现指南针功能
指南针功能是许多位置服务应用的基础功能之一。下面我将详细介绍如何在HarmonyOS 5中使用DevEco Studio实现指南针功能。 1. 开发环境准备 确保已安装DevEco Studio 3.1或更高版本确保项目使用的是HarmonyOS 5.0 SDK在项目的module.json5中配置必要的权限 2. 权限配置 在mo…...
