使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C++ API libtorch调用
本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。
Step1:导出模型
首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。
# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')
导出 JIT 模型的方式有两种:trace 和 script。
我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。
在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth
。
Step 2:安装libtorch
接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
也解压在我们的工程目录
demo
下即可。
Step 3:安装OpenCV
用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。
git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6
Step 4:准备测试图像并用Python测试
我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。
kitten.jpg
:
写一个脚本用 PyTorch 运行一下模型:
# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())
输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。
Step 5:准备cpp源文件
我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。
这里有几个点要说一下,不注意可能会犯错:
cv::imread()
默认读取为三通道BGR,需要进行B/R通道交换,这里采用
cv::cvtColor()
实现。- 图像尺寸需要调整到
224
×
224
224\times 224
2
2
4
×
2
2
4
,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。
// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}
Step 6:构建运行验证
我们先来写一下
CMakeLists.txt
:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50 test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50 PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
现在我们的工程目录
demo
下有以下文件:
CMakeLists.txt export_jit_model.py kitten.jpg libtorch pytorch_test.py resnet50_jit.pth test_model.cpp
然后开始用 CMake 构建工程:
mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make
整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。
接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:
./build/resnet50 resnet50_jit.pth kitten.jpg
输出:
The classifiction index is: 282
运行成功并且结果正确。
Ref:
https://www.jianshu.com/p/7cddc09ca7a4
https://blog.csdn.net/cxx654/article/details/115916275
https://zhuanlan.zhihu.com/p/370455320
相关文章:

使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 CAPI libtorch运行这个模型。 Step1:导出模型 首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署…...

