当前位置: 首页 > article >正文

目标检测 TaskAlignedAssigner 原理

文章目录

    • TaskAlignedAssigner 原理和代码
    • 使用示例

TaskAlignedAssigner 原理和代码

原理主要是结合预测的分类分数和边界框与真实标注的信息,找出与真实目标最匹配的锚点,为这些锚点分配对应的目标标签、边界框和分数。

TaskAlignedAssigner 是目标检测中用于对齐分类和定位任务的样本分配器。其核心思想是通过综合分类得分和预测框与真实框的 IoU,动态选择最合适的锚点作为正样本。具体步骤如下:

  1. 计算 IoU:预测框与真实框之间的 IoU。

  2. 提取分类得分:根据真实框的类别,提取对应类别的分类得分。

  3. 任务对齐指标:计算每个锚点的任务对齐指标,公式为:
    在这里插入图片描述

其中, α \alpha α β \beta β 是超参数,用于平衡分类和定位的重要性。

  1. 中心约束:过滤掉锚点中心不在真实框内的候选。

  2. 动态 Top-k 选择:对每个真实框,选择任务对齐指标最高的前 k k k 个锚点作为正样本。

  3. 冲突处理:若一个锚点被多个真实框选中,保留指标最高的分配。

import torch
import torch.nn as nndef pairwise_iou(boxes1, boxes2):"""计算两组框之间的 IoU。Args:boxes1 (Tensor): (N, 4) 格式为 xyxyboxes2 (Tensor): (M, 4) 格式为 xyxyReturns:iou (Tensor): (N, M)"""lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])wh = (rb - lt).clamp(min=0)inter = wh[:, :, 0] * wh[:, :, 1]area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])union = area1[:, None] + area2 - interreturn inter / (union + 1e-9)
class TaskAlignedAssigner(nn.Module):def __init__(self, topk=13, alpha=1.0, beta=6.0):super().__init__()# 初始化超参数self.topk = topk    # 每个真实框(GT)最多选择的正样本锚点数量,默认13self.alpha = alpha  # 分类得分的指数权重,用于平衡分类任务的重要性,默认1.0self.beta = beta    # IoU的指数权重,用于平衡定位任务的重要性,默认6.0@torch.no_grad()def forward(self, cls_scores, bbox_preds, gt_bboxes, gt_labels):"""输入:cls_scores  : (B, num_anchors, num_classes) 模型输出的分类得分(未归一化)bbox_preds : (B, num_anchors, 4)           模型输出的预测框坐标(xyxy格式)gt_bboxes  : (B, num_gts, 4)               真实框坐标(xyxy格式)gt_labels  : (B, num_gts)                  真实框的类别标签(0~num_classes-1)输出:pos_anchor_indices : (batch_idx, anchor_idx) 正样本锚点的批次索引和锚点索引pos_gt_indices     : (gt_idx,)               对应的真实框索引pos_labels         : (label,)               正样本的类别标签"""# 获取输入张量的维度信息batch_size, num_anchors, _ = cls_scores.shape  # B=批次大小, num_anchors=锚点总数device = cls_scores.device  # 设备信息(CPU/GPU)# 初始化分配结果张量(记录每个锚点分配的GT索引和类别)assigned_gt_inds = torch.zeros((batch_size, num_anchors), dtype=torch.long, device=device)  # Shape: (B, num_anchors),0表示未分配,非零值表示分配的GT索引+1(避免与0冲突)assigned_labels = torch.zeros((batch_size, num_anchors), dtype=torch.long, device=device)  # Shape: (B, num_anchors),记录分配的类别标签# 逐样本处理(每个批次独立处理)for b in range(batch_size):# 提取当前样本的预测框、真实框和类别标签bbox_pred = bbox_preds[b]  # (num_anchors,4) 当前批次的所有锚点预测框gt_bbox = gt_bboxes[b]     # (num_gts,4)     当前批次的所有真实框gt_label = gt_labels[b]    # (num_gts)       当前批次的所有真实框类别num_gts = gt_bbox.size(0)  # 当前样本的真实框数量# 若当前样本没有真实框,跳过处理if num_gts == 0:continue#################################################################### Step 1: 计算预测框与真实框的 IoU#################################################################### 输入: bbox_pred (num_anchors,4), gt_bbox (num_gts,4)# 输出: iou (num_anchors, num_gts),每个元素为锚点与GT的IoUiou = pairwise_iou(bbox_pred, gt_bbox)  #################################################################### Step 2: 提取对应真实框类别的分类得分#################################################################### cls_scores[b]的shape为 (num_anchors, num_classes)# gt_label的shape为 (num_gts),每个元素是对应GT的类别索引# 通过高级索引提取每个锚点在对应GT类别上的得分# 结果scores的shape为 (num_anchors, num_gts)scores = cls_scores[b][:, gt_label]  #################################################################### Step 3: 计算任务对齐指标(分类得分^α * IoU^β)###################################################################alignment_metrics = scores.pow(self.alpha) * iou.pow(self.beta)# alignment_metrics的shape: (num_anchors, num_gts)#################################################################### Step 4: 中心点约束(过滤锚点中心不在GT内部的候选)#################################################################### 计算锚点中心坐标cx = (bbox_pred[:, 0] + bbox_pred[:, 2]) / 2  # (num_anchors,)cy = (bbox_pred[:, 1] + bbox_pred[:, 3]) / 2  # (num_anchors,)# 判断锚点中心是否在GT框内(利用广播机制)# cx[:, None]的shape: (num_anchors,1)# gt_bbox[None, :, 0]的shape: (1, num_gts)# 比较操作后,in_gt的shape为 (num_anchors, num_gts)in_gt = (cx[:, None] >= gt_bbox[None, :, 0]) & \(cx[:, None] <= gt_bbox[None, :, 2]) & \(cy[:, None] >= gt_bbox[None, :, 1]) & \(cy[:, None] <= gt_bbox[None, :, 3])# 将不在GT内的锚点指标置零alignment_metrics *= in_gt.float()  # (num_anchors, num_gts)#################################################################### Step 5: 动态选择每个GT的Top-k锚点###################################################################candidate_metrics = []       # 保存所有候选锚点的指标值(多个张量)candidate_gt_indices = []    # 保存候选锚点对应的GT索引(平铺列表)candidate_anchor_indices = [] # 保存候选锚点的索引(多个张量)# 遍历每个真实框(GT)for gt_idx in range(num_gts):# 提取当前GT对应的所有锚点指标metrics = alignment_metrics[:, gt_idx]  # (num_anchors,)# 筛选有效锚点(指标>0表示中心在GT内且指标非零)valid = metrics > 0  # (num_anchors,)if not valid.any():  # 若没有有效锚点,跳过该GTcontinue# 确定实际选择的Top-k数量(不超过有效锚点数和预设topk)k = min(self.topk, valid.sum().item())# 选择当前GT的Top-k锚点(指标值和索引)topk_metrics, topk_anchors = metrics.topk(k)# 保存结果candidate_metrics.append(topk_metrics)  # 添加一个形状为(k,)的张量candidate_gt_indices.extend([gt_idx] * k)  # 扩展k个gt_idx元素candidate_anchor_indices.append(topk_anchors)  # 添加形状为(k,)的锚点索引# 若当前样本无候选锚点,跳过后续处理if not candidate_metrics:continue#################################################################### Step 6: 合并候选并排序#################################################################### 合并所有候选指标candidate_metrics = torch.cat(candidate_metrics)  # (total_candidates,)# 转换GT索引为张量(total_candidates,)candidate_gt_indices = torch.tensor(candidate_gt_indices, dtype=torch.long, device=device)# 合并所有候选锚点索引candidate_anchor_indices = torch.cat(candidate_anchor_indices)  # (total_candidates,)# 按指标降序排序(从高到低)sorted_idx = candidate_metrics.argsort(descending=True)  # (total_candidates,)candidate_gt_indices = candidate_gt_indices[sorted_idx]  # 按排序调整GT索引candidate_anchor_indices = candidate_anchor_indices[sorted_idx]  # 按排序调整锚点索引#################################################################### Step 7: 分配正样本(解决冲突,高优先级指标优先)###################################################################assigned_mask = torch.zeros(num_anchors, dtype=torch.bool, device=device)  # 标记锚点是否已被分配# 按排序后的顺序遍历候选锚点for anchor_idx, gt_idx in zip(candidate_anchor_indices, candidate_gt_indices):if not assigned_mask[anchor_idx]:# 记录分配的GT索引(+1避免与0冲突)assigned_gt_inds[b, anchor_idx] = gt_idx + 1  # 记录分配的类别标签assigned_labels[b, anchor_idx] = gt_label[gt_idx]  # 标记该锚点已分配assigned_mask[anchor_idx] = True  #################################################################### Step 8: 提取最终正样本信息#################################################################### 生成正样本的掩码(assigned_gt_inds > 0表示已分配)pos_mask = assigned_gt_inds > 0  # (B, num_anchors)# 获取正样本的批次索引和锚点索引(非零元素的坐标)pos_anchor_indices = pos_mask.nonzero(as_tuple=True)  # 格式为 (batch_indices, anchor_indices),例如:(tensor([0,0,1]), tensor([5,8,3]))# 获取对应的GT索引(需减去1恢复原始索引)pos_gt_indices = assigned_gt_inds[pos_mask] - 1  # (num_pos_samples,)# 获取正样本的类别标签pos_labels = assigned_labels[pos_mask]  # (num_pos_samples,)return pos_anchor_indices, pos_gt_indices, pos_labels

