当前位置: 首页 > news >正文

〖open-mmlab: MMDetection〗解析文件:mmdet/models/detectors/two_stage.py

目录

  • MMDetection中的两阶段检测器:深入解析`two_stage.py`源码
    • 两阶段检测器概述
    • `two_stage.py`的关键组件
      • 类定义和初始化
      • 构造函数
      • Neck头配置
      • RPN头配置
      • RoI头配置
      • `_load_from_state_dict`
        • 方法概述
        • 参数解释
        • 代码解析
      • 特征提取
        • 方法签名
        • 文档字符串(Docstring)
        • 方法体
        • 返回值
      • 前向传播
        • 方法签名
        • 文档字符串(Docstring)
        • 方法体
        • 返回值
      • 损失计算
        • 方法签名
        • 文档字符串(Docstring)
        • 方法体
        • 返回值
      • 预测
        • 方法签名
        • 文档字符串(Docstring)
        • 方法体
        • 返回值
    • 结论

MMDetection中的两阶段检测器:深入解析two_stage.py源码

在目标检测领域,两阶段检测器因其在准确性和速度之间取得的平衡而成为基石方法之一。MMDetection是一个基于PyTorch的开源目标检测工具箱,它为实现此类检测器提供了强大的框架。在这篇博客文章中,我们将深入解析two_stage.py源码,这是MMDetection两阶段检测架构中的核心部分。

两阶段检测器概述

两阶段检测器的操作分为两个主要阶段:

  1. 区域提议网络(Region Proposal Network, RPN):第一阶段识别潜在的目标位置,即区域提议。
  2. 感兴趣区域(Region of Interest, RoI)头:第二阶段对这些提议进行细化,以得到精确的目标检测结果。

two_stage.py的关键组件

TwoStageDetector类是MMDetection中两阶段检测器的基础构建模块。让我们分解其核心组件:

类定义和初始化

@MODELS.register_module()
class TwoStageDetector(BaseDetector):"""两阶段检测器的基类。"""
  • 类通过@MODELS.register_module()装饰器注册在MMDetection的模型注册表中,使其易于配置和实例化。

构造函数

def __init__(self, backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, data_preprocessor=None, init_cfg=None):super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)self.backbone = MODELS.build(backbone)...
  • 构造函数使用各种组件(如骨干网络、颈部网络、RPN头和RoI头)初始化检测器。它还处理训练和测试的配置。

Neck头配置

if neck is not None:self.neck = MODELS.build(neck)

RPN头配置

if rpn_head is not None:rpn_train_cfg = train_cfg.rpn if train_cfg is not None else Nonerpn_head_ = rpn_head.copy()rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)rpn_head_num_classes = rpn_head_.get('num_classes', None)if rpn_head_num_classes is None:rpn_head_.update(num_classes=1)else:if rpn_head_num_classes != 1:warnings.warn('The `num_classes` should be 1 in RPN, but get 'f'{rpn_head_num_classes}, please set ''rpn_head.num_classes = 1 in your config file.')rpn_head_.update(num_classes=1)self.rpn_head = MODELS.build(rpn_head_)
  • RPN头使用训练和测试配置进行配置。确保num_classes设置为1对于RPN至关重要,因为它只预测目标存在,而不是类别标签。
    这段代码是两阶段检测器中初始化和配置区域提议网络(Region Proposal Network, RPN)的逻辑部分。让我们逐行分析:
  1. 检查RPN头是否提供:

    if rpn_head is not None:
    

    这行代码检查是否提供了rpn_head配置。如果提供了,那么进入代码块进行进一步的配置。

  2. 获取训练配置:

    rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
    

    这行代码尝试从train_cfg(训练配置)中获取RPN部分的配置。如果train_cfg存在,则rpn_train_cfg被设置为train_cfg中的rpn部分,否则设置为None

  3. 复制RPN头配置:

    rpn_head_ = rpn_head.copy()
    

    这行代码创建了rpn_head配置的一个副本,以避免直接修改原始配置。

  4. 更新RPN头配置:

    rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
    

    这行代码将训练和测试的配置更新到RPN头的配置中。这样做是为了确保RPN在训练和测试时使用正确的参数。

  5. 获取RPN头的类别数:

    rpn_head_num_classes = rpn_head_.get('num_classes', None)
    

    这行代码尝试从RPN头配置中获取num_classes参数。如果不存在,则默认为None

  6. 设置RPN头的类别数:

    if rpn_head_num_classes is None:rpn_head_.update(num_classes=1)
    else:if rpn_head_num_classes != 1:warnings.warn('The `num_classes` should be 1 in RPN, but get 'f'{rpn_head_num_classes}, please set ''rpn_head.num_classes = 1 in your config file.')rpn_head_.update(num_classes=1)
    

    这部分代码首先检查num_classes是否为None。如果是,那么它将num_classes设置为1。如果不是None,但值不是1,那么它会发出一个警告,提示用户RPN中的num_classes应该是1,因为RPN只负责检测物体的存在与否,而不是分类物体。然后,它将num_classes强制设置为1。

  7. 构建RPN头:

    self.rpn_head = MODELS.build(rpn_head_)
    

    这行代码使用更新后的RPN头配置来构建RPN模型。MODELS.build是一个工厂方法,根据提供的配置创建并返回RPN模型的实例。

