当前位置: 首页 > news >正文

yolov5导出onnx模型问题

为了适配C++工程代码,我在导出onnx模型时,会把models/yolo.py里面的forward函数改成下面这样,

    #转模型def forward(self, x):z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))

也就是把后面类别得分中最大的那个计算出来赋值给idxs,

原来的yolov5输出是x y w h box_score  label1_confidence label2_confidence ....  labeln_confidence.

我改完之后,输出变成x y w h box_score idxs  label1_confidence label2_confidence ....  labeln_confidence.

然后之前我都是在转onnx之前手动的去改代码,然后转完模型再改回来因为train和detect也要用到这个yolo.py中的forward函数,但是后来某项目中,要实现一个自动训练、自动检测、自动转模型,这就不能我手动改了,所以我第一个方法是我复制一份yolo.py复制成yolo_onnx.py,然后export.py中from models.yolo_onnx import Detect,这种方法不可行,因为其他还有还有很多地方也是用的from models.yolo import Detect,最后用的方法如下:

首先在yolo.py中的Detect类中增加一个成员export

class Detect(nn.Module):stride = None  # strides computed during buildonnx_dynamic = False  # ONNX export parameterexport = False  #增加的成员......

然后我在export.py的run函数中给这个值赋值为true


@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'weights=ROOT / 'yolov5s.pt',  # weights pathimgsz=(640, 640),  # image (height, width)batch_size=1,  # batch sizedevice='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpuinclude=('torchscript', 'onnx'),  # include formatshalf=False,  # FP16 half-precision exportinplace=False,  # set YOLOv5 Detect() inplace=Truetrain=False,  # model.train() modeoptimize=False,  # TorchScript: optimize for mobileint8=False,  # CoreML/TF INT8 quantizationdynamic=False,  # ONNX/TF: dynamic axessimplify=False,  # ONNX: simplify modelopset=12,  # ONNX: opset versionverbose=False,  # TensorRT: verbose logworkspace=4,  # TensorRT: workspace size (GB)nms=False,  # TF: add NMS to modelagnostic_nms=False,  # TF: add agnostic NMS to modeltopk_per_class=100,  # TF.js NMS: topk per class to keeptopk_all=100,  # TF.js NMS: topk for all classes to keepiou_thres=0.45,  # TF.js NMS: IoU thresholdconf_thres=0.25  # TF.js NMS: confidence threshold):t = time.time()include = [x.lower() for x in include]  # to lowercaseformats = tuple(export_formats()['Argument'][1:])  # --include argumentsflags = [x in include for x in formats]assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleansfile = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights# Load PyTorch modeldevice = select_device(device)assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 modelnc, names = model.nc, model.names  # number of classes, class namesmodel.model[-1].export = True# Checksimgsz *= 2 if len(imgsz) == 1 else 1  # expandopset = 12 if ('openvino' in include) else opset  # OpenVINO requires opset <= 12assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'# Inputgs = int(max(model.stride))  # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection# Update modelif half:im, model = im.half(), model.half()  # to FP16model.train() if train else model.eval()  # training mode = no Detect() layer grid constructionfor k, m in model.named_modules():if isinstance(m, Conv):  # assign export-friendly activationsif isinstance(m.act, nn.SiLU):m.act = SiLU()elif isinstance(m, Detect):m.inplace = inplacem.onnx_dynamic = dynamicif hasattr(m, 'forward_export'):m.forward = m.forward_export  # assign custom forward (optional)for _ in range(2):y = model(im)  # dry runsshape = tuple(y[0].shape)  # model output shapeLOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")# Exportsf = [''] * 10  # exported filenameswarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarningif jit:f[0] = export_torchscript(model, im, file, optimize)if engine:  # TensorRT required before ONNXf[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)if onnx or xml:  # OpenVINO requires ONNXf[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)if xml:  # OpenVINOf[3] = export_openvino(model, im, file)if coreml:_, f[4] = export_coreml(model, im, file)# TensorFlow Exportsif any((saved_model, pb, tflite, edgetpu, tfjs)):if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'model, f[5] = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres)  # keras modelif pb or tfjs:  # pb prerequisite to tfjsf[6] = export_pb(model, im, file)if tflite or edgetpu:f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)if edgetpu:f[8] = export_edgetpu(model, im, file)if tfjs:f[9] = export_tfjs(model, im, file)# Finishf = [str(x) for x in f if x]  # filter out '' and Noneif any(f):LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'f"\nResults saved to {colorstr('bold', file.parent.resolve())}"f"\nDetect:          python detect.py --weights {f[-1]}"f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"f"\nValidate:        python val.py --weights {f[-1]}"f"\nVisualize:       https://netron.app")return f  # return list of exported files/dirs

