〖open-mmlab: MMDetection〗解析文件:mmdet/models/detectors/two_stage.py
目录
- MMDetection中的两阶段检测器:深入解析`two_stage.py`源码
- 两阶段检测器概述
- `two_stage.py`的关键组件
- 类定义和初始化
- 构造函数
- Neck头配置
- RPN头配置
- RoI头配置
- `_load_from_state_dict`
- 方法概述
- 参数解释
- 代码解析
- 特征提取
- 方法签名
- 文档字符串(Docstring)
- 方法体
- 返回值
- 前向传播
- 方法签名
- 文档字符串(Docstring)
- 方法体
- 返回值
- 损失计算
- 方法签名
- 文档字符串(Docstring)
- 方法体
- 返回值
- 预测
- 方法签名
- 文档字符串(Docstring)
- 方法体
- 返回值
- 结论
MMDetection中的两阶段检测器:深入解析two_stage.py源码
在目标检测领域,两阶段检测器因其在准确性和速度之间取得的平衡而成为基石方法之一。MMDetection是一个基于PyTorch的开源目标检测工具箱,它为实现此类检测器提供了强大的框架。在这篇博客文章中,我们将深入解析two_stage.py源码,这是MMDetection两阶段检测架构中的核心部分。
两阶段检测器概述
两阶段检测器的操作分为两个主要阶段:
- 区域提议网络(Region Proposal Network, RPN):第一阶段识别潜在的目标位置,即区域提议。
- 感兴趣区域(Region of Interest, RoI)头:第二阶段对这些提议进行细化,以得到精确的目标检测结果。
two_stage.py的关键组件
TwoStageDetector类是MMDetection中两阶段检测器的基础构建模块。让我们分解其核心组件:
类定义和初始化
@MODELS.register_module()
class TwoStageDetector(BaseDetector):"""两阶段检测器的基类。"""
- 类通过
@MODELS.register_module()装饰器注册在MMDetection的模型注册表中,使其易于配置和实例化。
构造函数
def __init__(self, backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, data_preprocessor=None, init_cfg=None):super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)self.backbone = MODELS.build(backbone)...
- 构造函数使用各种组件(如骨干网络、颈部网络、RPN头和RoI头)初始化检测器。它还处理训练和测试的配置。
Neck头配置
if neck is not None:self.neck = MODELS.build(neck)
RPN头配置
if rpn_head is not None:rpn_train_cfg = train_cfg.rpn if train_cfg is not None else Nonerpn_head_ = rpn_head.copy()rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)rpn_head_num_classes = rpn_head_.get('num_classes', None)if rpn_head_num_classes is None:rpn_head_.update(num_classes=1)else:if rpn_head_num_classes != 1:warnings.warn('The `num_classes` should be 1 in RPN, but get 'f'{rpn_head_num_classes}, please set ''rpn_head.num_classes = 1 in your config file.')rpn_head_.update(num_classes=1)self.rpn_head = MODELS.build(rpn_head_)
- RPN头使用训练和测试配置进行配置。确保
num_classes设置为1对于RPN至关重要,因为它只预测目标存在,而不是类别标签。
这段代码是两阶段检测器中初始化和配置区域提议网络(Region Proposal Network, RPN)的逻辑部分。让我们逐行分析:
-
检查RPN头是否提供:
if rpn_head is not None:这行代码检查是否提供了
rpn_head配置。如果提供了,那么进入代码块进行进一步的配置。 -
获取训练配置:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None这行代码尝试从
train_cfg(训练配置)中获取RPN部分的配置。如果train_cfg存在,则rpn_train_cfg被设置为train_cfg中的rpn部分,否则设置为None。 -
复制RPN头配置:
rpn_head_ = rpn_head.copy()这行代码创建了
rpn_head配置的一个副本,以避免直接修改原始配置。 -
更新RPN头配置:
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)这行代码将训练和测试的配置更新到RPN头的配置中。这样做是为了确保RPN在训练和测试时使用正确的参数。
-
获取RPN头的类别数:
rpn_head_num_classes = rpn_head_.get('num_classes', None)这行代码尝试从RPN头配置中获取
num_classes参数。如果不存在,则默认为None。 -
设置RPN头的类别数:
if rpn_head_num_classes is None:rpn_head_.update(num_classes=1) else:if rpn_head_num_classes != 1:warnings.warn('The `num_classes` should be 1 in RPN, but get 'f'{rpn_head_num_classes}, please set ''rpn_head.num_classes = 1 in your config file.')rpn_head_.update(num_classes=1)这部分代码首先检查
num_classes是否为None。如果是,那么它将num_classes设置为1。如果不是None,但值不是1,那么它会发出一个警告,提示用户RPN中的num_classes应该是1,因为RPN只负责检测物体的存在与否,而不是分类物体。然后,它将num_classes强制设置为1。 -
构建RPN头:
self.rpn_head = MODELS.build(rpn_head_)这行代码使用更新后的RPN头配置来构建RPN模型。
MODELS.build是一个工厂方法,根据提供的配置创建并返回RPN模型的实例。
总的来说,这段代码确保了RPN头被正确地配置和构建,特别是关于num_classes参数,它对于RPN的功能至关重要。
RoI头配置
if roi_head is not None:roi_head.update(train_cfg=rcnn_train_cfg)roi_head.update(test_cfg=test_cfg.rcnn)self.roi_head = MODELS.build(roi_head)
- 与RPN头类似,RoI头也配置了相应的训练和测试配置。
这段代码是两阶段检测器中初始化和配置感兴趣区域(Region of Interest, RoI)头的逻辑部分。让我们逐行分析:
-
检查RoI头是否提供:
if roi_head is not None:这行代码检查是否提供了
roi_head配置。如果提供了,那么进入代码块进行进一步的配置。 -
获取训练和测试配置:
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None这行代码尝试从
train_cfg(训练配置)中获取RoI部分的配置。如果train_cfg存在,则rcnn_train_cfg被设置为train_cfg中的rcnn部分,否则设置为None。 -
更新RoI头的训练配置:
roi_head.update(train_cfg=rcnn_train_cfg)这行代码将训练的配置更新到RoI头的配置中。这样做是为了确保RoI头在训练时使用正确的参数。
-
更新RoI头的测试配置:
roi_head.update(test_cfg=test_cfg.rcnn)这行代码将测试的配置更新到RoI头的配置中。这样做是为了确保RoI头在测试时使用正确的参数。
-
构建RoI头:
self.roi_head = MODELS.build(roi_head)这行代码使用更新后的RoI头配置来构建RoI模型。
MODELS.build是一个工厂方法,根据提供的配置创建并返回RoI模型的实例。
_load_from_state_dict
def _load_from_state_dict(self, state_dict: dict, prefix: str,local_metadata: dict, strict: bool,missing_keys: Union[List[str], str],unexpected_keys: Union[List[str], str],error_msgs: Union[List[str], str]) -> None:"""Exchange bbox_head key to rpn_head key when loading single-stageweights into two-stage model."""bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head'bbox_head_keys = [k for k in state_dict.keys() if k.startswith(bbox_head_prefix)]rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head'rpn_head_keys = [k for k in state_dict.keys() if k.startswith(rpn_head_prefix)]if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0:for bbox_head_key in bbox_head_keys:rpn_head_key = rpn_head_prefix + \bbox_head_key[len(bbox_head_prefix):]state_dict[rpn_head_key] = state_dict.pop(bbox_head_key)super()._load_from_state_dict(state_dict, prefix, local_metadata,strict, missing_keys, unexpected_keys,error_msgs)
在深度学习模型的训练和部署过程中,加载预训练权重是一个常见的操作。在两阶段检测器中,由于其结构与单阶段检测器不同,因此在加载权重时需要特别注意权重的匹配和转换。_load_from_state_dict方法正是为了解决这个问题而设计的。下面,我们将详细解析这个方法的工作原理,并探讨其在两阶段检测器中的重要性。
方法概述
_load_from_state_dict方法是在加载预训练权重时调用的,它的作用是将单阶段检测器的权重转换为两阶段检测器可以使用的格式。这是通过交换bbox_head和rpn_head的键来实现的。
参数解释
state_dict: 包含模型权重的字典。prefix: 权重键的前缀,用于区分不同部分的权重。local_metadata: 模型的元数据,通常包含模型结构信息。strict: 是否严格匹配权重,如果为True,权重不匹配会抛出错误。missing_keys: 缺失的权重键列表。unexpected_keys: 多余的权重键列表。error_msgs: 加载权重时的错误信息列表。
代码解析
-
定义
bbox_head和rpn_head的键前缀:bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head' rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head'这两行代码定义了
bbox_head和rpn_head的键前缀。如果提供了prefix,则将prefix加到bbox_head和rpn_head前面,否则使用默认的键名。 -
获取
bbox_head和rpn_head的键:bbox_head_keys = [k for k in state_dict.keys() if k.startswith(bbox_head_prefix)] rpn_head_keys = [k for k in state_dict.keys() if k.startswith(rpn_head_prefix)]这两行代码通过列表推导式获取所有以
bbox_head_prefix和rpn_head_prefix开头的键,这些键分别对应单阶段检测器的边界框头和两阶段检测器的RPN头的权重。 -
权重转换:
if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0:for bbox_head_key in bbox_head_keys:rpn_head_key = rpn_head_prefix + bbox_head_key[len(bbox_head_prefix):]state_dict[rpn_head_key] = state_dict.pop(bbox_head_key)这段代码检查是否存在
bbox_head的权重而没有rpn_head的权重。如果是这种情况,它会遍历所有的bbox_head权重键,将它们转换为rpn_head的权重键,并在state_dict中进行更新。这是通过删除原bbox_head的权重键并添加新的rpn_head的权重键来实现的。 -
调用父类的加载方法:
super()._load_from_state_dict(state_dict, prefix, local_metadata,strict, missing_keys, unexpected_keys,error_msgs)这行代码调用父类的
_load_from_state_dict方法,完成权重的加载。这一步是必要的,因为它会处理权重的最终匹配和加载过程。
特征提取
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:"""Extract features.Args:batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).Returns:tuple[Tensor]: Multi-level features that may havedifferent resolutions."""x = self.backbone(batch_inputs)if self.with_neck:x = self.neck(x)return x
extract_feat方法使用骨干网络和可选的颈部模块从输入图像中提取特征。
这段代码定义了一个名为 extract_feat 的方法,它是两阶段检测器中用于提取特征的关键步骤。下面,我们将详细解析这个方法的每个部分。
方法签名
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
self: 指向类的实例,允许访问类的属性和方法。batch_inputs: 输入的图像张量,其形状为(N, C, H, W),其中N是批量大小,C是通道数,H和W分别是图像的高度和宽度。-> Tuple[Tensor]: 方法的返回类型注解,表示该方法将返回一个包含张量的元组,这些张量是不同分辨率的特征。
文档字符串(Docstring)
"""
Extract features.Args:batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).Returns:tuple[Tensor]: Multi-level features that may havedifferent resolutions.
"""
- 这部分是对方法的简要说明,说明了该方法的功能是提取特征。
Args: 描述了方法的输入参数,即一批图像。Returns: 描述了方法的返回值,即具有不同分辨率的多级特征。
方法体
x = self.backbone(batch_inputs)
- 这行代码调用了检测器的
backbone网络,将输入的图像张量batch_inputs传递给它。 backbone通常是卷积神经网络(CNN)的一部分,负责从输入图像中提取特征。- 执行后,
x将包含从输入图像中提取的特征。
if self.with_neck:x = self.neck(x)
- 这行代码检查检测器是否具有
neck组件(通常称为“颈部”或“连接”网络)。 self.with_neck是一个布尔值,指示是否构建了颈部网络。- 如果存在颈部网络(
self.with_neck为True),则将backbone提取的特征x传递给neck网络进一步处理。 neck网络通常用于进一步提取或融合特征,以提高检测器的性能。
返回值
return x
- 方法返回
x,它包含了从输入图像中提取的特征。 - 这些特征可能包含多个尺度或分辨率,这对于两阶段检测器在后续步骤中生成区域提议和进行目标识别非常有用。
前向传播
def _forward(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> tuple:"""Network forward process. Usually includes backbone, neck and headforward without any post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (list[:obj:`DetDataSample`]): Each item containsthe meta information of each image and correspondingannotations.Returns:tuple: A tuple of features from ``rpn_head`` and ``roi_head``forward."""results = ()x = self.extract_feat(batch_inputs)if self.with_rpn:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]roi_outs = self.roi_head.forward(x, rpn_results_list,batch_data_samples)results = results + (roi_outs, )return results
_forward方法协调网络的前向传播,处理RPN和RoI头阶段。
这段代码定义了一个名为_forward的方法,它是两阶段检测器中用于执行网络前向传播的关键步骤。下面,我们将详细解析这个方法的每个部分。
方法签名
def _forward(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> tuple:
self: 指向类的实例,允许访问类的属性和方法。batch_inputs: 输入的图像张量,其形状为(N, C, H, W),其中N是批量大小,C是通道数,H和W分别是图像的高度和宽度。batch_data_samples: 包含每个图像的元信息和对应注释的DetDataSample对象列表。-> tuple: 方法的返回类型注解,表示该方法将返回一个元组。
文档字符串(Docstring)
"""
Network forward process. Usually includes backbone, neck and head
forward without any post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (list[:obj:`DetDataSample`]): Each item containsthe meta information of each image and correspondingannotations.Returns:tuple: A tuple of features from ``rpn_head`` and ``roi_head``forward.
"""
- 这部分是对方法的简要说明,说明了该方法的功能是执行网络的前向传播过程,通常包括骨干网络、颈部网络和头部网络的前向传播,但不包括任何后处理。
方法体
results = ()
- 初始化一个空的元组
results,用于存储前向传播的结果。
x = self.extract_feat(batch_inputs)
- 调用
extract_feat方法提取输入图像的特征。这些特征将被用于后续的区域提议网络(RPN)和感兴趣区域(RoI)头。
if self.with_rpn:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)
else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
- 检查检测器是否具有 RPN 头(
self.with_rpn)。 - 如果有 RPN 头,调用 RPN 头的
predict方法来生成区域提议。这些提议是候选的目标位置。 - 如果没有 RPN 头,假设输入数据中已经包含了预先定义的提议(
proposals),并从每个数据样本中提取这些提议。
roi_outs = self.roi_head.forward(x, rpn_results_list,batch_data_samples)
- 调用 RoI 头的
forward方法,传入从骨干网络提取的特征x、RPN 生成的区域提议rpn_results_list和包含图像元信息的数据样本batch_data_samples。 - RoI 头负责从提议的区域中提取更精细的特征,并进行目标识别。
results = results + (roi_outs, )
- 将 RoI 头的输出
roi_outs添加到results元组中。
返回值
return results
- 返回
results元组,它包含了 RPN 头和 RoI 头的前向传播结果。
在当前代码片段中,并没有直接将 RPN 的结果和 RoI 头的结果合并到同一个元组中。只有 RoI 头的结果被添加到了 results 元组中。如果需要同时包含 RPN 和 RoI 头的结果,代码可能需要稍作修改,例如:
results = (rpn_results_list, roi_outs)
或者,如果 RPN 结果也需要在后续处理中使用,可以这样修改:
results = results + (rpn_results_list, roi_outs)
这样,results 元组就会同时包含 RPN 和 RoI 头的结果。
损失计算
def loss(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> dict:"""Calculate losses from a batch of inputs and data samples.Args:batch_inputs (Tensor): Input images of shape (N, C, H, W).These should usually be mean centered and std scaled.batch_data_samples (List[:obj:`DetDataSample`]): The batchdata samples. It usually includes information suchas `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.Returns:dict: A dictionary of loss components"""x = self.extract_feat(batch_inputs)losses = dict()# RPN forward and lossif self.with_rpn:proposal_cfg = self.train_cfg.get('rpn_proposal',self.test_cfg.rpn)rpn_data_samples = copy.deepcopy(batch_data_samples)# set cat_id of gt_labels to 0 in RPNfor data_sample in rpn_data_samples:data_sample.gt_instances.labels = \torch.zeros_like(data_sample.gt_instances.labels)rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(x, rpn_data_samples, proposal_cfg=proposal_cfg)# avoid get same name with roi_head losskeys = rpn_losses.keys()for key in list(keys):if 'loss' in key and 'rpn' not in key:rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)losses.update(rpn_losses)else:assert batch_data_samples[0].get('proposals', None) is not None# use pre-defined proposals in InstanceData for the second stage# to extract ROI features.rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]roi_losses = self.roi_head.loss(x, rpn_results_list,batch_data_samples)losses.update(roi_losses)return losses
loss方法计算训练损失,考虑了RPN和RoI头的损失。
这段代码定义了一个名为 loss 的方法,用于计算两阶段目标检测器在一批输入图像和数据样本上的损失。这个方法是训练过程中的核心部分,因为它决定了如何通过反向传播更新模型的权重。下面,我们将详细解析这个方法的每个部分。
方法签名
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> dict:
self: 指向类的实例,允许访问类的属性和方法。batch_inputs: 输入的图像张量,其形状为(N, C, H, W),其中N是批量大小,C是通道数,H和W分别是图像的高度和宽度。batch_data_samples: 包含每个图像的元信息和对应注释的DetDataSample对象列表。-> dict: 方法的返回类型注解,表示该方法将返回一个包含损失组件的字典。
文档字符串(Docstring)
"""
Calculate losses from a batch of inputs and data samples.Args:batch_inputs (Tensor): Input images of shape (N, C, H, W).These should usually be mean centered and std scaled.batch_data_samples (List[:obj:`DetDataSample`]): The batchdata samples. It usually includes information suchas `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.Returns:dict: A dictionary of loss components
"""
- 这部分是对方法的简要说明,说明了该方法的功能是计算损失,并描述了输入参数和返回值。
方法体
x = self.extract_feat(batch_inputs)
- 调用
extract_feat方法提取输入图像的特征。这些特征将被用于后续的 RPN 和 RoI 头的损失计算。
losses = dict()
- 初始化一个空字典
losses,用于存储和返回损失组件。
if self.with_rpn:proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)rpn_data_samples = copy.deepcopy(batch_data_samples)for data_sample in rpn_data_samples:data_sample.gt_instances.labels = torch.zeros_like(data_sample.gt_instances.labels)rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(x, rpn_data_samples, proposal_cfg=proposal_cfg)keys = rpn_losses.keys()for key in list(keys):if 'loss' in key and 'rpn' not in key:rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)losses.update(rpn_losses)
else:assert batch_data_samples[0].get('proposals', None) is not Nonerpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
- 检查是否配置了 RPN 头(
self.with_rpn)。 - 如果有 RPN 头,首先获取 RPN 的配置,然后创建数据样本的深拷贝,并重置所有数据样本中的
gt_instances.labels为零(这是因为 RPN 阶段不涉及类别标签的预测)。 - 调用 RPN 头的
loss_and_predict方法计算损失并获取区域提议。 - 为了避免与 RoI 头的损失名称冲突,重命名 RPN 头的损失名称,添加前缀
rpn_。 - 如果没有 RPN 头,直接从数据样本中获取预定义的提议。
roi_losses = self.roi_head.loss(x, rpn_results_list, batch_data_samples)
losses.update(roi_losses)
- 调用 RoI 头的
loss方法计算损失,传入特征x、RPN 的结果rpn_results_list和数据样本batch_data_samples。 - 更新
losses字典,将 RoI 头的损失添加到其中。
返回值
return losses
- 返回
losses字典,它包含了 RPN 和 RoI 头的所有损失组件。
预测
def predict(self,batch_inputs: Tensor,batch_data_samples: SampleList,rescale: bool = True) -> SampleList:"""Predict results from a batch of inputs and data samples with post-processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (List[:obj:`DetDataSample`]): The DataSamples. It usually includes information such as`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.rescale (bool): Whether to rescale the results.Defaults to True.Returns:list[:obj:`DetDataSample`]: Return the detection results of theinput images. The returns value is DetDataSample,which usually contain 'pred_instances'. And the``pred_instances`` usually contains following keys.- scores (Tensor): Classification scores, has a shape(num_instance, )- labels (Tensor): Labels of bboxes, has a shape(num_instances, ).- bboxes (Tensor): Has a shape (num_instances, 4),the last dimension 4 arrange as (x1, y1, x2, y2).- masks (Tensor): Has a shape (num_instances, H, W)."""assert self.with_bbox, 'Bbox head must be implemented.'x = self.extract_feat(batch_inputs)# If there are no pre-defined proposals, use RPN to get proposalsif batch_data_samples[0].get('proposals', None) is None:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)else:rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]results_list = self.roi_head.predict(x, rpn_results_list, batch_data_samples, rescale=rescale)batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)return batch_data_samples
predict方法生成最终的检测结果,应用后处理步骤,如非极大值抑制。
这段代码定义了一个名为 predict 的方法,用于在两阶段目标检测器中对一批输入图像和数据样本进行预测,并执行后处理。以下是该方法的详细解析:
方法签名
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True) -> SampleList:
self: 指向类的实例,允许访问类的属性和方法。batch_inputs: 输入的图像张量,其形状为(N, C, H, W),其中N是批量大小,C是通道数,H和W分别是图像的高度和宽度。batch_data_samples: 包含每个图像的元信息和对应注释的DetDataSample对象列表。rescale: 一个布尔值,指示是否需要对预测结果进行尺度调整(例如,将边界框坐标从特征图尺度转换回原始图像尺度)。默认值为True。-> SampleList: 方法的返回类型注解,表示该方法将返回一个SampleList对象,它包含了预测结果。
文档字符串(Docstring)
"""
Predict results from a batch of inputs and data samples with post-
processing.Args:batch_inputs (Tensor): Inputs with shape (N, C, H, W).batch_data_samples (List[:obj:`DetDataSample`]): The DataSamples. It usually includes information such as`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.rescale (bool): Whether to rescale the results.Defaults to True.Returns:list[:obj:`DetDataSample`]: Return the detection results of theinput images. The returns value is DetDataSample,which usually contain 'pred_instances'. And the``pred_instances`` usually contains following keys.- scores (Tensor): Classification scores, has a shape(num_instance, )- labels (Tensor): Labels of bboxes, has a shape(num_instances, ).- bboxes (Tensor): Has a shape (num_instances, 4),the last dimension 4 arrange as (x1, y1, x2, y2).- masks (Tensor): Has a shape (num_instances, H, W).
"""
- 这部分是对方法的简要说明,说明了该方法的功能是进行预测并执行后处理,并描述了输入参数和返回值。
方法体
assert self.with_bbox, 'Bbox head must be implemented.'
- 这行代码是一个断言,确保检测器实现了边界框头(
bbox_head)。如果没有实现,将抛出异常。
x = self.extract_feat(batch_inputs)
- 调用
extract_feat方法提取输入图像的特征。这些特征将被用于后续的 RPN 和 RoI 头的预测。
if batch_data_samples[0].get('proposals', None) is None:rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False)
else:rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples]
- 检查输入数据样本中是否已经包含了预定义的提议(
proposals)。如果没有,使用 RPN 头的predict方法生成区域提议。如果有,直接使用这些预定义的提议。
results_list = self.roi_head.predict(x, rpn_results_list, batch_data_samples, rescale=rescale)
- 调用 RoI 头的
predict方法,传入特征x、RPN 的结果rpn_results_list、数据样本batch_data_samples和rescale参数。这一步将生成最终的预测结果,包括类别、置信度和边界框。
batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)
- 调用
add_pred_to_datasample方法,将预测结果results_list添加到数据样本batch_data_samples中。这通常涉及到更新数据样本中的pred_instances属性,它包含了预测的类别、置信度、边界框等信息。
返回值
return batch_data_samples
- 返回更新后的
batch_data_samples,它现在包含了每个图像的预测结果。
结论
two_stage.py文件封装了MMDetection中两阶段检测的本质。它提供了一种结构化的方法来构建具有模块化设计、灵活性和易于定制的检测器。理解这段代码对于任何希望使用MMDetection实现或修改两阶段检测器的人来说都是至关重要的。
想要更深入地探索或亲自动手使用MMDetection,可以参考官方文档和GitHub仓库。编程愉快!
本文旨在提供对MMDetection中TwoStageDetector类的全面理解,重点关注其架构和功能。对于进一步的探索或特定用例,建议探索源代码和配置文件。
相关文章:
〖open-mmlab: MMDetection〗解析文件:mmdet/models/detectors/two_stage.py
目录 MMDetection中的两阶段检测器:深入解析two_stage.py源码两阶段检测器概述two_stage.py的关键组件类定义和初始化构造函数Neck头配置RPN头配置RoI头配置_load_from_state_dict方法概述参数解释代码解析 特征提取方法签名文档字符串(Docstring&#x…...
【最新华为OD机试E卷-支持在线评测】机器人活动区域(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)
🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-E/D卷的三语言AC题解 💻 ACM金牌🏅️团队| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,…...
C语言:刷题日志(1)
一.阶乘计算升级版 本题要求实现一个打印非负整数阶乘的函数。 其中n是用户传入的参数,其值不超过1000。如果n是非负整数,则该函数必须在一行中打印出n!的值,否则打印“Invalid input”。 首先,知道阶乘是所有小于及等于该数的…...
ios私钥证书(p12)导入失败,Windows OpenSSl 1.1.1 下载
ios私钥证书(p12)导入失败 如果你用的OpenSSL版本是v3那么恭喜你V3必然报这个错,解决办法将OpenSSL 3降低成 v1。 Windows OpenSSl 1.1.1 下载 阿里云网盘下载地址:OpenSSL V1...
嵌入式面试经典30问:二
1. 嵌入式系统中,如何选择合适的微控制器或微处理器? 在嵌入式系统中选择合适的微控制器(MCU)或微处理器(MPU)时,需要考虑多个因素以确保所选组件能够满足项目的具体需求。以下是一些关键步骤和…...
目标检测-YOLOv1
YOLOv1介绍 YOLOv1(You Only Look Once version 1)是一种用于目标检测的深度学习算法,由Joseph Redmon等人于2016年提出。它基于单个卷积神经网络,将目标检测任务转化为一个回归问题,通过在图像上划分网格并预测每个网…...
python基础语法八-异常
书接上回: python基础语法一-基本数据类型 python基础语法二-多维数据类型 python基础语法三-类 python基础语法四-数据可视化 python基础语法五-函数 python基础语法六-正则匹配 python基础语法七-openpyxl操作excel 1. 异常简介 (1)异常:遇到…...
【堆的应用--C语言版】
前面一节我们都已将堆的结构(顺序存储)已经实现,对树的相关概念以及知识做了一定的了解。其中我们在实现删除操作和插入操作的时候,我们还同时实现了建大堆(小堆)的向上(下)调整算法…...
【微信小程序】搭建项目步骤 + 引入Tdesign UI
目录 创建1个空文件夹,选择下图基础模板 开启/支持sass 创建公共style文件并引入 引入Tdesign UI: 1. 初始化: 2. 安装后,开发工具进行构建: 3. 修改 app.json 4. 使用 5. 自定义主题色 创建1个空文件夹,选择下…...
android系统源码12 修改默认桌面壁纸--SRO方式
1、aosp12修改默认桌面壁纸 代码路径 :frameworks\base\core\res\res\drawable-nodpi 替换成自己的图片即可,不过需要覆盖所有目录下的图片。 由于是静态修改,则需要make一下,重新编译。 2、方法二Overlay方式 由于上述方法有…...
Echarts可视化
echarts是一个基于javascripts的开源可视化图表库 画图步骤: 1.引入echarts.js文件 <script src" https://cdn.jsdelivr.net/npm/echarts5.5.1/dist/echarts.min.js"></script> 也可将文件下载到本地通过src引入。 2. 准备一个呈现图表的…...
验证linux gpu是否可用
通过torch验证 import torchprint(torch.__version__) # 查看torch当前版本号 print(torch.version.cuda) # 编译当前版本的torch使用的cuda版本号 print(torch.cuda.is_available()) # 查看当前cuda是否可用于当前版本的Torch,如果输出True,则表示可…...
JavaScript( 简介)
目录 含义 实例 js代码位置 1 外部引入js文件 2 在 HTML 中,JavaScript 代码必须位于 标签之间。 小结 含义 js是一门脚本语言,能够改变HTML内容 实例 getElementById() 是多个 JavaScript HTML 方法之一。 本例使用该方法来“查找” id"d…...
Linux中的编译器gcc/g++
目录 一、gcc与g的区别 1.gcc编译器使用 2.g编译器使用 二、gcc/g编译器编译源文件过程 1.预处理 2.编译 3.汇编 4.链接 三、静态库和动态库 1.库中的头文件作用 2.静态库 3.动态库 四、gcc编译器的一些选项命令 一、gcc与g的区别 gcc用于编译C语言代码ÿ…...
RK3568安装部署Docker容器
设置华为镜像源 sudo sed -i s/huaweicloud.com/ustc.edu.cn/g /etc/apt/sources.list更新索引 rootok3568:/home/forlinx# sudo apt-get update Hit:1 http://ports.ubuntu.com/ubuntu-ports focal InRelease Hit:2 http://ports.ubuntu.com/ubuntu-ports focal-updates InR…...
Ubuntu 常用指令和作用解析
Ubuntu 常用指令和作用解析 Ubuntu 是一种常见的 Linux 发行版,它利用了 Unix 的力量和开源软件的精神。掌握常用指令可以提高我们在使用 Ubuntu 时的效率。本文将介绍一些常见的指令及其用途。 目录 更新与安装软件文件与目录操作系统信息与资源监控用户与权限管…...
2024国赛数学建模C题完整论文:农作物的种植策略
农作物种植策略优化的数学建模研究(完整论文,持续更新,大家持续关注,更新见文末名片 ) 摘要 在本文中,建立了基于整数规划、动态规划、马尔科夫决策过程、不确定性建模、多目标优化、相关性分析、蒙特卡洛…...
【语音告警】博灵智能语音报警灯JavaScript循环播报场景实例-语音报警灯|声光报警器|网络信号灯
功能说明 本文将以JavaScript代码为实例,讲解如何通过JavaScript代码调用博灵语音通知终端 A4实现声光语音告警。主要博灵语音通知终端如何实现无线循环播报或者周期播报的功能。 本代码实现HTTP接口的声光语音播报,并指定循环次数、播报内容。由于通知…...
指针与函数(三)
三 .指向函数的指针 函数和数组一样,经系统编译后,其目标代码在内存中连续存放,其名字本身就是一个地址,是函数的入口地址。C语言中,指针可以指向变量,也可以指向函数。 指问函数的指针的定义格式为 类型名(*指针变量名)参数表 其中参数表为函数指针所…...
锐捷网络2025届校园招聘正式启动,【NTA6dni】!
锐捷网络2025届校园招聘正式启动,内推码[NTA6dni]。 原文链接点这 投递链接点这 祝大家面试顺利,offer多多~ 有问题大家可以评论,互相交流~...
业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...
C++初阶-list的底层
目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)
引言:为什么 Eureka 依然是存量系统的核心? 尽管 Nacos 等新注册中心崛起,但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制,是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...
MySQL 8.0 OCP 英文题库解析(十三)
Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
Yolov8 目标检测蒸馏学习记录
yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...
处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...
在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...
探索Selenium:自动化测试的神奇钥匙
目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...