总的来说,这段代码确保了RPN头被正确地配置和构建,特别是关于num_classes参数,它对于RPN的功能至关重要。


RoI头配置

if roi_head is not None:roi_head.update(train_cfg=rcnn_train_cfg)roi_head.update(test_cfg=test_cfg.rcnn)self.roi_head = MODELS.build(roi_head)
  • 与RPN头类似,RoI头也配置了相应的训练和测试配置。
    这段代码是两阶段检测器中初始化和配置感兴趣区域(Region of Interest, RoI)头的逻辑部分。让我们逐行分析:
  1. 检查RoI头是否提供:

    if roi_head is not None:
    

    这行代码检查是否提供了roi_head配置。如果提供了,那么进入代码块进行进一步的配置。

  2. 获取训练和测试配置:

    rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
    

    这行代码尝试从train_cfg(训练配置)中获取RoI部分的配置。如果train_cfg存在,则rcnn_train_cfg被设置为train_cfg中的rcnn部分,否则设置为None

  3. 更新RoI头的训练配置:

    roi_head.update(train_cfg=rcnn_train_cfg)
    

    这行代码将训练的配置更新到RoI头的配置中。这样做是为了确保RoI头在训练时使用正确的参数。

  4. 更新RoI头的测试配置:

    roi_head.update(test_cfg=test_cfg.rcnn)
    

    这行代码将测试的配置更新到RoI头的配置中。这样做是为了确保RoI头在测试时使用正确的参数。

  5. 构建RoI头:

    self.roi_head = MODELS.build(roi_head)
    

    这行代码使用更新后的RoI头配置来构建RoI模型。MODELS.build是一个工厂方法,根据提供的配置创建并返回RoI模型的实例。


_load_from_state_dict

def _load_from_state_dict(self, state_dict: dict, prefix: str,local_metadata: dict, strict: bool,missing_keys: Union[List[str], str],unexpected_keys: Union[List[str], str],error_msgs: Union[List[str], str]) -> None:"""Exchange bbox_head key to rpn_head key when loading single-stageweights into two-stage model."""bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head'bbox_head_keys = [k for k in state_dict.keys() if k.startswith(bbox_head_prefix)]rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head'rpn_head_keys = [k for k in state_dict.keys() if k.startswith(rpn_head_prefix)]if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0:for bbox_head_key in bbox_head_keys:rpn_head_key = rpn_head_prefix + \bbox_head_key[len(bbox_head_prefix):]state_dict[rpn_head_key] = state_dict.pop(bbox_head_key)super()._load_from_state_dict(state_dict, prefix, local_metadata,strict, missing_keys, unexpected_keys,error_msgs)

在深度学习模型的训练和部署过程中,加载预训练权重是一个常见的操作。在两阶段检测器中,由于其结构与单阶段检测器不同,因此在加载权重时需要特别注意权重的匹配和转换。_load_from_state_dict方法正是为了解决这个问题而设计的。下面,我们将详细解析这个方法的工作原理,并探讨其在两阶段检测器中的重要性。

方法概述

_load_from_state_dict方法是在加载预训练权重时调用的,它的作用是将单阶段检测器的权重转换为两阶段检测器可以使用的格式。这是通过交换bbox_headrpn_head的键来实现的。

参数解释
  • state_dict: 包含模型权重的字典。
  • prefix: 权重键的前缀,用于区分不同部分的权重。
  • local_metadata: 模型的元数据,通常包含模型结构信息。
  • strict: 是否严格匹配权重,如果为True,权重不匹配会抛出错误。
  • missing_keys: 缺失的权重键列表。
  • unexpected_keys: 多余的权重键列表。
  • error_msgs: 加载权重时的错误信息列表。
