当前位置: 首页 > news >正文

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

  • 1.参考链接:
  • 2.性能对比
  • 3.相关依赖或命令
  • 4.测试代码
  • 5.HolisticTraceAnalysis代码
  • 6.可视化
    • A.优化前
    • B.优化后

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

1.参考链接:

  • Accelerating PyTorch with CUDA Graphs
  • BERT
  • torch-compiler

2.性能对比

序号运行方式build耗时(s)warmup耗时(s)运行耗时(w)备注
1普通模式0.70max:0.0791 min:0.0358 std:0.0126 mean:0.0586CPU Bound
2torch.cuda.CUDAGraph()0.01max:0.0109 min:0.0090 std:0.0006 mean:0.0094Kernel Bound
3torch.compile(“cudagraphs”)0.712610.7256max:3.9467 min:0.0197 std:1.1683 mean:0.4590
4torch.compile(“inductor”)0.000545.1444max:5.9465 min:0.0389 std:1.7684 mean:0.6415

3.相关依赖或命令

# 安装pytorch
pip install torch==2.3.1 -i https://pypi.tuna.tsinghua.edu.cn/simple# 安装HTA
git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git
cd HolisticTraceAnalysis
git submodule update --init
pip install -r requirements.txt
pip install -e .# 运行jupyter
pip install jupyter
jupyter notebook --allow-root --no-browser --ip=192.168.1.100 --port 8080

4.测试代码

