通过解读yolov5_gpu_optimization学习如何使用onnx_surgon
onnx实战一: 解析yolov5 gpu的onnx优化案例:
这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的。 通过这个案例理解原版的onnx的导出流程然后我们看英伟达是怎么拿gs来优化这个onnx
原版的export_onnx函数
先看torch.onnx.export
函数的参数解释:
-
model: 要导出的PyTorch模型, 在工程中这里输入的是训练好的pt文件
-
im: 这里对应torch.onnx.export的args, 这个是用作模型输入的示例张量。这帮助ONNX确定输入的形状和类型。
-
f: 输出ONNX模型的文件名或文件对象, 用来指定导出模型的路径和文件名。
-
verbose (默认为
False
): 如果设置为True
,则会打印出模型导出时的详细日志。 -
opset_version: 导出的ONNX模型的操作集版本。不同的版本可能支持不同的操作。
-
training:
torch.onnx.TrainingMode.TRAINING
: 表示模型处于训练模式。torch.onnx.TrainingMode.EVAL
: 表示模型处于评估模式。
-
do_constant_folding (默认为
True
): 当设置为True
,导出过程中会尝试简化模型,将常量子图折叠为一个常量节点。 -
input_names: 为模型的输入提供名称, 参数规定是数组
-
output_names: 为模型的输出提供名称, 参数规定是数组
-
dynamic_axes: 为模型的输入/输出定义动态轴。对于那些维度在推理时可能会发生变化的情况(例如,批处理大小),此参数允许指定哪些轴是动态的。这里images是输入, 本来是1x3x640x640, 这里通过指定把0, 2, 3维度变成了动态轴的输入, 第二个维度是3这个还是固定的。如果使用动态, 可以输入任意数量和任意大小的图片而不是规定的单张640x640
'images'
: 对应的张量名称。0: 'batch'
: 表示第0个维度(即批处理维度)是动态的,并命名为’batch’。2: 'height'
: 表示第2个维度(即图像的高度)是动态的。3: 'width'
: 表示第3个维度(即图像的宽度)是动态的。
'output'
: 对应的张量名称。0: 'batch'
: 表示第0个维度(即批处理维度)是动态的。1: 'anchors'
: 表示第1个维度是动态的。
- dynamic (没有在给定的函数调用中明确给出,但可以从上下文推断):
True
: 如果你想让某些轴动态,你可以设置此参数为True
。False
: 表示不使用动态轴。
导出了onnx之后开始做onnxsim
-
model_onnx, check = onnxsim.simplify(...):
使用onnxsim的simplify方法简化模型。它返回简化后的onnx模型和一个布尔值check,表示简化是否成功。 -
在对动态输入的onnx导出的时候,
dynamic_input_shape=dynamic
是不够的,还要把输入给他,让onnxsim更加谨慎的优化onnx, 确保满足我们给他的输出,所以这里多了一个input_shapes={'images': list(im.shape)} if dynamic else None
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):# YOLOv5 ONNX exporttry:check_requirements(('onnx',))import onnxLOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')f = file.with_suffix('.onnx')torch.onnx.export(model,im,f,verbose=False,opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['output'],dynamic_axes={'images': {0: 'batch',2: 'height',3: 'width'}, # shape(1,3,640,640)'output': {0: 'batch',1: 'anchors'} # shape(1,25200,85)} if dynamic else None)# Checksmodel_onnx = onnx.load(f) # load onnx modelonnx.checker.check_model(model_onnx) # check onnx model# Metadatad = {'stride': int(max(model.stride)), 'names': model.names}for k, v in d.items():meta = model_onnx.metadata_props.add()meta.key, meta.value = k, str(v)onnx.save(model_onnx, f)# Simplifyif simplify:try:check_requirements(('onnx-simplifier',))import onnxsimLOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')model_onnx, check = onnxsim.simplify(model_onnx,dynamic_input_shape=dynamic,input_shapes={'images': list(im.shape)} if dynamic else None)assert check, 'assert check failed'onnx.save(model_onnx, f)except Exception as e:LOGGER.info(f'{prefix} simplifier failure: {e}')LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')return fexcept Exception as e:LOGGER.info(f'{prefix} export failure: {e}')
更改过的export_onnx函数
- 首先是把onnx的输出由一个改成了3个, 然后指定动态输出, 因为有多个输出,全部都把他们的batch, width, height指定为动态的,满足不同的输入输出。 不过这边的问题是看起来是只改了最后的输出,但是前面在yolo.py的地方已经把sigmoid后面的计算都干掉了, 因为后面的计算映射了一堆的算子导致了在计算图太冗余
这一坨全部不要了就保留sigmoid就可以了,然后就是直接硬编码t就是int32
diff --git a/models/yolo.py b/models/yolo.py
index 02660e6..c810745 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -55,29 +55,15 @@ class Detect(nn.Module):z = [] # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i]) # conv
- bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
- x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
-
- if not self.training: # inference
- if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
- self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
-
- y = x[i].sigmoid()
- if self.inplace:
- y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy
- y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
- else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
- xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
- xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
- wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
- y = torch.cat((xy, wh, conf), 4)
- z.append(y.view(bs, -1, self.no))
-
- return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
+ y = x[i].sigmoid()
+ z.append(y)
+ return zdef _make_grid(self, nx=20, ny=20, i=0):d = self.anchors[i].device
- t = self.anchors[i].dtype
+ # t = self.anchors[i].dtype
+ # TODO(tylerz) hard-code data type to int
+ t = torch.int32shape = 1, self.na, ny, nx, 2 # grid shapey, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
--
2.36.0
-
onnxsim这里跟之前是一样的, 也是直接onnxsim, 如果动态的要给输出给onnxsim然后让它更加的谨慎,满足需求
-
这后面的重点是增加了用onnx-surgon来更改onnx, 先把整个onnx导入进来,然后使用然后用Variable做模型的输出, 这里做四个模型输出, 分别是
DecodeNumDetection
,DecodeDetectionBoxes
,DecodeDetectionScores
,DecodeDetectionClasses
-
然后设置一个attrs, gs设置的attrs使用字典的格式弄的。这里设置max_stride, num_classes, anchors, prenms_score_threshold四个属性,这些属性的操作会在TensorRT中实现的
-
decode_plugin是中间的节点,这个节点上面是inputs, 下面是四个不同的decodes, 这里就是把这个nodes做出来了
-
然后在整体的网络上添加这个节点,然后再把输出改成这个节点的输出保持一致,在计算图中把其他的节点claenup()清洁掉, 最后导出
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):# YOLOv5 ONNX export# try:check_requirements(('onnx',))import onnxLOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')f = file.with_suffix('.onnx')print(train)torch.onnx.export(model,im,f,verbose=False,opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['p3', 'p4', 'p5'],dynamic_axes={'images': {0: 'batch',2: 'height',3: 'width'}, # shape(1,3,640,640)'p3': {0: 'batch',2: 'height',3: 'width'}, # shape(1,25200,4)'p4': {0: 'batch',2: 'height',3: 'width'},'p5': {0: 'batch',2: 'height',3: 'width'}} if dynamic else None)# Checksmodel_onnx = onnx.load(f) # load onnx modelonnx.checker.check_model(model_onnx) # check onnx model# Simplifyif simplify:# try:check_requirements(('onnx-simplifier',))import onnxsimLOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')model_onnx, check = onnxsim.simplify(model_onnx,dynamic_input_shape=dynamic,input_shapes={'images': list(im.shape)} if dynamic else None)assert check, 'assert check failed'onnx.save(model_onnx, f)# except Exception as e:# LOGGER.info(f'{prefix} simplifier failure: {e}')# add yolov5_decoding:import onnx_graphsurgeon as onnx_gsimport numpy as npyolo_graph = onnx_gs.import_onnx(model_onnx)p3 = yolo_graph.outputs[0]p4 = yolo_graph.outputs[1]p5 = yolo_graph.outputs[2]decode_out_0 = onnx_gs.Variable("DecodeNumDetection",dtype=np.int32)decode_out_1 = onnx_gs.Variable("DecodeDetectionBoxes",dtype=np.float32)decode_out_2 = onnx_gs.Variable("DecodeDetectionScores",dtype=np.float32)decode_out_3 = onnx_gs.Variable("DecodeDetectionClasses",dtype=np.int32)decode_attrs = dict()decode_attrs["max_stride"] = int(max(model.stride))decode_attrs["num_classes"] = model.model[-1].ncdecode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]decode_attrs["prenms_score_threshold"] = 0.25decode_plugin = onnx_gs.Node(op="YoloLayer_TRT",name="YoloLayer",inputs=[p3, p4, p5],outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],attrs=decode_attrs)yolo_graph.nodes.append(decode_plugin)yolo_graph.outputs = decode_plugin.outputsyolo_graph.cleanup().toposort()model_onnx = onnx_gs.export_onnx(yolo_graph)d = {'stride': int(max(model.stride)), 'names': model.names}for k, v in d.items():meta = model_onnx.metadata_props.add()meta.key, meta.value = k, str(v)onnx.save(model_onnx, f)LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')return f# except Exception as e:# LOGGER.info(f'{prefix} export failure: {e}')
相关文章:

通过解读yolov5_gpu_optimization学习如何使用onnx_surgon
onnx实战一: 解析yolov5 gpu的onnx优化案例: 这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的。 通过这个案例理解原版的onnx的导出流程然后我们看英伟达是怎么拿gs来优化…...

图像复原与重建,解决噪声的几种空间域复原方法(数字图像处理概念 P4)
文章目录 图像复原模型噪声模型只存在噪声的空间域复原 图像复原模型 噪声模型 只存在噪声的空间域复原...

Android 启动优化案例:WebView非预期初始化排查
去年年底做启动优化时,有个比较好玩的 case 给大家分享下,希望大家能从我的分享里 get 到我在做一些问题排查修复时是怎么看上去又low又土又高效的。 1. 现象 在我们使用 Perfetto 进行app 启动过程性能观测时,在 UI 线程发现了一段 几十毫…...

20230919后台面经整理
1.你认为什么是操作系统,操作系统有哪些功能 os是:管理资源、向用户提供服务、硬件机器的扩展 1.进程线程管理:状态、控制、通信等 2.存储管理:分配回收、地址转换 3.文件管理:目录、操作、磁盘、存取 4.设备管理&…...

画一个时钟(html+css+js)
这是一个很简约的时钟。。。。。。。 效果: 代码: <template><div class"demo-box"><div class"clock"><ul class"mark"><liv-for"(rotate, index) in rotatedAngles":key"i…...