代码解析
  1. 定义bbox_headrpn_head的键前缀:

    bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head'
    rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head'
    

    这两行代码定义了bbox_headrpn_head的键前缀。如果提供了prefix,则将prefix加到bbox_headrpn_head前面,否则使用默认的键名。

  2. 获取bbox_headrpn_head的键:

    bbox_head_keys = [k for k in state_dict.keys() if k.startswith(bbox_head_prefix)]
    rpn_head_keys = [k for k in state_dict.keys() if k.startswith(rpn_head_prefix)]
    

    这两行代码通过列表推导式获取所有以bbox_head_prefixrpn_head_prefix开头的键,这些键分别对应单阶段检测器的边界框头和两阶段检测器的RPN头的权重。

  3. 权重转换:

    if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0:for bbox_head_key in bbox_head_keys:rpn_head_key = rpn_head_prefix + bbox_head_key[len(bbox_head_prefix):]state_dict[rpn_head_key] = state_dict.pop(bbox_head_key)
    

    这段代码检查是否存在bbox_head的权重而没有rpn_head的权重。如果是这种情况,它会遍历所有的bbox_head权重键,将它们转换为rpn_head的权重键,并在state_dict中进行更新。这是通过删除原bbox_head的权重键并添加新的rpn_head的权重键来实现的。

  4. 调用父类的加载方法:

    super()._load_from_state_dict(state_dict, prefix, local_metadata,strict, missing_keys, unexpected_keys,error_msgs)
    

    这行代码调用父类的_load_from_state_dict方法,完成权重的加载。这一步是必要的,因为它会处理权重的最终匹配和加载过程。


特征提取

def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:"""Extract features.Args:batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).Returns:tuple[Tensor]: Multi-level features that may havedifferent resolutions."""x = self.backbone(batch_inputs)if self.with_neck:x = self.neck(x)return x
  • extract_feat方法使用骨干网络和可选的颈部模块从输入图像中提取特征。

这段代码定义了一个名为 extract_feat 的方法,它是两阶段检测器中用于提取特征的关键步骤。下面,我们将详细解析这个方法的每个部分。

方法签名
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
  • self: 指向类的实例,允许访问类的属性和方法。
  • batch_inputs: 输入的图像张量,其形状为 (N, C, H, W),其中 N 是批量大小,C 是通道数,HW 分别是图像的高度和宽度。
  • -> Tuple[Tensor]: 方法的返回类型注解,表示该方法将返回一个包含张量的元组,这些张量是不同分辨率的特征。
文档字符串(Docstring)
"""
Extract features.Args:batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).Returns:tuple[Tensor]: Multi-level features that may havedifferent resolutions.
"""
  • 这部分是对方法的简要说明,说明了该方法的功能是提取特征。
  • Args: 描述了方法的输入参数,即一批图像。
  • Returns: 描述了方法的返回值,即具有不同分辨率的多级特征。
方法体
x = self.backbone(batch_inputs)
  • 这行代码调用了检测器的 backbone 网络,将输入的图像张量 batch_inputs 传递给它。
  • backbone 通常是卷积神经网络(CNN)的一部分,负责从输入图像中提取特征。
  • 执行后,x 将包含从输入图像中提取的特征。
if self.with_neck:x = self.neck(x)
  • 这行代码检查检测器是否具有 neck 组件(通常称为“颈部”或“连接”网络)。
  • self.with_neck 是一个布尔值,指示是否构建了颈部网络。
  • 如果存在颈部网络(self.with_neckTrue),则将 backbone 提取的特征 x 传递给 neck 网络进一步处理。
  • neck 网络通常用于进一步提取或融合特征,以提高检测器的性能。
返回值
return x
  • 方法返回 x,它包含了从输入图像中提取的特征。
  • 这些特征可能包含多个尺度或分辨率,这对于两阶段检测器在后续步骤中生成区域提议和进行目标识别非常有用。

前向传播


def _forward(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> tuple:"""Network forward process. Usually includes backbone, neck and headforward without any post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (list[:obj:`DetDataSample`]): Each item containsthe meta information of each image and correspondingannotations.Returns:tuple: A tuple of features from ``rpn_head`` and ``roi_head``forward."""results = ()x = self.extract_feat(batch_inputs)if self.with_rpn:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]roi_outs = self.roi_head.forward(x, rpn_results_list,batch_data_samples)results = results + (roi_outs, )return results
  • _forward方法协调网络的前向传播,处理RPN和RoI头阶段。
    这段代码定义了一个名为 _forward 的方法,它是两阶段检测器中用于执行网络前向传播的关键步骤。下面,我们将详细解析这个方法的每个部分。
