当前位置: 首页 > news >正文

TensorRT加速推理入门-1:Pytorch转ONNX

这篇文章,用于记录将TransReID的pytorch模型转换为onnx的学习过程,期间参考和学习了许多大佬编写的博客,在参考文章这一章节中都已列出,非常感谢。

1. 在pytorch下使用ONNX主要步骤

1.1. 环境准备

安装onnxruntime包
安装教程可参考:
onnx模型预测环境安装笔记
onnxruntime配置
CPU版本:
直接pip安装

pip install onnxruntime

GPU版本:
先查看自己CUDA版本然后在下面的链接去找对应的onnxruntime的版本
CUDA版本的查询,可参考这个
onnxruntime版本查询
查询到对应版本,直接pip安装即可,例如

pip install onnxruntime-gpu==1.13.1

安装onnxsim包

pip install onnx-simplifier

1.2. 搭建 PyTorch 模型(TransReID)

def get_net(model_path,opt_=False):if opt_:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")#cfg.freeze()train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)else:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']state_dict = torch.load(model_path, map_location=torch.device('cpu'))model_state_dict=net.state_dict()for key in list(state_dict.keys()):if key[7:] in model_state_dict.keys():model_state_dict[key[7:]]=state_dict[key]net.load_state_dict(model_state_dict)return net

1.3. pytorch模型转换为 ONNX 模型

这个提供了静态转换(静态转换支持静态输入)和动态转换(动态转换支持动态输入)两个函数,可根据需要选择。

def convert_onnx_dynamic(model,save_path,simp=False):x = torch.randn(4, 3, 256,128)input_name = 'input'output_name = 'class'torch.onnx.export(model,x,save_path,input_names = [input_name],output_names = [output_name],dynamic_axes= {input_name: {0: 'B'},output_name: {0: 'B'}})if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def convert_onnx(model,save_path,batch=1,simp=False):input_names = ['input']output_names=['class']x = torch.randn(batch, 3, 256, 128)for para in model.parameters():para.requires_grad = False# model_script = torch.jit.script(model)# model_trace = torch.jit.trace(model, x)torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')

pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export,来看看该函数的参数和用法。

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数用法
model需要导出的pytorch模型
args模型的任意一组输入(模拟实际输入数据的大小,比如三通道的512*512大小的图片,就可以设置为torch.randn(1, 3 , 512, 512)
path输出的onnx模型的位置,例如yolov5.onnx
export_params输出模型是否可训练。default=True,表示导出trained model, 否则untrained。
verbose是否打印模型转换信息,default=None
opset_versiononnx算子集的版本
input_names模型的输入节点名称(自己定义的),如果不写,默认输出数字类型 的名称
output_name模型的输出节点名称(自己定义的), 如果不写,默认输出数字类型的名称
do_constant_folding是否使用常量折叠,默认即可。default=True。
dynamic_axes设置动态输入输出,用法:“输入输出名:[支持动态的维度”,如"支持动态的维度设置为[0, 2, 3]"则表示第0维,第2维,第3维支持动态输入输出。
模型的输入输出有时是可变的,如rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b, 3, h, w), 其中batch、height、width是可变的,但是chancel是固定三通道。格式如下:1)仅list(int)dynamic_axes={‘input’:[0, 2, 3], ‘output’:{0:‘batch’, 1:‘c’}} 2)仅dict<int, string> dynamic_axes={‘input’:{‘input’:{0:‘batch’, 2:‘height’, 3:‘width’}, ‘output’:{0:‘batch’, 1:‘c’}} 3)mixed dynamic_axes={‘input’:{0:‘batch’, 2:‘height’, 3:‘width’}, ‘output’:[0,1]}

注意onnx不支持结构中带有if语句的模型,如:

当我们在网络中嵌入一些if选择性的语句时,不好意思,模型不会考虑这些, 它只会记录下运行时走过的节点,不会根据if的实际情况来选择走哪条路, 所以势必会丢弃一部分节点,而丢弃哪些则是根据我们转模型时的输入来定的,一旦指定了,后面运行onnx模型都会如此。另一个问题就是,我们在代码中有一些循环或者迭代的操作时,要注意,尤其是我们的迭代次数是根据输入不同 会有变化时,也会因为这些操作导致后面的推理出现意外错误,正像前面说的,模型转换不喜欢不确定的东西,它会把这些变量dump成常量,所以会导致推理 错误。

