当前位置: 首页 > news >正文

计算机视觉目标检测-DETR网络

目录

  • 摘要
  • abstract
  • DETR目标检测网络详解
    • 二分图匹配和损失函数
  • DETR总结
  • 总结

摘要

DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非极大值抑制(NMS)等复杂后处理步骤。DETR使用卷积神经网络提取图像特征,并将其通过位置编码转换为输入序列,送入Transformer的Encoder-Decoder结构。Decoder通过固定数量的目标查询(Object Queries),预测类别和边界框位置。DETR创新性地引入匈牙利算法进行二分图匹配,确保预测与真实值的唯一对应关系,且采用交叉熵损失和L1-GIoU损失进行优化。在COCO数据集上的实验表明,DETR在大目标检测中表现优异,并能灵活迁移到其他任务,如全景分割。

abstract

DETR (DEtection TRansformer) is an end-to-end target detection method based on Transformer architecture proposed by Facebook AI. By modeling object detection as a set prediction problem, it eliminates complex post-processing steps such as anchor frame design and non-maximum suppression (NMS). DETR uses convolutional neural networks to extract image features and convert them via positional encoding into input sequences that feed into Transformer’s Encoder-Decoder structure. Decoder predicts categories and bounding box positions with a fixed number of Object Queries. DETR innovates by introducing the Hungarian algorithm for bipartite graph matching to ensure a unique relationship between the prediction and the true value, and optimizes with cross-entropy losses and L1-GIoU losses. Experiments on the COCO dataset show that DETR performs well in large target detection and can be flexibly migrated to other tasks, such as panoramic segmentation.

下图是目标检测中检测器模型的发展:
在这里插入图片描述

DETR目标检测网络详解

DETR(DEtection TRansformer)是由Facebook AI在2020年提出的一种基于Transformer架构的端到端目标检测方法。与传统的目标检测方法(如Faster R-CNN、YOLO等)不同,DETR直接将目标检测建模为一个集合预测问题,摆脱了锚框设计和复杂的后处理(如NMS)。结果在 COCO 数据集上效果与 Faster RCNN 相当,在大目标上效果比 Faster RCNN 好,且可以很容易地将 DETR 迁移到其他任务例如全景分割。
在这里插入图片描述
简单来说,就是通过CNN提取图像特征(通常 Backbone 的输出通道为 2048,图像高和宽都变为了 1/32),并经过input embedding+positional encoding操作转换为图像序列(如下图所说,就是类似[N, HW, C]的序列)作为transformer encoder的输入,得到了编码后的图像序列,在图像序列的帮助下,将object queries(下图中说的是固定数量的可学习的位置embeddings)转换/预测为固定数量的类别+bbox预测。相当于Transformer本质上起了一个序列转换的作用。
在这里插入图片描述
下图为DETR的详细结构:
在这里插入图片描述
DETR中的encoder-decoder与transformer中的encoder-decoder对比:

  1. spatial positional encoding:新提出的二维空间位置编码方法,该位置编码分别被加入到了encoder的self attention的QK和decoder的cross attention的K,同时object queries也被加入到了decoder的两个attention(第一个加到了QK中,第二个加入了Q)中。而原版的Transformer将位置编码加到了input和output embedding中。
  2. DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。
  3. object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)。
  4. 得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。
  5. 分类loss采用的是交叉熵损失,针对所有predictions;bbox loss采用了L1 loss和giou loss,针对匹配成功的predictions。

匈牙利算法是用于解决二分图匹配的问题,即将Ground Truth的K个bbox和预测出的100个bbox作为二分图的两个集合,匈牙利算法的目标就是找到最大匹配,即在二分图中最多能找到多少条没有公共端点的边。匈牙利算法的输入就是每条边的cost 矩阵
在这里插入图片描述

二分图匹配和损失函数

思考
DETR 预测了一组固定大小的 N = 100 个边界框,这比图像中感兴趣的对象的实际数量大得多。怎么样来计算损失呢?或者说预测出来的框我们怎么知道对应哪一个 ground-truth 的框呢?

