使用libtorch加载YOLOv8生成的torchscript文件进行目标检测
在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe 工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容如下:
path: ../datasets/melon # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: # test images (optional)# Classes
names:0: watermelon1: wintermelon
使用以下python脚本进行训练生成torchscript文件:
import argparse
import colorama
from ultralytics import YOLOdef parse_args():parser = argparse.ArgumentParser(description="YOLOv8 object detect")parser.add_argument("--yaml", required=True, type=str, help="yaml file")parser.add_argument("--epochs", required=True, type=int, help="number of training")args = parser.parse_args()return argsdef train(yaml, epochs):model = YOLO("yolov8n.pt") # load a pretrained modelresults = model.train(data=yaml, epochs=epochs, imgsz=640) # train the modelmetrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings rememberedmodel.export(format="onnx") #, dynamic=True) # export the model, cannot specify dynamic=True, opencv does not support# model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)model.export(format="torchscript") # libtorchif __name__ == "__main__":colorama.init()args = parse_args()train(args.yaml, args.epochs)print(colorama.Fore.GREEN + "====== execution completed ======")
以下是使用libtorch接口加载torchscript文件进行目标检测的实现代码:
namespace {constexpr bool cuda_enabled{ false };
constexpr int image_size[2]{ 640, 640 }; // {height,width}, input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400)
constexpr float model_score_threshold{ 0.45 }; // confidence threshold
constexpr float model_nms_threshold{ 0.50 }; // iou threshold#ifdef _MSC_VER
constexpr char* onnx_file{ "../../../data/best.onnx" };
constexpr char* torchscript_file{ "../../../data/best.torchscript" };
constexpr char* images_dir{ "../../../data/images/predict" };
constexpr char* result_dir{ "../../../data/result" };
constexpr char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr char* onnx_file{ "data/best.onnx" };
constexpr char* torchscript_file{ "data/best.torchscript" };
constexpr char* images_dir{ "data/images/predict" };
constexpr char* result_dir{ "data/result" };
constexpr char* classes_file{ "data/images/labels.txt" };
#endifstd::vector<std::string> parse_classes_file(const char* name)
{std::vector<std::string> classes;std::ifstream file(name);if (!file.is_open()) {std::cerr << "Error: fail to open classes file: " << name << std::endl;return classes;}std::string line;while (std::getline(file, line)) {auto pos = line.find_first_of(" ");classes.emplace_back(line.substr(0, pos));}file.close();return classes;
}auto get_dir_images(const char* name)
{std::map<std::string, std::string> images; // image name, image path + image namefor (auto const& dir_entry : std::filesystem::directory_iterator(name)) {if (dir_entry.is_regular_file())images[dir_entry.path().filename().string()] = dir_entry.path().string();}return images;
}void draw_boxes(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,const std::vector<cv::Rect>& boxes, const std::string& name, cv::Mat& frame)
{if (ids.size() != confidences.size() || ids.size() != boxes.size() || confidences.size() != boxes.size()) {std::cerr << "Error: their lengths are inconsistent: " << ids.size() << ", " << confidences.size() << ", " << boxes.size() << std::endl;return;}std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;std::random_device rd;std::mt19937 gen(rd());std::uniform_int_distribution<int> dis(100, 255);for (auto i = 0; i < ids.size(); ++i) {auto color = cv::Scalar(dis(gen), dis(gen), dis(gen));cv::rectangle(frame, boxes[i], color, 2);std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);cv::rectangle(frame, text_box, color, cv::FILLED);cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);}cv::imshow("Inference", frame);cv::waitKey(-1);std::string path(result_dir);path += "/" + name;cv::imwrite(path, frame);
}float letter_box(const cv::Mat& src, cv::Mat& dst, const std::vector<int>& imgsz)
{if (src.cols == imgsz[1] && src.rows == imgsz[0]) {if (src.data == dst.data) {return 1.;} else {dst = src.clone();return 1.;}}auto resize_scale = std::min(imgsz[0] * 1. / src.rows, imgsz[1] * 1. / src.cols);int new_shape_w = std::round(src.cols * resize_scale);int new_shape_h = std::round(src.rows * resize_scale);float padw = (imgsz[1] - new_shape_w) / 2.;float padh = (imgsz[0] - new_shape_h) / 2.;int top = std::round(padh - 0.1);int bottom = std::round(padh + 0.1);int left = std::round(padw - 0.1);int right = std::round(padw + 0.1);cv::resize(src, dst, cv::Size(new_shape_w, new_shape_h), 0, 0, cv::INTER_AREA);cv::copyMakeBorder(dst, dst, top, bottom, left, right, cv::BORDER_CONSTANT, cv::Scalar(114.));return resize_scale;
}torch::Tensor xywh2xyxy(const torch::Tensor& x)
{auto y = torch::empty_like(x);auto dw = x.index({ "...", 2 }).div(2);auto dh = x.index({ "...", 3 }).div(2);y.index_put_({ "...", 0 }, x.index({ "...", 0 }) - dw);y.index_put_({ "...", 1 }, x.index({ "...", 1 }) - dh);y.index_put_({ "...", 2 }, x.index({ "...", 0 }) + dw);y.index_put_({ "...", 3 }, x.index({ "...", 1 }) + dh);return y;
}// reference: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/nms_kernel.cpp
torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, float iou_threshold)
{if (bboxes.numel() == 0)return torch::empty({ 0 }, bboxes.options().dtype(torch::kLong));auto x1_t = bboxes.select(1, 0).contiguous();auto y1_t = bboxes.select(1, 1).contiguous();auto x2_t = bboxes.select(1, 2).contiguous();auto y2_t = bboxes.select(1, 3).contiguous();torch::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));auto ndets = bboxes.size(0);torch::Tensor suppressed_t = torch::zeros({ ndets }, bboxes.options().dtype(torch::kByte));torch::Tensor keep_t = torch::zeros({ ndets }, bboxes.options().dtype(torch::kLong));auto suppressed = suppressed_t.data_ptr<uint8_t>();auto keep = keep_t.data_ptr<int64_t>();auto order = order_t.data_ptr<int64_t>();auto x1 = x1_t.data_ptr<float>();auto y1 = y1_t.data_ptr<float>();auto x2 = x2_t.data_ptr<float>();auto y2 = y2_t.data_ptr<float>();auto areas = areas_t.data_ptr<float>();int64_t num_to_keep = 0;for (int64_t _i = 0; _i < ndets; _i++) {auto i = order[_i];if (suppressed[i] == 1)continue;keep[num_to_keep++] = i;auto ix1 = x1[i];auto iy1 = y1[i];auto ix2 = x2[i];auto iy2 = y2[i];auto iarea = areas[i];for (int64_t _j = _i + 1; _j < ndets; _j++) {auto j = order[_j];if (suppressed[j] == 1)continue;auto xx1 = std::max(ix1, x1[j]);auto yy1 = std::max(iy1, y1[j]);auto xx2 = std::min(ix2, x2[j]);auto yy2 = std::min(iy2, y2[j]);auto w = std::max(static_cast<float>(0), xx2 - xx1);auto h = std::max(static_cast<float>(0), yy2 - yy1);auto inter = w * h;auto ovr = inter / (iarea + areas[j] - inter);if (ovr > iou_threshold)suppressed[j] = 1;}}return keep_t.narrow(0, 0, num_to_keep);
}torch::Tensor non_max_suppression(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300)
{using torch::indexing::Slice;using torch::indexing::None;auto bs = prediction.size(0);auto nc = prediction.size(1) - 4;auto nm = prediction.size(1) - nc - 4;auto mi = 4 + nc;auto xc = prediction.index({ Slice(), Slice(4, mi) }).amax(1) > conf_thres;prediction = prediction.transpose(-1, -2);prediction.index_put_({ "...", Slice({None, 4}) }, xywh2xyxy(prediction.index({ "...", Slice(None, 4) })));std::vector<torch::Tensor> output;for (int i = 0; i < bs; i++) {output.push_back(torch::zeros({ 0, 6 + nm }, prediction.device()));}for (int xi = 0; xi < prediction.size(0); xi++) {auto x = prediction[xi];x = x.index({ xc[xi] });auto x_split = x.split({ 4, nc, nm }, 1);auto box = x_split[0], cls = x_split[1], mask = x_split[2];auto [conf, j] = cls.max(1, true);x = torch::cat({ box, conf, j.toType(torch::kFloat), mask }, 1);x = x.index({ conf.view(-1) > conf_thres });int n = x.size(0);if (!n) { continue; }// NMSauto c = x.index({ Slice(), Slice{5, 6} }) * 7680;auto boxes = x.index({ Slice(), Slice(None, 4) }) + c;auto scores = x.index({ Slice(), 4 });auto i = nms(boxes, scores, iou_thres);i = i.index({ Slice(None, max_det) });output[xi] = x.index({ i });}return torch::stack(output);
}} // namespaceint test_yolov8_detect_libtorch()
{// reference: ultralytics/examples/YOLOv8-LibTorch-CPP-Inferenceif (auto flag = torch::cuda::is_available(); flag == true)std::cout << "cuda is available" << std::endl;elsestd::cout << "cuda is not available" << std::endl;torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);auto classes = parse_classes_file(classes_file);if (classes.size() == 0) {std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;return -1;}std::cout << "classes: ";for (const auto& val : classes) {std::cout << val << " ";}std::cout << std::endl;try {// load modeltorch::jit::script::Module model;if (torch::cuda::is_available() == true)model = torch::jit::load(torchscript_file, torch::kCUDA);elsemodel = torch::jit::load(torchscript_file, torch::kCPU);model.eval();// note: cpu is normal; gpu is abnormal: the model may not be fully placed on the gpu // model = torch::jit::load(file); model.to(torch::kCUDA) ==> model = torch::jit::load(file, torch::kCUDA)// model.to(device, torch::kFloat32);for (const auto& [key, val] : get_dir_images(images_dir)) {// load image and preprocesscv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);if (frame.empty()) {std::cerr << "Warning: unable to load image: " << val << std::endl;continue;}cv::Mat bgr;letter_box(frame, bgr, {image_size[0], image_size[1]});torch::Tensor tensor = torch::from_blob(bgr.data, { bgr.rows, bgr.cols, 3 }, torch::kByte).to(device);tensor = tensor.toType(torch::kFloat32).div(255);tensor = tensor.permute({ 2, 0, 1 });tensor = tensor.unsqueeze(0);std::vector<torch::jit::IValue> inputs{ tensor };// inferencetorch::Tensor output = model.forward(inputs).toTensor().cpu();// NMSauto keep = non_max_suppression(output, 0.1f, 0.1f, 300)[0];std::vector<int> ids;std::vector<float> confidences;std::vector<cv::Rect> boxes;for (auto i = 0; i < keep.size(0); ++i) {int x1 = keep[i][0].item().toFloat();int y1 = keep[i][1].item().toFloat();int x2 = keep[i][2].item().toFloat();int y2 = keep[i][3].item().toFloat();boxes.emplace_back(cv::Rect(x1, y1, x2 - x1, y2 - y1));confidences.emplace_back(keep[i][4].item().toFloat());ids.emplace_back(keep[i][5].item().toInt());}draw_boxes(classes, ids, confidences, boxes, key, bgr);}} catch (const c10::Error& e) {std::cerr << "Error: " << e.msg() << std::endl;}return 0;
}
labels.txt文件内容如下:仅2类
watermelon 0
wintermelon 1
说明:
1.这里使用的libtorch版本为2.2.2;
2.通过函数torch::cuda::is_available()判断执行cpu还是gpu
3.通过非cmake构建项目时,调用torch::cuda::is_available()时即使在gpu下也会返回false,解决方法:项目属性:链接器 --> 命令行:其他选项中添加如下语句:
/INCLUDE:?warp_size@cuda@at@@YAHXZ
4.gpu下,语句model.to(torch::kCUDA)有问题,好像并不能将模型全部放置在gpu上,应调整为如下语句:
model = torch::jit::load(torchscript_file, torch::kCUDA);
执行结果如下图所示:同样的预测图像集,结果不如使用 opencv dnn 方法好,它们的前处理和后处理方式不同
其中一幅图像的检测结果如下图所示:
GitHub:https://github.com/fengbingchun/NN_Test
相关文章:

