当前位置: 首页 > news >正文

记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {// BEGIN image_matcv::Mat image_mat(/*rows=*/image.size(0),/*cols=*/image.size(1),/*type=*/CV_32FC1,/*data=*/image.data_ptr<float>());// END image_mat// BEGIN warp_matcv::Mat warp_mat(/*rows=*/warp.size(0),/*cols=*/warp.size(1),/*type=*/CV_32FC1,/*data=*/warp.data_ptr<float>());// END warp_mat// BEGIN output_matcv::Mat output_mat;cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });// END output_mat// BEGIN output_tensortorch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });return output.clone();// END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {m.def("warp_perspective", warp_perspective);
}

2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)# Opencv
find_package(OpenCV REQUIRED)# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)# libtorch库文件
target_link_libraries(warp_perspective # CPUc10 torch_cpu# GPU# c10_cuda # torch_cuda)# opencv库文件
target_link_libraries(warp_perspective${OpenCV_LIBS}
)add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")im=cv2.imread("/data/xxx/mylib/cat.jpg",0)pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):def __init__(self, name):super(MyNet, self).__init__()self.model_name = namedef forward(self, in_data, warp_data):return torch.ops.my_ops.warp_perspective(in_data, warp_data)def my_custom(g, in_data, warp_data):return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)if __name__ == "__main__":net = MyNet("my_ops")in_data = torch.randn((32, 32))warp_data = torch.rand((3, 3))out = net(in_data, warp_data)print("out: ", out)# export onnxtorch.onnx.export(net,(in_data, warp_data),"./my_ops_export_model2.onnx",input_names=["img_data", "warp_mat"],output_names=["out_img"],custom_opsets={"cus_ops": 11},)

相关文章:

记录pytorch实现自定义算子并转onnx文件输出

概览&#xff1a;记录了如何自定义一个算子&#xff0c;实现pytorch注册&#xff0c;通过C编译为库文件供python端调用&#xff0c;并转为onnx文件输出 整体大概流程&#xff1a; 定义算子实现为torch的C版本文件注册算子编译算子生成库文件调用自定义算子 一、编译环境准备…...

ARPG----C++学习记录04 Section8 角色类,移动

角色类输入 新建一个角色C&#xff0c;继承建立蓝图,和Pawn一样&#xff0c;绑定输入移动和相机. 在构造函数中添加这段代码也能实现。打开UsePawnControlRotation就可以让人物不跟随鼠标旋转 得到旋转后的向前向量 使用旋转矩阵 想要前进方向和旋转的方向对应。获取当前控制…...

拆解软件定义汽车:OS突围

软件作为智能汽车的核心组成部分&#xff0c;由于自身较为独立和复杂的IT学科体系&#xff0c;其技术链路、产业分工、价值分配、商业模式相对硬件产品&#xff08;如域控、激光雷达、摄像头等硬件&#xff09;而言&#xff0c;在汽车产业内探讨和传播相对较少。 11月3日&…...

并发线程使用介绍(二)

2.2.6 线程的强占 Thread的非静态方法join方法 需要在某一个线程下去调用这个方法 如果在main线程中调用了t1.join()&#xff0c;那么main线程会进入到等待状态&#xff0c;需要等待t1线程全部执行完毕&#xff0c;在恢复到就绪状态等待 CPU调度。 如果在main线程中调用了t1.j…...

【Proteus仿真】【51单片机】多路温度控制系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LED、蜂鸣器、LCD1602、DS18B20温度传感器、HC05蓝牙模块等。 主要功能&#xff1a; 系统运行后&#xff0c;默认LCD1602显示前4路采集的温…...

一些可以参考的文档集合15

之前的文章集合: 一些可以参考文章集合1_xuejianxinokok的博客-CSDN博客 一些可以参考文章集合2_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合3_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合4_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合5…...

k8s的service自动发现服务:实战版

Service服务发现的必要性: 对于kubernetes整个集群来说&#xff0c;Pod的地址也可变的&#xff0c;也就是说如果一个Pod因为某些原因退出了&#xff0c;而由于其设置了副本数replicas大于1&#xff0c;那么该Pod就会在集群的任意节点重新启动&#xff0c;这个重新启动的Pod的I…...

项目笔记记录

一、node下载版本报错&#xff1a;npm install --legacy-peer-deps 二、Scheduled: 任务自动化调度 Scheduled 标记要调度的方法的注解&#xff0c;必须指定 cron&#xff0c;fixedDelay或fixedRate属性之一 fixedDelay&#xff1a;固定延迟 延迟执行任务&#xff0c;任务在…...

【leetcode】1137. 第 N 个泰波那契数

题目 泰波那契序列 Tn 定义如下&#xff1a; T0 0, T1 1, T2 1, 且在 n > 0 的条件下 Tn3 Tn Tn1 Tn2 给你整数 n&#xff0c;请返回第 n 个泰波那契数 Tn 的值。 示例 1&#xff1a; 输入&#xff1a;n 4 输出&#xff1a;4 解释&#xff1a; T_3 0 1 1 2 …...

【解决】conda-script.py: error: argument COMMAND: invalid choice: ‘activate‘

运行conda activate base报错&#xff1a; 试了网上找到的解决方法都不行&#xff1a; 最后切换了一下terminal&#xff1a; 从powershell改回cmd&#xff08;不知道为什么一开始手贱换成powershell&#xff09; 就可以了...

Linux 性能调优之硬件资源监控

写在前面 考试整理相关笔记博文内容涉及 Linux 硬件资源监控常见的命名介绍&#xff0c;涉及硬件基本信息查看查看硬件错误信息查看虚拟环境和云环境资源理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在心中坚守其一生&#x…...

Windows系统隐藏窗口启动控制台程序

背景 上线项目有时候需要一些控制台应用作为辅助服务来协助UI应用满足实际需求&#xff0c;这时候如果一运行UI就冒出一系列的黑框&#xff0c;这将会导致客户被下的不起&#xff0c;生怕中了什么不知名病毒 方案 可以使用vbs来启动&#xff0c;这个是window系统自带的&#…...

FreeSWITCH fail2ban.lua

--[[ 部署:在vars.xml里面增加配置项目&#xff1a;<X-PRE-PROCESS cmd"set" data"api_on_startupluarun fail2ban.lua"/>或者在 lua.conf.xml 里面增加下面这个配置项目&#xff1a;<param name"startup-script" value"fail2ban.…...

Qt HTTP下载数据

添加头文件&#xff1a; #include <QNetworkAccessManager> #include <QNetworkReply> #include <QUrl> #include <QDesktopServices> 创建对象&#xff1a; QNetworkAccessManager networkManager;//网络管理QNetworkReply *reply; …...

8. 深度学习——NLP

机器学习面试题汇总与解析——NLP 本章讲解知识点 什么是 NLP循环神经网络(RNN)RNN 变体Attention 机制RNN 反向传播推导LSTM 与 GRUTransformerBertGPT分词算法分类CBOW 模型与 Skip-Gram 模型本专栏适合于Python已经入门的学生或人士,有一定的编程基础。本专栏适合于算法…...

部署 KVM 虚拟化平台

虚拟化技术的演变过程分为软件模拟、虚拟化层翻译、容器虚拟化三个阶段 1 软件模拟的技术方式 软件模拟是通过软件完全模拟CPU、网卡、芯片组、磁盘等计算机硬件&#xff0c;因为是软件模拟&#xff0c;所以理论上可以模拟任何硬件&#xff0c;甚至不存在的硬件。但是由于是软…...

Juniper PPPOE双线路冗余RPM配置

------------------ 浮动静态路由 set routing-options static route 0.0.0.0/0 next-hop pp0.0 qualified-next-hop pp0.1 preference 10 ----------------- RPM测试的内容,包括从哪个接口发起测试,测试ping等等 #指定探针类型用ICMP请求 #探测的目标地址 #探测间隔 #探测阈…...

原生JS实现视频截图

视频截图效果预览 利用Canvas进行截图 要用原生js实现视频截图&#xff0c;可以利用canvas的绘图功能 ctx.drawImage&#xff0c;只需要获取到视频标签&#xff0c;就可以通过drawImage把视频当前帧图像绘制在canvas画布上。 const video document.querySelector(video) con…...

前端Rust二进制/wasm全平台构建流程简述

前言 开门见山&#xff0c;现代前端 Rust 构建基本分三大类&#xff0c;即 构建 .wasm 、构建 .node 二进制 、构建 swc 插件。 入门详见 《 前端Rust开发WebAssembly与Swc插件快速入门 》 。 对于单独开发某一类的流程&#xff0c;在上述参考文章中已有介绍&#xff0c;但对于…...

加解密算法相关技术详解

文章目录 简介工作机制加解密对称密钥算法非对称密钥算法 数字信封数字签名数字证书技术对比 推荐阅读 简介 随着网络技术的飞速发展&#xff0c;网络安全问题日益重要&#xff0c;加解密技术是网络安全技术中的核心技术&#xff0c;是最常用的安全保密手段。 加密&#xff1…...

7 低配置设备鸿蒙运行流畅度提升技巧 | 鸿蒙开发筑基实战

7 低配置设备鸿蒙运行流畅度提升技巧 | 鸿蒙开发筑基实战 作者&#xff1a;杨建宾&#xff08;华夏之光永存&#xff09; 摘要 本文面向鸿蒙开发者&#xff0c;特别是在低配设备、低内存机型上遇到卡顿、掉帧、加载慢的工程师。提供一套通用、可落地、不求炫技的流畅度提升方…...

效率倍增器:OpenClaw+千问3.5-27B自动化邮件处理

效率倍增器&#xff1a;OpenClaw千问3.5-27B自动化邮件处理 1. 为什么需要自动化邮件处理 每天早晨打开邮箱&#xff0c;看到堆积如山的未读邮件时&#xff0c;那种窒息感我至今难忘。作为技术团队的接口人&#xff0c;我的邮箱常年保持着2000未读邮件的状态——重要需求埋没…...

ArduinoEigen:嵌入式平台轻量级Eigen线性代数库移植

1. ArduinoEigen&#xff1a;面向嵌入式平台的轻量化Eigen线性代数库移植1.1 项目定位与工程价值ArduinoEigen 是一个专为资源受限嵌入式平台定制的 Eigen 线性代数库移植版本&#xff0c;其核心目标并非简单地将桌面级 C 数值计算库“搬上”MCU&#xff0c;而是通过深度裁剪、…...

企业级“衣依”服装销售平台管理系统源码|SpringBoot+Vue+MyBatis架构+MySQL数据库【完整版】

&#x1f4a1;实话实说&#xff1a;有自己的项目库存&#xff0c;不需要找别人拿货再加价&#xff0c;所以能给到超低价格。摘要 随着电子商务的快速发展&#xff0c;服装行业对高效、智能化的销售管理平台需求日益增长。传统的线下销售模式在库存管理、订单处理及客户服务等方…...

Gemma-3-12b-it Streamlit应用实战:顶部像素控制面板CSS3定制详解

Gemma-3-12b-it Streamlit应用实战&#xff1a;顶部像素控制面板CSS3定制详解 1. 引言&#xff1a;从传统侧边栏到像素控制面板 如果你用过Streamlit&#xff0c;肯定对那个默认的侧边栏不陌生。它很方便&#xff0c;但有时候也挺碍事——特别是当你想要一个全屏、沉浸式的对…...

mbedBug:面向mbed OS的轻量级嵌入式调试纳米框架

1. mbedBug&#xff1a;面向mbed OS的轻量级嵌入式调试纳米框架1.1 设计定位与工程价值mbedBug并非通用型调试器或完整测试框架&#xff0c;而是一个专为资源受限嵌入式环境裁剪的调试纳米框架&#xff08;Debug Nanoframework&#xff09;。其核心设计哲学是“最小侵入、最大可…...

手机号码定位查询工具:3分钟快速部署,轻松查询号码归属地

手机号码定位查询工具&#xff1a;3分钟快速部署&#xff0c;轻松查询号码归属地 【免费下载链接】location-to-phone-number This a project to search a location of a specified phone number, and locate the map to the phone number location. 项目地址: https://gitco…...

Python AOT编译成本如何从$280K/年压至$49K/年?2026前最后窗口期的6个不可逆决策点

第一章&#xff1a;Python AOT编译成本断崖式下降的战略本质Python 长期以来被诟病于运行时开销高、启动慢、内存占用大&#xff0c;其核心瓶颈在于 CPython 解释器的字节码解释执行机制。而近年来&#xff0c;以 Nuitka、Cython&#xff08;搭配 --aot 模式&#xff09;、以及…...

从NTU-RGB+D到实际应用:如何用这个数据集训练一个摔倒检测模型?

基于NTU-RGBD数据集的摔倒检测模型实战指南 在智能监护和安防领域&#xff0c;摔倒检测一直是个极具社会价值的课题。想象一下&#xff0c;当独居老人不慎跌倒时&#xff0c;系统能在第一时间发出警报&#xff1b;或是在建筑工地&#xff0c;实时监测工人安全状态——这些场景背…...

极空间玩出花!用 File Browser 搭建专属私有云,文件管理超丝滑

前言 玩 NAS 的朋友应该都懂&#xff0c;极空间的硬件确实够稳&#xff0c;但原生的文件管理功能总差那么点意思 —— 权限管控不精细、跨设备操作不够顺手&#xff0c;想把它打造成真正的私人网盘总差点火候。 直到我试了 File Browser&#xff0c;这款轻量又强大的开源 Web…...