使用示例

if __name__ == "__main__":# 参数设置batch_size = 2num_anchors = 100   # 锚点数量num_classes = 20     # 类别数num_gts = 3         # 每个样本的 GT 数量# 模拟数据cls_scores = torch.rand(batch_size, num_anchors, num_classes)  # 随机分类得分bbox_preds = torch.rand(batch_size, num_anchors, 4) * 100      # 随机预测框(xyxy)gt_bboxes = torch.rand(batch_size, num_gts, 4) * 100           # 随机真实框(xyxy)gt_labels = torch.randint(0, num_classes, (batch_size, num_gts)) # 随机 GT 类别标签# 初始化分配器assigner = TaskAlignedAssigner(topk=5, alpha=1.0, beta=6.0)# 分配正样本pos_anchors, pos_gts, pos_labels = assigner(cls_scores, bbox_preds, gt_bboxes, gt_labels)# 输出结果print("正样本锚点索引:", pos_anchors)  # 格式为 (batch_idx, anchor_idx) 的元组print("对应真实框索引:", pos_gts)     # 每个正样本锚点对应的 GT 索引print("正样本标签:", pos_labels)      # 每个正样本锚点的类别标签

相关文章:

目标检测 TaskAlignedAssigner 原理

文章目录 TaskAlignedAssigner 原理和代码使用示例 TaskAlignedAssigner 原理和代码 原理主要是结合预测的分类分数和边界框与真实标注的信息&#xff0c;找出与真实目标最匹配的锚点&#xff0c;为这些锚点分配对应的目标标签、边界框和分数。 TaskAlignedAssigner 是目标检…...

