当前位置: 首页 > news >正文

目标检测损失函数 yolos、DETR为例

yolos和DETR,除了yolos没有卷积层以外,几乎所有操作都一样。
HF官方文档

因为目标检测模型,实际会输出几百几千个“框”,所以损失函数计算比较复杂。损失函数为偶匹配损失 bipartite matching loss,参考此blog

target为class_label和box组成的字典。假设对于一张图片,我们有5个target框。
num_detection_tokens为模型对一张图最多可以产生的box的数量
简单阐述loss计算流程

  1. vit 模型,输入经过预处理的图片,输出最后隐含层状态, 大小为 [batchsize,seq_len,hidden_size]

  2. 取最后num_detection_tokens个token的隐藏状态,变为
    [batchsize,num_detection_tokens,hidden_size]

  3. 由于输出了num_detection_tokens个box,而target为5个box,所以需要进行一对一的匹配,

  4. 匹配过程:

    1. 先计算3个cost矩阵,shape均为【num_detection_tokens,num_target_box】,矩阵元素代表loss,矩阵代表对所有pred和target之间两两计算一次loss。
    2. 3个cost矩阵分别代表标签loss(交叉熵损失)、坐标loss(表示一个框的4个值的L1损失)、GIoU loss(框与框之间计算GIoU)
    3. 三个cost矩阵加权得到总体cost矩阵,大小为【num_detection_tokens,num_target_box】
    4. 对此矩阵进行linear_sum_assignment操作,得到一个匹配,此匹配下cost最小(即cost矩阵中找到不同行且不同列的5个元素,这5个元素之和最小)。匹配表示为长度为min(num_detection_tokens,num_target_box)的索引对。本例长度为5。
  5. 根据此匹配,pred和target之间计算一次loss(本例中一共计算5次loss并求和),最重loss就是上面说的3种loss的加权和

  6. 其实还有两种loss:

    1. “cardinality” loss,表示输出的num_detection_tokens个class_label中,class_label不为“无目标”的个数,与num_target_box的个数,的L1 loss. 说白了就是,除了5个框有实际的class以外,其他框应尽可能分类为“无目标”,避免检测出来目标过多。但之一loss不产生梯度,仅仅用于评估。
    2. mask loss:功能暂时不清楚

官方匹配函数,匈牙利算法

# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
class YolosHungarianMatcher(nn.Module):"""This class computes an assignment between the targets and the predictions of the network.For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are morepredictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others areun-matched (and thus treated as non-objects).Args:class_cost:The relative weight of the classification error in the matching cost.bbox_cost:The relative weight of the L1 error of the bounding box coordinates in the matching cost.giou_cost:The relative weight of the giou loss of the bounding box in the matching cost."""def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):super().__init__()requires_backends(self, ["scipy"])self.class_cost = class_costself.bbox_cost = bbox_costself.giou_cost = giou_costif class_cost == 0 and bbox_cost == 0 and giou_cost == 0:raise ValueError("All costs of the Matcher can't be 0")@torch.no_grad()def forward(self, outputs, targets):"""Args:outputs (`dict`):A dictionary that contains at least these entries:* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.targets (`List[dict]`):A list of targets (len(targets) = batch_size), where each target is a dict containing:* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number ofground-truthobjects in the target) containing the class labels* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.Returns:`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:- index_i is the indices of the selected predictions (in order)- index_j is the indices of the corresponding selected targets (in order)For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)"""batch_size, num_queries = outputs["logits"].shape[:2]# We flatten to compute the cost matrices in a batchout_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]# Also concat the target labels and boxestarget_ids = torch.cat([v["class_labels"] for v in targets])target_bbox = torch.cat([v["boxes"] for v in targets])# Compute the classification cost. Contrary to the loss, we don't use the NLL,# but approximate it in 1 - proba[target class].# The 1 is a constant that doesn't change the matching, it can be ommitted.class_cost = -out_prob[:, target_ids]# Compute the L1 cost between boxesbbox_cost = torch.cdist(out_bbox, target_bbox, p=1)# Compute the giou cost between boxesgiou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))# Final cost matrixcost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_costcost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()sizes = [len(v["boxes"]) for v in targets]indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