为了解决这个问题,第一步是将 ground-truth 也扩展成 N = 100 个检测框。使用了一个额外的特殊类标签 ϕ \phiϕ 来表示在未检测到任何对象,或者认为是背景类别。这样预测和真实都是两个100 个元素的集合了。这时候采用匈牙利算法进行二分图匹配,即对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。
σ ^ = arg ⁡ min ⁡ G ∈ G N ∑ i N L m a t c h ( y i , y ^ σ ( i ) ) \hat{\sigma}=\arg\min_{\mathrm{G\in G_N}}\sum_{\mathrm{i}}^{\mathrm{N}}\mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right) σ^=argGGNminiNLmatch(yi,y^σ(i))
L m a t c h ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L b o x ( b i , b ^ σ ( i ) ) \mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right)=-1_{\{\mathrm{c_i}\neq\varnothing\}}\hat{\mathrm{p}}_{\mathrm{\sigma(i)}}\left(\mathrm{c_i}\right)+1_{\{\mathrm{c_i}\neq\varnothing\}}\mathcal{L}_{\mathrm{box}}\left(\mathrm{b_i},\hat{\mathrm{b}}_{\mathrm{\sigma(i)}}\right) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))
对于那些不是背景的,获得其对应的预测是目标类别的概率,然后用框损失减去预测类别概率。这也就是说不仅框要近,类别也要基本一致,是最好的。经过匈牙利算法之后,我们就得到了 ground truth 和预测目标框之间的一一对应关系。然后就可以计算损失函数了。

下面是利用pytorch实现DETR的代码:
位置编码部分:

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]

用于为序列数据(如Transformer中的输入)添加位置信息。位置编码帮助模型保留序列中元素的位置信息,这是因为Transformer模型本身不具备位置信息感知能力。
使用正弦和余弦函数优点
优点:
正弦和余弦具有周期性和平滑性;
不同维度具有不同频率,编码了多尺度的位置信息。
作用:保留序列的位置信息,使模型能够感知数据的顺序。

编码可视化结果:

import matplotlib.pyplot as pltimport torch
import torch.nn as nn# 位置编码
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]pe = PositionalEncoding(d_model=16, max_len=100)
x = torch.zeros(100, 1, 16)
encoded = pe(x).squeeze(1).detach().numpy()plt.figure(figsize=(10, 5))
plt.imshow(encoded, aspect='auto', cmap='viridis')
plt.colorbar(label='Encoding Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Visualization')
plt.show()

在这里插入图片描述
上图反应以下几点变化
不同维度的变化

  1. 低频维度(如 d=0,1):颜色变化缓慢,代表位置之间编码的相似性较高,捕捉全局信息。
  2. 高频维度(如 d=14,15):颜色变化迅速,代表位置之间编码差异较大,捕捉局部信息。

同一位置的编码:
值的分布(正弦和余弦的相互作用)保证了每个位置在多维空间中具有唯一性。

时间步的相对差异:
相邻位置(如第1和第2位置)在高维上的值差异较大,这为模型提供了感知时间步变化的能力。

encoder-decoder:

class Transformer(nn.Module):def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6):super().__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)def forward(self, src, tgt, src_mask=None, tgt_mask=None):memory = self.encoder(src, mask=src_mask)output = self.decoder(tgt, memory, tgt_mask=tgt_mask)return output

DETR模型:

# DETR模型
class DETR(nn.Module):def __init__(self, num_classes, num_queries, backbone='resnet50'):super().__init__()self.num_queries = num_queries# Backboneself.backbone = models.resnet50(pretrained=True)self.conv = nn.Conv2d(2048, 256, kernel_size=1)# Transformerself.transformer = Transformer(d_model=256)self.query_embed = nn.Embedding(num_queries, 256)self.positional_encoding = PositionalEncoding(256)# Prediction headsself.class_embed = nn.Linear(256, num_classes + 1)  # +1 for no-object classself.bbox_embed = nn.Linear(256, 4)def forward(self, images):# Feature extractionfeatures = self.backbone(images)features = self.conv(features)h, w = features.shape[-2:]# Flatten and add positional encodingsrc = features.flatten(2).permute(2, 0, 1)  # (HW, N, C)src = self.positional_encoding(src)# Query embeddingquery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, images.size(0), 1)  # (num_queries, N, C)# Transformerhs = self.transformer(src, query_embed)# Predictionoutputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()  # Normalized to [0, 1]return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}

