如何精确统计Pytorch模型推理时间
文章目录
- 0 背景
- 1 精确统计方法
- 2 手动synchronize和Event适用场景
0 背景
在分析模型性能时需要精确地统计出模型的推理时间,但仅仅通过在模型推理前后打时间戳然后相减得到的时间其实是Host侧向Device侧下发指令的时间。如下图所示,Host侧下发指令与Device侧计算实际上是异步进行的。

1 精确统计方法
比较常用的精确统计方法有两种,一种是手动调用同步函数等待Device侧计算完成。另一种是通过Event方法在Device侧记录时间戳。

下面示例代码中分别给出了直接在模型推理前后打时间戳相减,使用同步函数以及Event方法统计模型推理时间(每种方法都重复50次,忽略前5次推理,取后45次的平均值)。
import timeimport torch
import torch.nn as nnclass CustomModel(nn.Module):def __init__(self):super().__init__()self.part0 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=2, padding=1),nn.GELU(),nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),nn.GELU())self.part1 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)),nn.Flatten(),nn.Linear(in_features=1024, out_features=2048),nn.GELU(),nn.Linear(in_features=2048, out_features=512),nn.GELU(),nn.Linear(in_features=512, out_features=1))def forward(self, x):x = self.part0(x)x = self.part1(x)return xdef cal_time1(model, x):with torch.inference_mode():time_list = []for _ in range(50):ts = time.perf_counter()ret = model(x)td = time.perf_counter()time_list.append(td - ts)print(f"avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def cal_time2(model, x):device = x.devicewith torch.inference_mode():time_list = []for _ in range(50):torch.cuda.synchronize(device)ts = time.perf_counter()ret = model(x)torch.cuda.synchronize(device)td = time.perf_counter()time_list.append(td - ts)print(f"syn avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def cal_time3(model, x):with torch.inference_mode():start_event = torch.cuda.Event(enable_timing=True)end_event = torch.cuda.Event(enable_timing=True)time_list = []for _ in range(50):start_event.record()ret = model(x)end_event.record()end_event.synchronize()time_list.append(start_event.elapsed_time(end_event) / 1000)print(f"event avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def main():device = torch.device("cuda:0")model = CustomModel().eval().to(device)x = torch.randn(size=(32, 3, 224, 224), device=device)cal_time1(model, x)cal_time2(model, x)cal_time3(model, x)if __name__ == '__main__':main()
终端输出:
avg time: 0.00023
syn avg time: 0.04709
event avg time: 0.04710
通过终端输出可以看到,如果直接在模型推理前后打时间戳相减得到的时间非常短(因为并没有等待Device侧计算完成)。而使用同步函数或者Event方法统计的时间明显要长很多。
2 手动synchronize和Event适用场景
通过上面的代码示例可以看到,通过同步函数统计的时间和Event方法统计的时间基本一致(差异1ms内)。那两者有什么区别呢?如果只是简单统计一个模型的推理时间确实看不出什么差异。但如果要统计一个完整AI应用通路(其中可能包含多个模型以及各种CPU计算)中不同模型的耗时,而又不想影响到整个通路的性能,那么建议使用Event方法。因为使用同步函数可能会让Host长期处于等待状态,等待过程中也无法干其他的事情,从而导致计算资源的浪费。可以看看下面这个示例,整个通路由Model1推理+一段纯CPU计算+Model2推理串行构成,假设想统计一下model1、model2推理分别用了多长时间:
import timeimport torch
import torch.nn as nn
import numpy as npclass CustomModel1(nn.Module):def __init__(self):super().__init__()self.part0 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=2, padding=1),nn.GELU(),nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),nn.GELU())def forward(self, x):x = self.part0(x)return xclass CustomModel2(nn.Module):def __init__(self):super().__init__()self.part1 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)),nn.Flatten(),nn.Linear(in_features=1024, out_features=2048),nn.GELU(),nn.Linear(in_features=2048, out_features=512),nn.GELU(),nn.Linear(in_features=512, out_features=1))def forward(self, x):x = self.part1(x)return xdef do_pure_cpu_task():x = np.random.randn(1, 3, 512, 512)x = x.astype(np.float32)x = x * 1024 ** 0.5def cal_time2(model1, model2, x):device = x.devicewith torch.inference_mode():time_total_list = []time_model1_list = []time_model2_list = []for _ in range(50):torch.cuda.synchronize(device)ts1 = time.perf_counter()ret = model1(x)torch.cuda.synchronize(device)td1 = time.perf_counter()do_pure_cpu_task()torch.cuda.synchronize(device)ts2 = time.perf_counter()ret = model2(ret)torch.cuda.synchronize(device)td2 = time.perf_counter()time_model1_list.append(td1 - ts1)time_model2_list.append(td2 - ts2)time_total_list.append(td2 - ts1)avg_model1 = sum(time_model1_list[5:]) / len(time_model1_list[5:])avg_model2 = sum(time_model2_list[5:]) / len(time_model2_list[5:])avg_total = sum(time_total_list[5:]) / len(time_total_list[5:])print(f"syn avg model1 time: {avg_model1:.5f}, model2 time: {avg_model2:.5f}, total time: {avg_total:.5f}")def cal_time3(model1, model2, x):with torch.inference_mode():model1_start_event = torch.cuda.Event(enable_timing=True)model1_end_event = torch.cuda.Event(enable_timing=True)model2_start_event = torch.cuda.Event(enable_timing=True)model2_end_event = torch.cuda.Event(enable_timing=True)time_total_list = []time_model1_list = []time_model2_list = []for _ in range(50):model1_start_event.record()ret = model1(x)model1_end_event.record()do_pure_cpu_task()model2_start_event.record()ret = model2(ret)model2_end_event.record()model2_end_event.synchronize()time_model1_list.append(model1_start_event.elapsed_time(model1_end_event) / 1000)time_model2_list.append(model2_start_event.elapsed_time(model2_end_event) / 1000)time_total_list.append(model1_start_event.elapsed_time(model2_end_event) / 1000)avg_model1 = sum(time_model1_list[5:]) / len(time_model1_list[5:])avg_model2 = sum(time_model2_list[5:]) / len(time_model2_list[5:])avg_total = sum(time_total_list[5:]) / len(time_total_list[5:])print(f"event avg model1 time: {avg_model1:.5f}, model2 time: {avg_model2:.5f}, total time: {avg_total:.5f}")def main():device = torch.device("cuda:0")model1 = CustomModel1().eval().to(device)model2 = CustomModel2().eval().to(device)x = torch.randn(size=(32, 3, 224, 224), device=device)cal_time2(model1, model2, x)cal_time3(model1, model2, x)if __name__ == '__main__':main()
终端输出:
syn avg model1 time: 0.04725, model2 time: 0.00125, total time: 0.05707
event avg model1 time: 0.04697, model2 time: 0.00099, total time: 0.04797
通过终端打印的结果可以看到无论是使用同步函数还是Event方法统计的model1、model2的推理时间基本是一致的。但对于整个通路而言使用同步函数时总时间明显变长了。下图大致解释了为什么使用同步函数时导致整个通路变长的原因,主要是在model1发送完指令后使用同步函数时会一直等待Device侧计算结束,期间啥也不能干。而使用Event方法时在model1发送完指令后不会阻塞Host,可以立马去进行后面的CPU计算任务。