然后修改yolo.py中的forward函数,增加分支判断

def forward(self, x):if self.export:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))else:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whelse:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # why = torch.cat((xy, wh, y[..., 4:]), -1)z.append(y.view(bs, -1, self.no))return x if self.training else (torch.cat(z, 1), x)

这样就可以实现train和export分别跑不同的代码了。

相关文章:

yolov5导出onnx模型问题

为了适配C工程代码&#xff0c;我在导出onnx模型时&#xff0c;会把models/yolo.py里面的forward函数改成下面这样&#xff0c; #转模型def forward(self, x):z [] # inference outputfor i in range(self.nl):x[i] self.m[i](x[i]) # convbs, _, ny, nx x[i].shape # x(…...

JS第一课简单看看这是啥东西

1.什么是JavaScript JS是一门编程语言&#xff0c;是一种运行在客户端(浏览器)的编程语言&#xff0c;主要是让前端的画面动起来&#xff0c;注意HTML和CSS不是编程语言&#xff0c;他俩是一种标记语言。JS只要有浏览器就能运行不用跟Python或者Java一样上来装一个jdk或者Pyth…...

2023年常用网络安全政策标准整合

文章目录 前言一、政策篇(一)等级保护(二)关键信息基础设施保护(三)数据安全(四)数据出境安全评估(五)网络信息安全(六)应急响应(七)网络安全专用产品检测认证制度(八)个人信息保护(九)商用密码二、标准篇前言 2023年,国家网络安全政策和标准密集发布,逐渐…...

Redis -- 背景知识

“知识就是力量” -- 弗朗西斯培根 目录 特性 为啥Redis快? 应用场景 Redis不能做什么&#xff1f; Redis是在内存中存储数据的一个中间件&#xff0c;用作为数据库&#xff0c;也可以用作为缓存&#xff0c;在分布式中有很高的威望。 特性 In-memory data structures&…...

如何在Shopee平台上进行手机类目选品?

在Shopee平台上进行手机类目的选品是一个关键而复杂的任务。卖家需要经过一系列的策略和步骤&#xff0c;以确保选品的成功和销售业绩的提升。下面将介绍一些有效的策略&#xff0c;帮助卖家在Shopee平台上进行手机类目选品。 先给大家推荐一款shopee知虾数据运营工具知虾免费…...

班级管理神器,教师在线发布系统

现如今&#xff0c;班级管理也需要与时俱进。传统的管理方式不仅效率低下&#xff0c;而且容易出错。为了更好地管理班级&#xff0c;教师需要一个强大的工具来帮助他们发布信息和管理学生。 发布系统是一款专门为教师设计的数字化管理工具。通过系统&#xff0c;老师们就可以…...

【Spring Boot 3】异步线程任务

【Spring Boot 3】异步线程任务 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…...

JAVA斗地主逻辑-控制台版

未排序版&#xff1a; 准备牌->洗牌 -> 发牌 -> 看牌: App程序入口&#xff1a; package doudihzu01;public class App {public static void main(String[] args) {/*作为斗地主程序入口这里不写代码逻辑*///无参创建对象&#xff0c;作为程序启动new PokerGame();…...

Harmony的自定义组件和Page的数据同步

在开发过程中会经常使用自定义组件,就会遇到一个问题,在页面中引入组件后,如何把改变的值传递到自定义组件中呢,这就用到了装饰器,在这是单向传递的,用的装饰器是@State和@Prop @State在page页面中监听数据的变化 @Prop在自定义组件中监听page页面传递过来的变化值,并赋…...

【Vue3+Vite】路由机制router 快速学习 第四期

文章目录 路由简介路由是什么路由的作用 一、路由入门案例1. 创建项目 导入路由依赖2. 准备页面和组件3. 准备路由配置4. main.js引入router配置 二、路由重定向三、编程式路由(useRouter)四、路由传参(useRoute)五、路由守卫总结 路由简介 路由是什么 路由就是根据不同的 URL…...

python脚本实现浏览器驱动chromedriver的版本自动升级

chromedriver的版本号与chrome浏览器版本不匹配时在运行程序时就会报错 用下面的脚本可以自动安装chromedriver的最新版本到指定路径 from webdriver_manager.utils import get_browser_version_from_os from webdriver_manager.chrome import ChromeDriverManager import re…...

npm使用国内淘宝镜像

一、命令配置 1、设置淘宝镜像源 npm config set registry https://registry.npmmirror.com2、查看镜像使用状态 npm config get registry如果返回https://registry.npmmirror.com/,说明配置的是淘宝镜像。 如果返回https://registry.npmjs.org/,说明配置的是官网镜像。 二…...

