当前位置: 首页 > news >正文

使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录
demo
下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread()
    默认读取为三通道BGR,需要进行B/R通道交换,这里采用
    cv::cvtColor()
    实现。
  2. 图像尺寸需要调整到

224

×

224

224\times 224

2

2

4

×

2

2

4

,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。

// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}

Step 6:构建运行验证

我们先来写一下
CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录
demo
下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

相关文章:

使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 CAPI libtorch运行这个模型。 Step1&#xff1a;导出模型 首先我们进行第一步&#xff0c;用 Python API 来导出模型&#xff0c;由于本文的重点是在后面的部署…...

Python——异常捕获,传递及其抛出操作

01. 异常的概念 1. 程序在运行时&#xff0c;如果 python解释器遇到一个错误&#xff0c;会停止程序的执行&#xff0c;并且提示一些错误信息&#xff0c;这就是异常。 2. 程序停止执行并且提示错误信息这个动作&#xff0c;我们通常称之为&#xff1a;抛出&#xff08;raise…...

【Maven】 的继承机制

Maven是一个强大的项目管理工具&#xff0c;主要用于Java项目的构建和管理。它以其项目对象模型&#xff08;POM&#xff09;为基础&#xff0c;允许开发者定义项目的依赖、构建过程和插件。Maven的继承机制是其核心特性之一&#xff0c;它允许子项目继承和复用父项目的配置&am…...

微信小程序结合后端php发送模版消息

前端&#xff1a; <view class"container"><button bindtap"requestSubscribeMessage">订阅消息</button> </view> // index.js Page({data: {tmplIds: [UTgCUfsjHVESf5FjOzls0I9i_FVS1N620G2VQCg1LZ0] // 使用你的模板ID},requ…...

sqlalchemy报错sqlalchemy.orm.exc.DetachedInstanceError

解决方案&#xff1a; 在初始化数据库的代码中&#xff0c;将 maker sessionmaker(bindeng)修改为 maker sessionmaker(bindeng, expire_on_commitFalse)为什么要添加 expire_on_commitFalse 参数&#xff1f; expire_on_commit 可以用来更改 SQLAlchemy 的对象刷新机制&…...

华为网络模拟器eNSP安装部署教程

eNSP是图形化网络仿真平台&#xff0c;该平台通过对真实网络设备的仿真模拟&#xff0c;帮助广大ICT从业者和客户快速熟悉华为数通系列产品&#xff0c;了解并掌握相关产品的操作和配置、提升对企业ICT网络的规划、建设、运维能力&#xff0c;从而帮助企业构建更高效&#xff0…...

【React】详解样式控制:从基础到进阶应用的全面指南

文章目录 一、内联样式1. 什么是内联样式&#xff1f;2. 内联样式的定义3. 基本示例4. 动态内联样式 二、CSS模块1. 什么是CSS模块&#xff1f;2. CSS模块的定义3. 基本示例4. 动态应用样式 三、CSS-in-JS1. 什么是CSS-in-JS&#xff1f;2. styled-components的定义3. 基本示例…...

【ROS2】高级:安全-理解安全密钥库

目标&#xff1a;探索位于 ROS 2 安全密钥库中的文件。 教程级别&#xff1a;高级 时间&#xff1a;15 分钟 内容 背景安全工件位置 公钥材料 私钥材料域治理政策 安全飞地 参加测验&#xff01; 背景 在继续之前&#xff0c;请确保您已完成设置安全教程。 sros2 包可以用来创…...

C语言 ——— 数组指针的定义 数组指针的使用

目录 前言 数组指针的定义 数组指针的使用 前言 之前有编写过关于 指针数组 的相关知识 C语言 ——— 指针数组 & 指针数组模拟二维整型数组-CSDN博客 指针数组 顾名思义就是 存放指针的数组 那什么是数组指针呢&#xff1f; 数组指针的定义 何为数组指针&#xf…...

opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习

前言 AIS_ManipulatorOwner是OpenCascade中的一个类&#xff0c;主要用于操纵对象的交互控制。AIS_ManipulatorOwner结合AIS_Manipulator类&#xff0c;允许用户通过可视化工具&#xff08;如旋转、平移、缩放等&#xff09;来操纵几何对象。 以下是AIS_ManipulatorOwner的基…...

如何防止用户通过打印功能复制页面文字

简单防白嫖&#xff0c;要让打印出来的页面是空白&#xff0c;通常的做法是在打印时隐藏页面上的所有内容。这可以通过CSS的媒体查询&#xff08;Media Queries&#xff09;来实现&#xff0c;特别是针对media print的查询。 在JavaScript中&#xff0c;你通常不会直接控制打印…...

Python3网络爬虫开发实战(3)网页数据的解析提取

