当前位置: 首页 > news >正文

Detr源码解读(mmdetection)

Detr源码解读(mmdetection)

1、原理简要介绍

在这里插入图片描述
整体流程: 在给定一张输入图像后,1)特征向量提取: 首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外,由于仅用最后一层特征图,故对小目标检测不友好,这也是后续deformable detr改进的原因。 2)添加位置编码信息: 经F拉平成一维张量并添加上位置编码信息得到I。3)Transformer中encoder部分4)Transformer中decoder部分,学习位置嵌入object queries。5)FFN部分:6)后续匈牙利匹配+损失计算。

2、mmdetection中源码介绍

2.1. 整体逻辑

Detr的内部逻辑如下:在mmdet/models/detector/single_stage.py。即首先提取图像特征向量,之后经过DetrHead来计算最终的损失。img[b,3,224,224] x[b,2048,7,7]

def forward_train(self,img,img_metas,gt_bboxes,gt_labels,gt_bboxes_ignore=None):super(SingleStageDetector, self).forward_train(img, img_metas)# img[b,3,224,224] x[b,2048,7,7]x = self.extract_feat(img) # 提取图像特征向量  # 经过DetrHead得到loss                   losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,gt_labels, gt_bboxes_ignore)return losses

forward_train
跟其他的检测头差不多,先是调用自己,也就是自身的 forward 函数,得到输出的 class label 和 reg coordinate,再调用自身的 loss 函数,不过这里是重载了一下,将 img_meta 传输进了 forward 函数的参数。执行完outs = self(x, img_metas)跳转到forward的num_levels = len(feats)

def forward_train(self,x,img_metas,gt_bboxes,gt_labels=None,gt_bboxes_ignore=None,proposal_cfg=None,**kwargs):"""Forward function for training mode.Args:x (list[Tensor]): Features from backbone.img_metas (list[dict]): Meta information of each image每个图像的元信息, e.g.,image size, scaling factor, etc.gt_bboxes (Tensor): Ground truth bboxes of the image,图像的地面真相框shape (num_gts, 4).gt_labels (Tensor): Ground truth labels of each box,shape (num_gts,).gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored,要忽略的基本事实框,shape (num_ignored_gts, 4).proposal_cfg (mmcv.Config): Test / postprocessing configuration,测试/后处理配置if None, test_cfg would be used.Returns:dict[str, Tensor]: A dictionary of loss components.损失成分词典。"""assert proposal_cfg is None, '"proposal_cfg" must be None'outs = self(x, img_metas) #x[b,2048,7,7]if gt_labels is None:loss_inputs = outs + (gt_bboxes, img_metas)else:loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)return losses

执行完outs = self(x, img_metas)跳转到forward的num_levels = len(feats)。feats[b,2048,7,7]

    def forward(self, feats, img_metas):#这里默认为1,因为DETR默认用最后一层特征图num_levels = len(feats)img_metas_list = [img_metas for _ in range(num_levels)]return multi_apply(self.forward_single, feats, img_metas_list)

执行完return multi_apply(self.forward_single, feats, img_metas_list)跳转到forward_single函数

2.2. 图像特征向量提取

mmdet中提取图像特征向量的config配置文件如下,可以发现用ResNet50并只提取了最后一层特征层,即out_indices=(3,)。骨干网络会输出特征图的1/32,输入为【2,3,224,224】。通过backbone后得到图像大小为【2, 2048, 7, 7】和mask大小为【2,7,7】

