当前位置: 首页 > news >正文

使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomAffineGrid(Function):@staticmethoddef forward(ctx, theta: torch.Tensor, size: torch.Tensor):grid = F.affine_grid(theta=theta, size=size.cpu().tolist())return grid@staticmethoddef symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):return g.op("AffineGrid", theta, size)class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):grid = CustomAffineGrid.apply(theta, size)x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")return xdef main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)theta = torch.randn(1, 2, 3)size = torch.as_tensor([1, 3, 512, 512])torch.onnx.export(model=custum_model,args=(x, theta, size),f="custom.onnx",input_names=["input0_x", "input1_theta", "input2_size"],output_names=["output"],dynamic_axes={"input0_x": {2: "h0", 3: "w0"},"output": {2: "h1", 3: "w1"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomRot90AndScale(Function):@staticmethoddef forward(ctx, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90x *= 1.2return x@staticmethoddef symbolic(g: torch.Graph, x: torch.Tensor):return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor):return CustomRot90AndScale.apply(x)def main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)torch.onnx.export(model=custum_model,args=(x,),f="custom_rot90.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {2: "h0", 3: "w0"},"output": {2: "w0", 3: "h0"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

相关文章:

使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的…...

unity-urp:视野雾

问题背景 恐怖游戏在黑夜或者某些场景下,需要用雾或者黑暗遮盖视野,搭建游戏氛围 效果 场景中,雾会遮挡场景和怪物,但是在玩家视野内雾会消散,距离玩家越近雾越薄。 当前是第三人称视角,但是可以轻松的…...

Spring Cloud Gateway介绍及入门配置

Spring Cloud Gateway介绍及入门配置 概述: Gateway是在Spring生态系统之上构建的API网关服务,基于Spring6,Spring Boot 3和Project Reactor等技术。它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式,并为它们提供…...

Thingsboard本地源码部署教程

本章将介绍ThingsBoard的本地环境搭建,以及源码的编译安装。本机环境:jdk11、maven 3.6.2、node v12.18.2、idea 2023.1、redis 6.2 环境安装 开发环境要求: Jdk 11 版本 ;Postgresql 9 以上;Maven 3.6 以上&#xf…...

【MySQL 系列】MySQL 起步篇

MySQL 是一个开放源代码的、免费的关系型数据库管理系统。在 Web 开发领域,MySQL 是最流行、使用最广泛的关系数据库。MySql 分为社区版和商业版,社区版完全免费,并且几乎能满足全部的使用场景。由于 MySQL 是开源的,我们还可以根…...

C++的成员初始化列表

C的成员构造函数初始化列表:构造函数中初始化类成员的一种方式,当我们编写一个类并向该类添加成员时,通常需要某种方式对这些成员变量进行初始化。 建议应该在所有地方使用成员初始化列表进行初始化 成员初始化的方法 方法一: …...

为什么TikTok视频0播放?账号权重提高要重视

许多TikTok账号运营者都会遇到一个难题,那就是视频要么播放量很低,要么0播放!不管内容做的多好,最好都是竹篮打水一场空!其实你可能忽略了一个问题,那就是账号权重。下面好好跟大家讲讲这个东西&#xff01…...

element---tree树形结构(返回的数据与官方的不一样)

项目中要用到属性结构数据&#xff0c;后端返回的数据不是官方默认的数据结构&#xff1a; <el-tree:data"treeData":filter-node-method"filterNode":props"defaultProps"node-click"handleNodeClick"></el-tree>这是文档…...

Spring Boot工程集成验证码生成与验证功能教程

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…...

Bert Encoder和Transformer Encoder有什么不同

前言&#xff1a;本篇文章主要从代码实现角度研究 Bert Encoder和Transformer Encoder 有什么不同&#xff1f;应该可以帮助你&#xff1a; 深入了解Bert Encoder 的结构实现深入了解Transformer Encoder的结构实现 本篇文章不涉及对注意力机制实现的代码研究。 注&#xff1a;…...

外汇天眼:频繁交钱却无法出金,只因误入假冒HFM惨成冤大头!

在外汇市场上这么久了&#xff0c;天眼君总结出了一个不争的事实&#xff0c;但凡是不给出金或者以各种理由拒绝出金的平台一定有问题&#xff01;想必不管是在外汇天眼还是其他地方&#xff0c;大家总是能看到一些外汇交易者投诉自己向平台申请出金需要缴纳各种费用&#xff0…...

Linux-信号3_sigaction、volatile与SIGCHLD

文章目录 前言一、sigaction__sighandler_t sa_handler;__sigset_t sa_mask; 二、volatile关键字三、SIGCHLD方法一方法二 前言 本章内容主要对之前的内容做一些补充。 一、sigaction #include <signal.h> int sigaction(int signum, const struct sigaction *act,struc…...

STM32 | STM32时钟分析、GPIO分析、寄存器地址查找、LED灯开发(第二天)

STM32 第二天 一、 STM32时钟分析 寄存器&#xff1a;寄存器的功能是存储二进制代码&#xff0c;它是由具有存储功能的触发器组合起来构成的。一个触发器可以存储1位二进制代码&#xff0c;故存放n位二进制代码的寄存器&#xff0c;需用n个触发器来构成 在计算机领域&#x…...

Python常用语法汇总(一):字符串、列表、字典操作

1. 字符串处理 print(message.title()) #首字母大写print(message.uper()) #全部大写print(message.lower()) #全部小写full_name "lin" "hai" #合并字符串print("Hello, " full_name.title() "!")print("John Q. %s10s&qu…...

Token的奥秘--一起学习吧之token

Token&#xff0c;在计算机科学中&#xff0c;是一个用于表示数据或一段数据的单位。它通常用于加密、身份验证、令牌化等场景&#xff0c;以确保数据的安全性和完整性。在编程语言中&#xff0c;Token通常是指代一段代码或数据的最小单元&#xff0c;例如一个变量、一个操作符…...

FlinkCDC快速搭建实现数据监控

引入依赖 <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><modelV…...

应急布控球远程视频监控方案:视频监控平台EasyCVR+4G/5G应急布控球

随着科技的不断发展&#xff0c;应急布控球远程视频监控方案在公共安全、交通管理、城市管理等领域的应用越来越广泛。这种方案通过在现场部署应急布控球&#xff0c;实现对特定区域的实时监控&#xff0c;有助于及时发现问题、快速响应&#xff0c;提高管理效率。 智慧安防视…...

3.6 C语言和汇编语言混合编程 “每日读书”

在一些嵌入式场合&#xff0c;我们经常看到C程序和汇编程序相互调用&#xff0c;混合编程&#xff0c;如在ARM启动代码中&#xff0c;系统上电首先运行的是汇编代码&#xff0c;等初始化好内存堆栈环境之后&#xff0c;才会跳到C程序中执行&#xff0c;对嵌入式软件进行优化时&…...

利用“定时执行专家”循环执行BAT、VBS、Python脚本——含参数指定功能

目录 一、软件概述 二、VBS脚本执行设置 三、触发器设置 四、功能亮点 五、总结 在自动化办公和日常计算机任务管理中&#xff0c;定时执行脚本是一项非常重要的功能。今天&#xff0c;我将为大家带来一款名为“定时执行专家”的软件的评测&#xff0c;特别是其定时执行VB…...

【算法集训】基础算法:模拟

一、基本理解 顾名思义&#xff0c;就是题目要求做什么&#xff0c;代码中就跟着做就可以。 二、题目练习 1252. 奇数值单元格的数目 根据题目要求列出如下代码。需要注意填充列和行的时候注意下标。 int oddCells(int m, int n, int** indices, int indicesSize, int* in…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

Python如何给视频添加音频和字幕

在Python中&#xff0c;给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加&#xff0c;包括必要的代码示例和详细解释。 环境准备 在开始之前&#xff0c;需要安装以下Python库&#xff1a;…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...