【CV学习笔记】tensorrtx-yolov5 逐行代码解析
1、前言
TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因:
- Flexible 很容易修改模型的任意一层,删除、增加、替换等操作。
- Debuggable 可以容易获得模型中间某一层的结果
- Chance to learn 可以对模型结构有进一步的了解
尽管onnx2trt的方式目前已经在绝大部分情况下都不会出现问题,但在trtx下,我们能够掌握更底层的原理和代码,有利于我们对模型的部署以及优化。下文将会以yolov5s在trtx框架下的例子,来逐行解析是trtx是如何工作的。
TensorRTx项目链接:https://github.com/wang-xinyu/tensorrtx。
2、步骤解析
在trtx中,对一个模型加速的过程可以分为两个步骤
- 提取pytorch模型参数 wts
- 利用trt底层API搭建网络结构,并将wts中的参数填充到网络中
2.1、get_wts.py
首先需要将pytorch中的模型参数提取出来,pytorch中的模型参数是以caffe中blob的格式存在的,每个操作都有对应的名字、数据长度、数据.
for k, v in model.state_dict().items():# k-> blob的名字vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制f.write('\n')
通过上get_wts.py,就可以得到包含yolov5s.pth的模型参数,打开yolov5s.wts如下图所示:

其中第一行的351为总的blob数量,第二行的model.0.conv.weight为第一个blob的名字,3456表示为该blob的数据长度,3a198000 3ca58000…为实际参数。
得到了上述的参数之后,就可以以trtx的方式进行加速了。
2.2、构造engine
在利用wts转engine的之前,需要十分清楚模型的网络结构,不太清楚的同学可以参考太阳花的小绿豆关于yolov5的网络结构图。了解完yolov5的网络结构后,就可以着手利用trt的api来搭建网络模型了。搭建模型的代码在 model.cpp中的build_det_engine函数,本文将其中的代码过程直接画到yolov5的网络结构图中了,可以直接对照代码和图来进行查看。