使用libtorch加载YOLOv8生成的torchscript文件进行目标检测
在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe 工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容…...

Oracle 并行和 session 数量的
这也就是为什么我们指定parallel为4,而实际并行度为8的原因。 insert create index,发现并行数都是加倍的 Indexes seem always created with parallel degree 1 during import as seen from a sqlfile. The sql file shows content like: CREATE INDE…...

Android 版本与 API level 以及 NDK 版本对应
采用 Android studio 开发 Android app 的时候,需要选择支持的最低 API Level 和使用的 NDK 版本,对应开发 app 的最低 SDK 版本: 在 app 的 build.gradle 文件里,对应于代码如下: 目前各版本的占有率情况如下…...
护网经验面试题目原版
文章目录 一、护网项目经验1.项目经验**Hvv的分组和流程**有没有遇到过有意思的逻辑漏洞?有没有自己开发过武器/工具?有做过代码审计吗?有0day吗有cve/cnvd吗?有src排名吗?有没有写过技战法有钓鱼经历吗?具…...

ipa 覆盖算法测试
相关文章 ipa 功能包测试 ipa 分区算法 ipa 分区算法总结,部分算法图解 ipa 覆盖算法分析(一) ipa 覆盖算法分析(二) 测试 网上找的地图: fig.1 测试地图 opencv fig.2 opencv 显示的覆盖路径 rviz fi…...