相关文章:
如何精确统计Pytorch模型推理时间
文章目录 0 背景1 精确统计方法2 手动synchronize和Event适用场景 0 背景 在分析模型性能时需要精确地统计出模型的推理时间,但仅仅通过在模型推理前后打时间戳然后相减得到的时间其实是Host侧向Device侧下发指令的时间。如下图所示,Host侧下发指令与De…...
Mybatis-plus-Generator 3.5.5 自定义模板支持 (DTO/VO 等) 配置
随着项目节奏越来越快,为了减少把时间浪费在新建DTO 、VO 等地方,直接直接基于Mybatis-plus 这颗大树稍微扩展一下,在原来生成PO、 DAO、Service、ServiceImpl、Controller 基础新增。为了解决这个问题,网上找了一堆资料ÿ…...
C#环境下MAC地址获取方法解析
在C#中,获取MAC地址并不是直接支持的,因为出于安全和隐私的考虑,操作系统通常会限制对这类硬件信息的直接访问。不过,仍然可以通过一些方法间接地获取到本地网络接口(比如以太网接口)的MAC地址。 以下是几…...
(k8s)Kubernetes 从0到1容器编排之旅
一、引言 在当今数字化的浪潮中,Kubernetes 如同一艘强大的航船,引领着容器化应用的部署与管理。它以其卓越的灵活性、可扩展性和可靠性,成为众多企业和开发者的首选。然而,要真正发挥 Kubernetes 的强大威力,仅仅掌握…...
Rust Web开发框架对比:Warp与Actix-web
文章目录 Rust Web开发框架对比:Warp与Actix-web引言框架概述Warp框架简介Actix-web框架简介 设计理念Warp的设计理念Actix-web的设计理念 性能比较可扩展性和生态插件和中间件支持社区和文档 使用示例使用Warp构建简单的HTTP服务使用Actix-web构建简单的HTTP服务 学…...
F12抓包12:Performance(性能)前端性能分析
课程大纲 使用场景: ① 前端界面加载性能测试。 ② 导出性能报告给前端开发。 复习:后端(接口)性能分析 ① 所有请求耗时时间轴:“网络”(Network) - 概览。 ② 单个请求耗时:“网络”(Network…...
数据结构(Day13)
一、学习内容 内存空间划分 1、一个进程启动后,计算机会给该进程分配4G的虚拟内存 2、其中0G-3G是用户空间【程序员写代码操作部分】【应用层】 3、3G-4G是内核空间【与底层驱动有关】 4、所有进程共享3G-4G的内核空间,每个进程独立拥有0G-3G的用户空间 …...
链表的快速排序(C/C++实现)
一、前言 大家在做需要排名的项目的时候,需要把各种数据从高到低排序。如果用的快速排序的话,处理数组是十分简单的。因为数组的存储空间的连续的,可以通过下标就可以简单的实现。但如果是链表的话,内存地址是随机分配的…...
css总结(记录一下...)
文字 语法说明word-wrapword-wrap:normal| break-word normal:使用浏览器默认的换行 break-word:允许在单词内换行 text-overflow clip:修剪文本 ellipsis:显示省略符号来代表被修剪的文本 text-shadow可向文本应用的阴影。能够规定水平阴影、垂直阴影、模糊距离,以…...
SpringBoot 处理 @KafkaListener 消息
消息监听容器 1、KafkaMessageListenerContainer 由spring提供用于监听以及拉取消息,并将这些消息按指定格式转换后交给由KafkaListener注解的方法处理,相当于一个消费者; 看看其整体代码结构: 可以发现其入口方法为doStart(),…...
Spring Boot-API版本控制问题
在现代软件开发中,API(应用程序接口)版本控制是一项至关重要的技术。随着应用的不断迭代,API 的改动不可避免,如何在引入新版本的同时保证向后兼容,避免对现有用户的影响,是每个开发者需要考虑的…...
Git 提取和拉取的区别在哪
1. 提取(Fetch) 操作说明:Fetch 操作会从远程仓库下载最新的提交、分支信息等,但不会将这些更改合并到你当前的分支中。它只是将远程仓库的更新信息存储在本地,并不会自动修改你当前的工作区。 使用场景: …...
【数据结构与算法 | 每日一题 | 力扣篇】力扣2390, 2848
1. 力扣2390:从字符串中删除星号 1.1 题目: 给你一个包含若干星号 * 的字符串 s 。 在一步操作中,你可以: 选中 s 中的一个星号。移除星号 左侧 最近的那个 非星号 字符,并移除该星号自身。 返回移除 所有 星号之…...
破解信息架构实施的密码:常见挑战与最佳解决方案全指南
信息架构的成功实施是企业数字化转型的关键步骤,但在实际操作中,企业往往会遇到各种复杂的挑战。这些挑战包括 技术整合的难度、数据管理的复杂性、合规性要求的变化 以及 资源限制 等。《信息架构:商业智能&分析与元数据管理参考模型》为…...
CodeChef Starters 151 (Div.2) A~D
codechef是真敢给分,上把刚注册,这把就div2了,再加上一周没打过还是有点不适应的,好在最后还是能够顺利上分 今天的封面是P3R的设置菜单 我抠出来做我自己的游戏主页了( A - Convert string 题意 在01串里面可以翻转…...
Redis学习——数据不一致怎么办?更新缓存失败了又怎么办?
文章目录 引言正文读写缓存的数据一致性只读缓存的数据一致性删除和修改数据不一致问题操作执行失败导致数据不一致解决办法 多线程访问导致数据不一致问题总结 总结参考信息 引言 最近面试快手的时候被问到了缓存不一致怎么解决?一开始还是很懵的,因为…...
跨境电商代购新纪元:一键解锁全球好物,系统流程全揭秘
添加图片注释,不超过 140 字(可选) 在全球化日益加深的今天,跨境电商代购成为了连接消费者与世界各地优质商品的桥梁。本文将在CSDN平台上,深入剖析跨境电商代购系统的功能流程,带您一窥其背后的技术奥秘与…...
Mac 上终端使用 MySql 记录
文章目录 下载安装终端进入 MySql常用操作查看数据库选择一个数据库查看当前选择的数据库Navcat 打开提示报错参考文章 下载安装 先下载社区版的 MySql 安装的过程需要设置 root 的密码,这个是要进入数据库所设定的,所以要记住 终端进入 MySql 首先输…...
461. 汉明距离
一:题目: 两个整数之间的 汉明距离 指的是这两个数字对应二进制位不同的位置的数目。 给你两个整数 x 和 y,计算并返回它们之间的汉明距离。 示例 1: 输入:x 1, y 4 输出:2 解释: 1 (0 0…...
开发指南061-nexus权限管理
平台后台服务的核心是组件,管理组件的软件有: Apache的Archiva、JFrog的Artifactory、Sonatype的Nexus。 本平台选择nexus。nexus的权限模型是用户-角色-权限体系:通过组合权限定义角色,通过给用户赋角色来赋权限。有关nexus的权…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...
Modbus RTU与Modbus TCP详解指南
目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...
ZYNQ学习记录FPGA(一)ZYNQ简介
一、知识准备 1.一些术语,缩写和概念: 1)ZYNQ全称:ZYNQ7000 All Pgrammable SoC 2)SoC:system on chips(片上系统),对比集成电路的SoB(system on board) 3)ARM:处理器…...
机器学习的数学基础:线性模型
线性模型 线性模型的基本形式为: f ( x ) ω T x b f\left(\boldsymbol{x}\right)\boldsymbol{\omega}^\text{T}\boldsymbol{x}b f(x)ωTxb 回归问题 利用最小二乘法,得到 ω \boldsymbol{\omega} ω和 b b b的参数估计$ \boldsymbol{\hat{\omega}}…...
怎么开发一个网络协议模块(C语言框架)之(六) ——通用对象池总结(核心)
+---------------------------+ | operEntryTbl[] | ← 操作对象池 (对象数组) +---------------------------+ | 0 | 1 | 2 | ... | N-1 | +---------------------------+↓ 初始化时全部加入 +------------------------+ +-------------------------+ | …...
【若依】框架项目部署笔记
参考【SpringBoot】【Vue】项目部署_no main manifest attribute, in springboot-0.0.1-sn-CSDN博客 多一个redis安装 准备工作: 压缩包下载:http://download.redis.io/releases 1. 上传压缩包,并进入压缩包所在目录,解压到目标…...
CVE-2023-25194源码分析与漏洞复现(Kafka JNDI注入)
漏洞概述 漏洞名称:Apache Kafka Connect JNDI注入导致的远程代码执行漏洞 CVE编号:CVE-2023-25194 CVSS评分:8.8 影响版本:Apache Kafka 2.3.0 - 3.3.2 修复版本:≥ 3.4.0 漏洞类型:反序列化导致的远程代…...