Qt popup窗口半透明背景

半透明弹窗需要paintEvent()接口支持 方法一&#xff1a;使用setStyleSheet设置半透明样式&#xff0c;如果是子窗口&#xff0c;则可注释构建函数内属性设置 class TranslucentWidget : public QWidget { public: explicit TranslucentWidget(QWidget *parent nullptr)…...

游戏:元梦之星游戏开发代码(谢苏)

《元梦之星》是一款轻松社交派对游戏,玩家们可以化身星宝,体验纯粹的游玩乐趣,收获简单的快乐。无论i人e人,都能轻松找到属于自己的社交方式。 《元梦之星》的快乐,可以是闯关夺冠时的激动&#xff0c;谁是狼人推理的巧妙&#xff0c;峡谷3V3打赢团战的爽感。也可以是星梦广场开…...

TCP协议原理与Java编程实战:从连接建立到断开的完整解析

1.TCP协议核心&#xff1a;面向连接的可靠通信基石 TCP&#xff08;Transmission Control Protocol&#xff0c;传输控制协议&#xff09;是互联网的“可靠信使”&#xff0c;属于传输层协议&#xff0c;其核心在于面向连接和可靠传输。它通过严谨的握手机制与数据控制逻辑&am…...

Linux的top命令使用

Linux系统中top命令详解及使用技巧 一、基础功能 top命令用于实时监控系统性能和进程活动&#xff0c;可查看以下信息&#xff1a; - CPU使用率 - 内存使用情况 - 进程状态信息 - 系统负载数据 二、使用步骤 1. 打开终端输入命令&#xff1a;top 2. 查看实时更新的数据界面&a…...