方法签名
def _forward(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> tuple:
  • self: 指向类的实例,允许访问类的属性和方法。
  • batch_inputs: 输入的图像张量,其形状为 (N, C, H, W),其中 N 是批量大小,C 是通道数,HW 分别是图像的高度和宽度。
  • batch_data_samples: 包含每个图像的元信息和对应注释的 DetDataSample 对象列表。
  • -> tuple: 方法的返回类型注解,表示该方法将返回一个元组。
文档字符串(Docstring)
"""
Network forward process. Usually includes backbone, neck and head
forward without any post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (list[:obj:`DetDataSample`]): Each item containsthe meta information of each image and correspondingannotations.Returns:tuple: A tuple of features from ``rpn_head`` and ``roi_head``forward.
"""
  • 这部分是对方法的简要说明,说明了该方法的功能是执行网络的前向传播过程,通常包括骨干网络、颈部网络和头部网络的前向传播,但不包括任何后处理。
方法体
results = ()
  • 初始化一个空的元组 results,用于存储前向传播的结果。
x = self.extract_feat(batch_inputs)
  • 调用 extract_feat 方法提取输入图像的特征。这些特征将被用于后续的区域提议网络(RPN)和感兴趣区域(RoI)头。
if self.with_rpn:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)
else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
  • 检查检测器是否具有 RPN 头(self.with_rpn)。
  • 如果有 RPN 头,调用 RPN 头的 predict 方法来生成区域提议。这些提议是候选的目标位置。
  • 如果没有 RPN 头,假设输入数据中已经包含了预先定义的提议(proposals),并从每个数据样本中提取这些提议。
roi_outs = self.roi_head.forward(x, rpn_results_list,batch_data_samples)
  • 调用 RoI 头的 forward 方法,传入从骨干网络提取的特征 x、RPN 生成的区域提议 rpn_results_list 和包含图像元信息的数据样本 batch_data_samples
  • RoI 头负责从提议的区域中提取更精细的特征,并进行目标识别。
results = results + (roi_outs, )
  • 将 RoI 头的输出 roi_outs 添加到 results 元组中。
返回值
return results
  • 返回 results 元组,它包含了 RPN 头和 RoI 头的前向传播结果。

在当前代码片段中,并没有直接将 RPN 的结果和 RoI 头的结果合并到同一个元组中。只有 RoI 头的结果被添加到了 results 元组中。如果需要同时包含 RPN 和 RoI 头的结果,代码可能需要稍作修改,例如:

results = (rpn_results_list, roi_outs)

或者,如果 RPN 结果也需要在后续处理中使用,可以这样修改:

results = results + (rpn_results_list, roi_outs)

这样,results 元组就会同时包含 RPN 和 RoI 头的结果。


损失计算

def loss(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> dict:"""Calculate losses from a batch of inputs and data samples.Args:batch_inputs (Tensor): Input images of shape (N, C, H, W).These should usually be mean centered and std scaled.batch_data_samples (List[:obj:`DetDataSample`]): The batchdata samples. It usually includes information suchas `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.Returns:dict: A dictionary of loss components"""x = self.extract_feat(batch_inputs)losses = dict()# RPN forward and lossif self.with_rpn:proposal_cfg = self.train_cfg.get('rpn_proposal',self.test_cfg.rpn)rpn_data_samples = copy.deepcopy(batch_data_samples)# set cat_id of gt_labels to 0 in RPNfor data_sample in rpn_data_samples:data_sample.gt_instances.labels = \torch.zeros_like(data_sample.gt_instances.labels)rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(x, rpn_data_samples, proposal_cfg=proposal_cfg)# avoid get same name with roi_head losskeys = rpn_losses.keys()for key in list(keys):if 'loss' in key and 'rpn' not in key:rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)losses.update(rpn_losses)else:assert batch_data_samples[0].get('proposals', None) is not None# use pre-defined proposals in InstanceData for the second stage# to extract ROI features.rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]roi_losses = self.roi_head.loss(x, rpn_results_list,batch_data_samples)losses.update(roi_losses)return losses
  • loss方法计算训练损失,考虑了RPN和RoI头的损失。

这段代码定义了一个名为 loss 的方法,用于计算两阶段目标检测器在一批输入图像和数据样本上的损失。这个方法是训练过程中的核心部分,因为它决定了如何通过反向传播更新模型的权重。下面,我们将详细解析这个方法的每个部分。

方法签名
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> dict:
  • self: 指向类的实例,允许访问类的属性和方法。
  • batch_inputs: 输入的图像张量,其形状为 (N, C, H, W),其中 N 是批量大小,C 是通道数,HW 分别是图像的高度和宽度。
  • batch_data_samples: 包含每个图像的元信息和对应注释的 DetDataSample 对象列表。
  • -> dict: 方法的返回类型注解,表示该方法将返回一个包含损失组件的字典。
文档字符串(Docstring)
"""
Calculate losses from a batch of inputs and data samples.Args:batch_inputs (Tensor): Input images of shape (N, C, H, W).These should usually be mean centered and std scaled.batch_data_samples (List[:obj:`DetDataSample`]): The batchdata samples. It usually includes information suchas `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.Returns:dict: A dictionary of loss components
"""
  • 这部分是对方法的简要说明,说明了该方法的功能是计算损失,并描述了输入参数和返回值。
方法体
x = self.extract_feat(batch_inputs)
  • 调用 extract_feat 方法提取输入图像的特征。这些特征将被用于后续的 RPN 和 RoI 头的损失计算。
losses = dict()
  • 初始化一个空字典 losses,用于存储和返回损失组件。
if self.with_rpn:proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)rpn_data_samples = copy.deepcopy(batch_data_samples)for data_sample in rpn_data_samples:data_sample.gt_instances.labels = torch.zeros_like(data_sample.gt_instances.labels)rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(x, rpn_data_samples, proposal_cfg=proposal_cfg)keys = rpn_losses.keys()for key in list(keys):if 'loss' in key and 'rpn' not in key:rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)losses.update(rpn_losses)
else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
  • 检查是否配置了 RPN 头(self.with_rpn)。
  • 如果有 RPN 头,首先获取 RPN 的配置,然后创建数据样本的深拷贝,并重置所有数据样本中的 gt_instances.labels 为零(这是因为 RPN 阶段不涉及类别标签的预测)。
  • 调用 RPN 头的 loss_and_predict 方法计算损失并获取区域提议。
  • 为了避免与 RoI 头的损失名称冲突,重命名 RPN 头的损失名称,添加前缀 rpn_
  • 如果没有 RPN 头,直接从数据样本中获取预定义的提议。
