当前位置: 首页 > news >正文

目标检测损失函数 yolos、DETR为例

yolos和DETR,除了yolos没有卷积层以外,几乎所有操作都一样。
HF官方文档

因为目标检测模型,实际会输出几百几千个“框”,所以损失函数计算比较复杂。损失函数为偶匹配损失 bipartite matching loss,参考此blog

target为class_label和box组成的字典。假设对于一张图片,我们有5个target框。
num_detection_tokens为模型对一张图最多可以产生的box的数量
简单阐述loss计算流程

  1. vit 模型,输入经过预处理的图片,输出最后隐含层状态, 大小为 [batchsize,seq_len,hidden_size]

  2. 取最后num_detection_tokens个token的隐藏状态,变为
    [batchsize,num_detection_tokens,hidden_size]

  3. 由于输出了num_detection_tokens个box,而target为5个box,所以需要进行一对一的匹配,

  4. 匹配过程:

    1. 先计算3个cost矩阵,shape均为【num_detection_tokens,num_target_box】,矩阵元素代表loss,矩阵代表对所有pred和target之间两两计算一次loss。
    2. 3个cost矩阵分别代表标签loss(交叉熵损失)、坐标loss(表示一个框的4个值的L1损失)、GIoU loss(框与框之间计算GIoU)
    3. 三个cost矩阵加权得到总体cost矩阵,大小为【num_detection_tokens,num_target_box】
    4. 对此矩阵进行linear_sum_assignment操作,得到一个匹配,此匹配下cost最小(即cost矩阵中找到不同行且不同列的5个元素,这5个元素之和最小)。匹配表示为长度为min(num_detection_tokens,num_target_box)的索引对。本例长度为5。
  5. 根据此匹配,pred和target之间计算一次loss(本例中一共计算5次loss并求和),最重loss就是上面说的3种loss的加权和

  6. 其实还有两种loss:

    1. “cardinality” loss,表示输出的num_detection_tokens个class_label中,class_label不为“无目标”的个数,与num_target_box的个数,的L1 loss. 说白了就是,除了5个框有实际的class以外,其他框应尽可能分类为“无目标”,避免检测出来目标过多。但之一loss不产生梯度,仅仅用于评估。
    2. mask loss:功能暂时不清楚

官方匹配函数,匈牙利算法

# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
class YolosHungarianMatcher(nn.Module):"""This class computes an assignment between the targets and the predictions of the network.For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are morepredictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others areun-matched (and thus treated as non-objects).Args:class_cost:The relative weight of the classification error in the matching cost.bbox_cost:The relative weight of the L1 error of the bounding box coordinates in the matching cost.giou_cost:The relative weight of the giou loss of the bounding box in the matching cost."""def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):super().__init__()requires_backends(self, ["scipy"])self.class_cost = class_costself.bbox_cost = bbox_costself.giou_cost = giou_costif class_cost == 0 and bbox_cost == 0 and giou_cost == 0:raise ValueError("All costs of the Matcher can't be 0")@torch.no_grad()def forward(self, outputs, targets):"""Args:outputs (`dict`):A dictionary that contains at least these entries:* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.targets (`List[dict]`):A list of targets (len(targets) = batch_size), where each target is a dict containing:* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number ofground-truthobjects in the target) containing the class labels* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.Returns:`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:- index_i is the indices of the selected predictions (in order)- index_j is the indices of the corresponding selected targets (in order)For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)"""batch_size, num_queries = outputs["logits"].shape[:2]# We flatten to compute the cost matrices in a batchout_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]# Also concat the target labels and boxestarget_ids = torch.cat([v["class_labels"] for v in targets])target_bbox = torch.cat([v["boxes"] for v in targets])# Compute the classification cost. Contrary to the loss, we don't use the NLL,# but approximate it in 1 - proba[target class].# The 1 is a constant that doesn't change the matching, it can be ommitted.class_cost = -out_prob[:, target_ids]# Compute the L1 cost between boxesbbox_cost = torch.cdist(out_bbox, target_bbox, p=1)# Compute the giou cost between boxesgiou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))# Final cost matrixcost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_costcost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()sizes = [len(v["boxes"]) for v in targets]indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

