当前位置: 首页 > news >正文

yolov5导出onnx模型问题

为了适配C++工程代码,我在导出onnx模型时,会把models/yolo.py里面的forward函数改成下面这样,

    #转模型def forward(self, x):z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))

也就是把后面类别得分中最大的那个计算出来赋值给idxs,

原来的yolov5输出是x y w h box_score  label1_confidence label2_confidence ....  labeln_confidence.

我改完之后,输出变成x y w h box_score idxs  label1_confidence label2_confidence ....  labeln_confidence.

然后之前我都是在转onnx之前手动的去改代码,然后转完模型再改回来因为train和detect也要用到这个yolo.py中的forward函数,但是后来某项目中,要实现一个自动训练、自动检测、自动转模型,这就不能我手动改了,所以我第一个方法是我复制一份yolo.py复制成yolo_onnx.py,然后export.py中from models.yolo_onnx import Detect,这种方法不可行,因为其他还有还有很多地方也是用的from models.yolo import Detect,最后用的方法如下:

首先在yolo.py中的Detect类中增加一个成员export

class Detect(nn.Module):stride = None  # strides computed during buildonnx_dynamic = False  # ONNX export parameterexport = False  #增加的成员......

然后我在export.py的run函数中给这个值赋值为true


@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'weights=ROOT / 'yolov5s.pt',  # weights pathimgsz=(640, 640),  # image (height, width)batch_size=1,  # batch sizedevice='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpuinclude=('torchscript', 'onnx'),  # include formatshalf=False,  # FP16 half-precision exportinplace=False,  # set YOLOv5 Detect() inplace=Truetrain=False,  # model.train() modeoptimize=False,  # TorchScript: optimize for mobileint8=False,  # CoreML/TF INT8 quantizationdynamic=False,  # ONNX/TF: dynamic axessimplify=False,  # ONNX: simplify modelopset=12,  # ONNX: opset versionverbose=False,  # TensorRT: verbose logworkspace=4,  # TensorRT: workspace size (GB)nms=False,  # TF: add NMS to modelagnostic_nms=False,  # TF: add agnostic NMS to modeltopk_per_class=100,  # TF.js NMS: topk per class to keeptopk_all=100,  # TF.js NMS: topk for all classes to keepiou_thres=0.45,  # TF.js NMS: IoU thresholdconf_thres=0.25  # TF.js NMS: confidence threshold):t = time.time()include = [x.lower() for x in include]  # to lowercaseformats = tuple(export_formats()['Argument'][1:])  # --include argumentsflags = [x in include for x in formats]assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleansfile = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights# Load PyTorch modeldevice = select_device(device)assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 modelnc, names = model.nc, model.names  # number of classes, class namesmodel.model[-1].export = True# Checksimgsz *= 2 if len(imgsz) == 1 else 1  # expandopset = 12 if ('openvino' in include) else opset  # OpenVINO requires opset <= 12assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'# Inputgs = int(max(model.stride))  # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection# Update modelif half:im, model = im.half(), model.half()  # to FP16model.train() if train else model.eval()  # training mode = no Detect() layer grid constructionfor k, m in model.named_modules():if isinstance(m, Conv):  # assign export-friendly activationsif isinstance(m.act, nn.SiLU):m.act = SiLU()elif isinstance(m, Detect):m.inplace = inplacem.onnx_dynamic = dynamicif hasattr(m, 'forward_export'):m.forward = m.forward_export  # assign custom forward (optional)for _ in range(2):y = model(im)  # dry runsshape = tuple(y[0].shape)  # model output shapeLOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")# Exportsf = [''] * 10  # exported filenameswarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarningif jit:f[0] = export_torchscript(model, im, file, optimize)if engine:  # TensorRT required before ONNXf[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)if onnx or xml:  # OpenVINO requires ONNXf[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)if xml:  # OpenVINOf[3] = export_openvino(model, im, file)if coreml:_, f[4] = export_coreml(model, im, file)# TensorFlow Exportsif any((saved_model, pb, tflite, edgetpu, tfjs)):if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'model, f[5] = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres)  # keras modelif pb or tfjs:  # pb prerequisite to tfjsf[6] = export_pb(model, im, file)if tflite or edgetpu:f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)if edgetpu:f[8] = export_edgetpu(model, im, file)if tfjs:f[9] = export_tfjs(model, im, file)# Finishf = [str(x) for x in f if x]  # filter out '' and Noneif any(f):LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'f"\nResults saved to {colorstr('bold', file.parent.resolve())}"f"\nDetect:          python detect.py --weights {f[-1]}"f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"f"\nValidate:        python val.py --weights {f[-1]}"f"\nVisualize:       https://netron.app")return f  # return list of exported files/dirs