对于实际部署的需求,很多时候pytorch是不满足的,所以需要转成其他模型格式来加快推理。常用的就是onnx,onnx天然支持很多框架模型的转换,如Pytorch,tf,darknet,caffe等。而pytorch也给我们提供了对应的接口,就是torch.onnx.export。下面具体到每一步。
原文来自:Windows下使用ONNX+pytorch记录

首先,环境和依赖:onnx包,cuda和cudnn,我用的版本号分别是1.7.0, 10.1, 7.5.4。
我们需要提供一个pytorch的模型,然后调用torch.onnx.export,同时还需要提供另外一些参数。我们一个个来分析,一是我们要给一个dummy input, 就是随便指定一个和我们实际输入时尺寸相同的一个随机数,是Tensor类型的,然后我们要指定转换的device,即是在gpu还是cpu。 然后我们要给一个input_names和output_names,这是绑定输入和输出,当然输入和输出可能不止一个,那就根据实际的输入和输出个数来给出name列表,
如果我们指定的输入和输出名和实际的网络结构不一致的话,onnx会自动给我们设置一个名字。一般是数字字符串。
输入和输出的绑定之后,我我们们可以看到还有一个参数叫做dynamic_axes,这是做什么的呢?哦,这是指定动态输入的,为了满足我们实际推理过程中,可能每张图片的分辨率不一样,所以允许我们给每个维度设置动态输入,这样是不是灵活多了?然后,设置完这些参数和输入,我们就可以开始转换模型了,如果不报错就是成功了,会在当前目录下生成一个.onnx文件。
原文来自: 一文掌握Pytorch-onnx-tensorrt模型转换

1.4 onnx-simplifier简化onnx模型

model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)

Pytorch转换为ONNX的完整代码pytorch_to_onnx.py

import json
import os
import onnx
import torch
import argparse
import torch.nn as nn
from onnxsim import simplify
from collections import OrderedDict
import torch.nn.functional as F# TransReID的模型构建需要的包
from model.make_model import *
from config import cfg
from datasets import make_dataloader os.environ['CUDA_VISIBLE_DEVICES'] = '1'def convert_onnx_dynamic(model,save_path,simp=False):x = torch.randn(4, 3, 256,128)input_name = 'input'output_name = 'class'torch.onnx.export(model,x,save_path,input_names = [input_name],output_names = [output_name],dynamic_axes= {input_name: {0: 'B'},output_name: {0: 'B'}})if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def convert_onnx(model,save_path,batch=1,simp=False):input_names = ['input']output_names=['class']x = torch.randn(batch, 3, 256, 128)for para in model.parameters():para.requires_grad = False# model_script = torch.jit.script(model)# model_trace = torch.jit.trace(model, x)torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def get_net(model_path,opt_=False):if opt_:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")#cfg.freeze()train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)else:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']state_dict = torch.load(model_path, map_location=torch.device('cpu'))model_state_dict=net.state_dict()for key in list(state_dict.keys()):if key[7:] in model_state_dict.keys():model_state_dict[key[7:]]=state_dict[key]net.load_state_dict(model_state_dict)return netif __name__=="__main__":parser = argparse.ArgumentParser(description='torch to onnx describe.')parser.add_argument("--model_path",type = str,default="/home/TransReID-main/weights/vit_transreid_occ_duke.pth",help="torch weight path, default is MobileViT_Pytorch/weights-file/model_best.pth.tar.")parser.add_argument("--save_path",type=str,default="/home/TransReID-main/weights/vit_transreid_occ_duke_v2.onnx",help="save direction of onnx models,default is ./target/MobileViT.onnx.")parser.add_argument("--batch",type=int,default=1,help="batchsize of onnx models, default is 1.")parser.add_argument("--opt",default=False, action='store_true',help="model optmization , default is False.")parser.add_argument("--dynamic",default=False, action='store_true',help="export  dynamic onnx model , default is False.")args = parser.parse_args()
#   print(args)#net = get_net(args.model_path,opt_=args.opt)net = get_net(args.model_path)if args.dynamic:convert_onnx_dynamic(net,args.save_path,simp=True)else:with torch.no_grad():convert_onnx(net,args.save_path,simp=True,batch=args.batch)

1.5 查看onnx模型

当将pytorch模型保存为 ONNX 之后,可以使用一款名为 Netron 的软件打开 .onnx 文件,查看模型结构。

2. 参考文章