目标检测还有很多细节问题,以后更新

相关文章:

目标检测损失函数 yolos、DETR为例

yolos和DETR,除了yolos没有卷积层以外,几乎所有操作都一样。 HF官方文档 因为目标检测模型,实际会输出几百几千个“框”,所以损失函数计算比较复杂。损失函数为偶匹配损失 bipartite matching loss,参考此blog targe…...

linux系统编程2--网络编程socket

在linux系统编程中网络编程是使用socket(套接字),socket这个词可以表示很多概念:在TCP/IP协议中,“IP地址TCP或UDP端口号”唯一标识网络通讯中的一个进程,“IP地址端口号”就称为socket。在TCP协议中&#…...

FPGA纯Verilog实现任意尺寸图像缩放,串口指令控制切换,贴近真实项目,提供工程源码和技术支持

目录1、前言2、目前主流的FPGA图像缩放方案3、本方案的优越性4、详细设计方案5、vivado工程详解6、上板调试验证并演示7、福利:工程源码获取1、前言 代码使用纯verilog实现,没有任何ip,可在Xilinx、Intel、国产FPGA间任意移植; 图…...

华为OD机试题 - 最长合法表达式(JavaScript)| 代码+思路+重要知识点

最近更新的博客 华为OD机试题 - 字符串加密(JavaScript) 华为OD机试题 - 字母消消乐(JavaScript) 华为OD机试题 - 字母计数(JavaScript) 华为OD机试题 - 整数分解(JavaScript) 华为OD机试题 - 单词反转(JavaScript) 使用说明 参加华为od机试,一定要注意不要完全背…...

L1-005 考试座位号

L1-005 考试座位号 每个 PAT 考生在参加考试时都会被分配两个座位号,一个是试机座位,一个是考试座位。正常情况下,考生在入场时先得到试机座位号码,入座进入试机状态后,系统会显示该考生的考试座位号码,考试…...

Obsidian + remotely save + 坚果云:实现电脑端和手机端的同步

写在前面:近年来某象笔记广告有增无减,不堪其扰,便转投其它笔记,Obsidian、OneNote、Notion、flomo都略有使用,本人更偏好obsidian操作简单,然其官方同步资费甚高,囊中羞涩,所幸可通…...

对比学习MoCo损失函数infoNCE理解(附代码)

MoCo loss计算采用的损失函数是InfoNCE: ​​ 下面是MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。 将k作为q的正样本,因为k与q是来自同一张图像的不同视图;将queue作为q的负样本,因为queue中含有大量…...

logd守护进程

logd守护进程1、adb logcat命令2、logd守护进程启动2.1 logd文件目录2.2 main方法启动3、LogBuffer缓存大小3.1 缓存大小优先级设置3.2 缓存大小相关代码位置android12-release1、adb logcat命令 命令功能adb bugreport > bugreport.txtbugreport 日志adb shell dmesg >…...

【汽车雷达通往自动驾驶的关键技术】

本文编辑:调皮哥的小助理 现代汽车雷达装置比手机还小,能探测前方、后方或侧方的盲点位置是否存在障碍物,但这还不百分之百实现全自动驾驶的。传统的汽车雷达分辨率都不高,只能“看到”一团东西,可以检测到汽车周围存在…...

2023实习面经

实习面经 秋招笔试面试全记录 字节-电商 字节实习一面: 二分类的损失函数是什么,怎么算?多分类的损失函数怎么算?如果文本分类的标签有多个,比如一个文本同时属于多个label那怎么办?如果文本分类里面的…...

linux shell 入门学习笔记2shell脚本

什么是shell脚本 当命令或者程序语句写在文件中,我们执行文件,读取其中的代码,这个程序就称之为shell脚本。 有了shell脚本肯定是要有对应的解释器了,常见的shell脚本解释器有sh、python、perl、tcl、php、ruby等。一般这种使用文…...

Android稳定性系列-01-使用 Address Sanitizer检测原生代码中的内存错误