import os
import warnings
warnings.filterwarnings("ignore")
import copy
import sys
import torch
from tqdm import tqdm
from torch.profiler import profile
import time
from typing import Final, Any, Callable
import random
import numpy as np
import os
import requests
import importlib.util
import sys
import jsondef download_module(url, destination_path):response = requests.get(url)response.raise_for_status()with open(destination_path, 'wb') as f:f.write(response.content)def module_from_path(module_name, file_path):spec = importlib.util.spec_from_file_location(module_name, file_path)module = importlib.util.module_from_spec(spec)sys.modules[module_name] = modulespec.loader.exec_module(module)return moduledef load_or_download_module(module_url, module_name, cache_dir=".cache"):if not os.path.exists(cache_dir):os.makedirs(cache_dir)destination_path = os.path.join(cache_dir, module_name + ".py")if not os.path.isfile(destination_path):download_module(module_url, destination_path)module = module_from_path(module_name, destination_path)return moduleimport sys
sys.path.append(".cache/")module_url = "https://raw.githubusercontent.com/NVIDIA/DeepLearningExamples/master/PyTorch/LanguageModeling/BERT/file_utils.py"
module_name = "file_utils"
load_or_download_module(module_url, module_name)module_url = "https://raw.githubusercontent.com/NVIDIA/DeepLearningExamples/master/PyTorch/LanguageModeling/BERT/modeling.py"
module_name = "modeling"
modeling = load_or_download_module(module_url, module_name)def fix_gelu_bug(fn):def wrapper(tensor, *args, **kwargs):return fn(tensor)return wrapper
torch.nn.functional.gelu=fix_gelu_bug(torch.nn.functional.gelu)class SyncFreeStats :def __init__(self) :self.host_stats = {}self.device_stats = {}self.device_funcs = {}def add_stat(self, name, dtype=torch.int32, device_tensor=None, device_func=None) :if device_tensor is not None :assert dtype == device_tensor.dtype, "Error: dtype do not match: {} {}".format(dtype, device_tensor.dtype)self.host_stats[name] = torch.zeros(1, dtype=dtype).pin_memory()self.device_stats[name] = device_tensorself.device_funcs[name] = device_funcdef copy_from_device(self) :for name in self.host_stats.keys() :# Apply device function to device statif self.device_stats[name] is not None and self.device_funcs[name] is not None:self.host_stats[name].copy_(self.device_funcs[name](self.device_stats[name]), non_blocking=True)elif self.device_stats[name] is not None :self.host_stats[name].copy_(self.device_stats[name], non_blocking=True)elif self.device_funcs[name] is not None :self.host_stats[name].copy_(self.device_funcs[name](), non_blocking=True)def host_stat(self, name) :assert name in self.host_statsreturn self.host_stats[name]def host_stat_value(self, name) :assert name in self.host_statsreturn self.host_stats[name].item()def update_host_stat(self, name, tensor) :self.host_stats[name] = tensordef device_stat(self, name) :assert self.device_stats[name] is not Nonereturn self.device_stats[name]def update_device_stat(self, name, tensor) :self.device_stats[name] = tensorclass BertPretrainingCriterion(torch.nn.Module):sequence_output_is_dense: Final[bool]def __init__(self, vocab_size, sequence_output_is_dense=False):super(BertPretrainingCriterion, self).__init__()self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)self.vocab_size = vocab_sizeself.sequence_output_is_dense = sequence_output_is_densedef forward(self, prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):if self.sequence_output_is_dense:# prediction_scores are already densemasked_lm_labels_flat = masked_lm_labels.view(-1)mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != -1]masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), mlm_labels.view(-1))else:masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))total_loss = masked_lm_loss + next_sentence_lossreturn total_lossdef setup_model_optimizer_data(device="cuda"):train_batch_size=1max_seq_length=128config=modeling.BertConfig(21128)sequence_output_is_dense=Falsemodel = modeling.BertForPreTraining(config, sequence_output_is_dense=sequence_output_is_dense)model=model.half()model.train().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)criterion = BertPretrainingCriterion(config.vocab_size, sequence_output_is_dense=sequence_output_is_dense).to(device)batch = {'input_ids': torch.ones(train_batch_size, max_seq_length, dtype=torch.int64, device=device),'token_type_ids': torch.ones(train_batch_size, max_seq_length, dtype=torch.int64, device=device),'attention_mask': torch.ones(train_batch_size, max_seq_length, dtype=torch.int64, device=device),'labels': torch.ones(train_batch_size, max_seq_length, dtype=torch.int64, device=device),'next_sentence_labels': torch.ones(train_batch_size, dtype=torch.int64, device=device),}stats = SyncFreeStats()stats.add_stat('average_loss', dtype=torch.float32, device_tensor=torch.zeros(1, dtype=torch.float32, device=device))return model,optimizer,criterion,batch,statsdef train_step(model,optimizer,criterion,batch,stats):optimizer.zero_grad(set_to_none=True)prediction_scores,seq_relationship_score=model(input_ids=batch['input_ids'],token_type_ids=batch['token_type_ids'],attention_mask=batch['attention_mask'],masked_lm_labels=batch['labels'])loss = criterion(prediction_scores, seq_relationship_score, batch['labels'], batch['next_sentence_labels'])stats.device_stat('average_loss').add_(loss.detach())loss.backward()optimizer.step()  def reset_seed():random.seed(0)np.random.seed(0)torch.manual_seed(0)torch.cuda.manual_seed(0)def stat(data):return f"max:{np.max(data):.4f} min:{np.min(data):.4f} std:{np.std(data):.4f} mean:{np.mean(data):.4f}"def prof_bert_native():reset_seed()activities=[torch.profiler.ProfilerActivity.CPU]activities.append(torch.profiler.ProfilerActivity.CUDA)model,optimizer,criterion,batch,stats=setup_model_optimizer_data()t0=time.time()train_step(model,optimizer,criterion,batch,stats)     torch.cuda.synchronize()t1=time.time()print(f"warmup:{t1-t0:.2f}")latency=[] with profile(activities=activities,record_shapes=True,with_stack=True,with_modules=True,schedule=torch.profiler.schedule(wait=1,warmup=1,active=3,repeat=0),with_flops=True,profile_memory=True) as prof:for i in range(10):t0=time.time()train_step(model,optimizer,criterion,batch,stats)     torch.cuda.synchronize()t1=time.time()latency.append(t1-t0)prof.step()stats.copy_from_device()      print(f"native average_loss:{stats.host_stat_value('average_loss'):.4f} {stat(latency)}")prof.export_chrome_trace("prof_bert_native.json")def prof_bert_cudagraph():reset_seed()activities=[torch.profiler.ProfilerActivity.CPU]activities.append(torch.profiler.ProfilerActivity.CUDA)model,optimizer,criterion,batch,stats=setup_model_optimizer_data()# Warmup Steps - includes jitting fusionsside_stream = torch.cuda.Stream()side_stream.wait_stream(torch.cuda.current_stream())with torch.cuda.stream(side_stream):for _ in range(11):train_step(model,optimizer,criterion,batch,stats)torch.cuda.current_stream().wait_stream(side_stream)# Capture Graphfull_cudagraph = torch.cuda.CUDAGraph()with torch.cuda.graph(full_cudagraph):train_step(model,optimizer,criterion,batch,stats)print("build done")t0=time.time()full_cudagraph.replay()torch.cuda.synchronize()t1=time.time()print(f"warmup:{t1-t0:.2f}")latency=[]with profile(activities=activities,record_shapes=True,with_stack=True,with_modules=True,schedule=torch.profiler.schedule(wait=1,warmup=1,active=3,repeat=0),with_flops=True,profile_memory=True) as prof:for i in range(10):t0=time.time()full_cudagraph.replay()torch.cuda.synchronize()t1=time.time()latency.append(t1-t0)prof.step()stats.copy_from_device()           print(f"cudagraph average_loss:{stats.host_stat_value('average_loss'):.4f} {stat(latency)}")prof.export_chrome_trace("prof_bert_cudagraph.json")def prof_bert_torchcompiler(backend):reset_seed()activities=[torch.profiler.ProfilerActivity.CPU]activities.append(torch.profiler.ProfilerActivity.CUDA)model,optimizer,criterion,batch,stats=setup_model_optimizer_data()latency=[]   t0=time.time()new_fn = torch.compile(train_step, backend=backend)t1=time.time()print(f"torchcompiler_{backend} build:{t1-t0:.4f}s")new_fn(model,optimizer,criterion,batch,stats)     torch.cuda.synchronize()t2=time.time()print(f"torchcompiler_{backend} warmup:{t2-t1:.4f}s")with profile(activities=activities,record_shapes=True,with_stack=True,with_modules=True,schedule=torch.profiler.schedule(wait=1,warmup=1,active=3,repeat=0),with_flops=True,profile_memory=True) as prof:for i in range(10):t0=time.time()new_fn(model,optimizer,criterion,batch,stats)     torch.cuda.synchronize()t1=time.time()latency.append(t1-t0)prof.step()stats.copy_from_device()print(f"torchcompiler_{backend} average_loss:{stats.host_stat_value('average_loss'):.4f} {stat(latency)}")prof.export_chrome_trace(f"prof_bert_torchcompiler_{backend}.json")os.environ['LOCAL_RANK']="0"
os.environ['RANK']="0"
os.environ['WORLD_SIZE']="1"
os.environ['MASTER_ADDR']="localhost"
os.environ['MASTER_PORT']="6006"import torch.distributed as dist
dist.init_process_group(backend='nccl')
rank=torch.distributed.get_rank()prof_bert_native()
prof_bert_cudagraph()
prof_bert_torchcompiler("cudagraphs")
prof_bert_torchcompiler("inductor")