linuxwindows硬件信息midecod和wmic命令
1、命令dmidecode -t实例 1.1命令格式 dmidecode -t [类型代码或名称 ] 指令 1.2获取系统信息 [rootlala docker]# dmidecode -t 1 1.3获取主板信息: [rootshanghai docker]# dmidecode -t 2 1.4获取CPU ID dmidecode -t 4 | grep ID 1.5获取系统序列号 …...
03. SpringBoot 整合 Redis
文章目录 Jedis导入依赖测试连接Jedis 实现事务 SpringBoot 整合 RedisRedisTemplateSpringBoot 整合 Redis 测试RedisTemplate 序列化RedisUtils Jedis Jedis 是 Redis 官方推荐的 Java 连接工具。 导入依赖 </dependencies><dependency><groupId>redis.c…...

01-Linux【准备篇】
一、学Linux的作用? 1.Linux下开发(部署)软件项目 2.Linux运维 二、Linux的强与弱 1.薄弱 个人桌面领域的应用 此领域是传统Linux应用薄弱的环节,近些年随着Ubuntu、fedora等优秀桌面环境的兴起,Linux在个人桌面领域的占有率在慢慢提高…...

在IDEA中配置servlet(maven配置完成的基础下)
在IDEA中配置servlet(maven配置完成的基础下) 1.先新建一个项目 2.选择尾巴是webapp的,名称自定义 3.点击高级设置,修改组id 点击创建,等待jar包下载完成。在pom.xml中配置以下 <dependency><groupId>ja…...

