【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论
文章目录
- 【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
- 前言
- 模型转换--pytorch转onnx
- Windows平台搭建依赖环境
- onnxruntime调用onnx模型
- ONNXRuntime推理核心流程
- ONNXRuntime推理代码
- 总结
前言
ONNXRuntime是微软推出的一款高性能的机器学习推理引擎框架,用户可以非常便利的用其运行一个onnx模型,专注于加速机器学习模型的预测阶段。ONNXRuntime设计目的是为了提供一个高效的执行环境,使机器学习模型能够在各种硬件上快速执行,支持多种运行后端包括CPU,GPU,TensorRT,DML等,使得开发者可以灵活选择最适合其应用场景的硬件平台。
ONNXRuntime是对ONNX模型最原生的支持。
读者可以通过学习【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。
模型转换–pytorch转onnx
Pytorch模型转onnx并推理的步骤如下:
- 将PyTorch预训练模型文件( .pth 或 .pt 格式)转换成ONNX格式的文件(.onnx格式),这一转换过程在PyTorch环境中进行。
- 将转换得到的 .onnx 文件随后作为输入,调用ONNXRuntime的C++ API来执行模型的推理。
博主使用AlexNet图像分类(五种花分类)进行演示,需要安装pytorch环境,对于该算法的基础知识,可以参考博主【AlexNet模型算法Pytorch版本详解】博文
conda create --name AlexNet python==3.10
conda activate AlexNet
# 根据自己主机配置环境
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 假设模型转化出错则降级为指定1.16.1版本
pip install onnx==1.16.1
然后把训练模型好的AlexNet.pth模型转成AlexNet.onnx模型,pyorch2onnx.py转换代码如下:
import torch
from model import AlexNet
model = AlexNet(num_classes=5)
weights_path = "./AlexNet.pth"
# 加载模型权重
model.load_state_dict(torch.load(weights_path))
# 模型推理模式
model.eval()
model.cpu()
# 虚拟输入数据
dummy_input1 = torch.randn(1, 3, 224, 224)
# 模型转化函数
torch.onnx.export(model, (dummy_input1), "AlexNet.onnx", verbose=True, opset_version=11)
【AlexNet.pth百度云链接,提取码:ktq5 】直接下载使用即可。
Windows平台搭建依赖环境
需要在anaconda虚拟环境安装onnxruntime,需要注意onnxruntime-gpu, cuda, cudnn三者的版本要对应,具体参照官方说明。
博主是win11+cuda12.1+cudnn8.8.1,对应onnxruntime-gpu==1.18.0
import torch
# 查询cuda版本
print(torch.version.cuda)
# 查询cudnn版本
print(torch.backends.cudnn.version())
# 激活环境
activate AlexNet
# 安装onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
# 安装GPU版
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime-gpu==1.18.0
# 或者可以安装CPU版本:没有版本对应要求
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime
# 安装opencv
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
onnxruntime调用onnx模型
ONNXRuntime推理核心流程
设置会话选项
通常包括配置优化器级别、线程数和设备(GPU/CPU)使用等。
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.intra_op_num_threads = 4
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
会换选项 | 日志严重性级别 | 优化器级别 | 线程数 | 设备使用 |
---|---|---|---|---|
选项 | log_severity_level | graph_optimization_level | graph_optimization_level | CUDAExecutionProvider;CPUExecutionProvider |
作用 | 决定了哪些级别的日志信息将被记录下来,运行时提供了几个预定义的宏来表示不同的日志级别。 | 在模型加载到ONNXRuntime之前对其进行图优化的过程,提高执行效率 | 设置每个运算符内部执行时的最大线程数 | CUDA/CPU设备选择。 |
参数 | 整形,1:Info, 2:Warning. 3:Error, 4:Fatal,默认是2。 | ORT_ENABLE_BASIC:基本的图优化; ORT_DISABLE_ALL:禁用所有优化;ORT_ENABLE_EXTENDED:启用扩展优化;ORT_ENABLE_ALL:启用所有优化。 | 整型 | 列表中的顺序决定了执行提供者的优先级。 |
加载模型并创建会话
加载预训练的ONNX模型文件,使用运行时环境、会话选项和模型创建一个Session对象。
session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)
ort.InferenceSession参数 | path_or_bytes | sess_options | providers |
---|---|---|---|
内容 | 模型的位置或者模型的二进制数据 | 会话选项 | 设备选择 |
获取模型输入输出信息
从Session对象中获取模型输入和输出的详细信息,包括数量、名称、类型和形状。
input_nodes_num = len(session.get_inputs())
output_nodes_num = len(session.get_outputs())
input_name = session.get_inputs()[i].name
output_name = session.get_outputs()[i].name
input_shape = session.get_inputs()[i].shape
output_shape = session.get_outputs()[i].shape
预处理输入数据
对输入数据进行颜色空间转换,尺寸缩放、标准化以及形状维度扩展操作。
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
blob = cv2.resize(rgb, (input_w, input_h))
blob = blob.astype(np.float32)
blob /= 255.0
blob -= np.array([0.485, 0.456, 0.406])
blob /= np.array([0.229, 0.224, 0.225])
timg = cv2.dnn.blobFromImage(blob)
这部分不是OnnxRuntime核心部分,根据任务需求不同,代码略微不同。
执行推理
调用Session.run方法,传入输入张量、输出张量名和其他必要的参数,执行推理。
ort_outputs = session.run(output_names=input_node_names, input_feed={output_node_names[0]: timg})
Session.run参数 | output_names | input_feed |
---|---|---|
含义 | 输出节点名称的列表。 | 输入节点名称和输入数据的键值对字典,可能有多个输入。 |
后处理推理结果
推理完成后,从输出张量中获取结果数据,根据需要对结果进行后处理,以获得最终的预测结果。
prob = ort_outputs[0]
max_index = np.argmax(prob)
这部分不是OnnxRuntime核心部分,根据任务需求不同,代码基本不同。
ONNXRuntime推理代码
需要配置flower_classes.txt文件存储五种花的分类标签,并将其放置到工程目录下(推荐)。
daisy
dandelion
roses
sunflowers
tulips
这里需要将AlexNet.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。
import onnxruntime as ort
import cv2
import numpy as np# 加载标签文件获得分类标签
def read_class_names(file_path="./flower_classes.txt"):class_names = []try:with open(file_path, 'r') as fp:for line in fp:name = line.strip()if name:class_names.append(name)except IOError:print("could not open file...")import syssys.exit(-1)return class_names# 主函数
def main():# 预测的目标标签数labels = read_class_names()# 测试图片image_path = "./sunflowers.jpg"image = cv2.imread(image_path)cv2.imshow("输入图", image)cv2.waitKey(0)# 设置会话选项sess_options = ort.SessionOptions()# 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATALsess_options.log_severity_level = 3# 优化器级别:基本的图优化级别sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC# 线程数:4sess_options.intra_op_num_threads = 4# 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']# onnx训练模型文件onnxpath = "./AlexNet.onnx"# 加载模型并创建会话session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)input_nodes_num = len(session.get_inputs()) # 输入节点输output_nodes_num = len(session.get_outputs()) # 输出节点数input_node_names = [] # 输入节点名称output_node_names = [] # 输出节点名称# 获取模型输入信息for i in range(input_nodes_num):# 获得输入节点的名称并存储input_name = session.get_inputs()[i].nameinput_node_names.append(input_name)# 显示输入图像的形状input_shape = session.get_inputs()[i].shapech, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]print(f"input format: {ch}x{input_h}x{input_w}")# 获取模型输出信息for i in range(output_nodes_num):# 获得输出节点的名称并存储output_name = session.get_outputs()[i].nameoutput_node_names.append(output_name)# 显示输出结果的形状output_shape = session.get_outputs()[i].shapenum, nc = output_shape[0], output_shape[1]print(f"output format: {num}x{nc}")input_shape = session.get_inputs()[0].shapeinput_h, input_w = input_shape[2], input_shape[3]print(f"input format: {input_shape[1]}x{input_h}x{input_w}")# 预处理输入数据# 默认是BGR需要转化成RGBrgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 对图像尺寸进行缩放blob = cv2.resize(rgb, (input_w, input_h))blob = blob.astype(np.float32)# 对图像进行标准化处理blob /= 255.0 # 归一化blob -= np.array([0.485, 0.456, 0.406]) # 减去均值blob /= np.array([0.229, 0.224, 0.225]) # 除以方差#CHW-->NCHW 维度扩展timg = cv2.dnn.blobFromImage(blob)# ---blobFromImage 可以用以下替换---# blob = blob.transpose(2, 0, 1)# blob = np.expand_dims(blob, axis=0)# -------------------------------# 模型推理try:ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})except Exception as e:print(e)ort_outputs = None# 后处理推理结果prob = ort_outputs[0]max_index = np.argmax(prob) # 获得最大值的索引print(f"label id: {max_index}")# 在测试图像上加上预测的分类标签label_text = labels[max_index]cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)cv2.imshow("输入图像", image)cv2.waitKey(0)if __name__ == '__main__':main()
图片正确预测为向日葵:
总结
尽可能简单、详细的介绍了pytorch模型到onnx模型的转化,python下onnxruntime环境的搭建以及ONNX模型的OnnxRuntime部署。
相关文章:

