当前位置: 首页 > news >正文

Pytorch与Onnx的转换与推理

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。

一、pytorch模型保存/加载
有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.

1、文件中保存模型结构和权重参数
模型保存与调用方式一(只保存权重):

保存:

torch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用:

model = My_model(*args, **kwargs)  #这里需要重新创建模型,My_model
model.load_state_dict(torch.load(mymodel.pth))#这里根据模型结构,导入存储的模型参数
model.eval()

模型保存与调用方式二(保存完整模型):

保存:

torch.save(model, mymodel.pth)#保存整个model的状态

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以
model.eval()

.pt表示pytorch的模型,.onnx表示onnx的模型,后缀名为.pt, .pth, .pkl的pytorch模型文件之间其实没有任何区别

二、pytorch模型转ONNX模型
1、文件中保存模型结构和权重参数

import torch
torch_model = torch.load("/home/pytorch/save.pth") # pytorch模型加载#set the model to inference mode
torch_model.eval()x = torch.randn(1,3,320,640)        # 生成张量(模型输入格式)
export_onnx_file = "/home/pytorch/test.onnx"   # 目的ONNX文件名// 导出export:pt->onnx
torch.onnx.export(torch_model,                    # pytorch模型x,                            # 生成张量(模型输入格式)export_onnx_file,            # 目的ONNX文件名do_constant_folding=True,    # 是否执行常量折叠优化input_names=["input"],        # 输入名(可略)output_names=["output"],    # 输出名(可略)dynamic_axes={"input":{0:"batch_size"},        # 批处理变量(可略)"output":{0:"batch_size"}}) 

注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

2、文件中只保留模型权重

import torch
torch_model = selfmodel()                      # 由研究员提供python.py文件#set the model to inference mode
torch_model.eval()x = torch.randn(1,3,320,640)        # 生成张量(模型输入格式)
export_onnx_file = "/home/pytorch/test.onnx"   # 目的ONNX文件名// 导出export:pt->onnx
torch.onnx.export(torch_model,                    # pytorch模型x,                            # 生成张量(模型输入格式)export_onnx_file,            # 目的ONNX文件名do_constant_folding=True,    # 是否执行常量折叠优化input_names=["input"],        # 输入名(可略)output_names=["output"],    # 输出名(可略)dynamic_axes={"input":{0:"batch_size"},        # 批处理变量(可略)"output":{0:"batch_size"}}) 

3、onnx文件操作

3.1 安装onnx,onnxruntime:

pip install onnx
pip install onnxruntime(只能用cpu)
pip install onnxruntime-gpu(gpu和cpu都能用)

首先要强调的是,有两个版本的onnxruntime,一个叫onnxruntime,只能使用cpu推理,另一个叫onnxruntime-gpu,既可以使用gpu,也可以使用cpu。
如果自己安装的是onnxruntime,需要卸载后安装gpu版本。

 确认一下是否可以使用gpu
注意:
```python
print(onnxruntime.get_device())
```
上面的代码给出的输出是'GPU'时,并不代表就成功了。

而要用下面的代码来验证:
```python
ort_session = onnxruntime.InferenceSession("path/model/model_name.onnx",
providers=['CUDAExecutionProvider'])
print(ort_session.get_providers())
```
当输出是:['CUDAExecutionProvider', 'CPUExecutionProvider']才表示成功了。

版本查询:NVIDIA - CUDA | onnxruntime

安装固定版本的onnxruntime:

pip install onnxruntime-gpu==1.9.0

卸载pip uninstall


3.2 加载onnx文件

# "加载load"
model=onnx.load('net.onnx')

检查模型格式是否完整及正确

onnx.checker.check_model(model)

3.3 打印onnx模型文件信息

session=onnxruntime.InferenceSession('net.onnx')
inp=session.get_inputs()[0]#conv1=session.get_inputs()['conv1']
#out1=session.get_outputs()[1]
out=session.get_provider_options()
#print(inp,conv1,out1)
print(inp)
#print(out)
"打印图信息:字符串信息"
graph=onnx.helper.printable_graph(model.graph)
print(type(graph))

3.4 获取onnx模型输入输出层

input=model.graph.input
output = model.graph.output
"""输入输出层"""
print(input,output)

3.5 推理过程

import onnx
import onnxruntime
import torchinputs=torch.randn(1,3,640,320)
#上述inputs仅用于测试使用,用于图片推理,应该换成自己的图片,如:
#img_path='1.jpg'#图片尺寸与onnx模型的处理尺寸保持一致
#img=cv2.imread(img_path)
#inputs=preprocess_imgae(img)#标准化等预处理操作,与源项目代码保持一致即可
#print('inputs.size():',inputs.size())model=onnx.load('/home/pytorch_DL/test_320_640.onnx')
onnx.checker.check_model(model)
session =onnxruntime.InferenceSession('/home/pytorch_DL/test_320_640.onnx',
providers['CUDAExecutionProvider','CPUExecutionProvider'])
print('session.get_providers():',session.get_providers())
input_name = session.get_inputs()
output_name=session.get_outputs()[0].name
res=session.run([output_name],{input_name[0].name:inputs.numpy()})

参考:Pytorch与Onnx模型的保存、转换与操作_onnx转pytorch_Yuezero_的博客-CSDN博客

pytorch 模型的保存与加载方法以及使用onnx模型部署推理 | 码农家园

onnxruntime使用gpu推理 - 知乎

相关文章:

Pytorch与Onnx的转换与推理

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。 一、pytorch模型保存/加载 有两种方式可用于保存/加载pytorch模型 1)文件…...

Linux权限详解

文章目录 1. shell命令及运行原理2. Linux权限的概念(1)用户种类(2)切换用户(3)命令提权 3. Linux权限管理(1)文件访问者的分类(人)(2&#xff09…...

基于react18+arco+zustand通用后台管理系统React18Admin

React-Arco-Admin轻量级后台管理系统解决方案 基于vite4构建react18后台项目ReactAdmin。使用了reactarco-designzustandbizcharts等技术架构非凡后台管理框架。支持 dark/light主题、i18n国际化、动态路由鉴权、3种经典布局、tabs路由标签 等功能。 技术框架 编辑器&#xff…...

BAT031:按列表名单将路径a下的文件夹批量剪切到路径b

引言:编写批处理程序,实现按列表名单将路径a下的文件夹批量剪切到路径b。 一、新建Windows批处理文件 参考博客: CSDNhttps://mp.csdn.net/mp_blog/creation/editor/132137544 二、写入批处理代码 1.右键新建的批处理文件,点击…...

随机专享记录第一话 -- RustDesk的自我搭建和使用

1.介绍 RustDesk是继TeamView、向日葵等远程桌面软件后的新起之秀,最主要的是开源的可自己搭建中继服务。相比于公共服务器,连接一次等待的时间要多久,用过TeamView的都知道,而且还是免费的,不像某些远程搞各种个人证书,各种登录设备限制! 先看看软件图,这是待连接界…...

【数据库】拼接字段 使用别名

拼接字段 使用别名 e . g . e.g. e.g. Vendors 表包含供应商名和电话信息,name 和 mobile;需要输出这两个属性的值的组合作为供应商的基本信息组合。 SELECT concat(name, _, mobile) FROM Vendors; -- 语句通过 MySQL 环境下测试,其他 DBMS…...

Golang设计22种模式

什么是设计模式 设计模式是面向对象软件的设计经验,是通常设计问题的解决方案。每一种设计模式系统的命名、解释和评价了面向对象中一个重要的和重复出现的设计。 设计模式的分类 创建模式 - 用来帮助我们创建对象的 工厂模式 (Factory Pattern)抽象工厂模式 (Abstract F…...

MMKV(3)

使用时遇到的问题 在项目的构建配置文件(如 Gradle 或 Maven)中添加相应的依赖项。 MMKV 是一个键值存储库,它存储的是原始的字节数组数据。需要存储和检索复杂的对象或数据结构,需要自行进行序列化和反序列化操作。可以使用任何…...

vivado报错警告之[Vivado 12-1017] Problems encountered:

文章目录 方法一方法二方法三(作者最终解决) 我们对vivado 的程序进行综合(Run Synthesis)时,可能会出现[Vivado 12-1017] Problems encountered: 1. Failed to delete one or more files in run directory的一个警告信息,导致我们…...

基于springboot汽车租赁系统

功能如下图所示 摘要 Spring Boot汽车租赁系统的设计旨在满足不断增长的租车市场需求,并通过简化开发和部署流程来提供方便的租车解决方案。系统采用了现代化的架构,主要基于以下技术栈: Spring Boot:作为后端的核心框架&#xff…...

C++禁用赋值操作符

1.禁用赋值操作符 在C中,void operator(const ClassName&) delete; 是一种特殊的语法,用于明确地禁止赋值操作符(assignment operator)的默认实现或自定义实现。 这通常用于防止类的实例被意外赋值。通过明确地删除赋值操作…...

小程序的数据驱动和vue的双向绑定有何异同?

小程序的数据驱动和Vue的双向绑定有以下异同之处: 异同点: 数据驱动:小程序的数据驱动是指通过编写数据绑定的代码,将数据与视图进行关联,当数据发生变化时,视图会自动更新。而Vue的双向绑定则是一种特殊的…...

Nvm管理NodeJs版本

文章目录 Nvm管理NodeJs版本一、前言1.简介2.环境 二、正文1.卸载NodeJs2.安装Nvm3.配置国内镜像4.Nvm使用5.其它1)报错12)报错2 Nvm管理NodeJs版本 一、前言 1.简介 Node Version Manager(nvm)可通过命令行快速安装和使用不同…...

阿里云国际站服务器开放端口详解!!

在互联网技术发展的今天,服务器扮演着至关重要的角色。作为云服务供给商,阿里云服务器供给了安稳、高效的服务,而敞开端口则是阿里云服务器功能的重要体现。本文将详细解读阿里云服务器敞开端口的意义、实现办法以及其带来的优点。 一、阿里云…...

【自动化测试入门】用Airtest - Selenium对Firefox进行自动化测试(0基础也能学会)

1. 前言 本文将详细介绍如何使用AirtestIDE驱动Firefox测试,以及脱离AirtestIDE怎么驱动Firefox(VScode为例)。看完本文零基础小白也能学会Firefox浏览器自动化测试!!! 2. 如何使用AirtestIDE驱动Firefox…...

Python 爬虫入门:常见工具介绍

接着我的上一篇文章《网页爬虫完全指南》,这篇文章将涵盖几乎所有的 Python 网页爬取工具。我们从最基本的开始讲起,逐步涉及到当前最前沿的技术,并且对它们的利弊进行分析。 当然,我们不能全面地介绍每个工具,但这篇…...

uniGUI文件操作

一.文件上传TUniFileUploadButton TUniFileUploadButton主要属性: Filter: 文件类型过滤,有图片image/* audio/* video/*三种过滤 MaxAllowedSize: 设置文件最大上传尺寸; Message:标题以及消息文本,可翻译成中文…...

Python多进程之分享(multiprocessing包)

threading和multiprocessing (可以阅读Python多线程与同步) multiprocessing包是Python中的多进程管理包。与threading.Thread类似,它可以利用multiprocessing.Process对象来创建一个进程。该进程可以运行在Python程序内部编写的函数。该Process对象与Thread对象的…...

【试题028】C语言关于逻辑与的短路例题

1.题目&#xff1a;设inta1,b;&#xff0c;执行b0&&(a);后&#xff0c;变量a的值是&#xff1f; 2.代码解析&#xff1a; #include <stdio.h> int main() {//设inta1,b;执行b0&&(a);后&#xff0c;变量a的值是?int a 1, b;printf("表达式的值是…...

TSINGSEE烟火识别算法的技术原理是什么?如何应用在视频监控中?

AI烟火识别算法是基于深度学习技术的一种视觉识别算法&#xff0c;主要用于在视频监控场景中自动检测和识别烟雾、火焰的行为。该技术基于深度学习神经网络技术&#xff0c;可以动态识别烟雾和火焰从有到无、从小到大、从大到小、从小烟到浓烟的状态转换过程。 1、技术原理 1…...

挑战杯推荐项目

“人工智能”创意赛 - 智能艺术创作助手&#xff1a;借助大模型技术&#xff0c;开发能根据用户输入的主题、风格等要求&#xff0c;生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用&#xff0c;帮助艺术家和创意爱好者激发创意、提高创作效率。 ​ - 个性化梦境…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

day52 ResNet18 CBAM

在深度学习的旅程中&#xff0c;我们不断探索如何提升模型的性能。今天&#xff0c;我将分享我在 ResNet18 模型中插入 CBAM&#xff08;Convolutional Block Attention Module&#xff09;模块&#xff0c;并采用分阶段微调策略的实践过程。通过这个过程&#xff0c;我不仅提升…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

蓝桥杯 2024 15届国赛 A组 儿童节快乐

P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡&#xff0c;轻快的音乐在耳边持续回荡&#xff0c;小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下&#xff0c;六一来了。 今天是六一儿童节&#xff0c;小蓝老师为了让大家在节…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文&#xff1f; 多租户隔离&#xff1a;自动为接入设备追加租户前缀&#xff0c;后端按 ClientID 拆分队列。零代码鉴权&#xff1a;将入站用户名替换为 OAuth Access-Token&#xff0c;后端 Broker 统一校验。灰度发布&#xff1a;根据 IP/地理位写…...

NFT模式:数字资产确权与链游经济系统构建

NFT模式&#xff1a;数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新&#xff1a;构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议&#xff1a;基于LayerZero协议实现以太坊、Solana等公链资产互通&#xff0c;通过零知…...

在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案

这个问题我看其他博主也写了&#xff0c;要么要会员、要么写的乱七八糟。这里我整理一下&#xff0c;把问题说清楚并且给出代码&#xff0c;拿去用就行&#xff0c;照着葫芦画瓢。 问题 在继承QWebEngineView后&#xff0c;重写mousePressEvent或event函数无法捕获鼠标按下事…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...