然后修改yolo.py中的forward函数,增加分支判断

def forward(self, x):if self.export:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))else:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whelse:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # why = torch.cat((xy, wh, y[..., 4:]), -1)z.append(y.view(bs, -1, self.no))return x if self.training else (torch.cat(z, 1), x)

这样就可以实现train和export分别跑不同的代码了。

相关文章:

yolov5导出onnx模型问题

为了适配C工程代码&#xff0c;我在导出onnx模型时&#xff0c;会把models/yolo.py里面的forward函数改成下面这样&#xff0c; #转模型def forward(self, x):z [] # inference outputfor i in range(self.nl):x[i] self.m[i](x[i]) # convbs, _, ny, nx x[i].shape # x(…...

JS第一课简单看看这是啥东西

1.什么是JavaScript JS是一门编程语言&#xff0c;是一种运行在客户端(浏览器)的编程语言&#xff0c;主要是让前端的画面动起来&#xff0c;注意HTML和CSS不是编程语言&#xff0c;他俩是一种标记语言。JS只要有浏览器就能运行不用跟Python或者Java一样上来装一个jdk或者Pyth…...

2023年常用网络安全政策标准整合

文章目录 前言一、政策篇(一)等级保护(二)关键信息基础设施保护(三)数据安全(四)数据出境安全评估(五)网络信息安全(六)应急响应(七)网络安全专用产品检测认证制度(八)个人信息保护(九)商用密码二、标准篇前言 2023年,国家网络安全政策和标准密集发布,逐渐…...

Redis -- 背景知识

“知识就是力量” -- 弗朗西斯培根 目录 特性 为啥Redis快? 应用场景 Redis不能做什么&#xff1f; Redis是在内存中存储数据的一个中间件&#xff0c;用作为数据库&#xff0c;也可以用作为缓存&#xff0c;在分布式中有很高的威望。 特性 In-memory data structures&…...

如何在Shopee平台上进行手机类目选品?

在Shopee平台上进行手机类目的选品是一个关键而复杂的任务。卖家需要经过一系列的策略和步骤&#xff0c;以确保选品的成功和销售业绩的提升。下面将介绍一些有效的策略&#xff0c;帮助卖家在Shopee平台上进行手机类目选品。 先给大家推荐一款shopee知虾数据运营工具知虾免费…...

班级管理神器,教师在线发布系统

现如今&#xff0c;班级管理也需要与时俱进。传统的管理方式不仅效率低下&#xff0c;而且容易出错。为了更好地管理班级&#xff0c;教师需要一个强大的工具来帮助他们发布信息和管理学生。 发布系统是一款专门为教师设计的数字化管理工具。通过系统&#xff0c;老师们就可以…...

【Spring Boot 3】异步线程任务

【Spring Boot 3】异步线程任务 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…...

JAVA斗地主逻辑-控制台版

未排序版&#xff1a; 准备牌->洗牌 -> 发牌 -> 看牌: App程序入口&#xff1a; package doudihzu01;public class App {public static void main(String[] args) {/*作为斗地主程序入口这里不写代码逻辑*///无参创建对象&#xff0c;作为程序启动new PokerGame();…...

Harmony的自定义组件和Page的数据同步

在开发过程中会经常使用自定义组件,就会遇到一个问题,在页面中引入组件后,如何把改变的值传递到自定义组件中呢,这就用到了装饰器,在这是单向传递的,用的装饰器是@State和@Prop @State在page页面中监听数据的变化 @Prop在自定义组件中监听page页面传递过来的变化值,并赋…...

【Vue3+Vite】路由机制router 快速学习 第四期

文章目录 路由简介路由是什么路由的作用 一、路由入门案例1. 创建项目 导入路由依赖2. 准备页面和组件3. 准备路由配置4. main.js引入router配置 二、路由重定向三、编程式路由(useRouter)四、路由传参(useRoute)五、路由守卫总结 路由简介 路由是什么 路由就是根据不同的 URL…...

python脚本实现浏览器驱动chromedriver的版本自动升级

chromedriver的版本号与chrome浏览器版本不匹配时在运行程序时就会报错 用下面的脚本可以自动安装chromedriver的最新版本到指定路径 from webdriver_manager.utils import get_browser_version_from_os from webdriver_manager.chrome import ChromeDriverManager import re…...

