⌈ 传知代码 ⌋ DETR[端到端目标检测]
💛前情提要💛
本文是传知代码平台中的相关前沿知识与技术的分享~
接下来我们即将进入一个全新的空间,对技术有一个全新的视角~
本文所涉及所有资源均在传知代码平台可获取
以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!
以下内容干货满满,跟上步伐吧~
📌导航小助手📌
- 💡本章重点
- 🍞一. 概述
- 🍞二. 模型主体框架
- 🍞三.演示效果
- 🍞四.核心逻辑
- 🫓总结
💡本章重点
- DETR[端到端目标检测]
🍞一. 概述
在目标检测需要许多手工设计的组件,例如非极大值抑制(NMS),基于人工经验生成的先验框(Anchor)等。DETR这篇文章通过将目标检测作为一个直接的集合预测问题,减少了人工设计组件的知识,简化了目标检测的流程。给定一组固定的可学习的目标查询,DETR推理目标和全局图像的上下文关系,由于DETR没有先验框的约束,因此对于较大的物体预测性能会更好。
🍞二. 模型主体框架

如图所示为DETR的主体框架,由于直接采用transformer结构,模型的计算量较大,因此DETR首先采用CNN卷积神经网络进行抽取特征,此时生成的特征图一般而言降采样32倍。之后将提取的特征图送Transformer的encoder结构中进行自注意力的交互,获取特征图中每个像素和其他像素之间关系。decoder首先预设了N个查询,该N个查询首先进行自注意力机制除去模型中的冗余框,之后与来自Encoder的特征进行交互形成数量为N查询,该查询通过线性层生成模型预测的类别和相应的边界框输出,最终预测得到结果。
在实验的时候N的数据要大于一张图片上所含有所有物体的数量,在计算损失函数的时候,DETR首先采用匈牙利算法去寻找到正确的匹配方式。之后再去计算bbox和分类的损失值。由于 L1
损失函数对于不同大小的边界框产生的误差不相同,因此我们采用了GIoU 损失函数去弥补这些误差。下图所示是DETR更详细的示意图:

主干网络
针对于一张通道数大小为3的图片,首先经过CNN的骨干网络,得到一个通道数为2048(该数据是我们人工设置的)
Transformer编码器

Transformer解码器
与标准的Transformer架构中的decoder不同,DETR没有采用掩码机制,因此N个预测的边界框可以同时输出。由于解码器仍然具有置换不变性,因此我们采用可学习的位置编码作为解码器的输入嵌入,并把它称为object query。通过多个层结构,该object query最终转变为输出的边界框,通过FFN结构,生成N个坐标点和分类的对象。

上图所示是模型Transformer的主要结构,来自CNN主干网络的图像特征被送到transformer编码器中,在每个多头自注意力机制中与空间位置编码相加作为多头自注意力机制的键和查询,(生成q,k,v需要矩阵相乘,并不是一个直接的结果)。作为在解码器和编码器进行注意力机制计算之前,首先object query需要进行一个自注意力机制,该步骤是为了去除模型中的冗余框。
🍞三.演示效果
DETR 进行目标检测

DETR 交叉注意力机制可视化

query表示当前物体的标号,下方对应的是相应的名称
DETR自注意力机制可视化

