当前位置: 首页 > news >正文

将Pytorch搭建的ViT模型转为onnx模型

本文尝试将pytorch搭建的ViT模型转为onnx模型。

首先将博主上一篇文章中搭建的模型ViT Vision Transformer超详细解析,网络构建,可视化,数据预处理,全流程实例教程-CSDN博客转存为.pth

torch.save(model, 'my_vit_model.pth')

然后新建一个py文件,要新建py文件的原因是,博主上一篇文章的main.py文件引用了很多torch相关的库,如果还是在main.py文件中运行转onnx的代码,回报错circle import 重复循环引用的错误,所以姑且将.pth作为一个中转。

新建一个py文件,写入

import importlib
torch = importlib.import_module('torch')model = torch.load("my_vit_model.pth")model.cpu()
# 创建一个随机的输入张量
dummy_input = torch.randn(1, 3, 16, 16)
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=18)

引入importlib,通过它来引用torch也是为了解决循环引用的问题。

这时运行这段代码,会报错onnx 不支持aten::unflatten运算。这里有两种解决方法,一种是将自己pytorch模型中的unflatten运算全部换成onnx支持的reshape函数(参见文章:https://www.cnblogs.com/antelx/p/17564039.html)

还有一种方法是,修改onnx库中的 symbolic_opset18.py 文件(/home/.local/lib/python3.8/site-packages/torch/onnx),改为如下形式

"""This file exports ONNX ops for opset 18.Note [ONNX Operators that are added/updated in opset 18]~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:CenterCropPadCol2ImMishOptionalGetElementOptionalHasElementPadResizeScatterElementsScatterND
"""import functools
from typing import Sequenceimport torch
import torch._C._onnx as _C_onnx
from torch.onnx import (_constants,_type_utils,errors,symbolic_helper,symbolic_opset11 as opset11,symbolic_opset9 as opset9,utils,
)
from torch.onnx._internal import _beartype, jit_utils, registrationfrom torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py__all__ = ["col2im"]_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(g,input: _C.Value,output_size: _C.Value,kernel_size: _C.Value,dilation: Sequence[int],padding: Sequence[int],stride: Sequence[int],
):# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]adjusted_padding = []for pad in padding:for _ in range(2):adjusted_padding.append(pad)num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]if not adjusted_padding:adjusted_padding = [0, 0] * num_dimensional_axisif not dilation:dilation = [1] * num_dimensional_axisif not stride:stride = [1] * num_dimensional_axisreturn g.op("Col2Im",input,output_size,kernel_size,dilations_i=dilation,pads_i=adjusted_padding,strides_i=stride,)@_onnx_symbolic("aten::unflatten")
def unflatten(g:jit_utils.GraphContext, input, dim, unflattened_size):input_dim = symbolic_helper._get_tensor_rank(input)if input_dim is None:return symbolic_helper._unimplemented("dim","ONNX and PyTorch use different strategies to split the input. ""Input rank must be known at export time.",)# dim could be negativeinput_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))dim = g.op("Add", input_dim, dim)dim = g.op("Mod", dim, input_dim)input_size = g.op("Shape", input)head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))head_end_idx = g.op("Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)))head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)dim_plus_one = g.op("Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)))tail_start_idx = g.op("Reshape",dim_plus_one,g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),)tail_end_idx = g.op("Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64))tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)final_shape = g.op("Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0)return symbolic_helper._reshape_helper(g, input, final_shape)

这里这样做是相当于自己在onnx库中注册aten::unflatten运算。

再新建一个py文件,写入

import onnxruntime as rt
import numpy as np# 加载模型
sess = rt.InferenceSession("model.onnx")# 获取输入和输出名称
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name# 创建输入数据
input_data = np.random.rand(1, 3, 16, 16).astype(np.float32)# 运行模型
pred_onnx = sess.run([output_name], {input_name: input_data})# 打印预测结果
print(pred_onnx)

就可以运行onnx模型了。

相关文章:

将Pytorch搭建的ViT模型转为onnx模型

本文尝试将pytorch搭建的ViT模型转为onnx模型。 首先将博主上一篇文章中搭建的模型ViT Vision Transformer超详细解析,网络构建,可视化,数据预处理,全流程实例教程-CSDN博客转存为.pth torch.save(model, my_vit_model.pth) 然…...

图神经网络(GNN)性能优化方案汇总,附37个配套算法模型和代码

图神经网络的表达能力对其性能和应用范围有着重要的影响,是GNN研究的核心问题和发展方向。增强表达能力是扩展GNN应用范围、提高性能的关键所在。 目前GNN的表达能力受特征表示和拓扑结构这两个因素的影响,其中GNN在学习和保持图拓扑方面的缺陷是限制表…...

国科大移动互联网考试资料(2023+2020+2018真题+答案)

老师王文杰。真题附加2022部分...

ModStart系统安全规范建议

1 不要使用弱密码 很多人为了系统管理方便(或者是懒),经常会设置类似 123456、admin 这样的管理密码,这样的密码很容易被暴力软件扫描出来。 2 不要使用默认配置 默认的软件系统设置、默认的系统端口、默认的网站设置在发生漏洞…...

【漏洞复现】Django_debug page_XSS漏洞(CVE-2017-12794)

感谢互联网提供分享知识与智慧,在法治的社会里,请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞分析3、漏洞验证 说明内容漏洞编号CVE-2017-12794漏洞名称Django_debug page_XSS漏洞漏洞评级影响范…...

Redis性能调优:深度剖析与示例解析

标题:Redis性能调优:深度剖析与示例解析 引言 Redis是一款强大的开源内存数据库,广泛应用于高性能系统。然而,为了充分发挥Redis的性能,需要进行合理的性能调优。本博客将深入介绍Redis性能调优的策略和示例&#xf…...

oracle查询前几条数据的方法

在Oralce中实现select top N&#xff1a;由于Oracle不支持select top 语句&#xff0c;所以在oracle中经常是用order by 跟rownum的组合来实现select top n的查询。 方法1&#xff1a; SELECT * FROM (SELECT * FROM EMP ORDER BY SAL DESC) WHERE ROWNUM < 5 --抽取处记录…...

c#弹性和瞬态故障处理库Polly

1. 重试&#xff08;Retry&#xff09; Policy .Handle<Exception>() //指定需要重试的异常类型 .Retry(2,(ex,count,context)> { //指定发生异常重试的次数Console.WriteLine($ "重试次数{count},异常{ex.Message}" ); }) …...

20231107-前端学习炫酷菜单效果和折叠侧边栏

炫酷菜单效果 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>炫酷菜单效果</title><…...

基于CLIP的图像分类、语义分割和目标检测

OpenAI CLIP模型是一个创造性的突破&#xff1b; 它以与文本相同的方式处理图像。 令人惊讶的是&#xff0c;如果进行大规模训练&#xff0c;效果非常好。 在线工具推荐&#xff1a; Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D…...

python爬虫(数据获取——selenium)

环境测试 from selenium import webdriverchromedriver_path r"C:\Program Files\Google\Chrome\Application\chromedriver.exe" driver webdriver.Chrome()url "https://www.xinpianchang.com/discover/article?fromnavigator" driver.get(url)drive…...

[wp]NewStarCTF 2023 WEEK5|WEB

前言:比赛是结束了&#xff0c;但我的学习还未结束&#xff0c;看看自己能复习几道题吧&#xff0c;第四周实在太难 Final 考点&#xff1a; ThinkPHP 5.0.23 RCE一句话木马上传SUID提权&#xff08;find&#xff09; 解题: 首先页面就给了ThinkPHP V5&#xff0c; 那无非考…...

未将对象引用设置到对象实例

环境 vs 2017 qt 5.13.0 qt-vs-addin 2.10 qt 项目打开的vs 2010 的项目 配置完成之后可以编译执行&#xff0c;但是新建qt 类提示 未将对象引用设置到对象实例 问题 插件的版本太高了使用低版本的&#xff0c;到qt 官网下载Index of /official_releases/vsaddin 下载q…...

网络的地址簿:Linux DNS服务的全面指南

1 dns 1.1 dns&#xff08;域名解析服务&#xff09;介绍 当访问 www.baidu.com 首先查询/etc/hosts&#xff0c;如果没有再去查询/etc/resolv.conf&#xff0c;还是没有就去查询域名服务器 关于客户端: /etc/resolv.conf ##dns指向文件 nameserver 172.25.254.20测试&…...

输电线路AR可视化巡检降低作业风险

随着现代工业的快速发展&#xff0c;各行业的一线技术工人要处理的问题越来越复杂&#xff0c;一些工作中棘手的问题迫切需要远端专家的协同处理。但远端专家赶来现场往往面临着专家差旅成本高、设备停机损失大、专业支持滞后、突发故障无法立即解决等痛点。传统的远程协助似乎…...

18. 四数之和

18. 四数之和 原题链接&#xff1a;完成情况&#xff1a;解题思路&#xff1a;参考代码&#xff1a;错误经验吸取 原题链接&#xff1a; 18. 四数之和 https://leetcode.cn/problems/4sum/description/ 完成情况&#xff1a; 解题思路&#xff1a; /** * //HashMap只能记录…...

排序:堆排序(未完待续)

文章目录 排序一、 排序的概念1.排序&#xff1a;2.稳定性&#xff1a;3.内部排序&#xff1a;4.外部排序&#xff1a; 二、插入排序1.直接插入排序 二、插入排序堆排序 排序 一、 排序的概念 1.排序&#xff1a; 一组数据按递增/递减排序 2.稳定性&#xff1a; 待排序的序列…...

小米智能电视投屏方法

小米智能电视也提供了投屏功能。 使用遥控器&#xff0c;在应用中找到它&#xff0c;点击进入。 小米电视支持windows笔记本&#xff0c;macbook笔记本&#xff0c;iphone手机&#xff0c;安卓手机投屏。 windows笔记本投屏 在投屏应用中找到windows投屏&#xff0c;选中开…...

保外就医罪犯收到指定医院《罪犯病情诊断书》及检测、检查报告等其他医疗文书后,应当在规定时限内提交( ),或者受委托司法所审查。

需要查看详细试题题库及其参考答案的&#xff0c;请到&#xff08;题-海-舟&#xff09;里进行搜索查看。可搜试题题干或者搜索关键词&#xff0c;搜题的时候&#xff0c;先进行题目识别&#xff0c;能大大提高学习效率&#xff0c;感谢使用&#xff01; 保外就医罪犯收到指定…...

pytorh模型训练、测试

目录 1 导入数据集 2 使用tensorboard展示经过各个层的图片数据 3 完整的模型训练测试流程 使用Gpu训练的两种方式 使用tensorboard显示模型 模型训练测试 L1Loss函数 保存未训练模型或者已经训练完的模型 4 加载训练好的模型进行测试 1 导入数据集 import torch from torch.u…...

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例&#xff1a;使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例&#xff1a;使用OpenAI GPT-3进…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

《基于Apache Flink的流处理》笔记

思维导图 1-3 章 4-7章 8-11 章 参考资料 源码&#xff1a; https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

Rapidio门铃消息FIFO溢出机制

关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系&#xff0c;以下是深入解析&#xff1a; 门铃FIFO溢出的本质 在RapidIO系统中&#xff0c;门铃消息FIFO是硬件控制器内部的缓冲区&#xff0c;用于临时存储接收到的门铃消息&#xff08;Doorbell Message&#xff09;。…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...