红 黑 树
文章目录 一、红黑树的概念二、红黑树的实现1. 红黑树的存储结构2. 红黑树的插入 一、红黑树的概念 在 AVL 树中删除一个结点,旋转可能要持续到根结点,此时效率较低 红黑树也是一种二叉搜索树,通过在每个结点中增加一个位置来存储红色或黑色…...
掷骰子的多线程应用程序1(复现《Qt C++6.0》)
说明:复现的代码来自《Qt C6.0》P496-P500。在复现时完全按照代码,出现了两处报错: (1)ui指针(2)按钮的响应函数。下面程序对以上问题进行了修改。除了图片、清空、关闭功能外,其他…...

【vue2第十八章】VueRouter 路由嵌套 与 keep-alive缓存组件(activated,deactivated)
VueRouter 路由嵌套 在使用vue开发中,可能会碰到使用多层级别的路由。比如: 其中就包含了两个主要页面,首页,详情,但是首页的下面又包含了列表,喜欢,收藏,我的四个子路由。 此时就…...

如何确保亚马逊、速卖通等平台测评补单的环境稳定性和安全性?
做亚马逊、速卖通等平台测评补单时,确保环境的安全性和稳定性是非常重要的。稳定的环境是测评的基础,如果无法解决安全性问题,那么测评就不值得进行。为了确保稳定的环境系统,需要考虑物理环境和IP环境两个方面。 首先࿰…...