backbone=dict(type='ResNet',depth=50,num_stages=4,out_indices=(3, ),     # detr仅要resnet50的最后一层特征图,并不需要FPNfrozen_stages=1,norm_cfg=dict(type='BN', requires_grad=False),norm_eval=True,style='pytorch',init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

2.3. 给图像特征向量添加位置编码信息(forward_single函数,里面是 head 前向的逻辑)。

本部分代码来自mmdet/models/dense_heads/detr_head.py的 forward_single函数中。
  mmdet中生成位置编码信息借助的是mask矩阵(所谓的mask就是为了统一批次大小而对图像进行了pad,被填充的部分在后续计算多头注意力时应该舍弃)故需要一个mask矩阵遮挡住,具体形状为[batch, h,w]这里先贴下生成mask的过程:

batch_size = x.size(0)   
input_img_h, input_img_w = img_metas[0]['batch_input_shape']# 一个批次图像大小
# 先将 mask 设置为全 1
masks = x.new_ones((batch_size, input_img_h, input_img_w))  # [b,224,224]
# 对每一张图来说,在原来图片有像素的地方把 mask 置 0
# 因此 mask 中 padding 的地方才是 1
for img_id in range(batch_size):img_h, img_w, _ = img_metas[img_id]['img_shape']    # 创建了一个mask,非0代表无效区域, 0 代表有效区域masks[img_id, :img_h, :img_w] = 0                   # 将pad部分置为1,非pad部分置为0.

输入图像的经过resnet50下采样后hw已经变了,所以还需进一步将mask下采样成和图像特征向量一样的shape。代码如下:

# 将每一层的特征图先投影到指定的特征维度,2048通道太多了转成256通道
x = self.input_proj(x) #Conv2d(self.in_channels, self.embed_dims, kernel_size=1)#[b,256,7,7]
# interpolate masks to have the same spatial shape with x
masks = F.interpolate(                                                         #masks[b,7,7]masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # masks和x的shape一样:[b,2,2]

后续便可以生成位置编码部分(mmdet/models/utils/position_encoding.py),代码里采用了sine位置编码,该函数给masks的每个像素位置生成了一个256维的唯一的位置向量。shape:[B, 256, 7, 7]

# position encoding
pos_embed = self.positional_encoding(masks)

2.4 送入Transformer

4.1. 整体逻辑

在得到图像特征向量x=[b,256,7,7]、masks[b,7,7]矩阵以及位置编码pos_embed[b,256,7,7]后,便可送入Transformer。进入transformer的之前四个变量维度分别为, x->[2, 256, 7, 7],mask->[2, 7, 7],query_embed->[100, 256],pos_embed->[2, 256, 7, 7]

# outs_dec: [nb_nb_decdec, bs, num_query, embed_dim]
outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,pos_embed)

在进入transformer之前,定义了一个query_embed(就是后边的object query),其第一个维度为num_queries(原文解释为一张图片里的最大检测数量),第二个维度为hidden_dim,就是256。

self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

关键是理清encoder和decoder的QKV分别指啥, 本部分代码来自mmdet\models\utils\transformer.py的 Transformer函数中。看代码:

       bs, c, h, w = x.shape# use `view` instead of `flatten` for dynamically exporting to ONNXx = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]  [49,2,256]pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)                 # [49,2,256]query_embed = query_embed.unsqueeze(1).repeat(                         #[100,b,256]1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]   [2,49]"""经过变换后的四个变量维度分别为, img->[49, 2, 256],mask->[2, 49],query_embed->[100, 2, 256],pos_embed->[49, 2, 256]"""memory = self.encoder(query=x,                            # [49,b,256]key=None,value=None,query_pos=pos_embed,                 # [49,b,256]query_key_padding_mask=mask)  # [b,49]target = torch.zeros_like(query_embed) # decoder初始化全0# out_dec: [num_layers, num_query, bs, dim]out_dec = self.decoder(query=target,              # 全0的target, 后续在MultiHeadAttn中执行了key=memory,              # query = query + query_pos又加回去了。value=memory,key_pos=pos_embed,query_pos=query_embed, # [num_query, bs, dim]key_padding_mask=mask)# outs_dec: [nb_nb_decdec, bs, num_query, embed_dim] [6,2,100,256]out_dec = out_dec.transpose(1, 2)memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)return out_dec, memory

其中encoder中q就是x,kv分别为None,query_pos代表位置编码,而query_key_padding_mask就是mask。decoder的q是全0的target,后续decoder会迭代更新q,而kv则 是memory,即encoder的输出;key_pos依旧是k的位置信息;query_embed即论文中Object query,可学习位置信息;key_padding_mask依然是mask。

4.2. encoder部分

先看下encoder初始化部分,内部循环调用了6次BaseTransformerLayer,因此只需讲解一层EncoderLayer即可。将img,mask,pos_embed送入transformer encoder中,进行注意力操作。得到[49, 2, 256]的输出

encoder=dict(type='DetrTransformerEncoder',num_layers=6,                        # 经过6层Layertransformerlayers=dict(              # 每层layer内部使用多头注意力type='BaseTransformerLayer',attn_cfgs=[dict(type='MultiheadAttention',embed_dims=256,           num_heads=8,dropout=0.1)],feedforward_channels=2048,        # FFN中间层的维度   ffn_dropout=0.1,operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # 定义运算流程

先跳转到mmdet\models\utils\transformer.py的DetrTransformerEncoder函数。再来看下BaseTransformerLayer的forward部分。该部分可以损失detr的核心部分了,因为本质上mmdet内部只是封装了pytorch现有的nn.MultiHeadAtten函数。所以,需要理解nn.MultiHeadAttn中两种mask参数的含义,限于篇幅原因,这里可参考nn.Transformer来理解这两个mask。 不过简单理解就是:attn_mask在detr中没用到,仅用key_padding_mask。attn_mask是为了遮挡未来文本信息用的,而图像可以看到全部的信息,因此不需要用attn_mask。
在这里插入图片描述

def forward(self,query,key=None,value=None,query_pos=None,key_pos=None,attn_masks=None,query_key_padding_mask=None,key_padding_mask=None,**kwargs):#Forward function for `TransformerDecoderLayer`.norm_index = 0attn_index = 0ffn_index = 0identity = queryif attn_masks is None:attn_masks = [None for _ in range(self.num_attn)]elif isinstance(attn_masks, torch.Tensor):attn_masks = [copy.deepcopy(attn_masks) for _ in range(self.num_attn)]warnings.warn(f'Use same attn_mask in all attentions in 'f'{self.__class__.__name__} ')else:assert len(attn_masks) == self.num_attn, f'The length of ' \f'attn_masks {len(attn_masks)} must be equal ' \f'to the number of attention in ' \f'operation_order {self.num_attn}'for layer in self.operation_order:                  # 遍历config文件的顺序if layer == 'self_attn':temp_key = temp_value = query query = self.attentions[attn_index](        # 内部调用nn.MultiHeadAttnquery,temp_key,temp_value,identity if self.pre_norm else None,query_pos=query_pos,                    # 若有位置编码信息则和query相加 key_pos=query_pos,                       # 若有位置编码信息则和key相加 attn_mask=attn_masks[attn_index],key_padding_mask=query_key_padding_mask,**kwargs)attn_index += 1identity = queryelif layer == 'norm':query = self.norms[norm_index](query)      # 层归一化norm_index += 1elif layer == 'cross_attn':                    # decoder用到query = self.attentions[attn_index](     query,key,value,identity if self.pre_norm else None,query_pos=query_pos,                   # 若有位置编码信息则和query相加 key_pos=key_pos,                        # 若有位置编码信息则和key相加 attn_mask=attn_masks[attn_index],key_padding_mask=key_padding_mask,**kwargs)attn_index += 1identity = queryelif layer == 'ffn':                         # 残差连接加全连接层query = self.ffns[ffn_index](query, identity if self.pre_norm else None)ffn_index += 1return query

decoder部分和encoder流程类似,只是多了交叉注意力。decoder部分将[49,2,256]的输出和query_embed[100,2,256]输入到transformer decoder中,得到[6, 2, 100, 256]的输出。这里是合并了6个不同层级解码层的输出,其实只需要最后一层即可。
decoder这里其实是将query_embed和feature做了注意力机制,q为query_embed[100, 2, 256],k为memory也就是feature[49, 2, 256],v也是memory[49, 2, 256]。

**

总结

**
decoder的输出经过Prediction feed-forward networks (FFNs)生成最终的预测。即[6,2,100,256]经过线性层生成[6,2,100,92]的类别预测,经过线性层生成[6, 2, 100, 4]的框坐标预测。
由于后续在detr上改进的论文对匈牙利算法以及loss计算改动不大,因此这部分代码就不讲解了。

相关文章:

Detr源码解读(mmdetection)

Detr源码解读(mmdetection) 1、原理简要介绍 整体流程: 在给定一张输入图像后,1)特征向量提取: 首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外&am…...

一个.Net Core开发的,撑起月6亿PV开源监控解决方案

更多开源项目请查看:一个专注推荐.Net开源项目的榜单 项目发布后,对于我们程序员来说,项目还不是真正的结束,保证项目的稳定运行也是非常重要的,而对于服务器的监控,就是保证稳定运行的手段之一。对数据库、…...

C语言数据结构初阶(2)----顺序表

目录 1. 顺序表的概念及结构 2. 动态顺序表的接口实现 2.1 SLInit(SL* ps) 的实现 2.2 SLDestory(SL* ps) 的实现 2.3 SLPrint(SL* ps) 的实现 2.4 SLCheckCapacity(SL* ps) 的实现 2.5 SLPushBack(SL* ps, SLDataType x) 的实现 2.6 SLPopBack(SL* ps) 的实现 2.7 SLP…...

K8S常用命令速查手册

K8S常用命令速查手册一. K8S日常维护常用命令1.1 查看kubectl版本1.2 启动kubelet1.3 master节点执行查看所有的work-node节点列表1.4 查看所有的pod1.5 检查kubelet运行状态排查问题1.6 诊断某pod故障1.7 诊断kubelet故障方式一1.8 诊断kubelet故障方式二二. 端口策略相关2.1 …...

Linux系统下命令行安装MySQL5.6+详细步骤

1、因为想在腾讯云的服务器上创建自己的数据库,所以我在这里是通过使用Xshell 7来连接腾讯云的远程服务器; 2、Xshell 7与服务器连接好之后,就可以开始进行数据库的安装了(如果服务器曾经安装过数据库,得将之前安装的…...

13.STM32超声波模块讲解与实战

目录 1.超声波模块讲解 2.超声波时序图 3.超声波测距步骤 4.项目实战 1.超声波模块讲解 超声波传感器模块上面通常有两个超声波元器件,一个用于发射,一个用于接收。电路板上有4个引脚:VCC GND Trig(触发)&#xff…...

逆向之Windows PE结构

写在前面 对于Windows PE文件结构,个人认为还是非常有必要掌握和了解的,不管是在做逆向分析、免杀、病毒分析,脱壳加壳都是有着非常重要的技能。但是PE文件的学习又是一个非常枯燥过程,希望本文可以帮你有一个了解。 PE文件结构…...

ACL是什么

目录 一、ACL是什么 二、ACL的使用:setacl与getacl 1)针对特定使用者的方式: 1. 创建acl_test1后设置其权限 2. 读取acl_test1的权限 2)针对特定群组的方式: 3)针对有效权限 mask 的设置方式&#xf…...