Spring Cloud Gateway 限流实践:基于 Redis 令牌桶算法的网关层流量治理

一、引言 在微服务架构中,API 网关作为流量枢纽,需对进入系统的请求进行精细化限流,以保护下游服务免受流量冲击。Spring Cloud Gateway 结合 Redis 实现的令牌桶算法,为网关层限流提供了高效、分布式的解决方案。本文将深入解析其原理、配置及实践优化。 二、技术栈与原…...

可视化大屏实现全屏或非全屏

通过点击按钮实现全屏和非全屏效果展示 代码如下&#xff1a; <template> //点击icon图片进入全屏或非全屏<img :src"screenStatus ? /src/assets/noFull.png : /src/assets/full.png" alt"" click"enterFullScreen" /> </te…...

java8函数式接口(函数式接口的匿名实现类作为某些方法的入参)

文章目录 前置介绍通过 lambda 表达式&#xff0c;使用匿名类&#xff0c;实现函数式接口函数式接口和回调函数的关系函数式接口的应用 前置介绍 是 Java 8 引入的核心概念之一&#xff0c;指的是 仅包含一个抽象方法的接口。它可以被 FunctionalInterface 注解标记&#xff0…...

linux自有服务

文章目录 [TOC](文章目录)linux自有服务概述systemctl管理服务命令CentOS 7 之前CentOS 7 常用自有服务ntpd或systemd-timesyncd时间同步服务ntp同步服务器原理ntpd时间同步操作systemd-timesyncd同步原理systemd-timesyncd时间同步操作 firewalld防火墙计划任务crontab CentOS…...

UniApp网页版集成海康视频播放器

注意&#xff1a;本人全部集成好后使用最新的海康平台下载插件进行替换后就不能预览视频 使用Uni插件进行集成&#xff1a;海康视频H5播放器组件 - DCloud 插件市场 CSDN资源下载&#xff1a;https://download.csdn.net/download/wangdaoyin2010/90910975 注意&#xff1a;初…...

Filter和Interceptor详解(一文了解执行阶段及其流程)

Filter和Interceptor的区别 Filter&#xff08;过滤器&#xff09;和 Interceptor&#xff08;拦截器&#xff09;都是用于在请求处理前后插入额外逻辑的组件&#xff0c;下面依次介绍&#xff0c;并额外介绍Spring Gateway的过滤器&#xff08;GlobalFilter/GatewayFilter&am…...

鸿蒙仓颉开发语言实战教程:实现商城应用详情页