5.HolisticTraceAnalysis代码

#!/usr/bin/env python
# coding: utf-8
# In[25]:
import warnings
warnings.filterwarnings("ignore")
from hta.trace_analysis import TraceAnalysis
analyzer = TraceAnalysis(trace_dir = "./traces")
# In[26]:
temporal_breakdown_df = analyzer.get_temporal_breakdown()
# kernel_type_metrics_df, kernel_metrics_df = analyzer.get_gpu_kernel_breakdown()
# In[28]:
kernel_type_metrics_df
# In[29]:
kernel_metrics_df
# In[30]:
idle_time_df, interval_stats_df = analyzer.get_idle_time_breakdown(ranks=[0], visualize=True,\visualize_pctg = 1,show_idle_interval_stats=True)
# In[31]:
cuda_launch_kernel_stats = analyzer.get_cuda_kernel_launch_stats()
# In[32]:
memory_bw_series = analyzer.get_memory_bw_time_series()
# In[33]:
memory_bw_series
# In[34]:
ql_series = analyzer.get_queue_length_time_series()
# In[35]:
ql_series
# In[36]:
ql_summary = analyzer.get_queue_length_summary()
# In[37]:
ql_summary
# In[38]:
annotation = "ProfilerStep"
instance_id = (0)
cp_graph, success = analyzer.critical_path_analysis(rank = 0, annotation=annotation, instance_id=instance_id)
cp_graph.summary()
# In[39]:
analyzer.overlay_critical_path_analysis(0, cp_graph, output_dir='traces/overlaid')
# In[40]:
cuda_sequences_df = analyzer.get_frequent_cuda_kernel_sequences(operator_name="cu", output_dir = "/tmp/")
# In[42]:
cuda_sequences_df

6.可视化

A.优化前

在这里插入图片描述
在这里插入图片描述

B.优化后

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