roi_losses = self.roi_head.loss(x, rpn_results_list, batch_data_samples)
losses.update(roi_losses)
  • 调用 RoI 头的 loss 方法计算损失,传入特征 x、RPN 的结果 rpn_results_list 和数据样本 batch_data_samples
  • 更新 losses 字典,将 RoI 头的损失添加到其中。
返回值
return losses
  • 返回 losses 字典,它包含了 RPN 和 RoI 头的所有损失组件。

预测

def predict(self,batch_inputs: Tensor,batch_data_samples: SampleList,rescale: bool = True) -> SampleList:"""Predict results from a batch of inputs and data samples with post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (List[:obj:`DetDataSample`]): The DataSamples. It usually includes information such as`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.rescale (bool): Whether to rescale the results.Defaults to True.Returns:list[:obj:`DetDataSample`]: Return the detection results of theinput images. The returns value is DetDataSample,which usually contain 'pred_instances'. And the``pred_instances`` usually contains following keys.- scores (Tensor): Classification scores, has a shape(num_instance, )- labels (Tensor): Labels of bboxes, has a shape(num_instances, ).- bboxes (Tensor): Has a shape (num_instances, 4),the last dimension 4 arrange as (x1, y1, x2, y2).- masks (Tensor): Has a shape (num_instances, H, W)."""assert self.with_bbox, 'Bbox head must be implemented.'x = self.extract_feat(batch_inputs)# If there are no pre-defined proposals, use RPN to get proposalsif batch_data_samples[0].get('proposals', None) is None:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)else:rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]results_list = self.roi_head.predict(x, rpn_results_list, batch_data_samples, rescale=rescale)batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)return batch_data_samples
  • predict方法生成最终的检测结果,应用后处理步骤,如非极大值抑制。

这段代码定义了一个名为 predict 的方法,用于在两阶段目标检测器中对一批输入图像和数据样本进行预测,并执行后处理。以下是该方法的详细解析:

方法签名
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True) -> SampleList:
  • self: 指向类的实例,允许访问类的属性和方法。
  • batch_inputs: 输入的图像张量,其形状为 (N, C, H, W),其中 N 是批量大小,C 是通道数,HW 分别是图像的高度和宽度。
  • batch_data_samples: 包含每个图像的元信息和对应注释的 DetDataSample 对象列表。
  • rescale: 一个布尔值,指示是否需要对预测结果进行尺度调整(例如,将边界框坐标从特征图尺度转换回原始图像尺度)。默认值为 True
  • -> SampleList: 方法的返回类型注解,表示该方法将返回一个 SampleList 对象,它包含了预测结果。
