使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。
安装依赖
首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。
pip install torch torchvision horovod
或者如果你使用的是 Anaconda:
conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]
编写 PyTorch + Horovod 代码
创建一个新的 Python 文件 train.py,然后添加如下代码:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
运行代码
要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:
horovodrun -np 4 -H localhost:4 python train.py
这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。
请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。
相关文章:
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo
使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 P…...
SQL复杂查询功能介绍及示例
文章目录 1. 多表连接(JOIN)功能介绍应用场景示例查询及初始表格customers 表(未查询前)orders 表(未查询前)INNER JOIN 示例LEFT JOIN 示例 2. 子查询(Subquery)功能介绍应用场景示…...
shell基础用法
shell基础知识 shell中的多行注释 :<<EOF read echo $REPLY # read不指定变量,则默认写入$REPLY EOF # :<<EOF ...EOF 多行注释,EOF可以替换为!# 等文件目录和执行目录 echo $0$0 # ./demo.sh echo $0的realpath$(realpath…...
C#设计模式--策略模式(Strategy Pattern)
策略模式是一种行为设计模式,它使你能在运行时改变对象的行为。在策略模式定义了一系列算法或策略,并将每个算法封装在独立的类中,使得它们可以互相替换。通过使用策略模式,可以在运行时根据需要选择不同的算法,而不需…...
【opencv入门教程】15. 访问像素的十四种方式
文章选自: 一、像素访问 一张图片由许多个点组成,每个点就是一个像素,每个像素包含不同的值,对图像像素操作是图像处理过程中常使用的 二、访问像素 void Samples::AccessPixels1(Mat &image, int div 64) {int nl imag…...
【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)
导航: 本文一些内容需要聚簇索引、非聚簇索引、B树、覆盖索引、索引下推等前置概念,虽然本文有简单回顾,但详细可以参考下文的【MySQL高级篇】 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模…...
根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落
要根据 HTML 段落的长度设置 QTextBrowser 的显示内容,并确保至少显示一个段落,可以通过以下步骤来实现: 加载 HTML 内容:首先,你需要加载 HTML 内容到 QTextBrowser 中。可以通过 setHtml() 方法来设置 HTML。 计算段…...
基于Huffman编码的GPS定位数据无损压缩算法
目录 一、引言 二、霍夫曼编码 三、经典Huffman编码 四、适应性Huffman编码 五、GPS定位数据压缩 提示:文末附定位数据压缩工具和源码 一、引言 车载监控系统中,车载终端需要获取GPS信号(经度、纬 度、速度、方向等)实时上传…...
php:完整部署Grid++Report到php项目,并实现模板打印
一、下载Grid++Report软件 路径:开发者安装包下载 - 锐浪报表工具 二、 安装软件 1、对下载的压缩包运行内部的exe文件 2、选择语言 3、 完成安装引导 下一步即可 4、接收许可协议 点击“我接受” 5、选择安装路径 “浏览”选择安装路径,点击"安装" 6、完成…...
C标签和 EL表达式的在前端界面的应用
目录 前言 常用的c标签有: for循环 1 表示 普通的for循环的 2 常在集合中使用 表示 选择关系 1 简单的表示如果 2 表示如果。。否则。。 EL表达式 格式 : ${属性名/对象/ 集合} 前言 本篇博客介绍 c标签和el表达式的使用 使用C标签 要引入 …...
Linux絮絮叨(四) 系统目录结构
Linux 系统的目录结构(Filesystem Hierarchy Standard, FHS)定义了 Linux 系统中文件系统的标准布局,以下是一些常见目录的功能: 根目录 / 描述:所有文件和目录的起始点,Linux 文件系统的根。内容…...
Java基于SpringBoot的网上订餐系统,附源码
博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…...
《Java核心技术I》死锁
死锁 账户1:200元账户2: 300元线程1:从账号1转300到账户2线程2:从账户2转400到账户1 如上,线程1和线程2显然都被阻塞,两个账户的余额都不足以转账,两个线程都无法执行下去。 有可能会因为每一个线程要等…...
【Windows11系统局域网共享文件数据】
【Windows11系统局域网共享文件数据】 1. 引言1. 规划网络2. 获取必要的硬件3. 设置网络4. 配置网络设备5. 测试网络连接6. 安全性和维护7. 扩展和优化 2. 准备工作2.1: 启用网络发现和文件共享2.2: 设置共享文件夹 3. 访问共享文件夹4. 小贴士5. 总结 1. 引言 随着家庭和小型办…...
MCU、ARM体系结构,单片机基础,单片机操作
计算机基础 计算机的组成 输入设备、输出设备、存储器、运算器、控制器 输入设备:将其他信号转换为计算机可以识别的信号(电信号)。输出设备:将电信号(0、1)转为人或其他设备能理解的…...
在办公室环境中用HMD替代传统显示器的优势
VR头戴式显示器(HMD)是进入虚拟现实环境的一把钥匙,拥有HMD的您将能够在虚拟现实世界中尽情探索未知领域,正如如今的互联网一样,虚拟现实环境能够为您提供现实中无法实现的或不可能实现的事。随着技术的不断进步&#…...
ssm 多数据源 注解版本
application.xml 配置如下 <!-- 使用 DruidDataSource 数据源 --><bean id"primaryDataSource" class"com.alibaba.druid.pool.DruidDataSource" init-method"init" destroy-method"close"></bean> <!-- 使用 数…...
selenium常见接口函数使用
博客主页:花果山~程序猿-CSDN博客 文章分栏:测试_花果山~程序猿的博客-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 1. 查找 查找方式 css_s…...
STM32F103单片机使用STM32CubeMX新建IAR工程步骤
打开STM32CubeMX软件,选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号,然后点右上角 start project,开始新建工程。 接下来设置调试接口,在左边System Core中选择 SYS,然后在右右边debu…...
刷题重开:找出字符串中第一个匹配项的下标——解题思路记录
问题描述: 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 示例 1: 输入&…...
利用ngx_stream_return_module构建简易 TCP/UDP 响应网关
一、模块概述 ngx_stream_return_module 提供了一个极简的指令: return <value>;在收到客户端连接后,立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量(如 $time_iso8601、$remote_addr 等)&a…...
(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...
ios苹果系统,js 滑动屏幕、锚定无效
现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...
代理篇12|深入理解 Vite中的Proxy接口代理配置
在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...
【Go语言基础【13】】函数、闭包、方法
文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
探索Selenium:自动化测试的神奇钥匙
目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...
从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...
协议转换利器,profinet转ethercat网关的两大派系,各有千秋
随着工业以太网的发展,其高效、便捷、协议开放、易于冗余等诸多优点,被越来越多的工业现场所采用。西门子SIMATIC S7-1200/1500系列PLC集成有Profinet接口,具有实时性、开放性,使用TCP/IP和IT标准,符合基于工业以太网的…...