//yolov5_det.cpp
viod serialize_engine(...){if (is_p6) {...} else {// 以yolov5s为例engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);}// 序列化IHostMemory* serialized_engine = engine->serialize();std::ofstream p(engine_name, std::ios::binary);// 写到文件中p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());}
model.cpp
// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {int32_t count; // wts文件第一行,共有351个blobinput >> count;//每一行是一个blob,模型名称 + 数据长度 + 参数while (count--) {// 一个blob的参数Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size; //blob 数据长度std::string name; // blob 数据名字for (uint32_t x = 0, y = size; x < y; ++x) {input >> std::hex >> val[x]; // 将数据转化成十进制,并放到val中}// 每个blob名字对应一个wtweightMap[name] = wt;}
}ICudaEngine* build_det_engine(){// 初始化网络结构INetworkDefinition* network = builder->createNetworkV2(0U);// 定义模型输入ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW });// 加载pytorch模型中的参数std::map<std::string, Weights> weightMap = loadWeights(wts_name);// 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图// 增加yolo后处理decode模块,使用了pluginauto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});network->markOutput(*yolo->getOutput(0)); //将plugin的输出设置为模型的最后输出(decode)#if defined(USE_FP16)// FP16config->setFlag(BuilderFlag::kFP16);#elif defined(USE_INT8)// INT8 量化std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;assert(builder->platformHasFastInt8());config->setFlag(BuilderFlag::kINT8);Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);config->setInt8Calibrator(calibrator);#endif// 根据网络结构来生成engineICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);return engine;
}
3、plugin
本人对plugin也在学习当中,下面是我在学习trtx-yolo5代码中对plugin浅显的认知。原作者在模型后面增加了一个模型解码的plugin,用于获得每个特征层上的bbox,调用代码在model.cpp中
auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});static IPluginV2Layer* addYoLoLayer(...){// 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");// plugin的数据PluginField plugin_fields[2];int netinfo[5] = {kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation}; //维度数据plugin_fields[0].data = netinfo; plugin_fields[0].length = 5; plugin_fields[0].name = "netinfo";plugin_fields[0].type = PluginFieldType::kFLOAT32;// 所有plugin的参数PluginFieldCollection plugin_data;plugin_data.nbFields = 2;plugin_data.fields = plugin_fields;// 创建plugin的对象 IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}
实现代码在yololayer.h/cu中
class API YoloLayerPlugin : public IPluginV2IOExt {// 设置插件名称,在注册插件时会寻找对应的插件const char* getPluginType() const TRT_NOEXCEPT override{return "YoloLayer_TRT";}//插件构造函数YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){/*classCount:类别数量netWidth:输入宽netHeight:输入高maxOut:最大检测数量is_segmentation:是否含有实例分割vYoloKernel:anchors参数*/}}// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){// 输出结果 1+ 是在第一个位置记录解码的数量int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);// 将存放结果的内存置为0for (int idx = 0; idx < batchSize; ++idx) {CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));// 遍历三种不同尺度的anchorfor (unsigned int i = 0; i < mYoloKernel.size(); ++i) {// 调用核函数进行解码CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)}}__global__ void CalDetection(...){// input:模型输出结果// output:decode存放地址// 当前线程的的全局索引IDint idx = threadIdx.x + blockDim.x * blockIdx.x;// yoloWidth * yoloHeightint total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数int bnIdx = idx / total_grid; // 第n个batch // x,y,w,h,score + 80int info_len_i = 5 + classes;// 如果带有实例分割分析,需要再加上32个分割系数if (is_segmentation) info_len_i += 32;// 第n个batch的推理结果开始地址const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);// 遍历三种不同尺寸的anchorfor (int k = 0; k < kNumAnchor; ++k) {//每个框的置信度float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);if (box_prob < kIgnoreThresh) continue;for (int i = 5; i < 5 + classes; ++i) {// 每个类别的概率float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);// 提取最大概率以及类别IDif (p > max_cls_prob) {max_cls_prob = p;class_id = i - 5;}}// float *res_count = output + bnIdx * outputElem;// 统计decode框的数量 int count = (int)atomicAdd(res_count, 1);// 下面是按照论文的公式将预测的宽和高恢复到原图大小...}
}
4、总结
通过本次对trtx开源代码的深入学习,知道了如何利用trt的api对模型进行加速,同时还了解到plugin的实现,后续还会继续学习trtx里面的知识点。
相关文章:
【CV学习笔记】tensorrtx-yolov5 逐行代码解析
1、前言 TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因: Flexible 很容易修改模型的任意一层,删…...
微信管理系统可以解决什么问题?
微信作为一款社交通讯软件,已经成为人们日常生活中不可缺少的工具。不仅个人,很多企业都用微信来联系客户、维护客户和营销,这自然而然就会有很多微信账号、手机也多,那管理起来就会带来很多的不便,而微信管理系统正好…...
mysql事务测试
mysql的事务处理主要有两种方法1、用begin,rollback,commit来实现 begin; -- 开始一个事务 rollback; -- 事务回滚 commit; -- 事务提交 2、直接用set来改变mysql的自动提交模式 mysql默认是自动提交的,也就是你提交一个sql,它就直接执行!我…...
Spring面试题14:Spring中什么是Spring Beans? 包含哪些?Spring容器提供几种方式配置元数据?Spring中怎样定义类的作用域?
该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring中什么是Spring Beans? 包含哪些? 在Spring中,Spring Beans是指由Spring容器管理的对象。Spring Beans包含以下内容: 类定义(Class De…...
Tomcat部署、优化、以及操作练习
一.Tomcat的基本介绍 1.1.Tomcat是什么? Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选。一般来说,T…...
服务器假死日志按时间统计排查
文章目录 场景解决方案排查过程根据cost时间来筛选 场景 服务器假死,进程还在,但是已经接不到请求了。因此有客户报事,发现服务假死了。 解决方案 这种假死问题一般不太好排查,常规来说有几种可能。 1、慢sql导致卡死。 2、大数…...
CSS——grid网格布局的基本使用
网格布局在实现页面自适应,大屏可视化中常常使用,在这篇博客里,记录一下网格布局的基本使用。 参考文档:网格布局_菜鸟教程 文章目录 1. 体会grid的自适应性2. grid-template-arr配置网格行列3. 网格单位fr与repeat()简写属性值4…...
【python】使用Nuitka打包python项目-demo示例
文章目录 写在前面参考准备工作Quick Start参数说明使用打包程序输出目录结构日志2023.09.20 写在前面 本文的demo示例的代码/数据可从笔者的GitCode获取: HelloWorld 参考 Nuitka官网: https://github.com/Nuitka/NuitkaNuitka使用: https://daobook.github.io/nuitka-doc/…...
Java多线程篇(5)——cas和atomic原子类
文章目录 CASAtomic 原子类一般原子类针对aba问题 —— AtomicStampedReference针对大量自旋问题 —— LongAdder CAS 原理大致如下: 在java的 Unsafe 类里封装了一些 cas 的api。以 compareAndSetInt 为例,来看看其底层实现。 可以发现,最…...
数据结构---栈和队列
栈(Stack) 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈 顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 压栈࿱…...
2023-9-23 合并果子
题目链接:合并果子 #include <iostream> #include <algorithm> #include <queue>using namespace std;int main() {int n;cin >> n;priority_queue<int, vector<int>, greater<int>> heap;for(int i 0; i < n; i){in…...
基于QT和UDP实现一个实时RTP数据包的接收,并将数据包转化成文件
简单介绍:代码写的比较详细,需要留意的地方看结尾介绍 头文件 #ifndef RTPRECEIVER_H #define RTPRECEIVER_H#include <QDialog> #include <QUdpSocket> #include <QFile> #include <QTextStream> #include <httpclient.h&g…...
云原生安全性:保护现代应用免受威胁
文章目录 引言云原生安全性的挑战云原生安全性的关键实践1. 安全的镜像构建2. 网络策略3. 漏洞扫描和漏洞管理4. 认证和授权5. 日志和监控 云原生安全工具结论 🎉欢迎来到云计算技术应用专栏~云原生安全性:保护现代应用免受威胁 ☆* o(≧▽≦)o *☆嗨~我…...
R语言绘图-3-Circular-barplot图
0. 参考: https://r-graph-gallery.com/web-circular-barplot-with-R-and-ggplot2.html 1. 说明: 利用 ggplot 绘制 环状的条形图 (circular barplot),并且每个条带按照数值大小进行排列。 2 绘图代码: 注意:绘图代码中的字体…...
解决Keil5下载没有对应芯片Flash的问题
问题描述 例如芯片是STM32F103ZET6,但是选项中并没有对应型号的芯片导致下载失败。 解决方法 1、寻找芯片安装包的具体位置,芯片安装包路径在软件安装过程中会有(如图1所示)。如果没有记录可以双击一下芯片安装包会直接提示。…...
深拷贝与浅拷贝(对象的引用)
可以用赋值 1.对象的引用 代码: <!-- 1.对象的引用 --><script>const info{name:"lucy",age:20}const objinfo;info.name"sam"console.log(obj.name) //sam</script>图解: 等于号的赋值,对象info…...
重新认识架构—不只是软件设计
前言 什么是架构? 通常情况下,人们对架构的认知仅限于在软件工程中的定义:架构主要指软件系统的结构设计,比如常见的SOLID准则、DDD架构。一个良好的软件架构可以帮助团队更有效地进行软件开发,降低维护成本࿰…...
我的创业笔记:困境与思索
现在是2023年9月22日傍晚,我一个人走在广州的珠江边,静静地思索着当前个人创业面临的困境,不由自主地想将这些想法记录下来。 故事需要从两个月前说起。2023年7月31号,我从金山办公离职后,就满心欢喜地开启了自己的个…...
minio文件上传
1.代码 大佬仓库:https://gitee.com/Gary2016/minio-upload?_fromgitee_search 关于这个代码的讲解:来自b站 2.准备minio 参考:[1]、[2] 2.1 下载 官网:https://min.io/download#/windows 2.2 启动 ①准备一个data文件夹…...
IDEA .iml文件及.idea文件夹详解
.iml文件 idea 对module 配置信息之意, infomation of module。每个模块都有一个iml文件。 IDEA中的.iml文件是项目标识文件,缺少了这个文件,IDEA就无法识别项目。跟Eclipse的.project文件性质是一样的。并且这些文件不同的设备上的内容也会…...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建
制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...
【解密LSTM、GRU如何解决传统RNN梯度消失问题】
解密LSTM与GRU:如何让RNN变得更聪明? 在深度学习的世界里,循环神经网络(RNN)以其卓越的序列数据处理能力广泛应用于自然语言处理、时间序列预测等领域。然而,传统RNN存在的一个严重问题——梯度消失&#…...
转转集团旗下首家二手多品类循环仓店“超级转转”开业
6月9日,国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解,“超级…...
从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...
Linux云原生安全:零信任架构与机密计算
Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
vue3+vite项目中使用.env文件环境变量方法
vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...
selenium学习实战【Python爬虫】
selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...
Python Ovito统计金刚石结构数量
大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...