文档字符串(Docstring)
"""
Predict results from a batch of inputs and data samples with post-
processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (List[:obj:`DetDataSample`]): The DataSamples. It usually includes information such as`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.rescale (bool): Whether to rescale the results.Defaults to True.Returns:list[:obj:`DetDataSample`]: Return the detection results of theinput images. The returns value is DetDataSample,which usually contain 'pred_instances'. And the``pred_instances`` usually contains following keys.- scores (Tensor): Classification scores, has a shape(num_instance, )- labels (Tensor): Labels of bboxes, has a shape(num_instances, ).- bboxes (Tensor): Has a shape (num_instances, 4),the last dimension 4 arrange as (x1, y1, x2, y2).- masks (Tensor): Has a shape (num_instances, H, W).
"""
  • 这部分是对方法的简要说明,说明了该方法的功能是进行预测并执行后处理,并描述了输入参数和返回值。
方法体
assert self.with_bbox, 'Bbox head must be implemented.'
  • 这行代码是一个断言,确保检测器实现了边界框头(bbox_head)。如果没有实现,将抛出异常。
x = self.extract_feat(batch_inputs)
  • 调用 extract_feat 方法提取输入图像的特征。这些特征将被用于后续的 RPN 和 RoI 头的预测。
if batch_data_samples[0].get('proposals', None) is None:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)
else:rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
  • 检查输入数据样本中是否已经包含了预定义的提议(proposals)。如果没有,使用 RPN 头的 predict 方法生成区域提议。如果有,直接使用这些预定义的提议。
results_list = self.roi_head.predict(x, rpn_results_list, batch_data_samples, rescale=rescale)
  • 调用 RoI 头的 predict 方法,传入特征 x、RPN 的结果 rpn_results_list、数据样本 batch_data_samplesrescale 参数。这一步将生成最终的预测结果,包括类别、置信度和边界框。
batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)
  • 调用 add_pred_to_datasample 方法,将预测结果 results_list 添加到数据样本 batch_data_samples 中。这通常涉及到更新数据样本中的 pred_instances 属性,它包含了预测的类别、置信度、边界框等信息。
返回值
return batch_data_samples
  • 返回更新后的 batch_data_samples,它现在包含了每个图像的预测结果。

结论

two_stage.py文件封装了MMDetection中两阶段检测的本质。它提供了一种结构化的方法来构建具有模块化设计、灵活性和易于定制的检测器。理解这段代码对于任何希望使用MMDetection实现或修改两阶段检测器的人来说都是至关重要的。

想要更深入地探索或亲自动手使用MMDetection,可以参考官方文档和GitHub仓库。编程愉快!


本文旨在提供对MMDetection中TwoStageDetector类的全面理解,重点关注其架构和功能。对于进一步的探索或特定用例,建议探索源代码和配置文件。

相关文章:

〖open-mmlab: MMDetection〗解析文件:mmdet/models/detectors/two_stage.py

