当前位置: 首页 > news >正文

PyTorch多GPU训练时同步梯度是mean还是sum?

PyTorch 通过两种方式可以进行多GPU训练: DataParallel, DistributedDataParallel. 当使用DataParallel的时候, 梯度的计算结果和在单卡上跑是一样的, 对每个数据计算出来的梯度进行累加. 当使用DistributedDataParallel的时候, 每个卡单独计算梯度, 然后多卡的梯度再进行平均.
下面是实验验证:

DataParallel

import torch
import os
import torch.nn as nndef main():model = nn.Linear(2, 3).cuda()model = torch.nn.DataParallel(model, device_ids=[0, 1])input = torch.rand(2, 2)labels = torch.tensor([[1, 0, 0], [0, 1, 0]]).cuda()(model(input) * labels).sum().backward()print('input', input)print([p.grad for p in model.parameters()])if __name__=="__main__":main()

执行CUDA_VISIBLE_DEVICES=0,1 python t.py可以看到输出, 代码中对两个样本分别求梯度, 梯度等于样本的值, DataParallel把两个样本的梯度累加起来在不同GPU中同步.

input tensor([[0.4362, 0.4574],[0.2052, 0.2362]])
[tensor([[0.4363, 0.4573],[0.2052, 0.2362],[0.0000, 0.0000]], device='cuda:0'), tensor([1., 1., 0.], device='cuda:0')]

DistributedDataParallel

import torch
import os
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDPdef example(rank, world_size):# create default process groupdist.init_process_group("gloo", rank=rank, world_size=world_size)# create local modelmodel = nn.Linear(2, 3).to(rank)print('model param', 'rank', rank, [p for p in model.parameters()])# construct DDP modelddp_model = DDP(model, device_ids=[rank])print('ddp model param', 'rank', rank, [p for p in ddp_model.parameters()])# forward passinput = torch.randn(1, 2).to(rank)outputs = ddp_model(input)labels = torch.randn(1, 3).to(rank) * 0labels[0, rank] = 1# backward pass(outputs * labels).sum().backward()print('rank', rank, 'grad', [p.grad for p in ddp_model.parameters()])print('rank', rank, 'input', input, 'outputs', outputs)print('rank', rank, 'labels', labels)# update parametersoptimizer.step()def main():world_size = 2mp.spawn(example,args=(world_size,),nprocs=world_size,join=True)if __name__=="__main__":# Environment variables which need to be# set when using c10d's default "env"# initialization mode.os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "29504"main()

执行CUDA_VISIBLE_DEVICES=0,1 python t1.py可以看到输出, 代码中对两个样本分别求梯度, 梯度等于样本的值, 最终的梯度是各个GPU的梯度的平均.

model param rank 0 [Parameter containing:
tensor([[-0.4819,  0.0253],[ 0.0858,  0.2256],[ 0.5614,  0.2702]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0090,  0.4461, -0.3493], device='cuda:0', requires_grad=True)]
model param rank 1 [Parameter containing:
tensor([[-0.3737,  0.3062],[ 0.6450,  0.2930],[-0.2422,  0.2089]], device='cuda:1', requires_grad=True), Parameter containing:
tensor([-0.5868,  0.2106, -0.4461], device='cuda:1', requires_grad=True)]
ddp model param rank 1 [Parameter containing:
tensor([[-0.4819,  0.0253],[ 0.0858,  0.2256],[ 0.5614,  0.2702]], device='cuda:1', requires_grad=True), Parameter containing:
tensor([-0.0090,  0.4461, -0.3493], device='cuda:1', requires_grad=True)]
ddp model param rank 0 [Parameter containing:
tensor([[-0.4819,  0.0253],[ 0.0858,  0.2256],[ 0.5614,  0.2702]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0090,  0.4461, -0.3493], device='cuda:0', requires_grad=True)]
rank 1 grad [tensor([[ 0.2605,  0.1631],[-0.0934, -0.5308],[ 0.0000,  0.0000]], device='cuda:1'), tensor([0.5000, 0.5000, 0.0000], device='cuda:1')]
rank 0 grad [tensor([[ 0.2605,  0.1631],[-0.0934, -0.5308],[ 0.0000,  0.0000]], device='cuda:0'), tensor([0.5000, 0.5000, 0.0000], device='cuda:0')]
rank 1 input tensor([[-0.1868, -1.0617]], device='cuda:1') outputs tensor([[ 0.0542,  0.1906, -0.7411]], device='cuda:1',grad_fn=<AddmmBackward0>)
rank 0 input tensor([[0.5209, 0.3261]], device='cuda:0') outputs tensor([[-0.2518,  0.5644,  0.0314]], device='cuda:0',grad_fn=<AddmmBackward0>)
rank 1 labels tensor([[-0., 1., -0.]], device='cuda:1')
rank 0 labels tensor([[1., 0., -0.]], device='cuda:0')