昨天有朋友提到鸿蒙既然有了ArkTs开发语言&#xff0c;为什么还需要仓颉开发语言。其实这个不难理解&#xff0c;安卓有Java和Kotlin&#xff0c;iOS先后推出了Objective-C和Swift&#xff0c;鸿蒙有两种开发语言也就不奇怪了。而且仓颉是比ArkTs更加灵活的语言&#xff0c;虽然…...

GitAny - 無需登入的 GitHub 最新倉庫檢索工具

地址&#xff1a;https://github.com/MartinxMax/gitany GitAny - 無需登入的 GitHub 專案搜尋工具 GitAny 是一款基於 Python 的工具&#xff0c;允許你在無需登入的情況下搜尋當天最新的 GitHub 專案。它支援模糊搜尋、條件篩選以及倉庫資料的視覺化分析。 安裝依賴 $ pip…...

在飞牛nas系统上部署gitlab

在飞牛nas系统上部署gitlab需要使用docker进行部署&#xff0c;如下将介绍详细的部署流程。 文章目录 1. docker镜像2. 拉取镜像3. 运行容器4. 运行和访问gitlab5. 一些小配置5.1 url问题5.2 ssh端口5.3 其他配置 1. docker镜像 首先需要找一个gitlab的docker镜像地址&#x…...

深入理解 Redis 哨兵模式

Redis 哨兵模式深度解析&#xff1a;从原理到实践的全流程指南 在分布式系统架构中&#xff0c;Redis 作为高性能的内存数据库&#xff0c;其哨兵模式&#xff08;Sentinel&#xff09;是保障服务高可用性的核心方案。本文将从基础概念、运行机制出发&#xff0c;结合具体配置…...

SQL进阶之旅 Day 4:子查询与临时表优化

文章标题 【SQL进阶之旅 Day 4】子查询与临时表优化 文章内容 开篇&#xff1a;SQL进阶之旅的第4天 在“SQL进阶之旅”系列中&#xff0c;第4天的主题是子查询与临时表优化。这是SQL开发中不可或缺的一部分&#xff0c;尤其在处理复杂查询时&#xff0c;合理使用子查询和临…...

[特殊字符]《Qt实战:基于QCustomPlot的装药燃面动态曲线绘制(附右键菜单/样式美化/完整源码)》

1、将qcustomplot.cpp qcustomplot.h放入工程目录下引入qcustomplot 2、代码 .h #if defined(_MSC_VER) #pragma execution_character_set(...

力扣-最大连续一的个数

1.题目描述 2.题目链接 1004. 最大连续1的个数 III - 力扣&#xff08;LeetCode&#xff09; 3.代码解答 class Solution {public int longestOnes(int[] nums, int k) {int zero0,length0;for(int left0,right0;right<nums.length;right){if(nums[right]0){zero;}while…...

无人机避障——深蓝学院浙大栅格地图以及ESDF地图内容

Occupancy Grid Map & Euclidean Signed Distance Field: 【注意】&#xff1a;目的是为了将有噪声的传感器收集起来&#xff0c;用于实时的建图。 Occupancy Grid Map&#xff1a; 概率栅格&#xff1a; 【注意】&#xff1a;由于传感器带有噪声&#xff0c;在实际中基于…...

Postman基础操作

1.Postman是什么&#xff1f; Postman是接口测试的工具&#xff0c;简单来说它能模拟浏览器对服务器的某个接口发起请求并接收响应数据。 1.1 Postman工作原理 2.Postman发送请求 2.1 发送GET请求 我们知道GET请求是没用请求体的&#xff0c;所以我们需要将请求参数写在Param…...

【MPC控制 - 从ACC到自动驾驶】3 MPC控制器设计原理与参数配置:打造ACC的“最强大脑”

【MPC控制 - 从ACC到自动驾驶】MPC控制器设计原理与参数配置&#xff1a;打造ACC的“最强大脑” 在Day 1&#xff0c;我们认识了ACC自适应巡航和MPC这位“深谋远虑的棋手”。Day 2&#xff0c;我们一起给汽车“画像”&#xff0c;建立了它的纵向动力学模型&#xff0c;并把它翻…...