目录 MMDetection中的两阶段检测器:深入解析two_stage.py源码两阶段检测器概述two_stage.py的关键组件类定义和初始化构造函数Neck头配置RPN头配置RoI头配置_load_from_state_dict方法概述参数解释代码解析 特征提取方法签名文档字符串(Docstring&#x…...

【最新华为OD机试E卷-支持在线评测】机器人活动区域(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-E/D卷的三语言AC题解 💻 ACM金牌🏅️团队| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,…...

C语言:刷题日志(1)

一.阶乘计算升级版 本题要求实现一个打印非负整数阶乘的函数。 其中n是用户传入的参数,其值不超过1000。如果n是非负整数,则该函数必须在一行中打印出n!的值,否则打印“Invalid input”。 首先,知道阶乘是所有小于及等于该数的…...

ios私钥证书(p12)导入失败,Windows OpenSSl 1.1.1 下载

ios私钥证书(p12)导入失败 如果你用的OpenSSL版本是v3那么恭喜你V3必然报这个错,解决办法将OpenSSL 3降低成 v1。 Windows OpenSSl 1.1.1 下载 阿里云网盘下载地址:OpenSSL V1...

嵌入式面试经典30问:二

1. 嵌入式系统中,如何选择合适的微控制器或微处理器? 在嵌入式系统中选择合适的微控制器(MCU)或微处理器(MPU)时,需要考虑多个因素以确保所选组件能够满足项目的具体需求。以下是一些关键步骤和…...

目标检测-YOLOv1

YOLOv1介绍 YOLOv1(You Only Look Once version 1)是一种用于目标检测的深度学习算法,由Joseph Redmon等人于2016年提出。它基于单个卷积神经网络,将目标检测任务转化为一个回归问题,通过在图像上划分网格并预测每个网…...

python基础语法八-异常

书接上回: python基础语法一-基本数据类型 python基础语法二-多维数据类型 python基础语法三-类 python基础语法四-数据可视化 python基础语法五-函数 python基础语法六-正则匹配 python基础语法七-openpyxl操作excel 1. 异常简介 (1)异常:遇到…...

【堆的应用--C语言版】

前面一节我们都已将堆的结构(顺序存储)已经实现,对树的相关概念以及知识做了一定的了解。其中我们在实现删除操作和插入操作的时候,我们还同时实现了建大堆(小堆)的向上(下)调整算法…...

【微信小程序】搭建项目步骤 + 引入Tdesign UI

目录 创建1个空文件夹,选择下图基础模板 开启/支持sass 创建公共style文件并引入 引入Tdesign UI: 1. 初始化: 2. 安装后,开发工具进行构建: 3. 修改 app.json 4. 使用 5. 自定义主题色 创建1个空文件夹,选择下…...

android系统源码12 修改默认桌面壁纸--SRO方式

1、aosp12修改默认桌面壁纸 代码路径 :frameworks\base\core\res\res\drawable-nodpi 替换成自己的图片即可,不过需要覆盖所有目录下的图片。 由于是静态修改,则需要make一下,重新编译。 2、方法二Overlay方式 由于上述方法有…...

Echarts可视化

echarts是一个基于javascripts的开源可视化图表库 画图步骤&#xff1a; 1.引入echarts.js文件 <script src" https://cdn.jsdelivr.net/npm/echarts5.5.1/dist/echarts.min.js"></script> 也可将文件下载到本地通过src引入。 2. 准备一个呈现图表的…...

验证linux gpu是否可用

通过torch验证 import torchprint(torch.__version__) # 查看torch当前版本号 print(torch.version.cuda) # 编译当前版本的torch使用的cuda版本号 print(torch.cuda.is_available()) # 查看当前cuda是否可用于当前版本的Torch&#xff0c;如果输出True&#xff0c;则表示可…...

JavaScript( 简介)

目录 含义 实例 js代码位置 1 外部引入js文件 2 在 HTML 中&#xff0c;JavaScript 代码必须位于 标签之间。 小结 含义 js是一门脚本语言&#xff0c;能够改变HTML内容 实例 getElementById() 是多个 JavaScript HTML 方法之一。 本例使用该方法来“查找” id"d…...

Linux中的编译器gcc/g++

目录 一、gcc与g的区别 1.gcc编译器使用 2.g编译器使用 二、gcc/g编译器编译源文件过程 1.预处理 2.编译 3.汇编 4.链接 三、静态库和动态库 1.库中的头文件作用 2.静态库 3.动态库 四、gcc编译器的一些选项命令 一、gcc与g的区别 gcc用于编译C语言代码&#xff…...

RK3568安装部署Docker容器

设置华为镜像源 sudo sed -i s/huaweicloud.com/ustc.edu.cn/g /etc/apt/sources.list更新索引 rootok3568:/home/forlinx# sudo apt-get update Hit:1 http://ports.ubuntu.com/ubuntu-ports focal InRelease Hit:2 http://ports.ubuntu.com/ubuntu-ports focal-updates InR…...

Ubuntu 常用指令和作用解析

Ubuntu 常用指令和作用解析 Ubuntu 是一种常见的 Linux 发行版&#xff0c;它利用了 Unix 的力量和开源软件的精神。掌握常用指令可以提高我们在使用 Ubuntu 时的效率。本文将介绍一些常见的指令及其用途。 目录 更新与安装软件文件与目录操作系统信息与资源监控用户与权限管…...

2024国赛数学建模C题完整论文:农作物的种植策略

农作物种植策略优化的数学建模研究&#xff08;完整论文&#xff0c;持续更新&#xff0c;大家持续关注&#xff0c;更新见文末名片 &#xff09; 摘要 在本文中&#xff0c;建立了基于整数规划、动态规划、马尔科夫决策过程、不确定性建模、多目标优化、相关性分析、蒙特卡洛…...

【语音告警】博灵智能语音报警灯JavaScript循环播报场景实例-语音报警灯|声光报警器|网络信号灯

功能说明 本文将以JavaScript代码为实例&#xff0c;讲解如何通过JavaScript代码调用博灵语音通知终端 A4实现声光语音告警。主要博灵语音通知终端如何实现无线循环播报或者周期播报的功能。 本代码实现HTTP接口的声光语音播报&#xff0c;并指定循环次数、播报内容。由于通知…...

指针与函数(三)

三 .指向函数的指针 函数和数组一样,经系统编译后,其目标代码在内存中连续存放,其名字本身就是一个地址,是函数的入口地址。C语言中,指针可以指向变量,也可以指向函数。 指问函数的指针的定义格式为 类型名&#xff08;*指针变量名&#xff09;参数表 其中参数表为函数指针所…...

锐捷网络2025届校园招聘正式启动,【NTA6dni】!

锐捷网络2025届校园招聘正式启动&#xff0c;内推码[NTA6dni]。 原文链接点这 投递链接点这 祝大家面试顺利&#xff0c;offer多多~ 有问题大家可以评论&#xff0c;互相交流~...

共享内存喜欢沙县小吃

旭日新摊子好耶&#xff01; 系统从0开始搭建过通信方案&#xff0c;本地通信方案的代码&#xff1a;System V IPC 里面有共享内存、消息队列、信号量 共享内存 原理 两个进程有自己的内存区域划分&#xff0c;共享内存被创建出的时候是归属操作系统的&#xff0c;还是通过…...

五、Build构建配置:jar包换名、自行定义编译规则

&#xff08;1&#xff09;jar包换名&#xff1a;finalName &#xff08;2&#xff09;自行定义编译规则&#xff08;通常不用&#xff09; Maven约定的规则就是java目录下写java代码&#xff0c;resources目录下写配置文件。 遵循规则&#xff0c;Maven会帮忙做编译。 如若…...

Html、Css3动画效果

文章目录 第九章 动画9.1 transform动画9.2 transition过渡动画9.3 定义动画 第九章 动画 9.1 transform动画 transform 2D变形 translate()&#xff1a;平移函数&#xff0c;基于X、Y坐标重新定位元素的位置 scale()&#xff1a;缩放函数&#xff0c;可以使任意元素对象尺…...

【AIStarter:AI绘画、设计、对话】零基础入门:Llama 3.1 + 千问2快速部署

对于希望在本地环境中运行先进语言模型的用户来说&#xff0c;Llama 3.1和千问2是非常不错的选择。本文将详细介绍如何在本地部署这两个模型&#xff0c;让你能够快速开始使用。 前期准备 确保你的计算机具备足够的存储空间和计算能力。安装Python环境以及必要的库&#xff0…...

多机编队—(1)ubuntu 配置Fast_Planner

文章目录 前言一、Could not find package ...二、使用error: no match for ‘operator’...总结 前言 最近想要做有轨迹引导的多机器人编队&#xff0c;打算采用分布式的编队架构&#xff0c;实时的给每个机器人规划出目标位置&#xff0c;然后通过Fast_Planner生成避障路径&…...

【数学建模经验贴】国赛拿到赛题后,该如何选题?

2024“高教社杯”全国大学生数学建模竞赛即将开赛。这可能是很多同学第一次参加国赛&#xff0c;甚至是第一次参加数学建模比赛。 那么赛题的公布也就意味着比赛的开始&#xff0c;也将是我们所要面对的第一个问题——选题。在国赛来临的前夕&#xff0c;C君想和大家聊一聊容易…...

如何快速融入大学课堂

快速融入大学课堂是适应大学生活的重要一步。以下是一些实用的建议&#xff0c;帮助你快速融入大学课堂并取得良好的学习效果。 ### 1. 提前准备 - **课前预习**&#xff1a;在上课前预习课程内容&#xff0c;了解基本概念和知识点&#xff0c;这样在课堂上更容易跟上老师的讲…...

el-table行编辑

需求&#xff1a;单点行编辑并且请求接口更新数据&#xff0c;表格中某几个字段是下拉框取值的&#xff0c;剩下的是文本域&#xff1b;展示的时候 需要区分下拉框编码还是中文 故障模式这个展示的是fault_mode编码,但要显示的文字fault_mode_chn 这点需要注意 <el-tablere…...

OpenSSL Windows编译

目录 1. 源码下载2. vs2022编译 1. 源码下载 源码地址 2. vs2022编译 (1) 将“VS2022安装目录VC\Auxiliary\Build\“设置为PATH环境变量&#xff0c;启动cmd命令行&#xff08;一定要先设置环境变量&#xff09;。 (2)在cmd下进入VS2013安装目录vs2022\VC\Auxiliary\Build&…...

优化LabVIEW中TCP通信速度的方法

在LabVIEW中&#xff0c;TCP通信速度较慢可能由多种因素导致&#xff0c;如数据包处理延迟、阻塞式读取或数据解析效率低等。通过调整读取模式、优化数据处理逻辑、以及使用并行处理结构&#xff0c;可以显著提升TCP通信的速度&#xff0c;使其接近第三方调试工具的表现。LabVI…...