当前位置: 首页 > news >正文

记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {// BEGIN image_matcv::Mat image_mat(/*rows=*/image.size(0),/*cols=*/image.size(1),/*type=*/CV_32FC1,/*data=*/image.data_ptr<float>());// END image_mat// BEGIN warp_matcv::Mat warp_mat(/*rows=*/warp.size(0),/*cols=*/warp.size(1),/*type=*/CV_32FC1,/*data=*/warp.data_ptr<float>());// END warp_mat// BEGIN output_matcv::Mat output_mat;cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });// END output_mat// BEGIN output_tensortorch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });return output.clone();// END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {m.def("warp_perspective", warp_perspective);
}

2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)# Opencv
find_package(OpenCV REQUIRED)# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)# libtorch库文件
target_link_libraries(warp_perspective # CPUc10 torch_cpu# GPU# c10_cuda # torch_cuda)# opencv库文件
target_link_libraries(warp_perspective${OpenCV_LIBS}
)add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")im=cv2.imread("/data/xxx/mylib/cat.jpg",0)pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):def __init__(self, name):super(MyNet, self).__init__()self.model_name = namedef forward(self, in_data, warp_data):return torch.ops.my_ops.warp_perspective(in_data, warp_data)def my_custom(g, in_data, warp_data):return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)if __name__ == "__main__":net = MyNet("my_ops")in_data = torch.randn((32, 32))warp_data = torch.rand((3, 3))out = net(in_data, warp_data)print("out: ", out)# export onnxtorch.onnx.export(net,(in_data, warp_data),"./my_ops_export_model2.onnx",input_names=["img_data", "warp_mat"],output_names=["out_img"],custom_opsets={"cus_ops": 11},)

相关文章:

记录pytorch实现自定义算子并转onnx文件输出

概览&#xff1a;记录了如何自定义一个算子&#xff0c;实现pytorch注册&#xff0c;通过C编译为库文件供python端调用&#xff0c;并转为onnx文件输出 整体大概流程&#xff1a; 定义算子实现为torch的C版本文件注册算子编译算子生成库文件调用自定义算子 一、编译环境准备…...

ARPG----C++学习记录04 Section8 角色类,移动

角色类输入 新建一个角色C&#xff0c;继承建立蓝图,和Pawn一样&#xff0c;绑定输入移动和相机. 在构造函数中添加这段代码也能实现。打开UsePawnControlRotation就可以让人物不跟随鼠标旋转 得到旋转后的向前向量 使用旋转矩阵 想要前进方向和旋转的方向对应。获取当前控制…...

拆解软件定义汽车:OS突围

软件作为智能汽车的核心组成部分&#xff0c;由于自身较为独立和复杂的IT学科体系&#xff0c;其技术链路、产业分工、价值分配、商业模式相对硬件产品&#xff08;如域控、激光雷达、摄像头等硬件&#xff09;而言&#xff0c;在汽车产业内探讨和传播相对较少。 11月3日&…...

并发线程使用介绍(二)

2.2.6 线程的强占 Thread的非静态方法join方法 需要在某一个线程下去调用这个方法 如果在main线程中调用了t1.join()&#xff0c;那么main线程会进入到等待状态&#xff0c;需要等待t1线程全部执行完毕&#xff0c;在恢复到就绪状态等待 CPU调度。 如果在main线程中调用了t1.j…...

【Proteus仿真】【51单片机】多路温度控制系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LED、蜂鸣器、LCD1602、DS18B20温度传感器、HC05蓝牙模块等。 主要功能&#xff1a; 系统运行后&#xff0c;默认LCD1602显示前4路采集的温…...

一些可以参考的文档集合15

之前的文章集合: 一些可以参考文章集合1_xuejianxinokok的博客-CSDN博客 一些可以参考文章集合2_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合3_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合4_xuejianxinokok的博客-CSDN博客 一些可以参考的文档集合5…...

k8s的service自动发现服务:实战版

Service服务发现的必要性: 对于kubernetes整个集群来说&#xff0c;Pod的地址也可变的&#xff0c;也就是说如果一个Pod因为某些原因退出了&#xff0c;而由于其设置了副本数replicas大于1&#xff0c;那么该Pod就会在集群的任意节点重新启动&#xff0c;这个重新启动的Pod的I…...

项目笔记记录

一、node下载版本报错&#xff1a;npm install --legacy-peer-deps 二、Scheduled: 任务自动化调度 Scheduled 标记要调度的方法的注解&#xff0c;必须指定 cron&#xff0c;fixedDelay或fixedRate属性之一 fixedDelay&#xff1a;固定延迟 延迟执行任务&#xff0c;任务在…...

【leetcode】1137. 第 N 个泰波那契数

题目 泰波那契序列 Tn 定义如下&#xff1a; T0 0, T1 1, T2 1, 且在 n > 0 的条件下 Tn3 Tn Tn1 Tn2 给你整数 n&#xff0c;请返回第 n 个泰波那契数 Tn 的值。 示例 1&#xff1a; 输入&#xff1a;n 4 输出&#xff1a;4 解释&#xff1a; T_3 0 1 1 2 …...

【解决】conda-script.py: error: argument COMMAND: invalid choice: ‘activate‘

运行conda activate base报错&#xff1a; 试了网上找到的解决方法都不行&#xff1a; 最后切换了一下terminal&#xff1a; 从powershell改回cmd&#xff08;不知道为什么一开始手贱换成powershell&#xff09; 就可以了...

Linux 性能调优之硬件资源监控

写在前面 考试整理相关笔记博文内容涉及 Linux 硬件资源监控常见的命名介绍&#xff0c;涉及硬件基本信息查看查看硬件错误信息查看虚拟环境和云环境资源理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在心中坚守其一生&#x…...

Windows系统隐藏窗口启动控制台程序

背景 上线项目有时候需要一些控制台应用作为辅助服务来协助UI应用满足实际需求&#xff0c;这时候如果一运行UI就冒出一系列的黑框&#xff0c;这将会导致客户被下的不起&#xff0c;生怕中了什么不知名病毒 方案 可以使用vbs来启动&#xff0c;这个是window系统自带的&#…...

FreeSWITCH fail2ban.lua

--[[ 部署:在vars.xml里面增加配置项目&#xff1a;<X-PRE-PROCESS cmd"set" data"api_on_startupluarun fail2ban.lua"/>或者在 lua.conf.xml 里面增加下面这个配置项目&#xff1a;<param name"startup-script" value"fail2ban.…...

Qt HTTP下载数据

添加头文件&#xff1a; #include <QNetworkAccessManager> #include <QNetworkReply> #include <QUrl> #include <QDesktopServices> 创建对象&#xff1a; QNetworkAccessManager networkManager;//网络管理QNetworkReply *reply; …...

8. 深度学习——NLP

机器学习面试题汇总与解析——NLP 本章讲解知识点 什么是 NLP循环神经网络(RNN)RNN 变体Attention 机制RNN 反向传播推导LSTM 与 GRUTransformerBertGPT分词算法分类CBOW 模型与 Skip-Gram 模型本专栏适合于Python已经入门的学生或人士,有一定的编程基础。本专栏适合于算法…...

部署 KVM 虚拟化平台

虚拟化技术的演变过程分为软件模拟、虚拟化层翻译、容器虚拟化三个阶段 1 软件模拟的技术方式 软件模拟是通过软件完全模拟CPU、网卡、芯片组、磁盘等计算机硬件&#xff0c;因为是软件模拟&#xff0c;所以理论上可以模拟任何硬件&#xff0c;甚至不存在的硬件。但是由于是软…...

Juniper PPPOE双线路冗余RPM配置

------------------ 浮动静态路由 set routing-options static route 0.0.0.0/0 next-hop pp0.0 qualified-next-hop pp0.1 preference 10 ----------------- RPM测试的内容,包括从哪个接口发起测试,测试ping等等 #指定探针类型用ICMP请求 #探测的目标地址 #探测间隔 #探测阈…...

原生JS实现视频截图

视频截图效果预览 利用Canvas进行截图 要用原生js实现视频截图&#xff0c;可以利用canvas的绘图功能 ctx.drawImage&#xff0c;只需要获取到视频标签&#xff0c;就可以通过drawImage把视频当前帧图像绘制在canvas画布上。 const video document.querySelector(video) con…...

前端Rust二进制/wasm全平台构建流程简述

前言 开门见山&#xff0c;现代前端 Rust 构建基本分三大类&#xff0c;即 构建 .wasm 、构建 .node 二进制 、构建 swc 插件。 入门详见 《 前端Rust开发WebAssembly与Swc插件快速入门 》 。 对于单独开发某一类的流程&#xff0c;在上述参考文章中已有介绍&#xff0c;但对于…...

加解密算法相关技术详解

文章目录 简介工作机制加解密对称密钥算法非对称密钥算法 数字信封数字签名数字证书技术对比 推荐阅读 简介 随着网络技术的飞速发展&#xff0c;网络安全问题日益重要&#xff0c;加解密技术是网络安全技术中的核心技术&#xff0c;是最常用的安全保密手段。 加密&#xff1…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具&#xff0c;相比原生 Python 生态&#xff08;如 pip 虚拟环境&#xff09;有许多独特优势&#xff0c;尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处&#xff1a; 一、一站式环境管理&#xff1a…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

ios苹果系统,js 滑动屏幕、锚定无效

现象&#xff1a;window.addEventListener监听touch无效&#xff0c;划不动屏幕&#xff0c;但是代码逻辑都有执行到。 scrollIntoView也无效。 原因&#xff1a;这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作&#xff0c;从而会影响…...

JavaScript基础-API 和 Web API

在学习JavaScript的过程中&#xff0c;理解API&#xff08;应用程序接口&#xff09;和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能&#xff0c;使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

基于Java+VUE+MariaDB实现(Web)仿小米商城

仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意&#xff1a;运行前…...

LLaMA-Factory 微调 Qwen2-VL 进行人脸情感识别(二)

在上一篇文章中,我们详细介绍了如何使用LLaMA-Factory框架对Qwen2-VL大模型进行微调,以实现人脸情感识别的功能。本篇文章将聚焦于微调完成后,如何调用这个模型进行人脸情感识别的具体代码实现,包括详细的步骤和注释。 模型调用步骤 环境准备:确保安装了必要的Python库。…...