[1] Windows下使用ONNX+pytorch记录
[2] pytorch-onnx-tensorrt全链路简单教程(支持动态输入)
[3] PyTorch语义分割模型转ONNX以及对比转换后的效果(PyTorch2ONNX、Torch2ONNX、pth2onnx、pt2onnx、修改名称、转换、测试、加载ONNX、运行ONNX)
[4] ONNX系列一:ONNX的使用,从转化到推理

相关文章:

TensorRT加速推理入门-1:Pytorch转ONNX

这篇文章&#xff0c;用于记录将TransReID的pytorch模型转换为onnx的学习过程&#xff0c;期间参考和学习了许多大佬编写的博客&#xff0c;在参考文章这一章节中都已列出&#xff0c;非常感谢。 1. 在pytorch下使用ONNX主要步骤 1.1. 环境准备 安装onnxruntime包 安装教程可…...

springboot常用扩展点

当涉及到Spring Boot的扩展和自定义时&#xff0c;Spring Boot提供了一些扩展点&#xff0c;使开发人员可以根据自己的需求轻松地扩展和定制Spring Boot的行为。本篇博客将介绍几个常用的Spring Boot扩展点&#xff0c;并提供相应的代码示例。 1. 自定义Starter(面试常问) Sp…...

19道ElasticSearch面试题(很全)

点击下载《19道ElasticSearch面试题&#xff08;很全&#xff09;》 1. elasticsearch的一些调优手段 1、设计阶段调优 &#xff08;1&#xff09;根据业务增量需求&#xff0c;采取基于日期模板创建索引&#xff0c;通过 roll over API 滚动索引&#xff1b; &#xff08;…...

向爬虫而生---Redis 拓宽篇3 <GEO模块>

前言: 继上一章: 向爬虫而生---Redis 拓宽篇2 &#xff1c;Pub/Sub发布订阅&#xff1e;-CSDN博客 这一章的用处其实不是特别大,主要是针对一些地图和距离业务的;就是Redis的GEO模块。 GEO模块是Redis提供的一种高效的地理位置数据管理方案&#xff0c;它允许我们存储和查询…...

Vue项目里实现json对象转formData数据

平常调用后端接口传参都是json对象&#xff0c;当提交表单遇到有附件需要传递时&#xff0c;通常是把附件上传单独做个接口&#xff0c;也有遇到后端让提交接口一并把附件传递到后端&#xff0c;这种情况需要把参数转成formData的数据&#xff0c;需要用到new FormData()。json…...

leetcode刷题记录

栈 2696. 删除子串后的字符串最小长度 哈希表 1. 两数之和 用map来保存每个数和他的索引 383. 赎金信 用map来存储字符的个数 链表 2. 两数相加 指针的移动 动态规划 53. 最大子数组和 2707. 字符串中的额外字符 递归 101. 对称二叉树 数学 1276. 不浪费原料的汉堡…...

SpringMVC通用后台管理系统源码

整体的SSM后台管理框架功能已经初具雏形&#xff0c;前端界面风格采用了结构简单、 性能优良、页面美观大的Layui页面展示框架 数据库支持了SQLserver,只需修改配置文件即可实现数据库之间的转换。 系统工具中加入了定时任务管理和cron生成器&#xff0c;轻松实现系统调度问…...

深度解析Dubbo的基本应用与高级应用:负载均衡、服务超时、集群容错、服务降级、本地存根、本地伪装、参数回调等关键技术详解

负载均衡 官网地址&#xff1a; http://dubbo.apache.org/zh/docs/v2.7/user/examples/loadbalance/ 如果在消费端和服务端都配置了负载均衡策略&#xff0c; 以消费端为准。 这其中比较难理解的就是最少活跃调用数是如何进行统计的&#xff1f; 讲道理&#xff0c; 最少活跃数…...

备战2024美赛数学建模,文末获取历史优秀论文

总说&#xff08;历年美赛优秀论文可获取&#xff09; 数模的题型千变万化&#xff0c;我今天想讲的主要是一些「画图」、「建模」、「写作」和「论文结构」的思路&#xff0c;这些往往是美赛阅卷官最看重的点&#xff0c;突破了这些点&#xff0c;才能真正让你的美赛论文更上…...

Java加密解密大全(MD5、RSA)