echarts图表 实现高度按照 内容撑起来或者超出部分滚动展示效果
背景:因为数据不固定 高度写死导致数据显示不全,所以图表高度要根据内容计算 实现代码如下: <divv-if"showCharts"id"business-bars"class"chart":style"{ height: chartHeight px }"></d…...

【论文阅读】检索增强发展历程及相关文章总结
文章目录 前言Knn-LMInsightMethodResultsDomain AdaptionTuning Nearest Neighbor Search Analysis REALMInsightsMethodKnowledge RetrieverKnowledge-Augmented Encoder ExpResultAblation StudyCase Study DPRInsightMethodExperimentsResults RAGInsightRAG-Sequence Mode…...

【漏洞复现系列】二、weblogic-cve_2020_2883(RCE/反序列化)
Key words:T3协议,weblogic Server,反序列化 2.1、漏洞原理 cve_2020_2883 远程代码执行漏洞存在于 WebLogic Server 核心组件中,允许未经身份验证的攻击者通过 T3 协议网络访问并破坏易受攻击的 WebLogic Server,成功的漏洞利…...
算法通关村-----LRU的设计与实现
LRU 缓存 问题描述 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类: LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存。int get(int key) 如果关键字 key 存在于缓存中,则返回关键字的值&…...

王江涛十天搞定考研词汇
学习目标: 考研词汇 学习内容: 2023-9-17 第一天考研词汇 学习时间: 开始:2023-9-17 结束:进行中 学习产出: 2023-9-17intellect智力;知识分子intellectual智力的;聪明的intellectualize使...理智化&a…...

算法(二)——数组章节和链表章节
数组章节 (1)二分查找 给定一个 n 个元素有序的(升序)整型数组 nums 和一个目标值 target ,写一个函数搜索 nums 中的 target,如果目标值存在返回下标,否则返回 -1。 class Solution {public i…...

Android:ListView在Fragment中的使用
一、前言: 因为工作一直在用mvvm框架,因此这篇文章是基于mvvm框架写的。在Fragment复制之前一定要谨记项目可以跑起来。确保能跑起来之后直接复制就行。 二、代码展示: 页面布局 ?xml version"1.0" encoding"utf-8"…...

BIGEMAP在土地规划中的应用
工具 Bigemap gis office地图软件 BIGEMAP GIS Office-全能版 Bigemap APP_卫星地图APP_高清卫星地图APP 1.使用软件一般都用于套坐标,比如我们常见的(kml shp CAD等土建规划图纸)以及一些项目厂区红线,方便于项目选址和居民建…...

软件测试常见术语和名词解释
1. Unit testing (单元测试):指一段代码的基本测试,其实际大小是未定的,通常是一个函数或子程序,一般由开发者执行。 2. Integration testing (集成测试):被测试系统的所有组件都集成在一起,找出被测试系统…...

prometheus+process_exporter进程监控
一、需要监控进程的服务器上配置 1、进入到临时工作目录,传入process_exporter包 [root Nginx1 ~]# cd work/ [root Nginx1 work]# rz 2、解压,并移动至/usr/local/目录下 [root Nginx1 work]# tar xzf process-exporter-0.7.5.linux-amd64.tar.gz [root…...

四川玖璨电子商务有限公司专注抖音电商运营
四川玖璨电商是一个靠谱的抖音培训公司,在电商行业内有着广泛的知名度和良好的口碑。该公司通过多年的发展,形成了独特的运营理念和有效的运营策略,为商家提供了一站式的抖音电商运营服务。 首先,四川玖璨电子商务有限公司注重与…...

【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...

Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包
文章目录 现象:mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时,可能是因为以下几个原因:1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...

云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...

Rust 开发环境搭建
环境搭建 1、开发工具RustRover 或者vs code 2、Cygwin64 安装 https://cygwin.com/install.html 在工具终端执行: rustup toolchain install stable-x86_64-pc-windows-gnu rustup default stable-x86_64-pc-windows-gnu 2、Hello World fn main() { println…...