【YOLO系列】YOLOv5 NMS源码理解、更换为DIoU-NMS
代码来源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
使用的代码是YOLOv5 6.1版本
参考笔记:YOLOv5改进系列(八) 更换NMS非极大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-NMS_diou nms-CSDN博客
yolov5 极大值抑制 nms 代码详解 - 金色旭光 - 博客园
https://zhuanlan.zhihu.com/p/511151467
目录
1.NMS源码理解
2.更换DIou-NMS
1.NMS源码理解
YOLOv5中NMS的实现代码在utils/general.py的non_max_suppression
#对推理结果执行NMS
def non_max_suppression(prediction,#模型的预测结果,shape=[batch_size,预测框数量,5+类别数量=中心x+中心y+w+h+conf+类别数量]conf_thres=0.25,#置信度阈值,用于NMS,置信度低于此阈值的预测框会被去除iou_thres=0.45,#IoU阈值,用于NMS,去除冗余的预测框classes=None,#只对某些类别作NMS,None则表示所有类别都作NMSagnostic=False,#是否作类别无关的NMS,即所有预测框不分类别一起作NMS处理,通常不开启,都是各类别各自作NMSmulti_label=False,labels=(),max_det=300#每张图片作NMS之后剩余的最多预测框数):'''函数返回值:返回值output是一个列表,存放每张图片的检测结果eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]'''#类别数量ncnc = prediction.shape[2] - 5#符合置信度阈值的预测框bool数组,xc shape=[batch_size,预测框数量]xc = prediction[..., 4] > conf_thres#检查置信度、IoU阈值的有效性assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'#设置参数min_wh, max_wh = 2, 4096 #框的最小和最大宽高(像素)max_nms = 30000 #每张图片作NMS之前的最多预测框数time_limit = 10.0 #处理图片超过此时间则退出multi_label &= nc > 1 #没啥用t = time.time() #记录开始时间output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] #初始化返回值output#遍历每张图像的预测结果for xi, x in enumerate(prediction):'''xi:当前图片在batch中的idx:存放当前图片的预测框信息,shape=[预测框数量,5+类别数量]'''#仅保留大于置信度阈值的预测框,x shape=[预测框数量,5+类别数量]x = x[xc[xi]]#如果存在真实标签,则将其合并到预测结果中(这段代码不知道有什么用)if labels and len(labels[xi]):l = labels[xi] #真实标签v = torch.zeros((len(l), nc + 5), device=x.device) # 初始化与真实标签相同形状的张量v[:, :4] = l[:, 1:5] # 提取真实框的坐标v[:, 4] = 1.0 # 置信度设为1.0v[range(len(l)), l[:, 0].long() + 5] = 1.0 # 设置类别x = torch.cat((x, v), 0) # 合并预测框和真实框#如果预测框数量为0,则处理下一张图片if not x.shape[0]:continue#重置类别概率=conf置信度*原始类别概率x[:, 5:] *= x[:, 4:5]#将坐标值从(中心x, 中心y, w, h)转换为(x1, y1, x2, y2),box shape=[预测框数量,4=xyxy]box = xywh2xyxy(x[:, :4])#通常multi_label为False,执行else部分if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T # 确定哪些框符合多标签条件x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) # 合并框信息else:#将最大类别概率作为检测框的置信度存放于conf中,并将类别索引存放于j中conf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]#合并xyxy+置信度+类别索引'''conf: shape=[预测框数,1=置信度]j: shape=[预测框数,1=类别索引]x: shape=[预测框数,6=xyxy+置信度+类别索引]'''#利用class进行过滤,筛选出指定的class,nms仅仅对指定的class进行nms;#若classes为None,则所有类别都需要作nmsif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]#仅保留指定类别的预测框#预测框数量nn = x.shape[0]#如果没有预测框,则处理下一张图片if not n:continueelif n > max_nms: #如果作NMS之前预测框的数量大于max_nms,则按置信度排序并保留前max_nms个框x = x[x[:, 4].argsort(descending=True)[:max_nms]]#Batches NMS#这行代码是在多类别中应用NMS#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=原类别索引*max_wh#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]boxes = x[:, :4] + c#获取所有预测框的置信度,scores shape=[预测框数量,]scores = x[:, 4]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]i = torchvision.ops.nms(boxes, scores, iou_thres)#每张图片NMS之后最多剩余max_det个预测框if i.shape[0] > max_det:i = i[:max_det]#将该图片的检测结果存储到输出output中output[xi] = x[i]#如果处理此图片超出时间限制if (time.time() - t) > time_limit:#提示超时print(f'WARNING: NMS time limit {time_limit}s exceeded')break #超时退出#返回值output是一个列表,存放每张图片的检测结果#eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]return output #返回每张图片的检测结果
真正作NMS过滤的代码是如下几行代码:
#Batches NMS
#这行代码是在多类别中应用NMS
#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量
#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠
c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=类别索引*max_whboxes = x[:, :4] + c#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]
scores = x[:, 4]#获取所有预测框的置信度,scores shape=[预测框数量,]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]
i = torchvision.ops.nms(boxes, scores, iou_thres)
代码重点是在 '+c’这里,c是偏移量
(1)agnostic参数为True,表示所有类别一起作NMS处理,偏移量c为0;
(2)agnostic参数为False,表示按照不同类别分别作NMS处理,c=类别索引*max_wh,对不同类别的预测框做一个偏移操作,防止不同类别的预测框互相影响
注意:源码中是传入参数boxes、scores、iou_thres调用torchvision.ops.nms实现NMS处理,下面是NMS的代码实现。看了下面的NMS代码可以发现上面说agnostic为False时表示按照不同类别分别作NMS处理,但源码这里应该不是特别严格按不同类别作NMS(因为连类别的索引都没有用到),添加偏移量c只是算是一种trick把(我个人的理解,如有错误请指出)
代码流程:
- 将所有预测框按置信度从高到低排序,确保置信度高的预测框排在前面。order存放排序后的预测框索引
- 从置信度最高的框开始(即order[0]),计算它和剩下所有预测框的IoU。剩下的预测框中IoU低于设定的IoU阈值则保留下来,高于IoU阈值的预测框则去除(即在order中删除当前预测框和IoU大于阈值的预测框索引)
- 重复步骤2,直到遍历完order中的预测框,得到最终筛选出来的预测框
import torch
def NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: IoU阈值'''x1 = boxes[:,0]y1 = boxes[:,1]x2 = boxes[:,2]y2 = boxes[:,3]#计算所有预测框的面积areas = (x2-x1)*(y2-y1)#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的IoUxx1 = x1[order[1:]].clamp(min=x1[i])yy1 = y1[order[1:]].clamp(min=y1[i])xx2 = x2[order[1:]].clamp(max=x2[i])yy2 = y2[order[1:]].clamp(max=y2[i])w = (xx2-xx1).clamp(min=0)h = (yy2-yy1).clamp(min=0)inter = w*hovr = inter/(areas[i] + areas[order[1:]] - inter)#当前预测框与剩下所有预测框的IoU值#筛选出IOU小于阈值的预测框索引, 过滤掉所有IOU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的IOU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)#实例
box = torch.tensor([[2, 3.1, 7, 5], [3, 4, 8, 4.8], [4, 4, 5.6, 7], [0.1, 0, 8, 1]])
score = torch.tensor([0.5, 0.3, 0.2, 0.4])
output =NMS(boxes=box, scores=score, iou_thres=0.3)
print(output)
2.更换DIou-NMS
YOLOv5源码中使用的是IoU-NMS,这里可以作一下改进,将其替换为DIoU-NMS,因为DIoU考虑到的要素比IoU更多,应用于NMS中,可以使得NMS后得到的结果更加合理
第1步:编写DIoU_NMS函数
def DIoU_NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: DIoU阈值'''#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的DIoU#boxes[i,:]为当前预测框的坐标值,shape=[4,]#boxes[order[1:],:]为其他预测框的坐标值,shape=[n,4]ovr = bbox_iou(boxes[i, :], boxes[order[1:], :], DIoU=True)#筛选出DIoU小于阈值的预测框索引, 过滤掉所有DIoU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的DIoU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)
这里计算DIoU的函数bbox_iou是直接引用了YOLOv5中的代码,该函数的实现在utils/metrics.py中,此函数集成了IoU、GIoU、DIoU、CIoU的计算,其他XIoU_NMS的实现方法类似。PS:GIoU、DIoU、CIoU用于损失函数的情况比较多

最后将DIoU_NMS函数复制到utils/general.py
第2步:将IoU-NMS更换为DIoU-NMS
将utils/general.py下non_max_suppression函数的
i = torchvision.ops.nms(boxes, scores, iou_thres)
替换为
i = DIoU_NMS(boxes, scores, iou_thres)
这样就将IoU-NMS更换为DIoU-NMS了,但是我用几张图片作测试,发现大多数时候使用IoU-NMS和DIoU-NMS的处理结果是完全一致的。如下:

处理结果
所以这种改进可能实际意义不大
更换其他XIoU-NMS的方法是一样的,这里不再赘述
相关文章:
【YOLO系列】YOLOv5 NMS源码理解、更换为DIoU-NMS
代码来源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite 使用的代码是YOLOv5 6.1版本 参考笔记:YOLOv5改进系列(八) 更换NMS非极大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-…...
Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1)
Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1) import android.graphics.Bitmap import android.graphics.BitmapFactory import android.graphics.HardwareRenderer import android.graphics.PixelFormat import android.graphic…...
【linux学习指南】线程同步与互斥
文章目录 📝线程互斥🌠 库函数strncpy🌉进程线程间的互斥相关背景概念🌉互斥量mutex 🌠线程同步🌉条件变量🌉同步概念与竞态条件🌉 条件变量函数 🚩总结 📝线…...
JavaScript函数与方法详解
目录 一、函数的定义 1. 函数声明 2. 函数表达式 3. 箭头函数 二、函数的调用 1. 调用方式 2. 参数数量的灵活性 三、arguments 对象 1. 基本概念 2. 属性 3. 应用场景 4. 转换为真数组 5. 总结 四、Rest参数 1. 基本概念 2. 特点 3. 应用场景 4. 总结 五、变…...
【论文笔记】ZeroGS:扩展Spann3R+GS+pose估计
spann3r是利用dust3r做了增量式的点云重建,这里zeroGS在前者的基础上,进行了增量式的GS重建以及进行了pose的联合优化,这是一篇dust3r与GS结合的具有启发意义的工作。 abstract NeRF和3DGS是重建和渲染逼真图像的流行技术。然而,…...
AtCoder - arc058_d Iroha Loves Strings解答与注意事项
链接:Iroha Loves Strings - AtCoder arc058_d - Virtual Judge 利用bitset这一数据结构,定义bitset类型的变量dp[i]表示第i到n个字符串能拼成的字符串长度都有哪些,比如00100101,表示能拼成的长度有0,2,5,࿰…...
企业使用统一终端管理(UEM)工具提高端点安全性
什么是统一终端管理(UEM) 统一终端管理(UEM)是一种从单个控制台管理和保护企业中所有端点的方法,包括智能手机、平板电脑、笔记本电脑、台式机和 IoT设备。UEM 解决方案为 IT 管理员提供了一个集中式平台,用于跨所有作系统和设备类型部署、配置、管理和…...
Leetcode 算法题 9 回文数
起因, 目的: 数学法。 % 求余数, 拆开组合,组合拆开。 这个题,翻来覆去,拆开组合, 组合拆开。构建的过程。 题目来源,9 回文数: https://leetcode.cn/problems/palindrome-number…...
设计模式Python版 命令模式(上)
文章目录 前言一、命令模式二、命令模式示例 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对象之间的组合&…...
C语言之循环结构:直到型循环
C语言 循环结构 直到型循环的实现 特点:先执行,后判断,不管条件是否满足,至少执行一次。典型代表:do…while,goto(已淘汰,不推荐使用) do…while 语法: d…...
细说STM32F407单片机RTC的备份寄存器原理及使用方法
目录 一、备份寄存器的功能 二、示例功能 三、项目设置 1、晶振、DEBUG、CodeGenerator、USART6 2、RTC 3、NVIC 4、GPIO 及KEYLED 四、软件设计 1、main.h 2、main.c 3、rtc.c 4、keyled.c、keyled.h 五、运行调试 本实例旨在介绍备份寄存器的作用。本实例继续使…...
MATLAB计算反映热需求和能源消耗的度数日指标(HDD+CDD)(全代码)
目录 度数日(Degree Days, DD)概述计算公式MATLAB计算代码调用函数1:计算单站点的 CDD参考度数日(Degree Days, DD)概述 度数日(Degree Days, DD)是用于衡量建筑、城市和地区的热需求和能源消耗模式的指标。它分为两部分: 加热度日(Heating Degree Days, HDD):当室…...
J6 X8B/X3C切换HDR各帧图像
1、OV手册上的切换命令 寄存器为Ox5074 各帧切换: 2、地平线control tool实现切换命令 默认HDR模式出图: HCG出图: LCG出图 SPD出图 VS出图...
09-轮转数组
给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 方法一:使用额外数组 function rotate(nums: number[], k: number): void {const n nums.length;k k % n; // 处理 k 大于数组长度的情况const newNums new A…...
用vue3写一个好看的wiki前端页面
以下是一个使用 Vue 3 Element Plus 实现的 Wiki 风格前端页面示例,包含现代设计、响应式布局和常用功能: <template><div class"wiki-container"><!-- 头部导航 --><el-header class"wiki-header"><d…...
瑞芯微烧写工具
文章目录 前言一、安装驱动二、安装烧写工具1.直接解压压缩包2. 如何使用 三、MASKROM 裸机必备四、LOADER 烧写,前提是搞过第三步没问题五、Update.img包的烧录六、linux下烧写总结 前言 提示:这里可以添加本文要记录的大概内容: 项目需要…...
说下JVM中一次完整的GC流程?
大家好,我是锋哥。今天分享关于【说下JVM中一次完整的GC流程?】面试题。希望对大家有帮助; 说下JVM中一次完整的GC流程? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 JVM中的一次完整的垃圾回收(GC)流程可以概括为…...
Open FPV VTX开源之OSD使用分类
Open FPV VTX开源之OSD使用分类 1. 源由2. 硬件2.1 【天空端】SigmaStar2.2 【天空端】Raspberry Pi2.3 【地面端】 3. 软件3.1 天空端软件3.2 地面端软件 4. 分类4.1 嵌入式OSD分类A1-嵌入式OSD:SigmaStar Android分类A2-嵌入式OSD:SigmaStar Hi3536分…...
智慧农业-虫害及生长预测
有害生物防控系统是一个综合性的管理体系,旨在预防和控制对人类生活、生产甚至生存产生危害的生物。这些生物可能包括昆虫、动物、植物、微生物乃至病毒等。 一、系统构成 1、监测预警系统:利用智能传感器、无人机、遥感技术等手段,实时监测…...
Python 识别图片和扫描PDF中的文字
目录 工具与设置 Python 识别图片中的文字 Python 识别图片中的文字及其坐标位置 Python 识别扫描PDF中的文字 注意事项 在处理扫描的PDF和图片时,文字信息往往无法直接编辑、搜索或复制,这给信息提取和分析带来了诸多不便。手动录入信息不仅耗时费…...
Matlab | matlab常用命令总结
常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...
pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)
目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关࿰…...
分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...
Python Ovito统计金刚石结构数量
大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...
AI语音助手的Python实现
引言 语音助手(如小爱同学、Siri)通过语音识别、自然语言处理(NLP)和语音合成技术,为用户提供直观、高效的交互体验。随着人工智能的普及,Python开发者可以利用开源库和AI模型,快速构建自定义语音助手。本文由浅入深,详细介绍如何使用Python开发AI语音助手,涵盖基础功…...
C++实现分布式网络通信框架RPC(2)——rpc发布端
有了上篇文章的项目的基本知识的了解,现在我们就开始构建项目。 目录 一、构建工程目录 二、本地服务发布成RPC服务 2.1理解RPC发布 2.2实现 三、Mprpc框架的基础类设计 3.1框架的初始化类 MprpcApplication 代码实现 3.2读取配置文件类 MprpcConfig 代码实现…...
sshd代码修改banner
sshd服务连接之后会收到字符串: SSH-2.0-OpenSSH_9.5 容易被hacker识别此服务为sshd服务。 是否可以通过修改此banner达到让人无法识别此服务的目的呢? 不能。因为这是写的SSH的协议中的。 也就是协议规定了banner必须这么写。 SSH- 开头,…...
pgsql:还原数据库后出现重复序列导致“more than one owned sequence found“报错问题的解决
问题: pgsql数据库通过备份数据库文件进行还原时,如果表中有自增序列,还原后可能会出现重复的序列,此时若向表中插入新行时会出现“more than one owned sequence found”的报错提示。 点击菜单“其它”-》“序列”,…...
Java数组Arrays操作全攻略
Arrays类的概述 Java中的Arrays类位于java.util包中,提供了一系列静态方法用于操作数组(如排序、搜索、填充、比较等)。这些方法适用于基本类型数组和对象数组。 常用成员方法及代码示例 排序(sort) 对数组进行升序…...
算法250609 高精度
加法 #include<stdio.h> #include<iostream> #include<string.h> #include<math.h> #include<algorithm> using namespace std; char input1[205]; char input2[205]; int main(){while(scanf("%s%s",input1,input2)!EOF){int a[205]…...
