当前位置: 首页 > news >正文

【论文阅读】DETR 论文逐段精读

【论文阅读】DETR 论文逐段精读

文章目录

  • 【论文阅读】DETR 论文逐段精读
  • 📖DETR 论文精读【论文精读】
    • 🌐前言
    • 📋摘要
    • 📚引言
    • 🧬相关工作
    • 🔍方法
      • 💡目标函数
      • 📜模型结构
      • ⚙️代码
    • 📌实验

参考跟李沐学AI: 精读DETR

📖DETR 论文精读【论文精读】


🌐前言

目标检测领域:从目标检测开始火到 detr 都很少有端到端的方法,大部分方法最后至少需要后处理操作(NMS, non-maximum suppression 非极大值抑制)。有了 NMS,模型调参就会很复杂,而且即使训练好了一个模型,部署起来也非常困难(NMS 不是所有硬件都支持)。

📋摘要

贡献:把目标检测做成一个端到端的框架,把之前特别依赖人的先验知识的部分删掉了(NMS 部分、anchor)。

DETR提出

  • 新的目标函数,通过二分图匹配的方式,强制模型输出一组独一无二的预测(没有那么多冗余框,每个物体理想状态下就会生成一个框)
  • 使用 encoder-decoder 的架构

两个小贡献:

  1. decoder 还有另外一个输入 learned object query,类似 anchor 的意思
    (给定这些object query之后,detr就可以把learned object query和全局图像信息结合一起,通过不同的做注意力操作,从而让模型直接输出最后的一组预测框)
  2. 想法&&实效性:并行比串行更合适

DETR 的好处:

  1. 简单性:想法上简单,不需要一个特殊的 library,只要硬件支持 transformer 或 CNN,就一定支持 detr
  2. 性能:在 coco 数据集上,detr 和一个训练非常好的 faster RCNN 基线网络取得了差不多的效果,模型内存和速度也和 faster RCNN 差不多
  3. 想法好,解决了目标检测领域很多痛点,写作好
  4. 别的任务:全景分割任务上 detr 效果很好,detr 能够非常简单拓展到其他任务上

📚引言


DETR 流程(训练)

  1. CNN 提特征
  2. 特征拉直,送到 encoder-decoder 中,encoder 作用:进一步学习全局信息,为近下来的 decoder,也就是最后出预测框做铺垫。
  3. decoder 生成框的输出,当你有了图像特征之后,还会有一个 object query(限定了你要出多少框),通过 query 和特征在 decoder 里进行自注意力操作,得到输出的框(文中是100,无论是什么图片都会预测100个框)
  4. loss :二分图匹配,计算100个预测框和2个 GT 框的 matching loss,决定100个预测框哪两个是独一无二对应到红黄色的 GT 框,匹配的框去算目标检测的 loss

推理
1、2、3一致,第四步 loss 不需要,直接在最后的输出上用一个阈值卡一个输出的置信度,置信度比较大(>0.7的)保留,置信度小于0.7的当做背景物体。

🧬相关工作

让 DETR 成功主要原因:transformer

🔍方法

分两块:1、基于集合的目标函数怎么做,作者如何通过二分图匹配把预测的框和 GT 框连接在一起,算得目标函数 2、detr 具体模型架构

💡目标函数

DETR模型最后输出是一个固定集合,无论图片是什么,最后都会输出 n 个(本文 n=100)

问题:detr 每次都会出 100 个输出,但是实际上一个图片的 GT 的 bounding box 可能只有几个,如何匹配?如何计算 loss?怎么知道哪个预测框对应 GT 框?
匈牙利算法是解决该问题的一个知名且高效的算法,能够以较低的复杂度得到唯一的最优解。
在 scipy 库中,已经封装好了匈牙利算法,只需要将成本矩阵(cost matrix)输入进去就能够得到最优的排列。在 DETR 的官方代码中,也是调用的这个函数进行匹配(from scipy.optimize import linear_sum_assignment)。
从N个预测框中,选出与M个GT Box最匹配的预测框,也可以转化为二分图匹配问题,这里需要填入矩阵的“成本”,就是每个预测框和GT Box的损失。对于目标检测问题,损失就是分类损失和边框损失组成。

所以整个步骤就是:

  • 遍历所有的预测框和 GT Box,计算其 loss。
  • 将 loss 构建为 cost matrix,然后用 scipy 的 linear_sum_assignment(匈牙利算法)求出最优解,即找到每个 GT Box 最匹配的那个预测框。
  • 计算最优的预测框和 GT Box 的损失。(分类+回归)

