【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论
文章目录
- 【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
- 前言
- 模型转换--pytorch转onnx
- Windows平台搭建依赖环境
- onnxruntime调用onnx模型
- ONNXRuntime推理核心流程
- ONNXRuntime推理代码
- 总结
前言
ONNXRuntime是微软推出的一款高性能的机器学习推理引擎框架,用户可以非常便利的用其运行一个onnx模型,专注于加速机器学习模型的预测阶段。ONNXRuntime设计目的是为了提供一个高效的执行环境,使机器学习模型能够在各种硬件上快速执行,支持多种运行后端包括CPU,GPU,TensorRT,DML等,使得开发者可以灵活选择最适合其应用场景的硬件平台。
ONNXRuntime是对ONNX模型最原生的支持。
读者可以通过学习【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。
模型转换–pytorch转onnx
Pytorch模型转onnx并推理的步骤如下:
- 将PyTorch预训练模型文件( .pth 或 .pt 格式)转换成ONNX格式的文件(.onnx格式),这一转换过程在PyTorch环境中进行。
- 将转换得到的 .onnx 文件随后作为输入,调用ONNXRuntime的C++ API来执行模型的推理。
博主使用AlexNet图像分类(五种花分类)进行演示,需要安装pytorch环境,对于该算法的基础知识,可以参考博主【AlexNet模型算法Pytorch版本详解】博文
conda create --name AlexNet python==3.10
conda activate AlexNet
# 根据自己主机配置环境
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 假设模型转化出错则降级为指定1.16.1版本
pip install onnx==1.16.1
然后把训练模型好的AlexNet.pth模型转成AlexNet.onnx模型,pyorch2onnx.py转换代码如下:
import torch
from model import AlexNet
model = AlexNet(num_classes=5)
weights_path = "./AlexNet.pth"
# 加载模型权重
model.load_state_dict(torch.load(weights_path))
# 模型推理模式
model.eval()
model.cpu()
# 虚拟输入数据
dummy_input1 = torch.randn(1, 3, 224, 224)
# 模型转化函数
torch.onnx.export(model, (dummy_input1), "AlexNet.onnx", verbose=True, opset_version=11)
【AlexNet.pth百度云链接,提取码:ktq5 】直接下载使用即可。
Windows平台搭建依赖环境
需要在anaconda虚拟环境安装onnxruntime,需要注意onnxruntime-gpu, cuda, cudnn三者的版本要对应,具体参照官方说明。
博主是win11+cuda12.1+cudnn8.8.1,对应onnxruntime-gpu==1.18.0
import torch
# 查询cuda版本
print(torch.version.cuda)
# 查询cudnn版本
print(torch.backends.cudnn.version())
# 激活环境
activate AlexNet
# 安装onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
# 安装GPU版
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime-gpu==1.18.0
# 或者可以安装CPU版本:没有版本对应要求
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime
# 安装opencv
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
onnxruntime调用onnx模型
ONNXRuntime推理核心流程
设置会话选项
通常包括配置优化器级别、线程数和设备(GPU/CPU)使用等。
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.intra_op_num_threads = 4
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
会换选项 | 日志严重性级别 | 优化器级别 | 线程数 | 设备使用 |
---|---|---|---|---|
选项 | log_severity_level | graph_optimization_level | graph_optimization_level | CUDAExecutionProvider;CPUExecutionProvider |
作用 | 决定了哪些级别的日志信息将被记录下来,运行时提供了几个预定义的宏来表示不同的日志级别。 | 在模型加载到ONNXRuntime之前对其进行图优化的过程,提高执行效率 | 设置每个运算符内部执行时的最大线程数 | CUDA/CPU设备选择。 |
参数 | 整形,1:Info, 2:Warning. 3:Error, 4:Fatal,默认是2。 | ORT_ENABLE_BASIC:基本的图优化; ORT_DISABLE_ALL:禁用所有优化;ORT_ENABLE_EXTENDED:启用扩展优化;ORT_ENABLE_ALL:启用所有优化。 | 整型 | 列表中的顺序决定了执行提供者的优先级。 |
加载模型并创建会话
加载预训练的ONNX模型文件,使用运行时环境、会话选项和模型创建一个Session对象。
session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)
ort.InferenceSession参数 | path_or_bytes | sess_options | providers |
---|---|---|---|
内容 | 模型的位置或者模型的二进制数据 | 会话选项 | 设备选择 |
获取模型输入输出信息
从Session对象中获取模型输入和输出的详细信息,包括数量、名称、类型和形状。
input_nodes_num = len(session.get_inputs())
output_nodes_num = len(session.get_outputs())
input_name = session.get_inputs()[i].name
output_name = session.get_outputs()[i].name
input_shape = session.get_inputs()[i].shape
output_shape = session.get_outputs()[i].shape
预处理输入数据
对输入数据进行颜色空间转换,尺寸缩放、标准化以及形状维度扩展操作。
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
blob = cv2.resize(rgb, (input_w, input_h))
blob = blob.astype(np.float32)
blob /= 255.0
blob -= np.array([0.485, 0.456, 0.406])
blob /= np.array([0.229, 0.224, 0.225])
timg = cv2.dnn.blobFromImage(blob)
这部分不是OnnxRuntime核心部分,根据任务需求不同,代码略微不同。
执行推理
调用Session.run方法,传入输入张量、输出张量名和其他必要的参数,执行推理。
ort_outputs = session.run(output_names=input_node_names, input_feed={output_node_names[0]: timg})
Session.run参数 | output_names | input_feed |
---|---|---|
含义 | 输出节点名称的列表。 | 输入节点名称和输入数据的键值对字典,可能有多个输入。 |
后处理推理结果
推理完成后,从输出张量中获取结果数据,根据需要对结果进行后处理,以获得最终的预测结果。
prob = ort_outputs[0]
max_index = np.argmax(prob)
这部分不是OnnxRuntime核心部分,根据任务需求不同,代码基本不同。
ONNXRuntime推理代码
需要配置flower_classes.txt文件存储五种花的分类标签,并将其放置到工程目录下(推荐)。
daisy
dandelion
roses
sunflowers
tulips
这里需要将AlexNet.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。
import onnxruntime as ort
import cv2
import numpy as np# 加载标签文件获得分类标签
def read_class_names(file_path="./flower_classes.txt"):class_names = []try:with open(file_path, 'r') as fp:for line in fp:name = line.strip()if name:class_names.append(name)except IOError:print("could not open file...")import syssys.exit(-1)return class_names# 主函数
def main():# 预测的目标标签数labels = read_class_names()# 测试图片image_path = "./sunflowers.jpg"image = cv2.imread(image_path)cv2.imshow("输入图", image)cv2.waitKey(0)# 设置会话选项sess_options = ort.SessionOptions()# 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATALsess_options.log_severity_level = 3# 优化器级别:基本的图优化级别sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC# 线程数:4sess_options.intra_op_num_threads = 4# 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']# onnx训练模型文件onnxpath = "./AlexNet.onnx"# 加载模型并创建会话session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)input_nodes_num = len(session.get_inputs()) # 输入节点输output_nodes_num = len(session.get_outputs()) # 输出节点数input_node_names = [] # 输入节点名称output_node_names = [] # 输出节点名称# 获取模型输入信息for i in range(input_nodes_num):# 获得输入节点的名称并存储input_name = session.get_inputs()[i].nameinput_node_names.append(input_name)# 显示输入图像的形状input_shape = session.get_inputs()[i].shapech, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]print(f"input format: {ch}x{input_h}x{input_w}")# 获取模型输出信息for i in range(output_nodes_num):# 获得输出节点的名称并存储output_name = session.get_outputs()[i].nameoutput_node_names.append(output_name)# 显示输出结果的形状output_shape = session.get_outputs()[i].shapenum, nc = output_shape[0], output_shape[1]print(f"output format: {num}x{nc}")input_shape = session.get_inputs()[0].shapeinput_h, input_w = input_shape[2], input_shape[3]print(f"input format: {input_shape[1]}x{input_h}x{input_w}")# 预处理输入数据# 默认是BGR需要转化成RGBrgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 对图像尺寸进行缩放blob = cv2.resize(rgb, (input_w, input_h))blob = blob.astype(np.float32)# 对图像进行标准化处理blob /= 255.0 # 归一化blob -= np.array([0.485, 0.456, 0.406]) # 减去均值blob /= np.array([0.229, 0.224, 0.225]) # 除以方差#CHW-->NCHW 维度扩展timg = cv2.dnn.blobFromImage(blob)# ---blobFromImage 可以用以下替换---# blob = blob.transpose(2, 0, 1)# blob = np.expand_dims(blob, axis=0)# -------------------------------# 模型推理try:ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})except Exception as e:print(e)ort_outputs = None# 后处理推理结果prob = ort_outputs[0]max_index = np.argmax(prob) # 获得最大值的索引print(f"label id: {max_index}")# 在测试图像上加上预测的分类标签label_text = labels[max_index]cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)cv2.imshow("输入图像", image)cv2.waitKey(0)if __name__ == '__main__':main()
图片正确预测为向日葵:
总结
尽可能简单、详细的介绍了pytorch模型到onnx模型的转化,python下onnxruntime环境的搭建以及ONNX模型的OnnxRuntime部署。
相关文章:

【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程 提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程前言模型转换--pytorch转on…...

React学习笔记(1.0)
在使用vite创建react时,有一个语言选项,就是typescript-SWC,这里介绍一下SWC。 SWC:可扩展的Rust的平台,用于下一代快速开发工具,SWC比Babel快20倍。 简单来说,就是用于格式转换的,…...

Axure RP实战:打造高效图形旋转验证码
Axure RP实战:打造高效图形旋转验证码 在数字产品设计的海洋中,验证码环节往往是用户交互体验的细微之处,却承载着验证用户身份的重要任务。 传统的文本验证码虽然简单直接,但随着用户需求的提高和设计趋势的发展,它…...

101012分页属性
4k页面 P(有效位):1有效,0无效 R/W(读写位):1可读可写,0可读 U/S(权限位):1(User),0(System) A(物理页访问位ÿ…...

从0-1 用AI做一个赚钱的小红书账号(不是广告不是广告)
大家好,我是胡广!是不是被标题吸引过来的呢?是不是觉得自己天赋异禀,肯定是那万中无一的赚钱天才。哈哈哈,我告诉你,你我皆是牛马,不要老想着突然就成功了,一夜暴富了,瞬…...
【Kubernetes】常见面试题汇总(十七)
目录 51.简述 Kubernetes 网络策略? 52.简述 Kubernetes 网络策略原理? 53.简述 Kubernetes 中 flannel 的作用? 54.简述 Kubernetes Calico 网络组件实现原理? 51.简述 Kubernetes 网络策略? - 为实现细粒度的容器…...
Vue 3 中动态赋值 ref 的应用
引言 Vue 3 引入了许多新特性,其中之一便是 Composition API。Composition API 提供了一种新的编程范式,使开发者能够更灵活地组织和复用逻辑。其中 ref 是一个核心概念,它允许我们在组件内部声明响应式的状态。本文将探讨如何在 Vue 3 中使…...
Spring Boot-应用启动问题
在使用 Spring Boot 进行开发时,应用启动问题是开发人员经常遇到的挑战之一。通过有效排查和解决这些问题,可以提高应用的稳定性和可靠性。 1. Spring Boot 启动问题的常见表现 Spring Boot 应用启动失败通常表现为以下几种情况: 应用启动…...
深入解析:如何通过网络命名空间跟踪单个进程的网络活动(C/C++代码实现)
在 Linux 系统中,网络命名空间(Network Namespaces)是一种强大的功能,它允许系统管理员和开发者隔离网络资源,使得每个命名空间都拥有独立的网络协议栈。这种隔离机制不仅用于容器技术如 Docker,也是网络安…...
C++ 科目二 [const_cast]
基础数据类型 const_cast 仅仅是深层拷贝改变,而不是改动之前的值 如果需要使用改动后的值,需要通过指针或者引用来间接使用 const int n 5; const string s "MyString";// cosnt_cast 针对指针,引用,this指针 // co…...

【电脑组装】✈️从配置拼装到安装系统组装自己的台式电脑
目录 🍸前言 🍻一、台式电脑基本组成 🍺二、组装 🍹三、安装系统 👋四、系统设置 👀五、章末 🍸前言 小伙伴们大家好,上篇文章分享了在平时开发的时候遇到的一种项目整合情况&…...

Hadoop生态圈拓展内容(一)
1. Hadoop的主要部分及其作用 HDFS(Hadoop分布式文件系统) HDFS是一个高容错、高可靠性、高可扩展性、高吞吐率的分布式文件存储系统,负责海量数据的存储。 YARN(资源管理调度系统) YARN是Hadoop的资源管理调度系统…...

使用随机森林模型在digits数据集上执行分类任务
程序功能 使用随机森林模型对digits数据集进行手写数字分类任务。具体步骤如下: 加载数据:从digits数据集中获取手写数字图片的特征和对应的标签。 划分数据:将数据集分为训练集和测试集,测试集占30%。 训练模型:使用…...

后端开发刷题 | 打家劫舍
描述 你是一个经验丰富的小偷,准备偷沿街的一排房间,每个房间都存有一定的现金,为了防止被发现,你不能偷相邻的两家,即,如果偷了第一家,就不能再偷第二家;如果偷了第二家࿰…...

欧美游戏市场的差异
欧洲和美国的游戏市场虽然高度发达且利润丰厚,但表现出由文化偏好、消费者行为、监管环境和平台受欢迎程度塑造的独特特征。这些差异对于寻求为每个地区量身定制策略的游戏开发商和发行商来说非常重要。 文化偏好和游戏类型 美国:美国游戏市场倾向于青…...

DeDeCMS靶场漏洞复现
打开靶场地址 姿势一:通过文件管理器上传webshell 1.登录后台 dedecms默认的后台登录地址为/dede 2.在附加管理里的文件式管理器中有文件上传 3.上传木马文件 4.访问木马文件 并连接 姿势二:修改模板文件获取webshell 1.点击模板里面的默认模板管理 …...

Transformer模型详细步骤
Transformer模型是nlp任务中不能绕开的学习任务,我将从数据开始,每一步骤都列举出来,然后对应重点的代码进行讲解 ------------------------------------------------------------------------------------------------------------- Trans…...

LC并联电路在正弦稳态下的传递函数推导(LC并联谐振选频电路)
LC并联电路在正弦稳态下的传递函数推导(LC并联谐振选频电路) 本文通过 1.解微分方程、2.阻抗模型两种方法推导 LC 并联选频电路在正弦稳态条件下的传递函数,并通过仿真验证不同频率时 vo(t) 与 vi(t) 的幅值相角的关系。 电路介绍 已知条件…...
【前后端】大文件切片上传
Ruoyi框架上传文件_若依微服务框架 文件上传-CSDN博客 原理介绍 大文件上传时,如果直接上传整个文件,可能会因为文件过大导致上传失败、服务器超时或内存溢出等问题。因此,通常采用文件切片(Chunking)的方式来解决这些…...
图像处理 -- ISP功能之局部对比度增强 LCE
局部对比度增强(LCE) 局部对比度增强(Local Contrast Enhancement, LCE)是一种图像处理技术,旨在通过调整图像的局部区域对比度,增强图像细节和视觉效果。LCE 的实现方式多种多样,以下是几种常…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...

初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

R语言速释制剂QBD解决方案之三
本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

力扣热题100 k个一组反转链表题解
题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...
MySQL 索引底层结构揭秘:B-Tree 与 B+Tree 的区别与应用
文章目录 一、背景知识:什么是 B-Tree 和 BTree? B-Tree(平衡多路查找树) BTree(B-Tree 的变种) 二、结构对比:一张图看懂 三、为什么 MySQL InnoDB 选择 BTree? 1. 范围查询更快 2…...
人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent
安全大模型训练计划:基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标:为安全大模型创建高质量、去偏、符合伦理的训练数据集,涵盖安全相关任务(如有害内容检测、隐私保护、道德推理等)。 1.1 数据收集 描…...

实战三:开发网页端界面完成黑白视频转为彩色视频
一、需求描述 设计一个简单的视频上色应用,用户可以通过网页界面上传黑白视频,系统会自动将其转换为彩色视频。整个过程对用户来说非常简单直观,不需要了解技术细节。 效果图 二、实现思路 总体思路: 用户通过Gradio界面上…...

云安全与网络安全:核心区别与协同作用解析
在数字化转型的浪潮中,云安全与网络安全作为信息安全的两大支柱,常被混淆但本质不同。本文将从概念、责任分工、技术手段、威胁类型等维度深入解析两者的差异,并探讨它们的协同作用。 一、核心区别 定义与范围 网络安全:聚焦于保…...
【实施指南】Android客户端HTTPS双向认证实施指南
🔐 一、所需准备材料 证书文件(6类核心文件) 类型 格式 作用 Android端要求 CA根证书 .crt/.pem 验证服务器/客户端证书合法性 需预置到Android信任库 服务器证书 .crt 服务器身份证明 客户端需持有以验证服务器 客户端证书 .crt 客户端身份…...
使用 uv 工具快速部署并管理 vLLM 推理环境
uv:现代 Python 项目管理的高效助手 uv:Rust 驱动的 Python 包管理新时代 在部署大语言模型(LLM)推理服务时,vLLM 是一个备受关注的方案,具备高吞吐、低延迟和对 OpenAI API 的良好兼容性。为了提高部署效…...