npm使用国内淘宝镜像

一、命令配置 1、设置淘宝镜像源 npm config set registry https://registry.npmmirror.com2、查看镜像使用状态 npm config get registry如果返回https://registry.npmmirror.com/,说明配置的是淘宝镜像。 如果返回https://registry.npmjs.org/,说明配置的是官网镜像。 二…...

# Redis 分布式锁如何自动续期

Redis 分布式锁如何自动续期 何为分布式 分布式&#xff0c;从狭义上理解&#xff0c;也与集群差不多&#xff0c;但是它的组织比较松散&#xff0c;不像集群&#xff0c;有一定组织性&#xff0c;一台服务器宕了&#xff0c;其他的服务器可以顶上来。分布式的每一个节点&…...

数据结构 归并排序详解

1.基本思想 归并排序&#xff08;MERGE-SORT&#xff09;是建立在归并操作上的一种有效的排序算法,该算法是采用分治法&#xff08;Divide andConquer&#xff09;的一个非常典型的应用。 将已有序的子序列合并&#xff0c;得到完全有序的序列&#xff0c;即先使每个子序列有序…...

服务器C盘突然满了,是什么问题

随着时代的发展、互联网的普及&#xff0c;加上近几年云计算服务的诞生以及大规模普及&#xff0c;对于服务器的使用目前是非常普遍的&#xff0c;用户运维的主要对象一般也主要是服务器方面。在日常使用服务器的过程中&#xff0c;我们也会遇到各式各样的问题。最近就有遇到用…...

【深度学习】ND4J-科学计算库

目录 简介 基础用法 基础信息 数组创建 打印数组 变更维度&堆叠 加减乘除 累加/最大/最小 转换操作 矩陈乘法 索引/迭代 深拷贝/引用传递/视图 引用传递 视图 深拷贝 其它 简介 ND4J主要是JVM的科学计算库&#xff0c;内置了很多计算方法&#xff0c;目的…...

2024-01-29 ubuntu 用脚本设置安装交叉编译工具链路径方法,设置PATH环境变量

一、设置PATH环境变量的方法,建议用~/.bash_profile的方法&#xff0c;不然在ssh登录的时候可能没有设置PATH. 二、下面的完整的脚本&#xff0c;里面的echo "export PATH$build_toolchain_path:\$PATH" >> $HOME/.bashrc 就是把交叉编译路径写写到.bashrc设置…...

今年春节很多年轻人选择不买战袍,减少年货置办,「极简过年」,如何看待此现象?

​近年来&#xff0c;春节期间出现了一种新的现象&#xff0c;越来越多的年轻人选择不买战袍&#xff0c;减少年货置办&#xff0c;采用“极简过年”的方式度过春节。对于这一现象&#xff0c;不同人有不同的看法。 首先&#xff0c;这种极简过年的方式符合当前社会的一些价值观…...

C语言·贪吃蛇游戏(下)

上节我们将要完成贪吃蛇游戏所需的前置知识都学完了&#xff0c;那么这节我们就开始动手写代码了 1. 程序规划 首先我们应该规划好我们的代码文件&#xff0c;设置3个文件&#xff1a;snack.h 用来声明游戏中实现各种功能的函数&#xff0c;snack.c 用来实现函数&#xff0c;t…...

Flask 入门2:路由

1. 前言 在上一节中&#xff0c;我们使用到了静态路由&#xff0c;即一个路由规则对应一个 URL。而在实际应用中&#xff0c;更多使用的则是动态路由&#xff0c;它的 URL是可变的。 2. 定义一个很常见的路由地址 app.route(/user/<username>) def user(username):ret…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接&#xff1a;3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯&#xff0c;要想要能够将所有的电脑解锁&#x…...

质量体系的重要

质量体系是为确保产品、服务或过程质量满足规定要求&#xff0c;由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面&#xff1a; &#x1f3db;️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限&#xff0c;形成层级清晰的管理网络&#xf…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

3-11单元格区域边界定位(End属性)学习笔记

返回一个Range 对象&#xff0c;只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意&#xff1a;它移动的位置必须是相连的有内容的单元格…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难&#xff0c;相信大家会学的很愉快&#xff0c;当然对于有后端基础的朋友来说&#xff0c;本期内容更加容易了解&#xff0c;当然没有基础的也别担心&#xff0c;本期内容会详细解释有关内容 本期用到的软件&#xff1a;yakit&#xff08;因为经过之前好多期…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...