但是在 DETR 中,损失函数有两点小改动:

  • 去掉分类损失中的 log
  • 回归损失为 L1 loss+GIOU

📜模型结构

下面参考官网的一个 demo,以输入尺寸3×800×1066为例进行前向过程:

  • CNN 提取特征([800,1066,3]→[25,34,256]
    backbone 为 ResNet-50,最后一个 stage 输出特征图为 25×34×2048(32 倍下采样),然后用 1×1 的卷积将通道数降为 256;
  • Transformer encoder 计算自注意力([25,34,256]→[850,256]
    将上一步的特征拉直为 850×256,并加上同样维度的位置编码(Transformer 本身没有位置信息),然后输入的 Transformer encoder 进行自注意力计算,最终输出维度还是 850×256;
  • Transformer decoder 解码,生成预测框
    decoder 输入除了 encoder 部分最终输出的图像特征,还有前面提到的 learned object query,其维度为 100×256。在解码时,learned object query 和全局图像特征不停地做 across attention,最终输出 100×256 的自注意力结果。
    这里的 object query 即相当于之前的 anchor/proposal,是一个硬性条件,告诉模型最后只得到 100 个输出。然后用这 100 个输出接 FFN 得到分类损失和回归损失。
  • 使用检测头输出预测框
    检测头就是目标检测中常用的全连接层(FFN),输出 100 个预测框( h x c e n t e r , y c e n t e r , w , h h x_{center}, y_{center}, w, h hxcenter,ycenter,w,h )和对应的类别。
  • 使用二分图匹配方式输出最终的预测框,然后计算预测框和真实框的损失,梯度回传,更新网络。

除此之外还有部分细节:

  • Transformer-encode/decoder 都有 6层
  • 除第一层外,每层 Transformer encoder 里都会先计算 object query 的 self-attention,主要是为了移除冗余框。这些 query 交互之后,大概就知道每个 query 会出哪种框,互相之间不会再重复(见实验)。
  • decoder 加了 auxiliary loss,即每层 decoder 输出的 100×256 维的结果,都加了 FFN 得到输出,然后去计算 loss,这样模型收敛更快。(每层 FFN 共享参数)

⚙️代码

import torch
from torch import nn
from torchvision.models import resnet50class DETR(nn.Module):def __init__(self, num_classes, hidden_dim, nheads,num_encoder_layers, num_decoder_layers):super().__init__()# We take only convolutional layers from ResNet-50 modelself.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])self.conv = nn.Conv2d(2048, hidden_dim, 1) # 1×1卷积层将2048维特征降到256维self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # 类别FFNself.linear_bbox = nn.Linear(hidden_dim, 4)                # 回归FFNself.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # object query# 下面两个是位置编码self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):x = self.backbone(inputs)h = self.conv(x)H, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1) # 位置编码h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))return self.linear_class(h), self.linear_bbox(h).sigmoid()detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

📌实验

  • 最上面一部分是 Detectron 2 实现的 Faster RCNN ,但是本文中作者使用了很多 trick
  • 中间部分是作者使用了 GIoU loss、更强的数据增强策略、更长的训练时间来把上面三个模型重新训练了一次,这样更显公平。重新训练的模型以+表示,参数量等这些是一样的,但是普偏提了两个点
  • 下面部分是 DETR 模型,可以看到参数量、GFLOPS 更小,但是推理更慢。模型比 Faster RCNN 精度高一点,主要是大物体检测提升 6 个点 AP,小物体相比降低了 4个点左右

相关文章:

【论文阅读】DETR 论文逐段精读

【论文阅读】DETR 论文逐段精读 文章目录 【论文阅读】DETR 论文逐段精读📖DETR 论文精读【论文精读】🌐前言📋摘要📚引言🧬相关工作🔍方法💡目标函数📜模型结构⚙️代码 &#x1f4…...

负载均衡:实现高效稳定的网络服务

随着互联网技术的快速发展,网络应用服务的规模和复杂性日益增加。为了满足日益增长的用户需求,确保服务的高可用性和稳定性,负载均衡技术应运而生。本文将详细介绍负载均衡的概念、原理、分类以及应用场景,帮助读者深入了解这一关…...

2024最新软件测试【测试理论+ 抓包与网络协议】面试题(内附答案)

一、测试理论 3.1 你们原来项目的测试流程是怎么样的? 我们的测试流程主要有三个阶段:需求了解分析、测试准备、测试执行。 1、需求了解分析阶段 我们的 SE 会把需求文档给我们自己先去了解一到两天这样,之后我们会有一个需求澄清会议, …...

极简7照训练法,奇趣相机引领儿童AI摄影潮流

近日,奇趣未来推出一款专注于儿童AI摄影市场的微信小程序——奇趣相机,搭载了专为中国儿童精心研发的AIGC大模型,精准捕捉并贴合亚洲儿童人脸特征,让每一个孩子的笑容都能被完美定格。它不仅涵盖了从3岁至12岁各个年龄段的儿童摄影…...

Flink应用

1.免密登录 2.flink StandAlone模式 3.Flink Yarn 模式 (on per 模式,on session 模式) Flink概述 按照Apache官方的介绍,Flink是一个对有界和无界数据流进行状态计算的分布式处理引擎和框架。通俗地讲,Flink就是一个流计算框架,主要用来处…...

C# 委托与事件 终章

C# 委托与事件 浅尝 C# 委托与事件 深入 委托 委托有什么用&#xff1f; 将函数作为函数的参数传递声明事件并用来注册 强类型委托 Action<T1> Func<T1, TResult>事件 希望一个类的某些成员在发生变化时能被外界观测到 CollctionChangedTextChanged 标准.Ne…...

MySQL-linux安装-万能RPM法

一、MySQL的Linux版安装 1、 CentOS7下检查MySQL依赖 1. 检查/tmp临时目录权限&#xff08;必不可少&#xff09; 由于mysql安装过程中&#xff0c;会通过mysql用户在/tmp目录下新建tmp_db文件&#xff0c;所以请给/tmp较大的权限。执行 &#xff1a; chmod -R 777 /tmp2. …...

elment UI el-date-picker 月份组件选定后提交后台页面显示正常,提交后台字段变成时区格式

需求&#xff1a;要实现一个日期的月份选择<el-date-picker :typeformData.dateType :value-formatdateFormat v-modelformData.leaveFactoryDateplaceholder选择月份></el-date-picker>错误示例&#xff1a;将日期显示类型(type)dateType或将日期绑定值的格式(val…...

基于 NGINX 的 ngx_http_geoip2 模块 来禁止国外 IP 访问网站

基于 NGINX 的 ngx_http_geoip2 模块 来禁止国外 IP 访问网站 一、安装 geoip2 扩展依赖 [rootfxkj ~]# yum install libmaxminddb-devel -y二、下载 ngx_http_geoip2_module 模块 [rootfxkj tmp]# git clone https://github.com/leev/ngx_http_geoip2_module.git三、解压模…...

C++经典面试题目(二十)

1、请解释运算符重载的限制。 运算符重载必须至少有一个操作数是用户自定义类型。不能改变运算符的优先级和结合性。不能创建新的运算符。不能重载以下运算符&#xff1a;::, .*, .*, ?:, sizeof, typeid。 2、什么是友元函数&#xff1f;它有什么作用&#xff1f; 友元函数…...

vue3+uniapp 动态渲染组件,兼容h5、app端

1.setup写在js中&#xff0c;使用ref绑定数据&#xff0c;事件和数据都需要return出去。调用数据{数据名}.value。 如果你想要通过接口动态获取组件路径&#xff0c;并据此动态渲染组件&#xff0c;你可以使用异步组件和defineAsyncComponent函数。在Vue 3中&#xff0c;你可以…...

CSS层叠样式表学习(2)

&#xff08;大家好&#xff0c;今天我们将继续来学习CSS&#xff08;2&#xff09;的相关知识&#xff0c;大家可以在评论区进行互动答疑哦~加油&#xff01;&#x1f495;&#xff09; 目录 二、CSS基础选择器 2.1 CSS选择器的作用 2.2 选择器分类 2.3 标签选择器 2.…...

【MySQL】DML的表操作详解:添加数据&修改数据&删除数据(可cv例题语句)

前言 大家好吖&#xff0c;欢迎来到 YY 滴MySQL系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C Linux的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; YY的《C》专栏YY的《C11》专栏YY的…...

Docker命令及部署Java项目

文章目录 简介Docker镜像镜像列表查找镜像拉取镜像删除镜像镜像标签 Docker容器容器启动容器查看容器停止和重启后台模式和进入强制停止容器清理停止的容器容器错误日志容器别名及操作 Docker部署Java项目 简介 Docker是一种容器化技术&#xff0c;可以帮助开发者轻松打包应用…...

深度学习入门:从理论到实践的全面指南

深度学习入门&#xff1a;从理论到实践的全面指南 引言第一部分&#xff1a;深度学习基础第二部分&#xff1a;数学基础第三部分&#xff1a;编程和工具第四部分&#xff1a;构建你的第一个模型第五部分&#xff1a;深入学习结语 引言 大家好&#xff0c;这里是程序猿代码之路。…...

后端前行Vue之路(二):模版语法之插值与指令

1.概述 Vue.js的模板语法是一种将Vue实例的数据绑定到HTML文档的方法。Vue的模板语法是一种基于HTML的扩展&#xff0c;允许开发者将Vue实例中的数据绑定到HTML元素&#xff0c;以及在HTML中使用一些简单的逻辑和指令。Vue.js 基于 HTML 的模板语法允许开发者声明式地将 DOM 绑…...

Kotlin 中的类和构造方法

Kotlin 中的类与接口和 Java 中的类与接口还是有区别的。例如&#xff0c;Koltin 中的接口可以包含属性声明&#xff0c;与 Java 不同的是。Kotlin 的声明默认是 final 和 public 的。此外&#xff0c;嵌套的类默认并不是内部类&#xff1a;它们并没有包含对其它外部类的隐式引…...

【2024最新】vue3的基本使用(超详细)

一、Vue 3 概述 1. 为什么要学习Vue 3 Vue 3是Vue.js的最新主要版本&#xff0c;它带来了许多改进和新特性&#xff0c;包括但不限于&#xff1a; 性能提升&#xff1a;Vue 3提供了更快的渲染速度和更低的内存使用率。Composition API&#xff1a;引入了一个新的API&#xf…...

【xinference】(8):在autodl上,使用xinference部署qwen1.5大模型,速度特别快,同时还支持函数调用,测试成功!

1&#xff0c;关于xinference Xorbits Inference (Xinference) 是一个开源平台&#xff0c;用于简化各种 AI 模型的运行和集成。借助 Xinference&#xff0c;您可以使用任何开源 LLM、嵌入模型和多模态模型在云端或本地环境中运行推理&#xff0c;并创建强大的 AI 应用。 Xor…...

YARN集群 和 MapReduce 原理及应用

YARN集群模式 本文内容需要基于 Hadoop 集群搭建完成的基础上来实现 如果没有搭建&#xff0c;请先按上一篇: <Linux 系统 CentOS7 上搭建 Hadoop HDFS集群详细步骤> 搭建&#xff1a;https://mp.weixin.qq.com/s/zPYsUexHKsdFax2XeyRdnA 配置hadoop安装目录下的 etc…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

模型参数、模型存储精度、参数与显存

模型参数量衡量单位 M&#xff1a;百万&#xff08;Million&#xff09; B&#xff1a;十亿&#xff08;Billion&#xff09; 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的&#xff0c;但是一个参数所表示多少字节不一定&#xff0c;需要看这个参数以什么…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

IGP(Interior Gateway Protocol,内部网关协议)

IGP&#xff08;Interior Gateway Protocol&#xff0c;内部网关协议&#xff09; 是一种用于在一个自治系统&#xff08;AS&#xff09;内部传递路由信息的路由协议&#xff0c;主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

基于Uniapp开发HarmonyOS 5.0旅游应用技术实践

一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架&#xff0c;支持"一次开发&#xff0c;多端部署"&#xff0c;可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务&#xff0c;为旅游应用带来&#xf…...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中&#xff0c;将 long long 类型转换为 QString 可以通过以下两种常用方法实现&#xff1a; 方法 1&#xff1a;使用 QString::number() 直接调用 QString 的静态方法 number()&#xff0c;将数值转换为字符串&#xff1a; long long value 1234567890123456789LL; …...

HDFS分布式存储 zookeeper

hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架&#xff0c;允许使用简单的变成模型跨计算机对大型集群进行分布式处理&#xff08;1.海量的数据存储 2.海量数据的计算&#xff09;Hadoop核心组件 hdfs&#xff08;分布式文件存储系统&#xff09;&a…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中&#xff0c;车辆不再仅仅是传统的交通工具&#xff0c;而是逐步演变为高度智能的移动终端。这一转变的核心支撑&#xff0c;来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒&#xff08;T-Box&#xff09;方案&#xff1a;NXP S32K146 与…...