相关文章:

PyTorch多GPU训练时同步梯度是mean还是sum?

PyTorch 通过两种方式可以进行多GPU训练: DataParallel, DistributedDataParallel. 当使用DataParallel的时候, 梯度的计算结果和在单卡上跑是一样的, 对每个数据计算出来的梯度进行累加. 当使用DistributedDataParallel的时候, 每个卡单独计算梯度, 然后多卡的梯度再进行平均.…...

Spring Framework IoC依赖注入-按Bean类型注入

Spring Framework 作为一个领先的企业级开发框架&#xff0c;以其强大的依赖注入&#xff08;Dependency Injection&#xff0c;DI&#xff09;机制而闻名。DI使得开发者可以更加灵活地管理对象之间的关系&#xff0c;而不必过多关注对象的创建和组装。在Spring Framework中&am…...

IDEA运行thymeleaf的html文件打开端口为63342且连不上数据库

这边贴apple.html代码 <!DOCTYPE html> <html xmlns:th"http://www.thymeleaf.org"> <head><meta charset"UTF-8"><title>User List</title> </head> <body> <h1>User List</h1> <table&…...

sql报错注入和联合注入

1.[NISACTF 2022]join-us 过滤&#xff1a; as IF rand() LEFT by updatesubstring handler union floor benchmark COLUMN UPDATE & sys.schema_auto_increment_columns && 11 database case AND right CAST FLOOR left updatexml DATABASES BENCHMARK BY sleep…...

028 - STM32学习笔记 - ADC结构体学习(二)

028 - STM32学习笔记 - 结构体学习&#xff08;二&#xff09; 上节对ADC基础知识进行了学习&#xff0c;这节在了解一下ADC相关的结构体。 一、ADC初始化结构体 在标准库函数中基本上对于外设都有一个初始化结构体xx_InitTypeDef&#xff08;其中xx为外设名&#xff0c;例如…...

Pytest自动化测试框架:mark用法---测试用例分组执行

pytest中的mark&#xff1a; mark主要用于在测试用例/测试类中给用例打标记(只能使用已注册的标记名)&#xff0c;实现测试分组功能&#xff0c;并能和其它插件配合设置测试方法执行顺序等。 如下图&#xff0c;现在需要只执行红色部分的测试方法&#xff0c;其它方法不执行&am…...

【TCP连接的状态】

linux查看tcp的状态命令&#xff1a; 1&#xff09;、netstat -nat 查看TCP各个状态的数量 2&#xff09;、lsof -i:port 可以检测到打开套接字的状况 3)、 sar -n SOCK 查看tcp创建的连接数 4)、tcpdump -iany tcp port 9000 对tcp端口为9000的进行抓包 查看占用端口…...

Node.js入门指南(一)