前言想必大家曾经被各种Native Crash折磨过,本地测试没啥问题,一到线上或者自动化测试就出现各种SIGSEGV、SIGABRT、SIGILL、SIGBUS、SIGFPE异常,而且堆栈还是崩溃到libc.so这种,看起来跟我们的代码没啥关系,关键还不好…...

HyperOpt-quniform 范围问题

在使用 quniform 的时候,可能会出现超出指定范围的值,例如对于 GBDT 设置参数空间为 learning_rate:hp.quniform(learning_rate,0.05,2.05,0.2),但是仍然会报错 ValueError: learning_rate must be greater than 0 but was 0.0,但…...

Pycharm搭建一个Django项目

File->new project 点击create, 等待一下即可 查看安装 Django 版本: 在 Pycharm 底部选择 Terminal 然后在里面输入:python -m django --version 启动项目: 在 Terminal 里面输入: python manage.py runserver 查看文件目…...

浅析前端工程化中的一部曲——模块化

在日益复杂和多元的 Web 业务背景下,前端工程化经常会被提及。工程化的目的是高性能、稳定性、可用性、可维护性、高效协同,只要是以这几个角度为目标所做的操作,都可成为工程化的一部分。工程化是软件工程中的一种思想,当下的工程…...

新版bing(集成ChatGPT)申请通过后在谷歌浏览器(Chrome)上的使用方法

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,科大讯飞比赛第三名,CCF比赛第四名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…...

Time-distributed 的理解

前言 今天看到论文中用到 Time-distributed CNN,第一次见到 Time-distributed,不理解是什么含义,看到代码实现也很懵。不管什么网络结构,外面都能套一个TimeDistributed。看了几个博客,还是不明白,问了问C…...

matlab 计算矩阵的Moore-Penrose 伪逆

目录 一、Moore-Penrose 伪逆1、主要函数2、输入输出参数二、代码示例使用伪逆求解线性方程组一、Moore-Penrose 伪逆 Moore-Penrose 伪逆是一种矩阵,可在不存在逆矩阵的情况下作为逆矩阵的部分替代。此矩阵常被用于求解没有唯一解或有许多解的线性方程组。    对于任何矩阵…...

简历制作方面的经验与建议

专栏推荐:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 专栏首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 专栏内容: 笔试复盘篇 2023秋招过程中整理的笔试题,来源包括我自己求职笔试以及整理其他同学的笔试。包含华为、中兴、联发科、AMD、大…...

C语言--static、const、volatile关键字

Static static修饰局部变量改变了变量的生命周期,让静态局部变量出了作用域依然存在,到程序结束,生命周期才结束。 static 修饰局部变量 改变局部变量的生命周期,本质上是改变了局部变量的存储位置,让局部变量不再是…...

XML Group端口详解

在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...

Xen Server服务器释放磁盘空间

disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...

Mysql8 忘记密码重置,以及问题解决

1.使用免密登录 找到配置MySQL文件,我的文件路径是/etc/mysql/my.cnf,有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...

Python+ZeroMQ实战:智能车辆状态监控与模拟模式自动切换

目录 关键点 技术实现1 技术实现2 摘要: 本文将介绍如何利用Python和ZeroMQ消息队列构建一个智能车辆状态监控系统。系统能够根据时间策略自动切换驾驶模式(自动驾驶、人工驾驶、远程驾驶、主动安全),并通过实时消息推送更新车…...

论文阅读笔记——Muffin: Testing Deep Learning Libraries via Neural Architecture Fuzzing

Muffin 论文 现有方法 CRADLE 和 LEMON,依赖模型推理阶段输出进行差分测试,但在训练阶段是不可行的,因为训练阶段直到最后才有固定输出,中间过程是不断变化的。API 库覆盖低,因为各个 API 都是在各种具体场景下使用。…...

在 Spring Boot 项目里,MYSQL中json类型字段使用

前言&#xff1a; 因为程序特殊需求导致&#xff0c;需要mysql数据库存储json类型数据&#xff0c;因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...

Python 训练营打卡 Day 47

注意力热力图可视化 在day 46代码的基础上&#xff0c;对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...