【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程 提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程前言模型转换--pytorch转on…...

React学习笔记(1.0)
在使用vite创建react时,有一个语言选项,就是typescript-SWC,这里介绍一下SWC。 SWC:可扩展的Rust的平台,用于下一代快速开发工具,SWC比Babel快20倍。 简单来说,就是用于格式转换的,…...

Axure RP实战:打造高效图形旋转验证码
Axure RP实战:打造高效图形旋转验证码 在数字产品设计的海洋中,验证码环节往往是用户交互体验的细微之处,却承载着验证用户身份的重要任务。 传统的文本验证码虽然简单直接,但随着用户需求的提高和设计趋势的发展,它…...

101012分页属性
4k页面 P(有效位):1有效,0无效 R/W(读写位):1可读可写,0可读 U/S(权限位):1(User),0(System) A(物理页访问位ÿ…...

从0-1 用AI做一个赚钱的小红书账号(不是广告不是广告)
大家好,我是胡广!是不是被标题吸引过来的呢?是不是觉得自己天赋异禀,肯定是那万中无一的赚钱天才。哈哈哈,我告诉你,你我皆是牛马,不要老想着突然就成功了,一夜暴富了,瞬…...
【Kubernetes】常见面试题汇总(十七)
目录 51.简述 Kubernetes 网络策略? 52.简述 Kubernetes 网络策略原理? 53.简述 Kubernetes 中 flannel 的作用? 54.简述 Kubernetes Calico 网络组件实现原理? 51.简述 Kubernetes 网络策略? - 为实现细粒度的容器…...
Vue 3 中动态赋值 ref 的应用
引言 Vue 3 引入了许多新特性,其中之一便是 Composition API。Composition API 提供了一种新的编程范式,使开发者能够更灵活地组织和复用逻辑。其中 ref 是一个核心概念,它允许我们在组件内部声明响应式的状态。本文将探讨如何在 Vue 3 中使…...
Spring Boot-应用启动问题
在使用 Spring Boot 进行开发时,应用启动问题是开发人员经常遇到的挑战之一。通过有效排查和解决这些问题,可以提高应用的稳定性和可靠性。 1. Spring Boot 启动问题的常见表现 Spring Boot 应用启动失败通常表现为以下几种情况: 应用启动…...
深入解析:如何通过网络命名空间跟踪单个进程的网络活动(C/C++代码实现)
在 Linux 系统中,网络命名空间(Network Namespaces)是一种强大的功能,它允许系统管理员和开发者隔离网络资源,使得每个命名空间都拥有独立的网络协议栈。这种隔离机制不仅用于容器技术如 Docker,也是网络安…...
C++ 科目二 [const_cast]
基础数据类型 const_cast 仅仅是深层拷贝改变,而不是改动之前的值 如果需要使用改动后的值,需要通过指针或者引用来间接使用 const int n 5; const string s "MyString";// cosnt_cast 针对指针,引用,this指针 // co…...

【电脑组装】✈️从配置拼装到安装系统组装自己的台式电脑
目录 🍸前言 🍻一、台式电脑基本组成 🍺二、组装 🍹三、安装系统 👋四、系统设置 👀五、章末 🍸前言 小伙伴们大家好,上篇文章分享了在平时开发的时候遇到的一种项目整合情况&…...

Hadoop生态圈拓展内容(一)
1. Hadoop的主要部分及其作用 HDFS(Hadoop分布式文件系统) HDFS是一个高容错、高可靠性、高可扩展性、高吞吐率的分布式文件存储系统,负责海量数据的存储。 YARN(资源管理调度系统) YARN是Hadoop的资源管理调度系统…...

使用随机森林模型在digits数据集上执行分类任务
程序功能 使用随机森林模型对digits数据集进行手写数字分类任务。具体步骤如下: 加载数据:从digits数据集中获取手写数字图片的特征和对应的标签。 划分数据:将数据集分为训练集和测试集,测试集占30%。 训练模型:使用…...

后端开发刷题 | 打家劫舍
描述 你是一个经验丰富的小偷,准备偷沿街的一排房间,每个房间都存有一定的现金,为了防止被发现,你不能偷相邻的两家,即,如果偷了第一家,就不能再偷第二家;如果偷了第二家࿰…...

欧美游戏市场的差异
欧洲和美国的游戏市场虽然高度发达且利润丰厚,但表现出由文化偏好、消费者行为、监管环境和平台受欢迎程度塑造的独特特征。这些差异对于寻求为每个地区量身定制策略的游戏开发商和发行商来说非常重要。 文化偏好和游戏类型 美国:美国游戏市场倾向于青…...

DeDeCMS靶场漏洞复现
打开靶场地址 姿势一:通过文件管理器上传webshell 1.登录后台 dedecms默认的后台登录地址为/dede 2.在附加管理里的文件式管理器中有文件上传 3.上传木马文件 4.访问木马文件 并连接 姿势二:修改模板文件获取webshell 1.点击模板里面的默认模板管理 …...

Transformer模型详细步骤
Transformer模型是nlp任务中不能绕开的学习任务,我将从数据开始,每一步骤都列举出来,然后对应重点的代码进行讲解 ------------------------------------------------------------------------------------------------------------- Trans…...

LC并联电路在正弦稳态下的传递函数推导(LC并联谐振选频电路)
LC并联电路在正弦稳态下的传递函数推导(LC并联谐振选频电路) 本文通过 1.解微分方程、2.阻抗模型两种方法推导 LC 并联选频电路在正弦稳态条件下的传递函数,并通过仿真验证不同频率时 vo(t) 与 vi(t) 的幅值相角的关系。 电路介绍 已知条件…...
【前后端】大文件切片上传
Ruoyi框架上传文件_若依微服务框架 文件上传-CSDN博客 原理介绍 大文件上传时,如果直接上传整个文件,可能会因为文件过大导致上传失败、服务器超时或内存溢出等问题。因此,通常采用文件切片(Chunking)的方式来解决这些…...
图像处理 -- ISP功能之局部对比度增强 LCE
局部对比度增强(LCE) 局部对比度增强(Local Contrast Enhancement, LCE)是一种图像处理技术,旨在通过调整图像的局部区域对比度,增强图像细节和视觉效果。LCE 的实现方式多种多样,以下是几种常…...

网络六边形受到攻击
大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...

Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
AspectJ 在 Android 中的完整使用指南
一、环境配置(Gradle 7.0 适配) 1. 项目级 build.gradle // 注意:沪江插件已停更,推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...
Go 语言并发编程基础:无缓冲与有缓冲通道
在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好࿰…...
Kafka主题运维全指南:从基础配置到故障处理
#作者:张桐瑞 文章目录 主题日常管理1. 修改主题分区。2. 修改主题级别参数。3. 变更副本数。4. 修改主题限速。5.主题分区迁移。6. 常见主题错误处理常见错误1:主题删除失败。常见错误2:__consumer_offsets占用太多的磁盘。 主题日常管理 …...
SQL Server 触发器调用存储过程实现发送 HTTP 请求
文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...