目标检测还有很多细节问题,以后更新

相关文章:

目标检测损失函数 yolos、DETR为例

yolos和DETR,除了yolos没有卷积层以外,几乎所有操作都一样。 HF官方文档 因为目标检测模型,实际会输出几百几千个“框”,所以损失函数计算比较复杂。损失函数为偶匹配损失 bipartite matching loss,参考此blog targe…...

linux系统编程2--网络编程socket

在linux系统编程中网络编程是使用socket(套接字),socket这个词可以表示很多概念:在TCP/IP协议中,“IP地址TCP或UDP端口号”唯一标识网络通讯中的一个进程,“IP地址端口号”就称为socket。在TCP协议中&#…...

FPGA纯Verilog实现任意尺寸图像缩放,串口指令控制切换,贴近真实项目,提供工程源码和技术支持

目录1、前言2、目前主流的FPGA图像缩放方案3、本方案的优越性4、详细设计方案5、vivado工程详解6、上板调试验证并演示7、福利:工程源码获取1、前言 代码使用纯verilog实现,没有任何ip,可在Xilinx、Intel、国产FPGA间任意移植; 图…...

华为OD机试题 - 最长合法表达式(JavaScript)| 代码+思路+重要知识点

最近更新的博客 华为OD机试题 - 字符串加密(JavaScript) 华为OD机试题 - 字母消消乐(JavaScript) 华为OD机试题 - 字母计数(JavaScript) 华为OD机试题 - 整数分解(JavaScript) 华为OD机试题 - 单词反转(JavaScript) 使用说明 参加华为od机试,一定要注意不要完全背…...

L1-005 考试座位号

L1-005 考试座位号 每个 PAT 考生在参加考试时都会被分配两个座位号,一个是试机座位,一个是考试座位。正常情况下,考生在入场时先得到试机座位号码,入座进入试机状态后,系统会显示该考生的考试座位号码,考试…...

Obsidian + remotely save + 坚果云:实现电脑端和手机端的同步

写在前面:近年来某象笔记广告有增无减,不堪其扰,便转投其它笔记,Obsidian、OneNote、Notion、flomo都略有使用,本人更偏好obsidian操作简单,然其官方同步资费甚高,囊中羞涩,所幸可通…...

对比学习MoCo损失函数infoNCE理解(附代码)

MoCo loss计算采用的损失函数是InfoNCE: ​​ 下面是MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。 将k作为q的正样本,因为k与q是来自同一张图像的不同视图;将queue作为q的负样本,因为queue中含有大量…...

logd守护进程

logd守护进程1、adb logcat命令2、logd守护进程启动2.1 logd文件目录2.2 main方法启动3、LogBuffer缓存大小3.1 缓存大小优先级设置3.2 缓存大小相关代码位置android12-release1、adb logcat命令 命令功能adb bugreport > bugreport.txtbugreport 日志adb shell dmesg >…...

【汽车雷达通往自动驾驶的关键技术】

本文编辑:调皮哥的小助理 现代汽车雷达装置比手机还小,能探测前方、后方或侧方的盲点位置是否存在障碍物,但这还不百分之百实现全自动驾驶的。传统的汽车雷达分辨率都不高,只能“看到”一团东西,可以检测到汽车周围存在…...

2023实习面经

实习面经 秋招笔试面试全记录 字节-电商 字节实习一面: 二分类的损失函数是什么,怎么算?多分类的损失函数怎么算?如果文本分类的标签有多个,比如一个文本同时属于多个label那怎么办?如果文本分类里面的…...

linux shell 入门学习笔记2shell脚本

什么是shell脚本 当命令或者程序语句写在文件中,我们执行文件,读取其中的代码,这个程序就称之为shell脚本。 有了shell脚本肯定是要有对应的解释器了,常见的shell脚本解释器有sh、python、perl、tcl、php、ruby等。一般这种使用文…...