目录 一、MD5加密二、RSA加解密(公加私解&#xff0c;私加公解)三、RSA私钥加密四、RSA私钥加密PKCS1Padding模式 一、MD5加密 密文形式&#xff1a;5eb63bbbe01eeed093cb22bb8f5acdc3 import java.math.BigInteger; import java.security.MessageDigest; import java.security…...

C语言程序设计考试掌握这些题妥妥拿绩点(写给即将C语言考试的小猿猴们)

目录 开篇说两句1. 水仙花数题目描述分析代码示例 2. 斐波那契数列题目描述分析代码示例 3. 猴子吃桃问题题目描述分析代码示例 4. 物体自由落地题目描述分析代码示例 5. 矩阵对角线元素之和题目描述分析代码示例 6. 求素数题目描述分析代码示例 7. 最大公约数和最小公倍数题目…...

编译ZLMediaKit(win10+msvc2019_x64)

前言 因工作需要&#xff0c;需要ZLMediaKit&#xff0c;为方便抓包分析&#xff0c;最好在windows系统上测试&#xff0c;但使用自己编译的第三方库一直出问题&#xff0c;无法编译通过。本文档记录下win10上的编译过程&#xff0c;供有需要的小伙伴使用 一、需要安装的软件…...

JS-基础语法(一)

JavaScript简单介绍 变量 常量 数据类型 类型转换 案例 1.JavaScript简单介绍 JavaScript 是什么&#xff1f; 是一种运行在客户端&#xff08;浏览器&#xff09;的编程语言&#xff0c;可以实现人机交互效果。 JS的作用 JavaScript的组成 JSECMAScript( 基础语法 )…...

18款Visual Studio实用插件(更新)

前言 俗话说的好工欲善其事必先利其器&#xff0c;安装一些Visual Studio实用插件对自己日常的开发和工作效率能够大大的提升&#xff0c;避免996从选一款好的IDE实用插件开始。以下是我认为比较实用的Visual Studio插件希望对大家有用&#xff0c;大家有更好的插件推荐可在文…...

三、java线性表(顺序表、链表、栈、队列)

java线性表 三、线性表1.1 顺序表1.2 链表1.2.1 单向链表&#xff08;Singly Linked List&#xff09;1.2.2 双向链表&#xff08;Doubly Linked List&#xff09; 1.3 LinkedList VS ArrayList1.3.7 使用 LinkedList 的场景 1.4 栈1.5 队列 三、线性表 线性表是一种经典的数据…...

PiflowX-MysqlCdc组件

MysqlCdc组件 组件说明 MySQL CDC连接器允许从MySQL数据库读取快照数据和增量数据。 计算引擎 flink 组件分组 cdc 端口 Inport&#xff1a;默认端口 outport&#xff1a;默认端口 组件属性 名称展示名称默认值允许值是否必填描述例子hostnameHostname“”无是MySQL…...

2023春季李宏毅机器学习笔记 03 :机器如何生成文句

资料 课程主页&#xff1a;https://speech.ee.ntu.edu.tw/~hylee/ml/2023-spring.phpGithub&#xff1a;https://github.com/Fafa-DL/Lhy_Machine_LearningB站课程&#xff1a;https://space.bilibili.com/253734135/channel/collectiondetail?sid2014800 一、大语言模型的两种…...

dplayer播放hls格式视频并自动开始播放

监控视频流为hls格式&#xff0c;需要打开或刷新页面自动开始播放&#xff0c;需要安装dplayer和hls.js插件&#xff0c;插件直接npm装就行&#xff0c;上代码 import DPlayer from dplayer import Hls from hls.js //jquery是用来注册点击事件&#xff0c;实现自动开始播放 i…...

使用Vivado Design Suite平台板、将IP目录与平台板流一起使用

使用Vivado Design Suite平台板流 Vivado设计套件允许您使用AMD目标设计平台板&#xff08;TDP&#xff09;创建项目&#xff0c;或者已经添加到板库的用户指定板。当您选择特定板&#xff0c;Vivado设计工具显示有关板的信息&#xff0c;并启用其他设计器作为IP定制的一部分以…...

PACS医学影像报告管理系统源码带CT三维后处理技术

PACS从各种医学影像检查设备中获取、存储、处理影像数据&#xff0c;传输到体检信息系统中&#xff0c;生成图文并茂的体检报告&#xff0c;满足体检中心高水准、高效率影像处理的需要。 自主知识产权&#xff1a;拥有完整知识产权&#xff0c;能够同其他模块无缝对接 国际标准…...

从myplaces.shp到专题地图:手把手教你用QGIS C++ API实现点要素分级渲染

从myplaces.shp到专题地图&#xff1a;QGIS C API实现点要素分级渲染实战指南 当我们需要在桌面GIS应用中直观展示气象站降雨量、城市人口密度或商业网点销售额等连续型空间数据时&#xff0c;分级色彩渲染是最有效的可视化手段之一。本文将深入探讨如何利用QGIS强大的C API&am…...

计算机科学第三难题:“树映射”问题在文件、写作、建筑、生物分类中无处不在!

计算机科学第三难题&#xff1a;将通用图映射到层次结构&#xff0c;“树映射”问题无处不在 根据一个归属于 菲尔卡尔顿 的 经典笑话&#xff0c;计算机科学只有两个难题&#xff1a;命名和缓存失效。这两个问题之所以难&#xff0c;是因为没有算法可以解决它们&#xff1a;好…...

5个场景深度解析:如何用bili2text将B站视频变成你的私人知识库

5个场景深度解析&#xff1a;如何用bili2text将B站视频变成你的私人知识库 【免费下载链接】bili2text Bilibili视频转文字&#xff0c;一步到位&#xff0c;输入链接即可使用 项目地址: https://gitcode.com/gh_mirrors/bi/bili2text 凌晨两点&#xff0c;小林还在为明…...

小红书无水印下载工具XHS-Downloader:3种使用模式全解析

小红书无水印下载工具XHS-Downloader&#xff1a;3种使用模式全解析 【免费下载链接】XHS-Downloader 小红书&#xff08;XiaoHongShu、RedNote&#xff09;链接提取/作品采集工具&#xff1a;提取账号发布、收藏、点赞、专辑作品链接&#xff1b;提取搜索结果作品、用户链接&a…...

使用mcp-maker快速构建AI工具调用服务器:从协议原理到工程实践

1. 项目概述与核心价值最近在折腾AI应用开发&#xff0c;特别是想给大语言模型&#xff08;LLM&#xff09;装上更强大的“手脚”&#xff0c;让它能直接操作我电脑上的各种软件和工具。这听起来很酷&#xff0c;对吧&#xff1f;但实际操作起来&#xff0c;你会发现一个核心痛…...

英雄联盟智能助手Seraphine:告别手动查询,实现高效游戏决策自动化

英雄联盟智能助手Seraphine&#xff1a;告别手动查询&#xff0c;实现高效游戏决策自动化 【免费下载链接】Seraphine 英雄联盟战绩查询工具 项目地址: https://gitcode.com/gh_mirrors/se/Seraphine 在英雄联盟排位赛中&#xff0c;你是否曾因错过接受对局而懊恼不已&a…...

构建个人知识库:从碎片化代码到结构化知识体系

1. 项目概述&#xff1a;从“ClawCode”看个人知识库的构建与价值最近在和一些开发者朋友交流时&#xff0c;发现一个普遍现象&#xff1a;大家电脑里都散落着无数代码片段、配置脚本、临时笔记和项目心得。这些“数字碎片”价值巨大&#xff0c;但往往因为缺乏有效的组织&…...

保姆级教程:INCA 7.2.3 从新建工程到观测标定的完整流程(附A2L文件处理技巧)

INCA 7.2.3 全流程实战指南&#xff1a;从工程搭建到参数标定的深度解析 在汽车电子开发领域&#xff0c;标定工具链的掌握程度直接影响开发效率。作为行业标准的INCA软件&#xff0c;其7.2.3版本在工程管理、实时观测和参数标定方面提供了更完善的解决方案。本文将采用"操…...

多语种出海必备,ElevenLabs菲律宾文语音质量实测对比:Wavenet vs. Instant Voice vs. Custom Model(附MOS评分表)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;多语种出海语音技术演进与菲律宾语本地化挑战 随着全球数字服务加速出海&#xff0c;语音交互系统正从单语种向多语种、低资源语言深度拓展。菲律宾语&#xff08;Filipino/Tagalog&#xff09;作为东…...

基于Vanilla JS与IndexedDB构建本地化Markdown笔记工具

1. 项目概述&#xff1a;从零开始构建一个轻量级笔记工具最近在整理个人知识库时&#xff0c;发现市面上的笔记软件要么功能过于臃肿&#xff0c;要么云端同步存在隐私顾虑&#xff0c;要么就是定制化程度不够。作为一个有十多年开发经验的从业者&#xff0c;我决定自己动手&am…...