当前位置: 首页 > news >正文

记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {// BEGIN image_matcv::Mat image_mat(/*rows=*/image.size(0),/*cols=*/image.size(1),/*type=*/CV_32FC1,/*data=*/image.data_ptr<float>());// END image_mat// BEGIN warp_matcv::Mat warp_mat(/*rows=*/warp.size(0),/*cols=*/warp.size(1),/*type=*/CV_32FC1,/*data=*/warp.data_ptr<float>());// END warp_mat// BEGIN output_matcv::Mat output_mat;cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });// END output_mat// BEGIN output_tensortorch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });return output.clone();// END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {m.def("warp_perspective", warp_perspective);
}

2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)# Opencv
find_package(OpenCV REQUIRED)# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)# libtorch库文件
target_link_libraries(warp_perspective # CPUc10 torch_cpu# GPU# c10_cuda # torch_cuda)# opencv库文件
target_link_libraries(warp_perspective${OpenCV_LIBS}
)add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")im=cv2.imread("/data/xxx/mylib/cat.jpg",0)pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):def __init__(self, name):super(MyNet, self).__init__()self.model_name = namedef forward(self, in_data, warp_data):return torch.ops.my_ops.warp_perspective(in_data, warp_data)def my_custom(g, in_data, warp_data):return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)if __name__ == "__main__":net = MyNet("my_ops")in_data = torch.randn((32, 32))warp_data = torch.rand((3, 3))out = net(in_data, warp_data)print("out: ", out)# export onnxtorch.onnx.export(net,(in_data, warp_data),"./my_ops_export_model2.onnx",input_names=["img_data", "warp_mat"],output_names=["out_img"],custom_opsets={"cus_ops": 11},)

相关文章:

记录pytorch实现自定义算子并转onnx文件输出

概览&#xff1a;记录了如何自定义一个算子&#xff0c;实现pytorch注册&#xff0c;通过C编译为库文件供python端调用&#xff0c;并转为onnx文件输出 整体大概流程&#xff1a; 定义算子实现为torch的C版本文件注册算子编译算子生成库文件调用自定义算子 一、编译环境准备…...

ARPG----C++学习记录04 Section8 角色类,移动

角色类输入 新建一个角色C&#xff0c;继承建立蓝图,和Pawn一样&#xff0c;绑定输入移动和相机. 在构造函数中添加这段代码也能实现。打开UsePawnControlRotation就可以让人物不跟随鼠标旋转 得到旋转后的向前向量 使用旋转矩阵 想要前进方向和旋转的方向对应。获取当前控制…...

拆解软件定义汽车:OS突围

软件作为智能汽车的核心组成部分&#xff0c;由于自身较为独立和复杂的IT学科体系&#xff0c;其技术链路、产业分工、价值分配、商业模式相对硬件产品&#xff08;如域控、激光雷达、摄像头等硬件&#xff09;而言&#xff0c;在汽车产业内探讨和传播相对较少。 11月3日&…...

并发线程使用介绍(二)

2.2.6 线程的强占 Thread的非静态方法join方法 需要在某一个线程下去调用这个方法 如果在main线程中调用了t1.join()&#xff0c;那么main线程会进入到等待状态&#xff0c;需要等待t1线程全部执行完毕&#xff0c;在恢复到就绪状态等待 CPU调度。 如果在main线程中调用了t1.j…...

【Proteus仿真】【51单片机】多路温度控制系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LED、蜂鸣器、LCD1602、DS18B20温度传感器、HC05蓝牙模块等。 主要功能&#xff1a; 系统运行后&#xff0c;默认LCD1602显示前4路采集的温…...

一些可以参考的文档集合15

之前的文章集合: 一些可以参考文章集合1_xuejianxinokok的博客-CSDN博客 一些可以参考文章集合2_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合3_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合4_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合5…...

k8s的service自动发现服务:实战版

Service服务发现的必要性: 对于kubernetes整个集群来说&#xff0c;Pod的地址也可变的&#xff0c;也就是说如果一个Pod因为某些原因退出了&#xff0c;而由于其设置了副本数replicas大于1&#xff0c;那么该Pod就会在集群的任意节点重新启动&#xff0c;这个重新启动的Pod的I…...

项目笔记记录

一、node下载版本报错&#xff1a;npm install --legacy-peer-deps 二、Scheduled: 任务自动化调度 Scheduled 标记要调度的方法的注解&#xff0c;必须指定 cron&#xff0c;fixedDelay或fixedRate属性之一 fixedDelay&#xff1a;固定延迟 延迟执行任务&#xff0c;任务在…...

【leetcode】1137. 第 N 个泰波那契数

题目 泰波那契序列 Tn 定义如下&#xff1a; T0 0, T1 1, T2 1, 且在 n > 0 的条件下 Tn3 Tn Tn1 Tn2 给你整数 n&#xff0c;请返回第 n 个泰波那契数 Tn 的值。 示例 1&#xff1a; 输入&#xff1a;n 4 输出&#xff1a;4 解释&#xff1a; T_3 0 1 1 2 …...

【解决】conda-script.py: error: argument COMMAND: invalid choice: ‘activate‘

运行conda activate base报错&#xff1a; 试了网上找到的解决方法都不行&#xff1a; 最后切换了一下terminal&#xff1a; 从powershell改回cmd&#xff08;不知道为什么一开始手贱换成powershell&#xff09; 就可以了...

Linux 性能调优之硬件资源监控

写在前面 考试整理相关笔记博文内容涉及 Linux 硬件资源监控常见的命名介绍&#xff0c;涉及硬件基本信息查看查看硬件错误信息查看虚拟环境和云环境资源理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在心中坚守其一生&#x…...

Windows系统隐藏窗口启动控制台程序

背景 上线项目有时候需要一些控制台应用作为辅助服务来协助UI应用满足实际需求&#xff0c;这时候如果一运行UI就冒出一系列的黑框&#xff0c;这将会导致客户被下的不起&#xff0c;生怕中了什么不知名病毒 方案 可以使用vbs来启动&#xff0c;这个是window系统自带的&#…...

FreeSWITCH fail2ban.lua

--[[ 部署:在vars.xml里面增加配置项目&#xff1a;<X-PRE-PROCESS cmd"set" data"api_on_startupluarun fail2ban.lua"/>或者在 lua.conf.xml 里面增加下面这个配置项目&#xff1a;<param name"startup-script" value"fail2ban.…...

Qt HTTP下载数据

添加头文件&#xff1a; #include <QNetworkAccessManager> #include <QNetworkReply> #include <QUrl> #include <QDesktopServices> 创建对象&#xff1a; QNetworkAccessManager networkManager;//网络管理QNetworkReply *reply; …...

8. 深度学习——NLP

机器学习面试题汇总与解析——NLP 本章讲解知识点 什么是 NLP循环神经网络(RNN)RNN 变体Attention 机制RNN 反向传播推导LSTM 与 GRUTransformerBertGPT分词算法分类CBOW 模型与 Skip-Gram 模型本专栏适合于Python已经入门的学生或人士,有一定的编程基础。本专栏适合于算法…...

部署 KVM 虚拟化平台

虚拟化技术的演变过程分为软件模拟、虚拟化层翻译、容器虚拟化三个阶段 1 软件模拟的技术方式 软件模拟是通过软件完全模拟CPU、网卡、芯片组、磁盘等计算机硬件&#xff0c;因为是软件模拟&#xff0c;所以理论上可以模拟任何硬件&#xff0c;甚至不存在的硬件。但是由于是软…...

Juniper PPPOE双线路冗余RPM配置

------------------ 浮动静态路由 set routing-options static route 0.0.0.0/0 next-hop pp0.0 qualified-next-hop pp0.1 preference 10 ----------------- RPM测试的内容,包括从哪个接口发起测试,测试ping等等 #指定探针类型用ICMP请求 #探测的目标地址 #探测间隔 #探测阈…...

原生JS实现视频截图

视频截图效果预览 利用Canvas进行截图 要用原生js实现视频截图&#xff0c;可以利用canvas的绘图功能 ctx.drawImage&#xff0c;只需要获取到视频标签&#xff0c;就可以通过drawImage把视频当前帧图像绘制在canvas画布上。 const video document.querySelector(video) con…...

前端Rust二进制/wasm全平台构建流程简述

前言 开门见山&#xff0c;现代前端 Rust 构建基本分三大类&#xff0c;即 构建 .wasm 、构建 .node 二进制 、构建 swc 插件。 入门详见 《 前端Rust开发WebAssembly与Swc插件快速入门 》 。 对于单独开发某一类的流程&#xff0c;在上述参考文章中已有介绍&#xff0c;但对于…...

加解密算法相关技术详解

文章目录 简介工作机制加解密对称密钥算法非对称密钥算法 数字信封数字签名数字证书技术对比 推荐阅读 简介 随着网络技术的飞速发展&#xff0c;网络安全问题日益重要&#xff0c;加解密技术是网络安全技术中的核心技术&#xff0c;是最常用的安全保密手段。 加密&#xff1…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”&#xff0c;无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息&#xff1a; 关注测试号&#xff1a;扫二维码关注测试号。 发送模版消息&#xff1a; import requests da…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

自然语言处理——循环神经网络

自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元&#xff08;GRU&#xff09;长短期记忆神经网络&#xff08;LSTM&#xff09…...