pyqt6水平布局
效果预览 main_window.ui <?xml version"1.0" encoding"UTF-8"?> <ui version"4.0"><class>MainWindow</class><widget class"QMainWindow" name"MainWindow"><property name"geo…...

CLIP论文学习
学习来自B站bryanyzhu...

手把手教大家,怎么查看抖音小店的类目保证金?
大家好,我是喷火龙。 抖音小店的类目保证金也介绍过很多次了,不同的类目有不同的保证金,要想准确的知道自己想做的类目要交多少保证金的话,还是去官网查询比较可靠。 今天,就教大家怎么去查询自己想做的类目要交多少…...

5.24作业
...

Linux之LLVM、Clang、Clang++区别及用法实例(六十五)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…...
CentOS7 安装 Mysql 5.7:密码查看与修改、更改端口、开机启动
文章目录 下载 MySQL yum包安装MySQL源安装MySQL服务端,需要等待一些时间启动MySQL修改密码方式一:临时密码获取临时密码,MySQL5.7为root用户随机生成了一个密码通过临时密码登录MySQL,进行修改密码操作 方式二:skip-grant-tables…...

专业渗透测试 Phpsploit-Framework(PSF)框架软件小白入门教程(十三)
本系列课程,将重点讲解Phpsploit-Framework框架软件的基础使用! 本文章仅提供学习,切勿将其用于不法手段! 接上一篇文章内容,讲述如何进行Phpsploit-Framework软件的基础使用和二次开发。 我们,继续讲一…...
linux替换文件中的字符串
linux替换文件中的字符串 方法一:使用sed命令进行替换 sed -i s/原字符串/新字符串/g 文件名 ex: sed -i s/2024-04-25%/2024-04-26%/g sql10.sql ex:,"analyzer":"ik_analyzer" 替换为空 sed -i s/,"analyzer":"ik_analyz…...
【前端每日基础】day22——js控制结构
循环语句用于重复执行代码块。 for 循环 常用于需要精确控制循环次数的情况。 for (let i 0; i < 5; i) {console.log("Iteration:", i); }while 循环 当条件为真时重复执行代码块,适用于循环次数不确定但条件明确的情况。 let i 0;while (i <…...

npm详解
引言 在JavaScript和Node.js开发领域,npm(Node Package Manager)是一个不可或缺的工具。它不仅是一个包管理器,也是一个强大的生态系统,允许开发者共享和重用代码。本文将详细介绍npm的基本概念、主要功能以及如何有效…...

ChatGPT-4o 实战 如何快速分析混淆加密和webpack打包的源码
ChatGPT-4o 几个特点 一个对话拥有长时间的记忆,可以连续上传文件,让其分析,最大一个代码文件只能3M,超出3M的文件,可以通过split-file可以进行拆分 其次ChatGPT-4o可以生成文件的下载链接,这有利于大文件的…...

【入坑系列】TiDB 强制索引在不同库下不生效问题
文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...
Objective-C常用命名规范总结
【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名(Class Name)2.协议名(Protocol Name)3.方法名(Method Name)4.属性名(Property Name)5.局部变量/实例变量(Local / Instance Variables&…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
代理篇12|深入理解 Vite中的Proxy接口代理配置
在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...

springboot整合VUE之在线教育管理系统简介
可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生,小白用户,想学习知识的 有点基础,想要通过项…...
音视频——I2S 协议详解
I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议,专门用于在数字音频设备之间传输数字音频数据。它由飞利浦(Philips)公司开发,以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...
虚拟电厂发展三大趋势:市场化、技术主导、车网互联
市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦࿰…...