当前位置: 首页 > news >正文

如何精确统计Pytorch模型推理时间

文章目录

  • 0 背景
  • 1 精确统计方法
  • 2 手动synchronize和Event适用场景


0 背景

在分析模型性能时需要精确地统计出模型的推理时间,但仅仅通过在模型推理前后打时间戳然后相减得到的时间其实是Host侧向Device侧下发指令的时间。如下图所示,Host侧下发指令与Device侧计算实际上是异步进行的。
在这里插入图片描述


1 精确统计方法

比较常用的精确统计方法有两种,一种是手动调用同步函数等待Device侧计算完成。另一种是通过Event方法在Device侧记录时间戳。
在这里插入图片描述

下面示例代码中分别给出了直接在模型推理前后打时间戳相减,使用同步函数以及Event方法统计模型推理时间(每种方法都重复50次,忽略前5次推理,取后45次的平均值)。

import timeimport torch
import torch.nn as nnclass CustomModel(nn.Module):def __init__(self):super().__init__()self.part0 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=2, padding=1),nn.GELU(),nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),nn.GELU())self.part1 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)),nn.Flatten(),nn.Linear(in_features=1024, out_features=2048),nn.GELU(),nn.Linear(in_features=2048, out_features=512),nn.GELU(),nn.Linear(in_features=512, out_features=1))def forward(self, x):x = self.part0(x)x = self.part1(x)return xdef cal_time1(model, x):with torch.inference_mode():time_list = []for _ in range(50):ts = time.perf_counter()ret = model(x)td = time.perf_counter()time_list.append(td - ts)print(f"avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def cal_time2(model, x):device = x.devicewith torch.inference_mode():time_list = []for _ in range(50):torch.cuda.synchronize(device)ts = time.perf_counter()ret = model(x)torch.cuda.synchronize(device)td = time.perf_counter()time_list.append(td - ts)print(f"syn avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def cal_time3(model, x):with torch.inference_mode():start_event = torch.cuda.Event(enable_timing=True)end_event = torch.cuda.Event(enable_timing=True)time_list = []for _ in range(50):start_event.record()ret = model(x)end_event.record()end_event.synchronize()time_list.append(start_event.elapsed_time(end_event) / 1000)print(f"event avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}")def main():device = torch.device("cuda:0")model = CustomModel().eval().to(device)x = torch.randn(size=(32, 3, 224, 224), device=device)cal_time1(model, x)cal_time2(model, x)cal_time3(model, x)if __name__ == '__main__':main()

终端输出:

avg time: 0.00023
syn avg time: 0.04709
event avg time: 0.04710

通过终端输出可以看到,如果直接在模型推理前后打时间戳相减得到的时间非常短(因为并没有等待Device侧计算完成)。而使用同步函数或者Event方法统计的时间明显要长很多。


2 手动synchronize和Event适用场景

通过上面的代码示例可以看到,通过同步函数统计的时间和Event方法统计的时间基本一致(差异1ms内)。那两者有什么区别呢?如果只是简单统计一个模型的推理时间确实看不出什么差异。但如果要统计一个完整AI应用通路(其中可能包含多个模型以及各种CPU计算)中不同模型的耗时,而又不想影响到整个通路的性能,那么建议使用Event方法。因为使用同步函数可能会让Host长期处于等待状态,等待过程中也无法干其他的事情,从而导致计算资源的浪费。可以看看下面这个示例,整个通路由Model1推理+一段纯CPU计算+Model2推理串行构成,假设想统计一下model1、model2推理分别用了多长时间:

import timeimport torch
import torch.nn as nn
import numpy as npclass CustomModel1(nn.Module):def __init__(self):super().__init__()self.part0 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=2, padding=1),nn.GELU(),nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),nn.GELU())def forward(self, x):x = self.part0(x)return xclass CustomModel2(nn.Module):def __init__(self):super().__init__()self.part1 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)),nn.Flatten(),nn.Linear(in_features=1024, out_features=2048),nn.GELU(),nn.Linear(in_features=2048, out_features=512),nn.GELU(),nn.Linear(in_features=512, out_features=1))def forward(self, x):x = self.part1(x)return xdef do_pure_cpu_task():x = np.random.randn(1, 3, 512, 512)x = x.astype(np.float32)x = x * 1024 ** 0.5def cal_time2(model1, model2, x):device = x.devicewith torch.inference_mode():time_total_list = []time_model1_list = []time_model2_list = []for _ in range(50):torch.cuda.synchronize(device)ts1 = time.perf_counter()ret = model1(x)torch.cuda.synchronize(device)td1 = time.perf_counter()do_pure_cpu_task()torch.cuda.synchronize(device)ts2 = time.perf_counter()ret = model2(ret)torch.cuda.synchronize(device)td2 = time.perf_counter()time_model1_list.append(td1 - ts1)time_model2_list.append(td2 - ts2)time_total_list.append(td2 - ts1)avg_model1 = sum(time_model1_list[5:]) / len(time_model1_list[5:])avg_model2 = sum(time_model2_list[5:]) / len(time_model2_list[5:])avg_total = sum(time_total_list[5:]) / len(time_total_list[5:])print(f"syn avg model1 time: {avg_model1:.5f}, model2 time: {avg_model2:.5f}, total time: {avg_total:.5f}")def cal_time3(model1, model2, x):with torch.inference_mode():model1_start_event = torch.cuda.Event(enable_timing=True)model1_end_event = torch.cuda.Event(enable_timing=True)model2_start_event = torch.cuda.Event(enable_timing=True)model2_end_event = torch.cuda.Event(enable_timing=True)time_total_list = []time_model1_list = []time_model2_list = []for _ in range(50):model1_start_event.record()ret = model1(x)model1_end_event.record()do_pure_cpu_task()model2_start_event.record()ret = model2(ret)model2_end_event.record()model2_end_event.synchronize()time_model1_list.append(model1_start_event.elapsed_time(model1_end_event) / 1000)time_model2_list.append(model2_start_event.elapsed_time(model2_end_event) / 1000)time_total_list.append(model1_start_event.elapsed_time(model2_end_event) / 1000)avg_model1 = sum(time_model1_list[5:]) / len(time_model1_list[5:])avg_model2 = sum(time_model2_list[5:]) / len(time_model2_list[5:])avg_total = sum(time_total_list[5:]) / len(time_total_list[5:])print(f"event avg model1 time: {avg_model1:.5f}, model2 time: {avg_model2:.5f}, total time: {avg_total:.5f}")def main():device = torch.device("cuda:0")model1 = CustomModel1().eval().to(device)model2 = CustomModel2().eval().to(device)x = torch.randn(size=(32, 3, 224, 224), device=device)cal_time2(model1, model2, x)cal_time3(model1, model2, x)if __name__ == '__main__':main()