Python——异常捕获,传递及其抛出操作
01. 异常的概念 1. 程序在运行时,如果 python解释器遇到一个错误,会停止程序的执行,并且提示一些错误信息,这就是异常。 2. 程序停止执行并且提示错误信息这个动作,我们通常称之为:抛出(raise…...
【Maven】 的继承机制
Maven是一个强大的项目管理工具,主要用于Java项目的构建和管理。它以其项目对象模型(POM)为基础,允许开发者定义项目的依赖、构建过程和插件。Maven的继承机制是其核心特性之一,它允许子项目继承和复用父项目的配置&am…...
微信小程序结合后端php发送模版消息
前端: <view class"container"><button bindtap"requestSubscribeMessage">订阅消息</button> </view> // index.js Page({data: {tmplIds: [UTgCUfsjHVESf5FjOzls0I9i_FVS1N620G2VQCg1LZ0] // 使用你的模板ID},requ…...
sqlalchemy报错sqlalchemy.orm.exc.DetachedInstanceError
解决方案: 在初始化数据库的代码中,将 maker sessionmaker(bindeng)修改为 maker sessionmaker(bindeng, expire_on_commitFalse)为什么要添加 expire_on_commitFalse 参数? expire_on_commit 可以用来更改 SQLAlchemy 的对象刷新机制&…...

华为网络模拟器eNSP安装部署教程
eNSP是图形化网络仿真平台,该平台通过对真实网络设备的仿真模拟,帮助广大ICT从业者和客户快速熟悉华为数通系列产品,了解并掌握相关产品的操作和配置、提升对企业ICT网络的规划、建设、运维能力,从而帮助企业构建更高效࿰…...

【React】详解样式控制:从基础到进阶应用的全面指南
文章目录 一、内联样式1. 什么是内联样式?2. 内联样式的定义3. 基本示例4. 动态内联样式 二、CSS模块1. 什么是CSS模块?2. CSS模块的定义3. 基本示例4. 动态应用样式 三、CSS-in-JS1. 什么是CSS-in-JS?2. styled-components的定义3. 基本示例…...

【ROS2】高级:安全-理解安全密钥库
目标:探索位于 ROS 2 安全密钥库中的文件。 教程级别:高级 时间:15 分钟 内容 背景安全工件位置 公钥材料 私钥材料域治理政策 安全飞地 参加测验! 背景 在继续之前,请确保您已完成设置安全教程。 sros2 包可以用来创…...

C语言 ——— 数组指针的定义 数组指针的使用
目录 前言 数组指针的定义 数组指针的使用 前言 之前有编写过关于 指针数组 的相关知识 C语言 ——— 指针数组 & 指针数组模拟二维整型数组-CSDN博客 指针数组 顾名思义就是 存放指针的数组 那什么是数组指针呢? 数组指针的定义 何为数组指针…...

opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习
前言 AIS_ManipulatorOwner是OpenCascade中的一个类,主要用于操纵对象的交互控制。AIS_ManipulatorOwner结合AIS_Manipulator类,允许用户通过可视化工具(如旋转、平移、缩放等)来操纵几何对象。 以下是AIS_ManipulatorOwner的基…...
如何防止用户通过打印功能复制页面文字
简单防白嫖,要让打印出来的页面是空白,通常的做法是在打印时隐藏页面上的所有内容。这可以通过CSS的媒体查询(Media Queries)来实现,特别是针对media print的查询。 在JavaScript中,你通常不会直接控制打印…...

Python3网络爬虫开发实战(3)网页数据的解析提取
文章目录 一、XPath1. 选取节点2. 查找某个特定的节点或者包含某个指定的值的节点3. XPath 运算符4. 节点轴5. 利用 lxml 使用 XPath 二、CSS三、Beautiful Soup1. 信息提取2. 嵌套选择3. 关联选择4. 方法选择器5. css 选择器 四、PyQuery1. 初始化2. css 选择器3. 信息提取4. …...

基于 HTML+ECharts 实现监控平台数据可视化大屏(含源码)
构建监控平台数据可视化大屏:基于 HTML 和 ECharts 的实现 监控平台的数据可视化对于实时掌握系统状态、快速响应问题至关重要。通过直观的数据展示,运维团队可以迅速发现异常,优化资源配置。本文将详细介绍如何利用 HTML 和 ECharts 实现一个…...

立创梁山派--移植开源的SFUD和FATFS实现SPI-FLASH文件系统
本文主要是在sfud的基础上进行fatfs文件系统的移植,并不对sfud的移植再进行过多的讲解了哦,所以如果想了解sfud的移植过程,请参考我的另外一篇文章:传送门 正文开始咯 首先我们需要先准备资料准备好,这里对于fatfs的…...
MySQL之视图和索引实战
1.新建数据库 mysql> create database myudb5_indexstu; Query OK, 1 row affected (0.01 sec) mysql> use myudb5_indexstu; Database changed 2.新建表 1.学生表student,定义主键,姓名不能重名,性别只能输入男或女,所在…...

快速参考:用C# Selenium实现浏览器窗口缩放的步骤
背景介绍 在现代网络环境中,浏览器自动化已成为数据抓取和测试的重要工具。Selenium作为一个强大的浏览器自动化工具,能够与多种编程语言结合使用,其中C#是非常受欢迎的选择之一。在实际应用中,我们常常需要调整浏览器窗口的缩放…...
MyBatis 插件机制、分页插件如何实现的
MyBatis 插件机制允许开发者在 SQL 执行的各个阶段(如预处理、执行、结果处理等)中插入自定义逻辑,从而实现对 MyBatis 行为的扩展和增强。以下是 MyBatis 插件运行原理的详细介绍: 插件接口 MyBatis 插件通过实现 org.apache.i…...

CentOS6.0安装telnet-server启用telnet服务
CentOS6.0安装telnet-server启用telnet服务 一步到位 fp"/etc/yum.repos.d" ; cp -a ${fp} ${fp}.$(date %0y%0m%0d%0H%0M%0S).bkup echo [base] nameCentOS-$releasever - Base baseurlhttp://mirrors.163.com/centos-vault/6.0/os/$basearch/http://mirrors.a…...

H5+CSS+JS工作性价比计算器
工作性价比=平均日新x综合环境系数/35 x(工作时长+通勤时长—0.5 x摸鱼时长) x学历系数 如果代码中的公式不对,请指正 效果图 源代码 <!DOCTYPE html> <html> <head> <style> .calculator { width: 300px; padd…...

Linux:基础命令学习
目录 一、ls命令 实例:-l以长格式显示文件和目录信息 实例:-F根据文件类型在列出的文件名称后加一符号 实例: -R 递归显示目录中的所有文件和子目录。 实例: 组合使用 Home目录和工作目录 二、目录修改和查看命令 三、mkd…...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?
大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...

Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

让AI看见世界:MCP协议与服务器的工作原理
让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

push [特殊字符] present
push 🆚 present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中,push 和 present 是两种不同的视图控制器切换方式,它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...