操作系统核心知识点整理--内存篇

操作系统核心知识点整理--内存篇按段对内存进行管理内存分区内存分页为什么需要多级页表TLB解决了多级页表什么样的缺陷?TLB缓存命中率高的原理是什么?段页结合: 为什么需要虚拟内存?虚拟地址到物理地址的转换过程段页式管理下程序如何载入内存?页面置…...

从零开始学习iftop流量监控(找出服务器耗费流量最多的ip和端口)

一、iftop是什么iftop是类似于top的实时流量监控工具。作用&#xff1a;监控网卡的实时流量&#xff08;可以指定网段&#xff09;、反向解析IP、显示端口信息等官网&#xff1a;http://www.ex-parrot.com/~pdw/iftop/二、界面说明>代表发送数据&#xff0c;< 代表接收数…...

第一篇博客------自我介绍篇

目录&#x1f506;自我介绍&#x1f506;学习目标&#x1f506;如何学习单片机Part 1 基础理论知识学习Part 2 单片机实践Part 3 单片机硬件设计&#x1f506;希望进入的公司&#x1f506;结束语&#x1f506;自我介绍 Hello!!!我是一名即已经步入大二的计算机小白。 --------…...

No suitable device found for this connection (device lo not available(网络突然出问题)

当执行 ifup ens33 出现错误&#xff1a;[rootlocalhost ~]# ifup ens33Error: Connection activation failed: No suitable device found for this connection (device lo not available because device is strictly unmanaged).1解决办法&#xff1a;[rootlocalhost ~]# chkc…...

【算法设计技巧】分治算法

分治算法 用于设计算法的另一种常用技巧为分治算法(divide and conquer)。分治算法由两部分组成&#xff1a; 分(divide)&#xff1a;递归解决较小的问题(当然&#xff0c;基准情况除外)治(conquer)&#xff1a;然后&#xff0c;从子问题的解构建原问题的解。 传统上&#x…...

已解决kettle新建作业,点击保存抛出异常Invalid state, the Connection object is closed.

已解决kettle新建作业&#xff0c;点击保存进资源数据库抛出异常Invalid state, the Connection object is closed.的解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 文章目录报错问题报错翻译报错原因解决方法联系博主免费帮忙解决报错报错问题 一个小伙伴…...

【设计模式】 工厂模式介绍及C代码实现

【设计模式】 工厂模式介绍及C代码实现 背景 在软件系统中&#xff0c;经常面临着创建对象的工作&#xff1b;由于需求的变化&#xff0c;需要创建的对象的具体类型经常变化。 如何应对这种变化&#xff1f;如何绕过常规的对象创建方法(new)&#xff0c;提供一种“封装机制”来…...

深入浅出PaddlePaddle函数——paddle.arange

分类目录&#xff1a;《深入浅出PaddlePaddle函数》总目录 相关文章&#xff1a; 深入浅出TensorFlow2函数——tf.range 深入浅出Pytorch函数——torch.arange 深入浅出PaddlePaddle函数——paddle.arange 语法 paddle.arange(start0, endNone, step1, dtypeNone, nameNone…...

X86 ATT常用寄存器及其操作指令

X86 AT&T常用寄存器及其操作指令 常用寄存器 esp寄存器&#xff1a;当我们需要访问堆栈帧中的变量时&#xff0c;可以使用esp寄存器来获取堆栈帧的基址&#xff0c;以便能够正确地访问堆栈帧中的变量。ebp寄存器&#xff1a;当我们需要调用一个函数时&#xff0c;可以使用…...

Kotlin 高端玩法之DSL

如何在 kotlin 优雅的封装匿名内部类&#xff08;DSL、高阶函数&#xff09;匿名内部类在 Java 中是经常用到的一个特性&#xff0c;例如在 Android 开发中的各种 Listener&#xff0c;使用时也很简单&#xff0c;比如&#xff1a;//lambda button.setOnClickListener(v -> …...

理光M2701复印机载体初始化方法

理光M2701基本参数&#xff1a; 产品类型&#xff1a;数码复合机 颜色类型&#xff1a;黑白 复印速度&#xff1a;单面&#xff1a;27cpm 双面&#xff1a;16cpm 涵盖功能&#xff1a;复印、打印、扫描 网络功能&#xff1a;支持无线、有线网络打印 接口类型&#xff1a;USB2.0…...

2.25Maven的安装与配置

一.Mavenmaven是一个Java世界中,非常知名的"工程管理工具"/构建工具"核心功能:1.管理依赖在进行一个A 操作之前,要先进行一个B操作.依赖有的时候是很复杂的,而且是嵌套的2.构建/编译(也是在调用jdk)3. 打包把java代码给构建成jar或者warjar就是一个特殊的压缩包…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外&#xff0c;K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案&#xff0c;全安装在K8S群集中。 具体可参…...

<6>-MySQL表的增删查改

目录 一&#xff0c;create&#xff08;创建表&#xff09; 二&#xff0c;retrieve&#xff08;查询表&#xff09; 1&#xff0c;select列 2&#xff0c;where条件 三&#xff0c;update&#xff08;更新表&#xff09; 四&#xff0c;delete&#xff08;删除表&#xf…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 &#xff08;忘了有没有这步了 估计有&#xff09; 刷机程序 和 镜像 就不提供了。要刷的时…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的&#xff0c;比GNOME简单得多&#xff01; 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版&#xff0c;柱状图PPT模版&#xff0c;线状图PPT模版&#xff0c;折线图PPT模版&#xff0c;饼状图PPT模版&#xff0c;雷达图PPT模版&#xff0c;树状图PPT模版 图表类系列各种样式PPT模版分享&#xff1a;图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...