终端输出:

syn avg model1 time: 0.04725, model2 time: 0.00125, total time: 0.05707
event avg model1 time: 0.04697, model2 time: 0.00099, total time: 0.04797

通过终端打印的结果可以看到无论是使用同步函数还是Event方法统计的model1、model2的推理时间基本是一致的。但对于整个通路而言使用同步函数时总时间明显变长了。下图大致解释了为什么使用同步函数时导致整个通路变长的原因,主要是在model1发送完指令后使用同步函数时会一直等待Device侧计算结束,期间啥也不能干。而使用Event方法时在model1发送完指令后不会阻塞Host,可以立马去进行后面的CPU计算任务。

在这里插入图片描述

相关文章:

如何精确统计Pytorch模型推理时间

文章目录 0 背景1 精确统计方法2 手动synchronize和Event适用场景 0 背景 在分析模型性能时需要精确地统计出模型的推理时间,但仅仅通过在模型推理前后打时间戳然后相减得到的时间其实是Host侧向Device侧下发指令的时间。如下图所示,Host侧下发指令与De…...

Mybatis-plus-Generator 3.5.5 自定义模板支持 (DTO/VO 等) 配置

随着项目节奏越来越快,为了减少把时间浪费在新建DTO 、VO 等地方,直接直接基于Mybatis-plus 这颗大树稍微扩展一下,在原来生成PO、 DAO、Service、ServiceImpl、Controller 基础新增。为了解决这个问题,网上找了一堆资料&#xff…...

C#环境下MAC地址获取方法解析

在C#中,获取MAC地址并不是直接支持的,因为出于安全和隐私的考虑,操作系统通常会限制对这类硬件信息的直接访问。不过,仍然可以通过一些方法间接地获取到本地网络接口(比如以太网接口)的MAC地址。 以下是几…...

(k8s)Kubernetes 从0到1容器编排之旅

一、引言 在当今数字化的浪潮中,Kubernetes 如同一艘强大的航船,引领着容器化应用的部署与管理。它以其卓越的灵活性、可扩展性和可靠性,成为众多企业和开发者的首选。然而,要真正发挥 Kubernetes 的强大威力,仅仅掌握…...

Rust Web开发框架对比:Warp与Actix-web

文章目录 Rust Web开发框架对比:Warp与Actix-web引言框架概述Warp框架简介Actix-web框架简介 设计理念Warp的设计理念Actix-web的设计理念 性能比较可扩展性和生态插件和中间件支持社区和文档 使用示例使用Warp构建简单的HTTP服务使用Actix-web构建简单的HTTP服务 学…...

F12抓包12:Performance(性能)前端性能分析

