【YOLO系列】YOLOv5 NMS源码理解、更换为DIoU-NMS
代码来源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
使用的代码是YOLOv5 6.1版本
参考笔记:YOLOv5改进系列(八) 更换NMS非极大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-NMS_diou nms-CSDN博客
yolov5 极大值抑制 nms 代码详解 - 金色旭光 - 博客园
https://zhuanlan.zhihu.com/p/511151467
目录
1.NMS源码理解
2.更换DIou-NMS
1.NMS源码理解
YOLOv5中NMS的实现代码在utils/general.py的non_max_suppression
#对推理结果执行NMS
def non_max_suppression(prediction,#模型的预测结果,shape=[batch_size,预测框数量,5+类别数量=中心x+中心y+w+h+conf+类别数量]conf_thres=0.25,#置信度阈值,用于NMS,置信度低于此阈值的预测框会被去除iou_thres=0.45,#IoU阈值,用于NMS,去除冗余的预测框classes=None,#只对某些类别作NMS,None则表示所有类别都作NMSagnostic=False,#是否作类别无关的NMS,即所有预测框不分类别一起作NMS处理,通常不开启,都是各类别各自作NMSmulti_label=False,labels=(),max_det=300#每张图片作NMS之后剩余的最多预测框数):'''函数返回值:返回值output是一个列表,存放每张图片的检测结果eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]'''#类别数量ncnc = prediction.shape[2] - 5#符合置信度阈值的预测框bool数组,xc shape=[batch_size,预测框数量]xc = prediction[..., 4] > conf_thres#检查置信度、IoU阈值的有效性assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'#设置参数min_wh, max_wh = 2, 4096 #框的最小和最大宽高(像素)max_nms = 30000 #每张图片作NMS之前的最多预测框数time_limit = 10.0 #处理图片超过此时间则退出multi_label &= nc > 1 #没啥用t = time.time() #记录开始时间output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] #初始化返回值output#遍历每张图像的预测结果for xi, x in enumerate(prediction):'''xi:当前图片在batch中的idx:存放当前图片的预测框信息,shape=[预测框数量,5+类别数量]'''#仅保留大于置信度阈值的预测框,x shape=[预测框数量,5+类别数量]x = x[xc[xi]]#如果存在真实标签,则将其合并到预测结果中(这段代码不知道有什么用)if labels and len(labels[xi]):l = labels[xi] #真实标签v = torch.zeros((len(l), nc + 5), device=x.device) # 初始化与真实标签相同形状的张量v[:, :4] = l[:, 1:5] # 提取真实框的坐标v[:, 4] = 1.0 # 置信度设为1.0v[range(len(l)), l[:, 0].long() + 5] = 1.0 # 设置类别x = torch.cat((x, v), 0) # 合并预测框和真实框#如果预测框数量为0,则处理下一张图片if not x.shape[0]:continue#重置类别概率=conf置信度*原始类别概率x[:, 5:] *= x[:, 4:5]#将坐标值从(中心x, 中心y, w, h)转换为(x1, y1, x2, y2),box shape=[预测框数量,4=xyxy]box = xywh2xyxy(x[:, :4])#通常multi_label为False,执行else部分if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T # 确定哪些框符合多标签条件x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) # 合并框信息else:#将最大类别概率作为检测框的置信度存放于conf中,并将类别索引存放于j中conf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]#合并xyxy+置信度+类别索引'''conf: shape=[预测框数,1=置信度]j: shape=[预测框数,1=类别索引]x: shape=[预测框数,6=xyxy+置信度+类别索引]'''#利用class进行过滤,筛选出指定的class,nms仅仅对指定的class进行nms;#若classes为None,则所有类别都需要作nmsif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]#仅保留指定类别的预测框#预测框数量nn = x.shape[0]#如果没有预测框,则处理下一张图片if not n:continueelif n > max_nms: #如果作NMS之前预测框的数量大于max_nms,则按置信度排序并保留前max_nms个框x = x[x[:, 4].argsort(descending=True)[:max_nms]]#Batches NMS#这行代码是在多类别中应用NMS#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=原类别索引*max_wh#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]boxes = x[:, :4] + c#获取所有预测框的置信度,scores shape=[预测框数量,]scores = x[:, 4]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]i = torchvision.ops.nms(boxes, scores, iou_thres)#每张图片NMS之后最多剩余max_det个预测框if i.shape[0] > max_det:i = i[:max_det]#将该图片的检测结果存储到输出output中output[xi] = x[i]#如果处理此图片超出时间限制if (time.time() - t) > time_limit:#提示超时print(f'WARNING: NMS time limit {time_limit}s exceeded')break #超时退出#返回值output是一个列表,存放每张图片的检测结果#eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]return output #返回每张图片的检测结果
真正作NMS过滤的代码是如下几行代码:
#Batches NMS
#这行代码是在多类别中应用NMS
#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量
#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠
c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=类别索引*max_whboxes = x[:, :4] + c#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]
scores = x[:, 4]#获取所有预测框的置信度,scores shape=[预测框数量,]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]
i = torchvision.ops.nms(boxes, scores, iou_thres)
代码重点是在 '+c’这里,c是偏移量
(1)agnostic参数为True,表示所有类别一起作NMS处理,偏移量c为0;
(2)agnostic参数为False,表示按照不同类别分别作NMS处理,c=类别索引*max_wh,对不同类别的预测框做一个偏移操作,防止不同类别的预测框互相影响
注意:源码中是传入参数boxes、scores、iou_thres调用torchvision.ops.nms实现NMS处理,下面是NMS的代码实现。看了下面的NMS代码可以发现上面说agnostic为False时表示按照不同类别分别作NMS处理,但源码这里应该不是特别严格按不同类别作NMS(因为连类别的索引都没有用到),添加偏移量c只是算是一种trick把(我个人的理解,如有错误请指出)
代码流程:
- 将所有预测框按置信度从高到低排序,确保置信度高的预测框排在前面。order存放排序后的预测框索引
- 从置信度最高的框开始(即order[0]),计算它和剩下所有预测框的IoU。剩下的预测框中IoU低于设定的IoU阈值则保留下来,高于IoU阈值的预测框则去除(即在order中删除当前预测框和IoU大于阈值的预测框索引)
- 重复步骤2,直到遍历完order中的预测框,得到最终筛选出来的预测框
import torch
def NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: IoU阈值'''x1 = boxes[:,0]y1 = boxes[:,1]x2 = boxes[:,2]y2 = boxes[:,3]#计算所有预测框的面积areas = (x2-x1)*(y2-y1)#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的IoUxx1 = x1[order[1:]].clamp(min=x1[i])yy1 = y1[order[1:]].clamp(min=y1[i])xx2 = x2[order[1:]].clamp(max=x2[i])yy2 = y2[order[1:]].clamp(max=y2[i])w = (xx2-xx1).clamp(min=0)h = (yy2-yy1).clamp(min=0)inter = w*hovr = inter/(areas[i] + areas[order[1:]] - inter)#当前预测框与剩下所有预测框的IoU值#筛选出IOU小于阈值的预测框索引, 过滤掉所有IOU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的IOU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)#实例
box = torch.tensor([[2, 3.1, 7, 5], [3, 4, 8, 4.8], [4, 4, 5.6, 7], [0.1, 0, 8, 1]])
score = torch.tensor([0.5, 0.3, 0.2, 0.4])
output =NMS(boxes=box, scores=score, iou_thres=0.3)
print(output)
2.更换DIou-NMS
YOLOv5源码中使用的是IoU-NMS,这里可以作一下改进,将其替换为DIoU-NMS,因为DIoU考虑到的要素比IoU更多,应用于NMS中,可以使得NMS后得到的结果更加合理
第1步:编写DIoU_NMS函数
def DIoU_NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: DIoU阈值'''#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的DIoU#boxes[i,:]为当前预测框的坐标值,shape=[4,]#boxes[order[1:],:]为其他预测框的坐标值,shape=[n,4]ovr = bbox_iou(boxes[i, :], boxes[order[1:], :], DIoU=True)#筛选出DIoU小于阈值的预测框索引, 过滤掉所有DIoU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的DIoU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)
这里计算DIoU的函数bbox_iou是直接引用了YOLOv5中的代码,该函数的实现在utils/metrics.py中,此函数集成了IoU、GIoU、DIoU、CIoU的计算,其他XIoU_NMS的实现方法类似。PS:GIoU、DIoU、CIoU用于损失函数的情况比较多

最后将DIoU_NMS函数复制到utils/general.py
第2步:将IoU-NMS更换为DIoU-NMS
将utils/general.py下non_max_suppression函数的
i = torchvision.ops.nms(boxes, scores, iou_thres)
替换为
i = DIoU_NMS(boxes, scores, iou_thres)
这样就将IoU-NMS更换为DIoU-NMS了,但是我用几张图片作测试,发现大多数时候使用IoU-NMS和DIoU-NMS的处理结果是完全一致的。如下:

处理结果
所以这种改进可能实际意义不大
更换其他XIoU-NMS的方法是一样的,这里不再赘述
相关文章:
【YOLO系列】YOLOv5 NMS源码理解、更换为DIoU-NMS
代码来源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite 使用的代码是YOLOv5 6.1版本 参考笔记:YOLOv5改进系列(八) 更换NMS非极大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-…...
Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1)
Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1) import android.graphics.Bitmap import android.graphics.BitmapFactory import android.graphics.HardwareRenderer import android.graphics.PixelFormat import android.graphic…...
【linux学习指南】线程同步与互斥
文章目录 📝线程互斥🌠 库函数strncpy🌉进程线程间的互斥相关背景概念🌉互斥量mutex 🌠线程同步🌉条件变量🌉同步概念与竞态条件🌉 条件变量函数 🚩总结 📝线…...
JavaScript函数与方法详解
目录 一、函数的定义 1. 函数声明 2. 函数表达式 3. 箭头函数 二、函数的调用 1. 调用方式 2. 参数数量的灵活性 三、arguments 对象 1. 基本概念 2. 属性 3. 应用场景 4. 转换为真数组 5. 总结 四、Rest参数 1. 基本概念 2. 特点 3. 应用场景 4. 总结 五、变…...
【论文笔记】ZeroGS:扩展Spann3R+GS+pose估计
spann3r是利用dust3r做了增量式的点云重建,这里zeroGS在前者的基础上,进行了增量式的GS重建以及进行了pose的联合优化,这是一篇dust3r与GS结合的具有启发意义的工作。 abstract NeRF和3DGS是重建和渲染逼真图像的流行技术。然而,…...
AtCoder - arc058_d Iroha Loves Strings解答与注意事项
链接:Iroha Loves Strings - AtCoder arc058_d - Virtual Judge 利用bitset这一数据结构,定义bitset类型的变量dp[i]表示第i到n个字符串能拼成的字符串长度都有哪些,比如00100101,表示能拼成的长度有0,2,5,࿰…...
企业使用统一终端管理(UEM)工具提高端点安全性
什么是统一终端管理(UEM) 统一终端管理(UEM)是一种从单个控制台管理和保护企业中所有端点的方法,包括智能手机、平板电脑、笔记本电脑、台式机和 IoT设备。UEM 解决方案为 IT 管理员提供了一个集中式平台,用于跨所有作系统和设备类型部署、配置、管理和…...
Leetcode 算法题 9 回文数
起因, 目的: 数学法。 % 求余数, 拆开组合,组合拆开。 这个题,翻来覆去,拆开组合, 组合拆开。构建的过程。 题目来源,9 回文数: https://leetcode.cn/problems/palindrome-number…...
设计模式Python版 命令模式(上)
文章目录 前言一、命令模式二、命令模式示例 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对象之间的组合&…...
C语言之循环结构:直到型循环
C语言 循环结构 直到型循环的实现 特点:先执行,后判断,不管条件是否满足,至少执行一次。典型代表:do…while,goto(已淘汰,不推荐使用) do…while 语法: d…...
细说STM32F407单片机RTC的备份寄存器原理及使用方法
目录 一、备份寄存器的功能 二、示例功能 三、项目设置 1、晶振、DEBUG、CodeGenerator、USART6 2、RTC 3、NVIC 4、GPIO 及KEYLED 四、软件设计 1、main.h 2、main.c 3、rtc.c 4、keyled.c、keyled.h 五、运行调试 本实例旨在介绍备份寄存器的作用。本实例继续使…...
MATLAB计算反映热需求和能源消耗的度数日指标(HDD+CDD)(全代码)
目录 度数日(Degree Days, DD)概述计算公式MATLAB计算代码调用函数1:计算单站点的 CDD参考度数日(Degree Days, DD)概述 度数日(Degree Days, DD)是用于衡量建筑、城市和地区的热需求和能源消耗模式的指标。它分为两部分: 加热度日(Heating Degree Days, HDD):当室…...
J6 X8B/X3C切换HDR各帧图像
1、OV手册上的切换命令 寄存器为Ox5074 各帧切换: 2、地平线control tool实现切换命令 默认HDR模式出图: HCG出图: LCG出图 SPD出图 VS出图...
09-轮转数组
给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 方法一:使用额外数组 function rotate(nums: number[], k: number): void {const n nums.length;k k % n; // 处理 k 大于数组长度的情况const newNums new A…...
用vue3写一个好看的wiki前端页面
以下是一个使用 Vue 3 Element Plus 实现的 Wiki 风格前端页面示例,包含现代设计、响应式布局和常用功能: <template><div class"wiki-container"><!-- 头部导航 --><el-header class"wiki-header"><d…...
瑞芯微烧写工具
文章目录 前言一、安装驱动二、安装烧写工具1.直接解压压缩包2. 如何使用 三、MASKROM 裸机必备四、LOADER 烧写,前提是搞过第三步没问题五、Update.img包的烧录六、linux下烧写总结 前言 提示:这里可以添加本文要记录的大概内容: 项目需要…...
说下JVM中一次完整的GC流程?
大家好,我是锋哥。今天分享关于【说下JVM中一次完整的GC流程?】面试题。希望对大家有帮助; 说下JVM中一次完整的GC流程? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 JVM中的一次完整的垃圾回收(GC)流程可以概括为…...
Open FPV VTX开源之OSD使用分类
Open FPV VTX开源之OSD使用分类 1. 源由2. 硬件2.1 【天空端】SigmaStar2.2 【天空端】Raspberry Pi2.3 【地面端】 3. 软件3.1 天空端软件3.2 地面端软件 4. 分类4.1 嵌入式OSD分类A1-嵌入式OSD:SigmaStar Android分类A2-嵌入式OSD:SigmaStar Hi3536分…...
智慧农业-虫害及生长预测
有害生物防控系统是一个综合性的管理体系,旨在预防和控制对人类生活、生产甚至生存产生危害的生物。这些生物可能包括昆虫、动物、植物、微生物乃至病毒等。 一、系统构成 1、监测预警系统:利用智能传感器、无人机、遥感技术等手段,实时监测…...
Python 识别图片和扫描PDF中的文字
目录 工具与设置 Python 识别图片中的文字 Python 识别图片中的文字及其坐标位置 Python 识别扫描PDF中的文字 注意事项 在处理扫描的PDF和图片时,文字信息往往无法直接编辑、搜索或复制,这给信息提取和分析带来了诸多不便。手动录入信息不仅耗时费…...
Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...
【Linux】C语言执行shell指令
在C语言中执行Shell指令 在C语言中,有几种方法可以执行Shell指令: 1. 使用system()函数 这是最简单的方法,包含在stdlib.h头文件中: #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...
IGP(Interior Gateway Protocol,内部网关协议)
IGP(Interior Gateway Protocol,内部网关协议) 是一种用于在一个自治系统(AS)内部传递路由信息的路由协议,主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...
基于数字孪生的水厂可视化平台建设:架构与实践
分享大纲: 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年,数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段,基于数字孪生的水厂可视化平台的…...
ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...
04-初识css
一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