目录 Node.js入门 什么是Node.js Node.js的作用 Node.js安装 Node.js编码注意事项 Buffer(缓冲器&#xff09; 定义 使用 fs模块 概念 文件写入 文件读取 文件移动与重命名 文件删除 文件夹操作 查看资源状态 路径问题 path模块 Node.js入门 什么是Node.js …...

使用Grpc实现高性能PHP RPC服务

文档&#xff1a;Quick start | PHP | gRPC 下面将介绍使用 Grpc 和 Protobuf 实现高性能 RPC 服务的具体步骤&#xff1a; 1. 安装 Grpc 和 Protobuf 首先需要安装 Grpc 和 Protobuf。可以从官网下载相应的安装包&#xff08;Supported languages | gRPC&#xff09;或通过…...

二、爬虫-爬取肯德基在北京的店铺地址

1、算法框架解释 针对这个案例&#xff0c;现在对爬虫的基础使用做总结如下&#xff1a; 1、算法框架 (1)设定传入参数 ~url: 当前整个页面的url:当前页面的网址 当前页面某个局部的url:打开检查 ~data:需要爬取数据的关键字&…...

linux驱动开发.之spi测试工具spidev_test源码(一)

同i2c-tools工具类似&#xff0c;spidev_test是用来测试SPI BUS的用户态程序&#xff0c;其源码存在kernel目录下的tools下&#xff0c;具体为tools\spi\spidev_test.c。buildroot同样也提供名为spidev_test的package&#xff0c;可以直接进行编译&#xff0c;方便用户调试spi总…...

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于材料生成优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神…...

Go——二、变量和数据类型

Go 一、Go语言中的变量和常量1、Go语言中变量的声明2、如何定义变量方式1&#xff1a;方式2&#xff1a;带类型方式3&#xff1a;类型推导方式定义变量方式4&#xff1a;声明多个变量总结 3、如何定义常量4、Const常量结合iota的使用 二、Golang的数据类型1、概述2、整型2.1 类…...

合并区间问题

以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好覆盖输入中的所有区间 。 示例 1&#xff1a; 输入&#xff1a;intervals [[1,…...

2023 年最新 MySQL 数据库 Windows 本地安装、Centos 服务器安装详细教程

MySQL 基本概述 MySQL是一个流行的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;广泛应用于各种业务场景。它是由瑞典MySQL AB公司开发&#xff0c;后来被Sun Microsystems收购&#xff0c;最终被甲骨文公司&#xff08;Oracle Corporation&#xff09;收购…...

每天一道算法题(十)——获取和为k的子数组

文章目录 1、问题2、示例3、解决方法&#xff08;1&#xff09;方法1——双指针 总结 1、问题 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 2、示例 示例 1&#xff1a; 输入&#x…...

2023年亚太杯数学建模思路 - 案例:最短时间生产计划安排

文章目录 0 赛题思路1 模型描述2 实例2.1 问题描述2.2 数学模型2.2.1 模型流程2.2.2 符号约定2.2.3 求解模型 2.3 相关代码2.4 模型求解结果 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 最短时…...

在vscode中使用Latex:TexLive2023

安装TexLive2023及配置vscode可参考https://zhuanlan.zhihu.com/p/166523064 然后编译模板 .tex文件时&#xff0c;出现以下几个错误&#xff1a; 1. ctexbook找不到字体集 d:/texlive/2023/texmf-dist/tex/latex/ctex/ctexbook.cls:1678: Class ctexbook Error: CTeX fo…...

Unity开发之C#基础-File文件读取

前言 今天我们将要讲解到c#中 对于文件的读写是怎样的 那么没接触过特别系统编程小伙伴们应该会有一个疑问 这跟文件有什么关系呢&#xff1f; 我们这样来理解 首先 大家对电脑或多或少都应该有不少的了解吧 那么我们这些软件 都是通过变成一个一个文件保存在电脑中 我们才可以…...

深度学习之二(前馈神经网络--Feedforward Neural Network)

概念 前馈神经网络(Feedforward Neural Network)是一种最基本的神经网络结构,也被称为多层感知器(Multilayer Perceptron,MLP)。它的特点是信息只在网络中单向传播,不会形成环路。每一层神经元的输出都作为下一层神经元的输入,没有反馈回路。 结构: 前馈神经网络通…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

<6>-MySQL表的增删查改

目录 一&#xff0c;create&#xff08;创建表&#xff09; 二&#xff0c;retrieve&#xff08;查询表&#xff09; 1&#xff0c;select列 2&#xff0c;where条件 三&#xff0c;update&#xff08;更新表&#xff09; 四&#xff0c;delete&#xff08;删除表&#xf…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)

漏洞概览 漏洞名称&#xff1a;Apache Flink REST API 任意文件读取漏洞CVE编号&#xff1a;CVE-2020-17519CVSS评分&#xff1a;7.5影响版本&#xff1a;Apache Flink 1.11.0、1.11.1、1.11.2修复版本&#xff1a;≥ 1.11.3 或 ≥ 1.12.0漏洞类型&#xff1a;路径遍历&#x…...