DETR总结

DETR通过Transformer实现端到端的目标检测,无需(如NMS)复杂的后处理。相比传统检测器,DETR具有简洁的架构和强大的全局建模能力,但训练时对数据和计算资源的需求较高。

总结

DETR简化了目标检测的流程,摒弃了传统检测器中繁琐的锚框设计和后处理步骤,架构更简洁,且依托于Transformer的全局建模能力,在捕捉长距离特征关系方面表现出色。相比传统方法,DETR在目标数量固定的场景下,能够更高效地处理目标检测任务。其优点包括易迁移、多任务适用性和端到端优化能力,但其劣势在于训练时间较长、计算资源消耗较大,尤其是在小目标检测和训练数据量不足的情况下效果略显不足。

相关文章:

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶

目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知,乃至可以使用感知结果直接构建鸟瞰图(bird eye view, BEV),而 L4 则依赖离线地图。 高精地…...

【UE5 C++课程系列笔记】23——多线程基础——AsyncTask

目录 概念 函数说明 注意事项 (1)线程安全问题 (2)依赖特定线程执行的任务限制 (3)任务执行顺序和时间不确定性 使用示例 概念 AsyncTask 允许开发者将一个函数或者一段代码逻辑提交到特定的线程去执…...

基于Python的音乐播放器 毕业设计-附源码73733

摘 要 本项目基于Python开发了一款简单而功能强大的音乐播放器。通过该音乐播放器,用户可以轻松管理自己的音乐库,播放喜爱的音乐,并享受音乐带来的愉悦体验。 首先,我们使用Python语言结合相关库开发了这款音乐播放器。利用Tkin…...

cursor vip

https://cursor.jeter.eu.org?pf7f4f3fab0af4119bece19ff4a4360c3 可以直接复制命令使用git bash执行即可 命令&#xff1a; bash <(curl -Lk https://gitee.com/kingparks/cursor-vip/releases/download/latest/ic.sh) f7f4f3fab0af4119bece19ff4a4360c3 等待执行完成后…...

Docker部署项目,Mysql数据库总是宕机并且上传数据全部被删除了

刚开始排查原因我以为是一些内存占用问题的原因&#xff0c;后来查看数据库日志发现有多个异常ip尝试连接数据库并且也连接成功了随后数据库就被异常关闭了&#xff0c;然后我就重启容器远程连接数据库发现数据全没了&#xff0c;又在数据库中找到了如下内容&#xff1a; All y…...

C++ 复习总结记录六