文章目录 一、XPath1. 选取节点2. 查找某个特定的节点或者包含某个指定的值的节点3. XPath 运算符4. 节点轴5. 利用 lxml 使用 XPath 二、CSS三、Beautiful Soup1. 信息提取2. 嵌套选择3. 关联选择4. 方法选择器5. css 选择器 四、PyQuery1. 初始化2. css 选择器3. 信息提取4. …...

基于 HTML+ECharts 实现监控平台数据可视化大屏(含源码)

构建监控平台数据可视化大屏&#xff1a;基于 HTML 和 ECharts 的实现 监控平台的数据可视化对于实时掌握系统状态、快速响应问题至关重要。通过直观的数据展示&#xff0c;运维团队可以迅速发现异常&#xff0c;优化资源配置。本文将详细介绍如何利用 HTML 和 ECharts 实现一个…...

立创梁山派--移植开源的SFUD和FATFS实现SPI-FLASH文件系统

本文主要是在sfud的基础上进行fatfs文件系统的移植&#xff0c;并不对sfud的移植再进行过多的讲解了哦&#xff0c;所以如果想了解sfud的移植过程&#xff0c;请参考我的另外一篇文章&#xff1a;传送门 正文开始咯 首先我们需要先准备资料准备好&#xff0c;这里对于fatfs的…...

MySQL之视图和索引实战

1.新建数据库 mysql> create database myudb5_indexstu; Query OK, 1 row affected (0.01 sec) mysql> use myudb5_indexstu; Database changed 2.新建表 1.学生表student&#xff0c;定义主键&#xff0c;姓名不能重名&#xff0c;性别只能输入男或女&#xff0c;所在…...

快速参考:用C# Selenium实现浏览器窗口缩放的步骤

背景介绍 在现代网络环境中&#xff0c;浏览器自动化已成为数据抓取和测试的重要工具。Selenium作为一个强大的浏览器自动化工具&#xff0c;能够与多种编程语言结合使用&#xff0c;其中C#是非常受欢迎的选择之一。在实际应用中&#xff0c;我们常常需要调整浏览器窗口的缩放…...

MyBatis 插件机制、分页插件如何实现的

MyBatis 插件机制允许开发者在 SQL 执行的各个阶段&#xff08;如预处理、执行、结果处理等&#xff09;中插入自定义逻辑&#xff0c;从而实现对 MyBatis 行为的扩展和增强。以下是 MyBatis 插件运行原理的详细介绍&#xff1a; 插件接口 MyBatis 插件通过实现 org.apache.i…...

CentOS6.0安装telnet-server启用telnet服务

CentOS6.0安装telnet-server启用telnet服务 一步到位 fp"/etc/yum.repos.d" ; cp -a ${fp} ${fp}.$(date %0y%0m%0d%0H%0M%0S).bkup echo [base] nameCentOS-$releasever - Base baseurlhttp://mirrors.163.com/centos-vault/6.0/os/$basearch/http://mirrors.a…...

H5+CSS+JS工作性价比计算器

工作性价比&#xff1d;平均日新x综合环境系数/35 x(工作时长&#xff0b;通勤时长—0.5 x摸鱼时长) x学历系数 如果代码中的公式不对&#xff0c;请指正 效果图 源代码 <!DOCTYPE html> <html> <head> <style> .calculator { width: 300px; padd…...

Linux:基础命令学习

目录 一、ls命令 实例&#xff1a;-l以长格式显示文件和目录信息 实例&#xff1a;-F根据文件类型在列出的文件名称后加一符号 实例&#xff1a; -R 递归显示目录中的所有文件和子目录。 实例&#xff1a; 组合使用 Home目录和工作目录 二、目录修改和查看命令 三、mkd…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具&#xff0c;相比原生 Python 生态&#xff08;如 pip 虚拟环境&#xff09;有许多独特优势&#xff0c;尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处&#xff1a; 一、一站式环境管理&#xff1a…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

今日科技热点速览

&#x1f525; 今日科技热点速览 &#x1f3ae; 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售&#xff0c;主打更强图形性能与沉浸式体验&#xff0c;支持多模态交互&#xff0c;受到全球玩家热捧 。 &#x1f916; 人工智能持续突破 DeepSeek-R1&…...

Linux --进程控制

本文从以下五个方面来初步认识进程控制&#xff1a; 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程&#xff0c;创建出来的进程就是子进程&#xff0c;原来的进程为父进程。…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中&#xff0c;科研绘图是必不可少的&#xff0c;一张好看的图形会是文章很大的加分项。 为了便于使用&#xff0c;本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中&#xff0c;获取方式&#xff1a; R 语言科研绘图模板 --- sciRplothttps://mp.…...

深入理解Optional:处理空指针异常

1. 使用Optional处理可能为空的集合 在Java开发中&#xff0c;集合判空是一个常见但容易出错的场景。传统方式虽然可行&#xff0c;但存在一些潜在问题&#xff1a; // 传统判空方式 if (!CollectionUtils.isEmpty(userInfoList)) {for (UserInfo userInfo : userInfoList) {…...