相关文章:

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTraceAnalysis分析性能瓶颈 1.参考链接:2.性能对比3.相关依赖或命令4.测试代码5.HolisticTraceAnalysis代码6.可视化A.优化前B.优化后 以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTra…...

地球地图:快速进行先进土地监测和气候评估的新工具Earth Map

地球地图:快速进行先进土地监测和气候评估的新工具 这个工具是居于GEE 开发的多功能的一个APP应用,主要进行土地监测和气候评估 Earth Map 什么是地球地图? 地球地图是联合国粮食及农业组织(粮农组织)在粮农组织与谷歌合作框架内开发的一个创新、免费和开放源码的工具。…...

6.22套题

B. Dark 题意:每次能在数列中能使相邻两个数-1,求当数列没有连续非0值的最小贡献 解法:设表示前i个数中前i-1个数是否为0,当前数是j的最小贡献。表示i1以后减掉d的最小贡献。 C. 幸运值 D. 凤凰院真凶...

openEuler搭建hadoop Standalone 模式

Standalone 升级软件安装常用软件关闭防火墙修改主机名和IP地址修改hosts配置文件下载jdk和hadoop并配置环境变量配置ssh免密钥登录修改配置文件初始化集群windows修改hosts文件测试 1、升级软件 yum -y update2、安装常用软件 yum -y install gcc gcc-c autoconf automake…...

nginx更新https/ssl证书的步骤

一、上传nginx证书到服务器 上传步骤略。。。 二、更新证书 (一)确认nginx的安装目录 我这里的环境是/etc/nginx/ (二)确认nginx的证书目录 查看/etc/nginx/nginx.conf,证书目录就在/etc/nginx目录下 将新的证书tes…...

【Android面试八股文】说一说Handler的sendMessage和postDelay的区别?

文章目录 一、`sendMessage` 方法1.1 主要用法1.2 适用场景二、`postDelayed` 方法2.1 主要用法2.2 适用场景三、 区别总结3.1 区别3.2 本质上有差别吗?四、实例对比4.1 使用`sendMessage`4.2 使用`postDelayed`五、结论Handler类在Android中用于消息传递和任务调度。 sendMe…...

Java学习 - Redis主从复制

主从复制是什么 用于建立一个和主数据库完全一样的数据库环境,称为从数据库 主从复制的作用 数据备份读写分离 主从复制使用方式 通过slaveof命令 创建从节点 redis-slave> slaveof 127.0.0.1 6379取消从节点 redis-slave> slaveof no one通过配置 配置…...

图的拓扑排序

图的拓扑排序(Topological Sorting)是一种线性排序,用于有向无环图(Directed Acyclic Graph,DAG)。拓扑排序将图中的顶点排成一个线性序列,使得对于每一条有向边 (u, v),顶点 u 都排…...

windows USB 设备驱动开发-总章

通用串行总线 (USB) 提供可扩展的即插即用串行接口,确保外围设备的标准、低成本的连接。 USB 设备包括键盘、鼠标、游戏杆、打印机、扫描仪、存储设备、调制解调器、视频会议摄像头等。USB-IF 是一个特别兴趣组 (SIG),负责维护官方 USB 规范、测试规范和…...

springboot解析自定义yml文件

背景 公司产品微服务架构下有十几个模块,几乎大部分模块都要连接redis。每次在客户那里部署应用,都要改十几遍配置,太痛苦了。当然可以用nacos配置中心的功能,配置公共参数。不过我是喜欢在应用级别上解决问题,因为并不…...

【C/C++】静态函数调用类中成员函数方法 -- 最快捷之一

背景 注册回调函数中,回调函数是一个静态函数。需要调用类对象中的一个成员函数进行后续通知逻辑。 方案 定义全局指针,用于指向类对象this指针 static void *s_this_obj;类构造函数中,将全局指针指向所需类的this指针 s_this_obj this…...

佣金的定义和类型

1. 佣金的定义 基本定义:佣金是指在商业交易中,代理人或中介机构为促成交易所获得的报酬。它通常是按交易金额的一定比例计算和支付的。支付方式:佣金可以是固定金额,也可以是交易金额的百分比。 2. 佣金的类型 销售佣金&#…...

python数据分析实训任务二(‘风力风向’)

import numpy as np import matplotlib.pyplot as plt # 数据 labelsnp.array([东风, 东北风, 北风, 西北风, 西风, 西南风, 南风, 东南风]) statsnp.array([2.1, 2, 0, 3, 1.5, 3, 6, 4]) # 将角度转换为弧度 anglesnp.linspace(0, 2*np.pi, len(labels), endpointFalse).toli…...