Android稳定性系列-01-使用 Address Sanitizer检测原生代码中的内存错误

前言想必大家曾经被各种Native Crash折磨过,本地测试没啥问题,一到线上或者自动化测试就出现各种SIGSEGV、SIGABRT、SIGILL、SIGBUS、SIGFPE异常,而且堆栈还是崩溃到libc.so这种,看起来跟我们的代码没啥关系,关键还不好…...

HyperOpt-quniform 范围问题

在使用 quniform 的时候,可能会出现超出指定范围的值,例如对于 GBDT 设置参数空间为 learning_rate:hp.quniform(learning_rate,0.05,2.05,0.2),但是仍然会报错 ValueError: learning_rate must be greater than 0 but was 0.0,但…...

Pycharm搭建一个Django项目

File->new project 点击create, 等待一下即可 查看安装 Django 版本: 在 Pycharm 底部选择 Terminal 然后在里面输入:python -m django --version 启动项目: 在 Terminal 里面输入: python manage.py runserver 查看文件目…...

浅析前端工程化中的一部曲——模块化

在日益复杂和多元的 Web 业务背景下,前端工程化经常会被提及。工程化的目的是高性能、稳定性、可用性、可维护性、高效协同,只要是以这几个角度为目标所做的操作,都可成为工程化的一部分。工程化是软件工程中的一种思想,当下的工程…...

新版bing(集成ChatGPT)申请通过后在谷歌浏览器(Chrome)上的使用方法

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,科大讯飞比赛第三名,CCF比赛第四名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…...

Time-distributed 的理解

前言 今天看到论文中用到 Time-distributed CNN,第一次见到 Time-distributed,不理解是什么含义,看到代码实现也很懵。不管什么网络结构,外面都能套一个TimeDistributed。看了几个博客,还是不明白,问了问C…...

matlab 计算矩阵的Moore-Penrose 伪逆

目录 一、Moore-Penrose 伪逆1、主要函数2、输入输出参数二、代码示例使用伪逆求解线性方程组一、Moore-Penrose 伪逆 Moore-Penrose 伪逆是一种矩阵,可在不存在逆矩阵的情况下作为逆矩阵的部分替代。此矩阵常被用于求解没有唯一解或有许多解的线性方程组。    对于任何矩阵…...

简历制作方面的经验与建议

专栏推荐:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 专栏首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 专栏内容: 笔试复盘篇 2023秋招过程中整理的笔试题,来源包括我自己求职笔试以及整理其他同学的笔试。包含华为、中兴、联发科、AMD、大…...

C语言--static、const、volatile关键字

Static static修饰局部变量改变了变量的生命周期,让静态局部变量出了作用域依然存在,到程序结束,生命周期才结束。 static 修饰局部变量 改变局部变量的生命周期,本质上是改变了局部变量的存储位置,让局部变量不再是…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容,使用AI(2025)可以参考以下方法: 四个洞见 模型已经比人聪明:以ChatGPT o3为代表的AI非常强大,能运用高级理论解释道理、引用最新学术论文,生成对顶尖科学家都有用的…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计,聪明的码友立马就知道了,该到数据访问模块了,要不就这俩玩个6啊,查库势在必行,至此,它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据(数据库、No…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...

面向无人机海岸带生态系统监测的语义分割基准数据集

描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 (一)项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台,其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言,首次接触 OpenBCI 设备时,往…...

6.9-QT模拟计算器

源码: 头文件: widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QMouseEvent>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent nullptr);…...

跨平台商品数据接口的标准化与规范化发展路径:淘宝京东拼多多的最新实践

在电商行业蓬勃发展的当下&#xff0c;多平台运营已成为众多商家的必然选择。然而&#xff0c;不同电商平台在商品数据接口方面存在差异&#xff0c;导致商家在跨平台运营时面临诸多挑战&#xff0c;如数据对接困难、运营效率低下、用户体验不一致等。跨平台商品数据接口的标准…...