Unity3D仿星露谷物语开发52之菜单页面

1、目标 创建菜单页面&#xff0c;可通过Esc键开启或关闭。 当把鼠标悬停在上面时它会高亮&#xff0c;然后当点击按钮时标签页会被选择。 2、 创建PauseMenuCanvas &#xff08;1&#xff09;创建Canvas 在Hierarchy -> PersistentScene -> UI下创建新的Cavans命名为…...

待定事项之存储数据

#### 部署云服务器 ![alt text](./img/屏幕截图%202025-05-18%20132353.png) ### 部署云服务器完整步骤 1. **连接到云服务器** bash ssh root<服务器IP> 2. **创建项目目录结构** bash mkdir -p /var/www/three/study/待办事项 3. **克隆项目仓库** bash cd /var/www…...

电脑装的数据越多,会不会越重

在这个数字化飞速发展的时代&#xff0c;有一个看似荒诞却又引人深思的问题&#xff1a;电脑装的数据越多&#xff0c;会不会越重&#xff1f; 先来说说大家的普遍认知&#xff0c;我们通常认为数据只是一些虚拟的代码和信息&#xff0c;存放在电脑的硬盘或其他存储设备中&…...

君正Ingenic webRTC P2P库libyangpeerconnection7编程指南

概述 libyangpeerconnection7是一个实现P2P媒体传输/数据通道的一个轻量级的webRTC库&#xff0c;基于metaRTC7.0的传输模块构建&#xff0c;支持H264/H265视频编码&#xff0c;通过 P2P 连接为用户提供高效、低延迟的音视频和数据通信。 君正版libyangpeerconnection7可适用…...

MySQL——复合查询表的内外连

目录 复合查询 回顾基本查询 多表查询 自连接 子查询 where 字句中使用子查询 单行子查询 多行子查询 多列子查询 from 字句中使用子查询 合并查询 实战OJ 查找所有员工入职时候的薪水情况 获取所有非manager的员工emp_no 获取所有员工当前的manager 表的内外…...

小米玄戒O1架构深度解析(一):十核异构设计与缓存层次详解

前言 这两天&#xff0c;小米的全新SOC玄戒O1横空出世&#xff0c;引发了科技数码圈的一次小地震&#xff0c;那么小米的这颗所谓的自研SOC&#xff0c;内部究竟有着什么不为人知的秘密呢&#xff1f;我们一起一探究竟。 目录 前言1 架构总览1.1 基本构成1.2 SLC缺席的原因探…...

Numba模块的用法(高性能计算)

文章目录 介绍核心装饰器与基础用法@jit(nopython=True):最常用的编译装饰器@njit的简写编译时指定类型签名并行加速(parallel=True)@cuda.jit: GPU 编程(CUDA)向量化函数(@vectorize)性能优化技巧调试与常见问题调试模式常见错误适用场景与局限性实例:加速蒙特卡洛模拟…...

Kafka自定义分区策略实战避坑指南

文章目录 概要代码示例小结 概要 kafka生产者发送消息默认根据总分区数和设置的key计算哈希取余数&#xff0c;key不变就默认存放在一个分区&#xff0c;没有key则随机数分区&#xff0c;明显默认的是最不好用的&#xff0c;那kafka也提供了一个轮询分区策略&#xff0c;我自己…...

PyTorch中cdist和sum函数使用示例详解

以下是PyTorch中cdist与sum函数的联合使用详解: 1. cdist函数解析 功能:计算两个张量间的成对距离矩阵 输入格式: X1:形状为(B, P, M)的张量X2:形状为(B, R, M)的张量p:距离类型(默认2表示欧式距离)输出:形状为(B, P, R)的距离矩阵,其中元素 d i j d_{ij} dij​表示…...