Java技术栈总结:数据库MySQL篇

一、慢查询 1、常见情形 聚合查询 多表查询 表数据量过大查询 深度分页查询 2、定位慢查询 方案一、开源工具 调试工具:Arthas运维工具:Prometheus、Skywalking 方案二、MySQL自带慢日志 在MySQL配置文件 /etc/my.conf 中配置: # …...

vue-cli 项目打包优化-基础篇

1、项目打包完运行空白 引用资源路径问题,打包完的【index.html】文件引用其他文件的引用地址不对 参考配置:https://cli.vuejs.org/zh/config 修改vue.config.js ,根据与 后端 或 运维 沟通修改 module.export {// 默认 publicPath: //…...

24/06/26(1.1129)动态内存

strtok 字符串分割函数 #include<stdio.h> int main(){ char str[] "this,a sample string."; char* sep ","; char* pch strtok(str, sep); printf("%s\n", pch); while (pch ! NULL){ printf("%s\…...

基于 elementUI / elementUI plus,实现 主要色(主题色)的一件换色(换肤)

一、效果图 二、方法 改变elementUI 的主要色 --el-color-primary 为自己选择的颜色&#xff0c;核心代码如下&#xff1a; // 处理主题样式 export function handleThemeStyle(theme) {document.documentElement.style.setProperty(--el-color-primary, theme) } 三、全部代…...

js 计算某个日期加月份最后月份不会增加或者跳变

/** * * param {*} dateString 原来日期 2023-12-31 * param {*} months 加月份 2 * returns 2024-02-29 */ export function getDateByMonth(dateString, months0) { console.log(1); let oldMonths dateString.substring(0,7); let day dateString.substring(8); let …...

Git简介与详细教程

一、简介 什么是Git&#xff1f; Git是一款分布式版本控制系统&#xff0c;由Linux之父Linus Torvalds于2005年开发。它旨在快速、高效地处理从小型到大型项目的所有内容。Git与传统的版本控制系统相比&#xff0c;具备显著的优势&#xff0c;主要体现在其分布式架构、强大的…...

创建OpenWRT虚拟机

环境&#xff1a;Ubuntu 2204&#xff0c;VM VirtualBox 7.0.18 安装必备软件包&#xff1a; sudo apt update sudo apt install subversion automake make cmake uuid-dev gcc vim build-essential clang flex bison g gawk gcc-multilib g-multilib gettext git libncurses…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指&#xff1a;像函数调用/返回一样轻量地完成任务切换。 举例说明&#xff1a; 当你在程序中写一个函数调用&#xff1a; funcA() 然后 funcA 执行完后返回&…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文&#xff1f; 多租户隔离&#xff1a;自动为接入设备追加租户前缀&#xff0c;后端按 ClientID 拆分队列。零代码鉴权&#xff1a;将入站用户名替换为 OAuth Access-Token&#xff0c;后端 Broker 统一校验。灰度发布&#xff1a;根据 IP/地理位写…...

测试markdown--肇兴

day1&#xff1a; 1、去程&#xff1a;7:04 --11:32高铁 高铁右转上售票大厅2楼&#xff0c;穿过候车厅下一楼&#xff0c;上大巴车 &#xffe5;10/人 **2、到达&#xff1a;**12点多到达寨子&#xff0c;买门票&#xff0c;美团/抖音&#xff1a;&#xffe5;78人 3、中饭&a…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本&#xff1a; 3.8.1 语言&#xff1a; JavaScript/TypeScript、C、Java 环境&#xff1a;Window 参考&#xff1a;Java原生反射机制 您好&#xff0c;我是鹤九日&#xff01; 回顾 在上篇文章中&#xff1a;CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

Python爬虫(一):爬虫伪装

一、网站防爬机制概述 在当今互联网环境中&#xff0c;具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类&#xff1a; 身份验证机制&#xff1a;直接将未经授权的爬虫阻挡在外反爬技术体系&#xff1a;通过各种技术手段增加爬虫获取数据的难度…...

NFT模式:数字资产确权与链游经济系统构建

NFT模式&#xff1a;数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新&#xff1a;构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议&#xff1a;基于LayerZero协议实现以太坊、Solana等公链资产互通&#xff0c;通过零知…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...