上方显示的点可以人工手动调整
🍞四.核心逻辑
DETR模型的基本框架
class DETR(nn.Module):""" This is the DETR module that performs object detection """def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):""" Initializes the model.Parameters:backbone: torch module of the backbone to be used. See backbone.pytransformer: torch module of the transformer architecture. See transformer.pynum_classes: number of object classesnum_queries: number of object queries, ie detection slot. This is the maximal number of objectsDETR can detect in a single image. For COCO, we recommend 100 queries.aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used."""super().__init__()self.num_queries = num_queriesself.transformer = transformerhidden_dim = transformer.d_modelself.class_embed = nn.Linear(hidden_dim, num_classes + 1)self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)self.query_embed = nn.Embedding(num_queries, hidden_dim)self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)self.backbone = backboneself.aux_loss = aux_lossdef forward(self, samples: NestedTensor):""" The forward expects a NestedTensor, which consists of:- samples.tensor: batched images, of shape [batch_size x 3 x H x W]- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixelsIt returns a dict with the following elements:- "pred_logits": the classification logits (including no-object) for all queries.Shape= [batch_size x num_queries x (num_classes + 1)]- "pred_boxes": The normalized boxes coordinates for all queries, represented as(center_x, center_y, height, width). These values are normalized in [0, 1],relative to the size of each individual image (disregarding possible padding).See PostProcess for information on how to retrieve the unnormalized bounding box.- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list ofdictionnaries containing the two above keys for each decoder layer."""if isinstance(samples, (list, torch.Tensor)):samples = nested_tensor_from_tensor_list(samples)# backbone 网络进行了两个操作,分别是获取特征图和位置编码features, pos = self.backbone(samples)src, mask = features[-1].decompose()assert mask is not None# input_proj: src: [2,2048,28,38]->[2,256,28,38] 改变特征图的通道维数# mask: [2,28,38] mask的通道维数为1 pos: [2,256,28,38] query表示查询,也就是图片里面可能有多少物体的个数hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]outputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()# 都只使用最后一层decoder输出的结果out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}if self.aux_loss:out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)return out@torch.jit.unuseddef _set_aux_loss(self, outputs_class, outputs_coord):# this is a workaround to make torchscript happy, as torchscript# doesn't support dictionary with non-homogeneous values, such# as a dict having both a Tensor and a list.return [{'pred_logits': a, 'pred_boxes': b}for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
Transformer模块
class Transformer(nn.Module):def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False,return_intermediate_dec=False):super().__init__()encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)encoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=return_intermediate_dec)self._reset_parameters()self.d_model = d_modelself.nhead = nheaddef _reset_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, src, mask, query_embed, pos_embed):# flatten NxCxHxW to HWxNxC [2,256,28,38]bs, c, h, w = src.shape# src: [2,256,28,38]->[2,256,28*38]->[1064,2,256]# pos_embed: [2,256,28,38]->[2,256,28*38]->[1064,2,256]src = src.flatten(2).permute(2, 0, 1)pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# query_embed:[100,256]->[100,1,256]->[100,2,256]query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)# mask: [2,28,38]->[2,1064]mask = mask.flatten(1)# 其实也是一个位置编码,表示目标的信息,一开始被初始化为0 [100,2,256]tgt = torch.zeros_like(query_embed)# memory的shape和src的一样是[1064,2,256]memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,pos=pos_embed, query_pos=query_embed)# hs 不止输出最后一层的结构,而是输出解码器所有层结构的输出情况# hs: [6,100,2,256]->[6,2,100,256] [depth,batch_size,num_query,channel]# 一般只使用最后一层特征所以未hs[-1]return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
🫓总结
综上,我们基本了解了“一项全新的技术啦” 🍭 ~~
恭喜你的内功又双叒叕得到了提高!!!
感谢你们的阅读😆
后续还会继续更新💓,欢迎持续关注📌哟~
💫如果有错误❌,欢迎指正呀💫
✨如果觉得收获满满,可以点点赞👍支持一下哟~✨
【传知科技 – 了解更多新知识】
相关文章:
⌈ 传知代码 ⌋ DETR[端到端目标检测]
💛前情提要💛 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间,对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…...
Oracle之触发器
简介 触发器在数据库里以独立的对象存储,他与存储过程不同的是,存储过程通过其他程序来启动运行或直接启动运行而触发器是由一个事件来启动运行,即触发器是当某个事件发生时自动式运行。并企,触发器不能接收参数。所以运行触发器…...
从零搭建微前端架构:解耦大型项目的终极方案
随着前端应用的复杂度不断提升,单体前端应用(Monolithic Frontend)的维护和扩展难度也日益增加。微前端(Micro-Frontend)作为一种新兴架构理念,旨在将大型前端项目拆分为多个独立、可独立部署的微应用,从而提升项目的可维护性和灵活性。这篇文章将带你从零开始搭建一个微…...
24/8/17算法笔记 MPC算法
MPC算法,在行动前推演一下 MPC(Model Predictive Control,模型预测控制)是一种先进的控制策略,它利用未来预测模型来优化当前的控制动作。MPC的核心思想是,在每一个控制步骤中,都基于当前系统状…...
GROUP_CONCAT 用法详解(Mysql)
GROUP_CONCAT GROUP_CONCAT 是 MySQL 中的一个聚合函数,用于将分组后的多行数据连接成一个单一的字符串。 通常用于将某个列的多个值合并到一个字符串中,以便更方便地显示或处理数据。 GROUP_CONCAT([DISTINCT] column_name[ORDER BY column_name [ASC…...
Golang httputil 包深度解析:HTTP请求与响应的操控艺术
标题:Golang httputil 包深度解析:HTTP请求与响应的操控艺术 引言 在Go语言的丰富标准库中,net/http/httputil包是一个强大的工具集,它提供了操作HTTP请求和响应的高级功能。从创建自定义的HTTP代理到调试HTTP流量,h…...
SQLALchemy 分页
SQLALchemy 分页 1. 使用SQLAlchemy的`slice`和`offset`/`limit`SQLAlchemy 1.4及更新版本SQLAlchemy 1.3及更早版本使用第三方库注意事项在Web开发中,分页是处理大量数据时一个非常重要的功能。SQLAlchemy是一个流行的Python SQL工具包和对象关系映射(ORM)库,它允许开发者…...
快速上手体验MyPerf4J监控springboot应用(docker版快速开始-本地版)
使用MyPerf4J监控springboot应用 快速启动influxdb时序数据库日志收集器telegrafgrafana可视化界面安装最终效果 项目地址 项目简介: 一个针对高并发、低延迟应用设计的高性能 Java 性能监控和统计工具。 价值 快速定位性能瓶颈快速定位故障原因 快速启动 监控本地应用 idea配…...
C语言 之 strlen、strcpy、strcat、strcmp字符串函数的使用和模拟实现
文章目录 strlen的使用和模拟实现函数的原型strlen模拟实现:方法1方法2方法3 strcpy的使用和模拟实现函数的原型strcpy的模拟实现: strcat的使用和模拟实现函数的原型strcat的模拟实现: strcmp的使用和模拟实现函数的原型strcmp的模拟实现 本…...
CAPL使用结构体的方式组装一条DoIP车辆识别请求报文(payload type 0x0002)
DoIP车辆识别请求(payload type 0x0002)报文的格式为: /******************************************************** +--------+--------+--------+--------+ |version | inVer | type | +--------+--------+--------+--------+ | length …...
数据接入教学
数据接入教学 1、开通外部网络策略2、检查本地防火墙策略3、测试网络连通性4、工具抓包命令5、本地测试发送与监听 1、开通外部网络策略 保证外部网络联通、保证内部防火墙开通策略(可以关闭进行测试) 2、检查本地防火墙策略 关闭进行测试 停止firewa…...
炒作将引发人工智能寒冬
我们似乎经常看到人工智能的进步被吹捧为机器真正变得智能的一大飞跃。我将在这里挑选其中的一个例子,并确切解释为什么这种态度会为人工智能的未来埋下隐患。 这很酷,这是一个非常困难且非常具体的问题,这个团队花了3 年时间才解决。他们一定…...
clamp靶机复现
靶机设置 设置靶机为NAT模式 靶机IP发现 nmap 192.168.112.0/24 靶机IP为192.168.112.143 目录扫描 dirsearch 192.168.112.143 访问浏览器 提示让我们扫描更多的目录 换个更大的字典,扫出来一个 /nt4stopc/ 目录 目录拼接 拼接 /nt4stopc/ 发现页面中有很多…...
mfc100u.dll丢失问题分析,详细讲解mfc100u.dll丢失解决方法
面对mfc100u.dll文件丢失带来的挑战时,许多用户都可能感到有些无助,尤其是当这一问题影响到他们日常使用的软件时。但实际上,存在几种有效方法可以帮助您快速恢复该关键的系统文件。为了方便不同水平的用户,本文将详细解析各种处理…...
【C++】什么是内存管理?
如果有不懂的地方,可以看我以往文章哦! 个人主页:CSDN_小八哥向前冲 所属专栏:C入门 目录 C/C内存分布 C内存管理方式 new/delete操作内置类型 new/delete操作自定义类型 operator new与operator delete函数 new和delete实现…...
产业经济大脑建设方案(五)
为了提升产业经济的智能化水平,我们提出建设一个综合产业经济大脑系统,该系统通过整合大数据分析、人工智能和云计算技术,构建全方位的数据采集、处理和决策支持平台。该平台能够实时监测产业链各环节的数据,运用智能算法进行深度…...
如何在 Odoo 16 中覆盖创建、写入和取消链接方法
Odoo 是一款强大的开源业务应用程序套件,可为各种业务运营提供广泛的功能。其主要功能之一是能够自定义和扩展其功能以满足特定的业务需求。在本博客中,我们将探讨如何覆盖Odoo 16中的创建、写入和取消链接方法,从而使您无需修改核心代码…...
pip离线安装accelerate
一、离线下载到当前文件夹 pip download accelerate -d ./anzhuangbao# 制定版本使用以下命令pip download accelerate0.32.0 -d ./anzhuangbao二、离线安装 cd anzhuangbaipip install --no-index --find-links. accelerate三、验证是否安装 pip show accelerateAccelerate: …...
VUE3请求意外报跨越错误或者500错误问题
1.有可能是请求传参和传参类型写错了 首先要确保该请求接口是支持跨域的(不支持叫后端改) access-control-allow-headers:Content-Type, Accept, Access-Control-Allow-Origin, api_key, Authorization access-control-allow-methods:GET, POST, OPTIO…...
vue 关于两个if条件中的promise
一、案例效果 期望if判断条件里的两个promise 都同时执行完成 二、 初始代码案例 const formatDetail async (fnArgsJsonParams: MapLogicType) > {if (fnArgsJsonParams?.targetFeatureName) {const resDetailData await formatFeatureInfo(fnArgsJsonParams.targetF…...
观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...
前端倒计时误差!
提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...
Go 语言接口详解
Go 语言接口详解 核心概念 接口定义 在 Go 语言中,接口是一种抽象类型,它定义了一组方法的集合: // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
QT3D学习笔记——圆台、圆锥
类名作用Qt3DWindow3D渲染窗口容器QEntity场景中的实体(对象或容器)QCamera控制观察视角QPointLight点光源QConeMesh圆锥几何网格QTransform控制实体的位置/旋转/缩放QPhongMaterialPhong光照材质(定义颜色、反光等)QFirstPersonC…...
[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.
ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #:…...
day36-多路IO复用
一、基本概念 (服务器多客户端模型) 定义:单线程或单进程同时监测若干个文件描述符是否可以执行IO操作的能力 作用:应用程序通常需要处理来自多条事件流中的事件,比如我现在用的电脑,需要同时处理键盘鼠标…...
第7篇:中间件全链路监控与 SQL 性能分析实践
7.1 章节导读 在构建数据库中间件的过程中,可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中,必须做到: 🔍 追踪每一条 SQL 的生命周期(从入口到数据库执行)&#…...
【p2p、分布式,区块链笔记 MESH】Bluetooth蓝牙通信 BLE Mesh协议的拓扑结构 定向转发机制
目录 节点的功能承载层(GATT/Adv)局限性: 拓扑关系定向转发机制定向转发意义 CG 节点的功能 节点的功能由节点支持的特性和功能决定。所有节点都能够发送和接收网格消息。节点还可以选择支持一个或多个附加功能,如 Configuration …...