课程大纲 使用场景: ① 前端界面加载性能测试。 ② 导出性能报告给前端开发。 复习:后端(接口)性能分析 ① 所有请求耗时时间轴:“网络”(Network) - 概览。 ② 单个请求耗时:“网络”(Network&#xf…...

数据结构(Day13)

一、学习内容 内存空间划分 1、一个进程启动后,计算机会给该进程分配4G的虚拟内存 2、其中0G-3G是用户空间【程序员写代码操作部分】【应用层】 3、3G-4G是内核空间【与底层驱动有关】 4、所有进程共享3G-4G的内核空间,每个进程独立拥有0G-3G的用户空间 …...

链表的快速排序(C/C++实现)

一、前言 大家在做需要排名的项目的时候,需要把各种数据从高到低排序。如果用的快速排序的话,处理数组是十分简单的。因为数组的存储空间的连续的,可以通过下标就可以简单的实现。但如果是链表的话,内存地址是随机分配的&#xf…...

css总结(记录一下...)

文字 语法说明word-wrapword-wrap:normal| break-word normal:使用浏览器默认的换行 break-word:允许在单词内换行 text-overflow clip:修剪文本 ellipsis:显示省略符号来代表被修剪的文本 text-shadow可向文本应用的阴影。能够规定水平阴影、垂直阴影、模糊距离,以…...

SpringBoot 处理 @KafkaListener 消息

消息监听容器 1、KafkaMessageListenerContainer 由spring提供用于监听以及拉取消息,并将这些消息按指定格式转换后交给由KafkaListener注解的方法处理,相当于一个消费者; 看看其整体代码结构: 可以发现其入口方法为doStart(),…...

Spring Boot-API版本控制问题

在现代软件开发中,API(应用程序接口)版本控制是一项至关重要的技术。随着应用的不断迭代,API 的改动不可避免,如何在引入新版本的同时保证向后兼容,避免对现有用户的影响,是每个开发者需要考虑的…...

Git 提取和拉取的区别在哪

1. 提取(Fetch) 操作说明:Fetch 操作会从远程仓库下载最新的提交、分支信息等,但不会将这些更改合并到你当前的分支中。它只是将远程仓库的更新信息存储在本地,并不会自动修改你当前的工作区。 使用场景: …...

【数据结构与算法 | 每日一题 | 力扣篇】力扣2390, 2848

1. 力扣2390:从字符串中删除星号 1.1 题目: 给你一个包含若干星号 * 的字符串 s 。 在一步操作中,你可以: 选中 s 中的一个星号。移除星号 左侧 最近的那个 非星号 字符,并移除该星号自身。 返回移除 所有 星号之…...

破解信息架构实施的密码:常见挑战与最佳解决方案全指南

信息架构的成功实施是企业数字化转型的关键步骤,但在实际操作中,企业往往会遇到各种复杂的挑战。这些挑战包括 技术整合的难度、数据管理的复杂性、合规性要求的变化 以及 资源限制 等。《信息架构:商业智能&分析与元数据管理参考模型》为…...

CodeChef Starters 151 (Div.2) A~D

codechef是真敢给分,上把刚注册,这把就div2了,再加上一周没打过还是有点不适应的,好在最后还是能够顺利上分 今天的封面是P3R的设置菜单 我抠出来做我自己的游戏主页了( A - Convert string 题意 在01串里面可以翻转…...

Redis学习——数据不一致怎么办?更新缓存失败了又怎么办?

文章目录 引言正文读写缓存的数据一致性只读缓存的数据一致性删除和修改数据不一致问题操作执行失败导致数据不一致解决办法 多线程访问导致数据不一致问题总结 总结参考信息 引言 最近面试快手的时候被问到了缓存不一致怎么解决?一开始还是很懵的,因为…...

跨境电商代购新纪元:一键解锁全球好物,系统流程全揭秘

添加图片注释,不超过 140 字(可选) 在全球化日益加深的今天,跨境电商代购成为了连接消费者与世界各地优质商品的桥梁。本文将在CSDN平台上,深入剖析跨境电商代购系统的功能流程,带您一窥其背后的技术奥秘与…...

Mac 上终端使用 MySql 记录

文章目录 下载安装终端进入 MySql常用操作查看数据库选择一个数据库查看当前选择的数据库Navcat 打开提示报错参考文章 下载安装 先下载社区版的 MySql 安装的过程需要设置 root 的密码,这个是要进入数据库所设定的,所以要记住 终端进入 MySql 首先输…...

461. 汉明距离

一:题目: 两个整数之间的 汉明距离 指的是这两个数字对应二进制位不同的位置的数目。 给你两个整数 x 和 y,计算并返回它们之间的汉明距离。 示例 1: 输入:x 1, y 4 输出:2 解释: 1 (0 0…...

开发指南061-nexus权限管理

平台后台服务的核心是组件,管理组件的软件有: Apache的Archiva、JFrog的Artifactory、Sonatype的Nexus。 本平台选择nexus。nexus的权限模型是用户-角色-权限体系:通过组合权限定义角色,通过给用户赋角色来赋权限。有关nexus的权…...

JavaSec-RCE

简介 RCE(Remote Code Execution),可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景:Groovy代码注入 Groovy是一种基于JVM的动态语言,语法简洁,支持闭包、动态类型和Java互操作性&#xff0c…...

Cursor实现用excel数据填充word模版的方法

cursor主页:https://www.cursor.com/ 任务目标:把excel格式的数据里的单元格,按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例,…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂&#xff…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...

【python异步多线程】异步多线程爬虫代码示例

claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...

ios苹果系统,js 滑动屏幕、锚定无效

现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...