C 复习总结记录六 模板初阶主要内容 1、泛型编程 2、函数模板 3、类模板 4、STL 简介 一 泛型编程 如何实现一个通用的交换函数 void Swap(int& left, int& right) {int temp left;left right;right temp; } void Swap(double& left, double& right…...

spring boot 集成 knife4j

1、knife4j介绍以及环境介绍 knife4j是为Java MVC框架集成Swagger生成Api文档的增强解决方案,前身是swagger-bootstrap-ui,取名knife4j是希望它能像一把匕首一样小巧,轻量,并且功能强悍!其底层是对Springfox的封装&#xff0c;使用方式也和Springfox一致&#xff0c;只是对接口…...

WordPress静态缓存插件WP Super Cache与 WP Fastest Cache

引言 WordPress是一款开源的内容管理系统&#xff08;CMS&#xff09;&#xff0c;最初作为博客平台开发&#xff0c;现已发展成为一个功能强大的建站工具&#xff0c;支持创建各种类型的网站&#xff0c;包括企业网站、在线商店、个人博客等。它具有用户友好的界面、丰富的插…...

Pytest钩子函数,测试框架动态切换测试环境

在软件测试中&#xff0c;测试环境的切换是个令人头疼的问题。不同环境的配置不同&#xff0c;如何高效切换测试环境成为许多测试开发人员关注的重点。你是否希望在运行测试用例时&#xff0c;能够动态选择测试环境&#xff0c;而不是繁琐地手动修改配置&#xff1f; Pytest 测…...

VUE3封装一个Hook

在 Vue 3 中&#xff0c;Composition API 让我们能够封装和复用代码逻辑&#xff0c;尤其是通过 setup 函数进行组件间的复用。为了提高代码的可复用性&#xff0c;我们可以把一些常见的 API 请求和状态管理逻辑封装到一个单独的 hook 中。 以下是一个简单的例子&#xff0c;我…...

【Spring Boot】Spring AOP 快速上手指南:开启面向切面编程新旅程

前言 &#x1f31f;&#x1f31f;本期讲解关于spring aop的入门介绍~~~ &#x1f308;感兴趣的小伙伴看一看小编主页&#xff1a;GGBondlctrl-CSDN博客 &#x1f525; 你的点赞就是小编不断更新的最大动力 &#x1f386;那么废话不…...

HTML基础入门——简单网页页面

目录 一&#xff0c;网上转账电子账单 ​编辑 1&#xff0c;所利用到的标签 2&#xff0c;代码编写 3&#xff0c;运行结果 二&#xff0c;李白诗词 1&#xff0c;所用到的标签 2&#xff0c;照片的编辑 3&#xff0c;代码编写 4&#xff0c;运行结果 一&#xff0c;网…...

INT301 Bio Computation 题型整理

perceptron 设计和计算 1. XOR: 当两个输入值中只有一个为真时&#xff0c;输出为真 2. 3. 5. 6. 7. 2^3 2^n 9. a) 直接test b) 把v≥2 改成 v≥1 10. no, because it cant be separate through only one decision boundary,its not linearlly separable. Backpropagatio…...

机器学习免费使用的数据集及网站链接

机器学习领域存在许多可以免费使用的数据集&#xff0c;这些数据集来自于学习、研究、比赛等目的。 一、综合性数据集平台 1.Kaggle 网址&#xff1a;Kaggle 数据集https://www.kaggle.com/datasets Kaggle是一个数据科学竞赛和社区平台&#xff0c;提供了大量的数据集供用…...

低空经济——飞行汽车运营建模求解问题思路

1. 掌握问题背景和领域知识 目标&#xff1a; 理解飞行汽车及其运营问题的核心要素和应用背景。学习内容&#xff1a; 飞行汽车基础&#xff1a; 了解飞行汽车的技术特点&#xff08;垂直起降、电动推进等&#xff09;。阅读行业报告&#xff0c;如 Uber Elevate 白皮书。共享…...

英伟达Project Digits赋能医疗大模型:创新应用与未来展望

英伟达Project Digits赋能医疗大模型&#xff1a;创新应用与未来展望 一、引言 1.1 研究背景与意义 在当今数字化时代&#xff0c;医疗行业作为关乎国计民生的关键领域&#xff0c;正面临着前所未有的挑战与机遇。一方面&#xff0c;传统医疗模式在应对海量医疗数据的处理、复…...

【Python3】异步操作 redis

aioredis 在高版本已经不支持了&#xff0c; 不要用 代码示例 redis 连接池异步操作redis以及接口 import asyncio from sanic import Sanic from sanic.response import json import redis.asyncio as redis from redis.asyncio import ConnectionPool# 创建 Sanic 应用 app…...

【W800】UART 的使用与问题

1.开发环境 OS: Windows 11开发板&#xff1a;海凌科 HLK-W800-KIT-PROSDK: W80X_SDK_v1.00.10IDE: CSKY Development Kit 2.UART 使用 在 SDK 中创建文件 uart_test.h 和 uart_test.c&#xff0c;然后在 CDK 项目中添加这两个文件&#xff0c;CDK 会自动 include 头文件。 …...

UART串口数据分析

串口基础知识详细介绍&#xff1a; 该链接详细介绍了串并行、单双工、同异步、连接方式 https://blog.csdn.net/weixin_43386810/article/details/127156063 该文章将介绍串口数据的电平变化、波特率计算、脉宽计算以及数据传输量的计算。 捕获工具&#xff1a;逻辑分析仪&…...

OpenClaw权限管理:千问3.5-35B-A3B-FP8操作范围最小化实践

OpenClaw权限管理&#xff1a;千问3.5-35B-A3B-FP8操作范围最小化实践 1. 为什么需要限制OpenClaw的权限 去年夏天&#xff0c;我在本地部署OpenClaw对接千问3.5模型时&#xff0c;曾因为一个简单的文件整理指令差点酿成大祸。当时我让AI帮我整理下载文件夹&#xff0c;结果它…...

如何用readme.so快速制作专业README:揭秘实时预览与Markdown同步技术

如何用readme.so快速制作专业README&#xff1a;揭秘实时预览与Markdown同步技术 【免费下载链接】readme.so An online drag-and-drop editor to easily build READMEs 项目地址: https://gitcode.com/gh_mirrors/re/readme.so readme.so是一款功能强大的在线拖放编辑器…...

SAP MM模块预留功能实战:从创建到发料的完整流程解析

SAP MM模块预留功能实战&#xff1a;从创建到发料的完整流程解析 在制造业和供应链管理领域&#xff0c;物料预留是确保生产计划顺利执行的关键环节。SAP MM模块中的预留功能&#xff0c;就像一位经验丰富的仓库管理员&#xff0c;能够提前为未来需求锁定必要的物料资源。想象一…...

AI Agent在物流与运输中的应用:路径优化与调度自动化

AI Agent在物流与运输中的应用:路径优化与调度自动化 引言 在当今快速发展的商业环境中,物流与运输行业正面临着前所未有的挑战。随着电子商务的爆发式增长,消费者对配送速度、成本和可靠性的要求越来越高。同时,全球化供应链的复杂性、燃油价格的波动以及环保法规的日益…...

从‘轮胎压力传感器’到‘魔数饼干’:手把手拆解SOME/IP协议栈的五个核心通信模型

从轮胎压力到魔数饼干&#xff1a;SOME/IP协议栈五大通信模型实战解码 1. 引言&#xff1a;当汽车电子遇上分布式通信 想象一下&#xff0c;你驾驶的现代汽车正以每小时100公里的速度飞驰&#xff0c;此时轮胎压力监测系统突然检测到右前轮气压异常。这个信号需要以毫秒级速度传…...

STM32F4标准库实战:用DMA+FSMC驱动TFT-LCD,让你的GUI刷新快人一步(附避坑指南)

STM32F4标准库实战&#xff1a;DMAFSMC驱动TFT-LCD的性能飞跃与避坑全攻略 在嵌入式GUI开发中&#xff0c;流畅的界面刷新体验往往决定着产品的第一印象。当你在STM32F4平台上使用LVGL或emWin时&#xff0c;是否遇到过这些场景&#xff1a;手指滑动列表时的明显卡顿、动画渲染…...

OpenClaw如何做好记忆持久化的 · 六、经济学与可扩展性——记忆的代价

六、经济学与可扩展性——记忆的代价⏱ 30 秒速览 | 中度使用&#xff08;日均 50 次对话&#xff09;纯记忆附加成本&#xff1a;~$5/月&#xff08;Claude Sonnet&#xff09;/ ~$1/月&#xff08;GPT-4o-mini&#xff09;。72% 花在记忆注入&#xff0c;24% 花在自动提取&am…...

好写作AI:毕业论文“智造”新引擎,开启学术创作新纪元!

在学术探索的征途中&#xff0c;毕业论文无疑是一座巍峨的山峰&#xff0c;让无数莘莘学子既期待又忐忑。但别怕&#xff0c;时代在进步&#xff0c;科技在发展&#xff0c;我们有了新的“登山装备”——好写作AI。它不仅是你的学术助手&#xff0c;更是毕业论文“智造”的新引…...

5G毫米波手机天线设计实战:TLM算法在CST中的高效整机仿真

1. 5G毫米波天线设计的挑战与TLM算法优势 5G毫米波频段&#xff08;24GHz以上&#xff09;的天线设计就像在针尖上跳舞——既要保证高频信号的传输效率&#xff0c;又要应对手机内部寸土寸金的布局空间。我去年参与的一个项目就遇到过典型问题&#xff1a;当把毫米波天线集成到…...

Spring 事务从入门到精通:一篇搞定事务失效、传播行为、回滚规则(Spring系列10)

一、前言 在日常开发中&#xff0c;事务是保证数据一致性的核心手段。尤其是转账这类业务&#xff0c;必须保证「A减钱」和「B加钱」两个操作同成功、同失败&#xff0c;否则就会出现资金异常。 Spring 提供了一套完整的声明式事务解决方案&#xff0c;基于 AOP 实现&#xff0…...