# Redis 分布式锁如何自动续期

Redis 分布式锁如何自动续期 何为分布式 分布式&#xff0c;从狭义上理解&#xff0c;也与集群差不多&#xff0c;但是它的组织比较松散&#xff0c;不像集群&#xff0c;有一定组织性&#xff0c;一台服务器宕了&#xff0c;其他的服务器可以顶上来。分布式的每一个节点&…...

数据结构 归并排序详解

1.基本思想 归并排序&#xff08;MERGE-SORT&#xff09;是建立在归并操作上的一种有效的排序算法,该算法是采用分治法&#xff08;Divide andConquer&#xff09;的一个非常典型的应用。 将已有序的子序列合并&#xff0c;得到完全有序的序列&#xff0c;即先使每个子序列有序…...

服务器C盘突然满了,是什么问题

随着时代的发展、互联网的普及&#xff0c;加上近几年云计算服务的诞生以及大规模普及&#xff0c;对于服务器的使用目前是非常普遍的&#xff0c;用户运维的主要对象一般也主要是服务器方面。在日常使用服务器的过程中&#xff0c;我们也会遇到各式各样的问题。最近就有遇到用…...

【深度学习】ND4J-科学计算库

目录 简介 基础用法 基础信息 数组创建 打印数组 变更维度&堆叠 加减乘除 累加/最大/最小 转换操作 矩陈乘法 索引/迭代 深拷贝/引用传递/视图 引用传递 视图 深拷贝 其它 简介 ND4J主要是JVM的科学计算库&#xff0c;内置了很多计算方法&#xff0c;目的…...

2024-01-29 ubuntu 用脚本设置安装交叉编译工具链路径方法,设置PATH环境变量

一、设置PATH环境变量的方法,建议用~/.bash_profile的方法&#xff0c;不然在ssh登录的时候可能没有设置PATH. 二、下面的完整的脚本&#xff0c;里面的echo "export PATH$build_toolchain_path:\$PATH" >> $HOME/.bashrc 就是把交叉编译路径写写到.bashrc设置…...

今年春节很多年轻人选择不买战袍,减少年货置办,「极简过年」,如何看待此现象?

​近年来&#xff0c;春节期间出现了一种新的现象&#xff0c;越来越多的年轻人选择不买战袍&#xff0c;减少年货置办&#xff0c;采用“极简过年”的方式度过春节。对于这一现象&#xff0c;不同人有不同的看法。 首先&#xff0c;这种极简过年的方式符合当前社会的一些价值观…...

C语言·贪吃蛇游戏(下)

上节我们将要完成贪吃蛇游戏所需的前置知识都学完了&#xff0c;那么这节我们就开始动手写代码了 1. 程序规划 首先我们应该规划好我们的代码文件&#xff0c;设置3个文件&#xff1a;snack.h 用来声明游戏中实现各种功能的函数&#xff0c;snack.c 用来实现函数&#xff0c;t…...

Flask 入门2:路由

1. 前言 在上一节中&#xff0c;我们使用到了静态路由&#xff0c;即一个路由规则对应一个 URL。而在实际应用中&#xff0c;更多使用的则是动态路由&#xff0c;它的 URL是可变的。 2. 定义一个很常见的路由地址 app.route(/user/<username>) def user(username):ret…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外&#xff0c;K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案&#xff0c;全安装在K8S群集中。 具体可参…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

Keil 中设置 STM32 Flash 和 RAM 地址详解

文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询操作。 缓存逻辑分析&#xff1a; ①每个分类下的菜品保持一份缓存数据…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块&#xff0c;它提供了一个轻量级的 HTTP 服务器实现&#xff0c;主要用于构建基于 HTTP 的应用程序和服务。 功能介绍&#xff1a; 主要功能 HTTP服务器功能&#xff1a; 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

&#x1f680; C extern 关键字深度解析&#xff1a;跨文件编程的终极指南 &#x1f4c5; 更新时间&#xff1a;2025年6月5日 &#x1f3f7;️ 标签&#xff1a;C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言&#x1f525;一、extern 是什么&#xff1f;&…...

什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南

文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/55aefaea8a9f477e86d065227851fe3d.pn…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

git: early EOF

macOS报错&#xff1a; Initialized empty Git repository in /usr/local/Homebrew/Library/Taps/homebrew/homebrew-core/.git/ remote: Enumerating objects: 2691797, done. remote: Counting objects: 100% (1760/1760), done. remote: Compressing objects: 100% (636/636…...