使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C++ API libtorch调用
本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。
Step1:导出模型
首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。
# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')
导出 JIT 模型的方式有两种:trace 和 script。
我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。
在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth
。
Step 2:安装libtorch
接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
也解压在我们的工程目录
demo
下即可。
Step 3:安装OpenCV
用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。
git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6
Step 4:准备测试图像并用Python测试
我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。
kitten.jpg
:

写一个脚本用 PyTorch 运行一下模型:
# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())
输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。
Step 5:准备cpp源文件
我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。
这里有几个点要说一下,不注意可能会犯错:
cv::imread()
默认读取为三通道BGR,需要进行B/R通道交换,这里采用
cv::cvtColor()
实现。- 图像尺寸需要调整到
224
×
224
224\times 224
2
2
4
×
2
2
4
,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。
// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}
Step 6:构建运行验证
我们先来写一下
CMakeLists.txt
:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50 test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50 PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
现在我们的工程目录
demo
下有以下文件:
CMakeLists.txt export_jit_model.py kitten.jpg libtorch pytorch_test.py resnet50_jit.pth test_model.cpp
然后开始用 CMake 构建工程:
mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make
整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。
接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:
./build/resnet50 resnet50_jit.pth kitten.jpg
输出:
The classifiction index is: 282
运行成功并且结果正确。
Ref:
https://www.jianshu.com/p/7cddc09ca7a4
https://blog.csdn.net/cxx654/article/details/115916275
https://zhuanlan.zhihu.com/p/370455320
相关文章:
使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 CAPI libtorch运行这个模型。 Step1:导出模型 首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署…...
Python——异常捕获,传递及其抛出操作
01. 异常的概念 1. 程序在运行时,如果 python解释器遇到一个错误,会停止程序的执行,并且提示一些错误信息,这就是异常。 2. 程序停止执行并且提示错误信息这个动作,我们通常称之为:抛出(raise…...
【Maven】 的继承机制
Maven是一个强大的项目管理工具,主要用于Java项目的构建和管理。它以其项目对象模型(POM)为基础,允许开发者定义项目的依赖、构建过程和插件。Maven的继承机制是其核心特性之一,它允许子项目继承和复用父项目的配置&am…...
微信小程序结合后端php发送模版消息
前端: <view class"container"><button bindtap"requestSubscribeMessage">订阅消息</button> </view> // index.js Page({data: {tmplIds: [UTgCUfsjHVESf5FjOzls0I9i_FVS1N620G2VQCg1LZ0] // 使用你的模板ID},requ…...
sqlalchemy报错sqlalchemy.orm.exc.DetachedInstanceError
解决方案: 在初始化数据库的代码中,将 maker sessionmaker(bindeng)修改为 maker sessionmaker(bindeng, expire_on_commitFalse)为什么要添加 expire_on_commitFalse 参数? expire_on_commit 可以用来更改 SQLAlchemy 的对象刷新机制&…...
华为网络模拟器eNSP安装部署教程
eNSP是图形化网络仿真平台,该平台通过对真实网络设备的仿真模拟,帮助广大ICT从业者和客户快速熟悉华为数通系列产品,了解并掌握相关产品的操作和配置、提升对企业ICT网络的规划、建设、运维能力,从而帮助企业构建更高效࿰…...
【React】详解样式控制:从基础到进阶应用的全面指南
文章目录 一、内联样式1. 什么是内联样式?2. 内联样式的定义3. 基本示例4. 动态内联样式 二、CSS模块1. 什么是CSS模块?2. CSS模块的定义3. 基本示例4. 动态应用样式 三、CSS-in-JS1. 什么是CSS-in-JS?2. styled-components的定义3. 基本示例…...
【ROS2】高级:安全-理解安全密钥库
目标:探索位于 ROS 2 安全密钥库中的文件。 教程级别:高级 时间:15 分钟 内容 背景安全工件位置 公钥材料 私钥材料域治理政策 安全飞地 参加测验! 背景 在继续之前,请确保您已完成设置安全教程。 sros2 包可以用来创…...
C语言 ——— 数组指针的定义 数组指针的使用
目录 前言 数组指针的定义 数组指针的使用 前言 之前有编写过关于 指针数组 的相关知识 C语言 ——— 指针数组 & 指针数组模拟二维整型数组-CSDN博客 指针数组 顾名思义就是 存放指针的数组 那什么是数组指针呢? 数组指针的定义 何为数组指针…...
opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习
前言 AIS_ManipulatorOwner是OpenCascade中的一个类,主要用于操纵对象的交互控制。AIS_ManipulatorOwner结合AIS_Manipulator类,允许用户通过可视化工具(如旋转、平移、缩放等)来操纵几何对象。 以下是AIS_ManipulatorOwner的基…...
如何防止用户通过打印功能复制页面文字
简单防白嫖,要让打印出来的页面是空白,通常的做法是在打印时隐藏页面上的所有内容。这可以通过CSS的媒体查询(Media Queries)来实现,特别是针对media print的查询。 在JavaScript中,你通常不会直接控制打印…...
Python3网络爬虫开发实战(3)网页数据的解析提取
文章目录 一、XPath1. 选取节点2. 查找某个特定的节点或者包含某个指定的值的节点3. XPath 运算符4. 节点轴5. 利用 lxml 使用 XPath 二、CSS三、Beautiful Soup1. 信息提取2. 嵌套选择3. 关联选择4. 方法选择器5. css 选择器 四、PyQuery1. 初始化2. css 选择器3. 信息提取4. …...
基于 HTML+ECharts 实现监控平台数据可视化大屏(含源码)
构建监控平台数据可视化大屏:基于 HTML 和 ECharts 的实现 监控平台的数据可视化对于实时掌握系统状态、快速响应问题至关重要。通过直观的数据展示,运维团队可以迅速发现异常,优化资源配置。本文将详细介绍如何利用 HTML 和 ECharts 实现一个…...
立创梁山派--移植开源的SFUD和FATFS实现SPI-FLASH文件系统
本文主要是在sfud的基础上进行fatfs文件系统的移植,并不对sfud的移植再进行过多的讲解了哦,所以如果想了解sfud的移植过程,请参考我的另外一篇文章:传送门 正文开始咯 首先我们需要先准备资料准备好,这里对于fatfs的…...
MySQL之视图和索引实战
1.新建数据库 mysql> create database myudb5_indexstu; Query OK, 1 row affected (0.01 sec) mysql> use myudb5_indexstu; Database changed 2.新建表 1.学生表student,定义主键,姓名不能重名,性别只能输入男或女,所在…...
快速参考:用C# Selenium实现浏览器窗口缩放的步骤
背景介绍 在现代网络环境中,浏览器自动化已成为数据抓取和测试的重要工具。Selenium作为一个强大的浏览器自动化工具,能够与多种编程语言结合使用,其中C#是非常受欢迎的选择之一。在实际应用中,我们常常需要调整浏览器窗口的缩放…...
MyBatis 插件机制、分页插件如何实现的
MyBatis 插件机制允许开发者在 SQL 执行的各个阶段(如预处理、执行、结果处理等)中插入自定义逻辑,从而实现对 MyBatis 行为的扩展和增强。以下是 MyBatis 插件运行原理的详细介绍: 插件接口 MyBatis 插件通过实现 org.apache.i…...
CentOS6.0安装telnet-server启用telnet服务
CentOS6.0安装telnet-server启用telnet服务 一步到位 fp"/etc/yum.repos.d" ; cp -a ${fp} ${fp}.$(date %0y%0m%0d%0H%0M%0S).bkup echo [base] nameCentOS-$releasever - Base baseurlhttp://mirrors.163.com/centos-vault/6.0/os/$basearch/http://mirrors.a…...
H5+CSS+JS工作性价比计算器
工作性价比=平均日新x综合环境系数/35 x(工作时长+通勤时长—0.5 x摸鱼时长) x学历系数 如果代码中的公式不对,请指正 效果图 源代码 <!DOCTYPE html> <html> <head> <style> .calculator { width: 300px; padd…...
Linux:基础命令学习
目录 一、ls命令 实例:-l以长格式显示文件和目录信息 实例:-F根据文件类型在列出的文件名称后加一符号 实例: -R 递归显示目录中的所有文件和子目录。 实例: 组合使用 Home目录和工作目录 二、目录修改和查看命令 三、mkd…...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...
深入理解JavaScript设计模式之单例模式
目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...
STM32标准库-DMA直接存储器存取
文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...
对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
LLM基础1_语言模型如何处理文本
基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...
大模型多显卡多服务器并行计算方法与实践指南
一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...
听写流程自动化实践,轻量级教育辅助
随着智能教育工具的发展,越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式,也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建,…...
Go 语言并发编程基础:无缓冲与有缓冲通道
在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好࿰…...
