超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题:
- 真实的人脸识别数据噪声严重
- 真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少
- 人脸类别越来越多,全连接层训练成本越来越高,难度越来越大
于是,有研究人员提出Partial FC,拒绝全量更新负类别中心,而是仅更新少部分负类别中心。该做法优势在于
- 降低噪声数据被采样的概率
- 降低高频负类别中心被选中的概率
- 降低负类别中心的更新频率,降低训练难度
如下图所示:
![![[Pasted image 20250219215540.png]]](https://i-blog.csdnimg.cn/direct/5a1cf3c3680e4880b811c4b3290956ec.png)
问题建模
人脸识别领域,常用的分类损失公式化定义如下:
L = − 1 B ∑ i = 1 B l o g e W y i T ⋅ x i e W y i T ⋅ x i + ∑ j = 1 , j ≠ y i C e W j T ⋅ x i (1) L=-\frac{1}{B}\sum_{i=1}^{B}log\frac{e^{W_{y_i}^T\cdot x_i}}{e^{W_{y_i}^T}\cdot x_i+\sum_{j=1,j\neq y_i}^Ce^{W_j^T\cdot x_i}} \tag{1} L=−B1i=1∑BlogeWyiT⋅xi+∑j=1,j=yiCeWjT⋅xieWyiT⋅xi(1)
,其中, B B B表示batch size, C C C表示类别个数, W j T W_{j}^T WjT表示第 j j j个类别中心的特征, ( x i , y i ) (x_i,y_i) (xi,yi)表示第 i i i个样本的特征为 x i x_i xi,类别为 y i y_i yi。
真实大规模人脸数据实际使用时,有以下问题:![![[Pasted image 20250219222626.png]]](https://i-blog.csdnimg.cn/direct/72a04a75e2ee4ffdb4f6fb3a1ed68c62.png)
- 噪声问题:见上图(a),图片对都是一个人的图片,但是被分到不同的类别,这对模型训练有非常大的干扰。
- 长尾分布:见上图(b),大部分类别(identity)包含的图像数量很少,在WebFace42M中,44.57%的类别包含的图像数量少于10张。这会导致低频类别的类别中心更新缓慢,而高频类别的类别中心更新频繁。
- 训练资源:全连接层一般表示为 W ∈ R D × C W\in \mathbb{R}^{D\times C} W∈RD×C,其中 D D D表示维度, C C C表示类别数。假设 D = 512 D=512 D=512,如果类别数是1000,000(一百万)
- fp16下,全连接层的显存消耗为: 512 × 100 , 000 × 2 1024 × 1024 × 1024 = 0.95 G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}=0.95GB 1024×1024×1024512×100,000×2=0.95GB
- 公式(1)中,需要计算 B B B个 x i x_i xi属于类别中心 W j T W_{j}^T WjT的logit,维度是 R B × D × C \mathbb{R}^{B\times D\times C} RB×D×C,显存消耗为 512 × 100 , 000 × 2 1024 × 1024 × 1024 ⋅ B = 0.95 B G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}\cdot B=0.95B \,GB 1024×1024×1024512×100,000×2⋅B=0.95BGB,batchsize越大,需要的显存越大。
- 在下图,进行了模型并行和partial fc在显存消耗和训练速度上的比较,可以发现:
- partial fc显著降低了对logit的显存消耗
- partial fc略微降低了对存储类别中心的显存消耗
- partial fc未降低对特征抽取网络的显存消耗(将原图像转换为特征的模型的消耗)
- 由于partial fc减少了负类别中心的数量,降低了logit计算的复杂度,随着训练类别越多,加速比越高。
![![[Pasted image 20250219222640.png]]](https://i-blog.csdnimg.cn/direct/5120c47931814a75aff2bfa7564d4914.png)
partial fc
为了缓解上述问题,提出了partial fc,通过稀疏更新全连接层的参数,来支持大规模人脸识别模型的训练。
整体架构如下图所示:![![[Pasted image 20250219222722.png]]](https://i-blog.csdnimg.cn/direct/e1214ff7f3ee4102a614396d3e009a0e.png)
模型通过数据并行训练的,不同GPU包含了不同数据的特征,整体步骤如下:
- 汇总不同GPU里的图像特征和图像标签
- 将汇总的图像特征和图像标签送到每张GPU上
- 将全连接层(即 C C C个类别中心)均分到每张GPU上
- 在单张卡上,保留需要的正类别中心,以及采样固定比例的负类别中心
- 利用样本、正类别中心、负类别中心计算损失函数
代码实现
def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)# 汇总不同GPU里的图像特征和图像标签embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)labels = labels.view(-1, 1)# self.class_start表示该GPU中,分配的类别中心起始id# self.num_local表示该GPU中,分配的类别中心数量# 于是,该GPU的类别中心id范围是[类别中心起始id, 类别中心起始id + 类别中心数量]# 在单张卡上,仅保留需要的正类别中心index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)labels[~index_positive] = -1labels[index_positive] -= self.class_start# 在单张卡上,采样固定比例的负类别中心if self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)# 基于样本特征、样本标签、正类别中心,采样的负类别中心,计算损失logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss
优势
partial fc的核心思想是”降低训练中负类别中心数量,显式得减少需要更新的参数量“。负类别中心采样比例越低,节约的显存越多。
为了更好理解partial fc对长尾分布、噪声问题的影响,计算分类损失对于样本 x i x_i xi的梯度,如下:
∂ L ∂ x i = − ( ( 1 − p + ) W + − ∑ j ∈ S , j ≠ y i p j − W j − ) (2) \frac{\partial L}{\partial x_i}=-((1-p^+)W^+-\sum_{j\in \mathbb{S}, j\neq y_i}p_j^-W_j^-) \tag{2} ∂xi∂L=−((1−p+)W+−j∈S,j=yi∑pj−Wj−)(2)
其中, p + p^+ p+、 p − p^- p−分别表示通过样本特征 x i x_i xi计算的logit分数, S \mathbb{S} S表示负类别中心, ∣ S ∣ = C × r |\mathbb{S}|=C\times r ∣S∣=C×r,通过采样比例 r r r,调整训练时的负样本数量。
样本特征 x i x_i xi的更新方向和正类别中心和负类别中心都有关系,partial fc随机减少负类别中心数量,减低噪声数据被采样的概率,降低高频负类别中心被选中的概率,有效缓解长尾问题和噪声问题。
注意:采样率为1,等同于选取所有负类别中心,进行模型训练。相当于原始fc分类器
为了进一步验证partial fc的作用原理,做了下述验证下实验
探究采样率与类内、类间相似度关系
![![[Pasted image 20250224175118.png]]](https://i-blog.csdnimg.cn/direct/2541376000dc4efdaa6abf21036b4a60.png)
(a)图中,采样率越低,APCS收敛至更高数值。APCS表示类内距离 A P C S = 1 B ∑ i = 1 B W y i T x i ∣ ∣ W y i ∣ ∣ ⋅ ∣ ∣ x i ∣ ∣ APCS=\frac{1}{B}\sum_{i=1}^B\frac{W_{y_i}^Tx_i}{||W_{y_i}||\cdot ||x_i||} APCS=B1i=1∑B∣∣Wyi∣∣⋅∣∣xi∣∣WyiTxi,说明采样率越低,类内相似度越大,类内越紧密。
(b)图中,采样率越低,MICS分布越往右,整体数值越大。MICS表示最大的类间余弦相似度 MICS i = max j ≠ i W i T W j ∥ W i ∥ ∥ W j ∥ \text{MICS}_i = \max_{j \neq i} \frac{W_i^T W_j}{\|W_i\| \, \|W_j\|} MICSi=j=imax∥Wi∥∥Wj∥WiTWj,说明采样率越低,类间相似度越大,类间拉不开。
探究采样率与评测集合效果关系
随着采样率越来越低,IJB-C、MFR-All评测集上的效果越来越差,如下:
![![[Pasted image 20250226180734.png]]](https://i-blog.csdnimg.cn/direct/63452da8fae1489ca5d26791f6649afe.png)
探究采样率对噪声数据的鲁棒性
为验证在噪声数据上的效果,做了如下实验:
![![[Pasted image 20250226181829.png]]](https://i-blog.csdnimg.cn/direct/a0501aed8a2d4647be2707b3f91dd07e.png)
构造了WebFace12M-Conflict数据集,随机将20万个类的样本放到另外60万个类别数据中。
图(a)纵坐标为AMNCS,指最小的负类中心距离(越大,说明样本和负类离得远)。可以发现,降低采样率(1.0->0.1),在干净数据上,效果接近;在噪声数据上,缓解过拟合问题。
图(b)的MICS,指最大的类间余弦相似度(越大,说明类别越相似,越区分不开类别)。可以发现,降低采样率(1.0->0.1),MICS分布往右,整体数值偏大。由于WebFace12M-Conflict数据集中,20万类的样本随机分布在其他类别中,类间余弦相似度本就很大,图(b)更好的刻画实际噪声分布。
![![[Pasted image 20250226210518.png]]](https://i-blog.csdnimg.cn/direct/706a7c96b908436cb6d91343dd251d90.png)
上图定义了两个概念,分别是conflict-hard和conflict-noise。conflict-hard表示利用真实负样本计算AMNCS;conflict-noise表示利用噪声负样本计算AMNCS。结果表明:
- r=1.0时,针对AMNCS指标,conflict-hard>conflict-noise,表明负样本不采样,会使得模型过分拟合数据集,导致对噪声数据不鲁棒(按理说应该是conflict-noise>conflict-hard)。
- r=0.1时,针对AMNCS指标,conflict-hard<conflict-noise,刻画出真实数据特性。
消融实验
不同数据集、不同采样率下partial fc
![![[Pasted image 20250226212125.png]]](https://i-blog.csdnimg.cn/direct/5cf2334d4bf449548ede42866290f39a.png)
不同网络结构下partial fc
![![[Pasted image 20250226212145.png]]](https://i-blog.csdnimg.cn/direct/e14f72a5958e4205b1675de850e32f7e.png)
对噪声数据鲁棒
采用WebFace12M-Conflict作为训练集合。
![![[Pasted image 20250226212355.png]]](https://i-blog.csdnimg.cn/direct/cf0b8d19dbe9428198916461333838ad.png)
![![[Pasted image 20250226212449.png]]](https://i-blog.csdnimg.cn/direct/24e5ed17deb341028d028001f919021c.png)
对长尾数据鲁棒
![![[Pasted image 20250226212618.png]]](https://i-blog.csdnimg.cn/direct/6a69a7a80869480d8c5746206f475622.png)
收敛速度、训练时间
![![[Pasted image 20250226212802.png]]](https://i-blog.csdnimg.cn/direct/dbfc930f35ac4ed59fde5491813881e7.png)
相关文章:
超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题: 真实的人脸识别数据噪声严重真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少人脸类别越来越多,全连接层…...
uniapp 小程序如何实现大模型流式交互?前端SSE技术完整实现解析
文章目录 一、背景概述二、核心流程图解三、代码模块详解1. UTF-8解码器(处理二进制流)2. 请求控制器(核心通信模块)3. 流式请求处理器(分块接收)4. 数据解析器(处理SSE格式)5. 回调…...
因子分析详解:从理论到MATLAB实战
内容摘要: 本文系统解析因子分析的核心原理与MATLAB实战,涵盖数学模型、载荷矩阵估计、因子旋转及得分计算。通过上市公司盈利能力、消费者偏好等案例,演示数据标准化、因子提取与解释的全流程,并提供完整代码实现。深入对比因子分…...
【组态PLC】基于三菱西门子S7-200PLC和组态王液料混合系统组态设计【含PLC组态源码 M016期】
控制要求 总体控制要求:如面板图所示,本装置为三种液体混合模拟装置,由液面传感器SL1、SL2、SL3,液体A、B、C阀门与混合液阀门由电磁阀YV1、YV2、YV3、YV4,搅匀电机M,加热器H,温度传感器T组成。…...
js:根据后端返回的数组取出每一个数组的keyword字段然后拼接成一个逗号分隔的字符串
问: 现在有一个el-select, 后端接口返回数据为keyword:xxx,referenceNum:1,tagId:132sf32fasdfaf组成的数组, 现在select是多选, 但是但我选择多个下拉框选项后,后端需要前端返回的数据tagIds字段需要时一个字符串…...
基于大模型的肺纤维化预测及临床方案研究报告
目录 一、引言 1.1 研究背景与意义 1.2 研究目的与方法 二、大模型技术概述 2.1 大模型的基本原理 2.2 大模型在医疗领域的应用现状 三、肺纤维化相关知识 3.1 肺纤维化的病因与发病机制 3.2 肺纤维化的临床症状与诊断方法 3.3 肺纤维化的治疗现状与挑战 四、大模型…...
7. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--路由
路由是API网关的核心功能,对系统性能和可靠性至关重要。路由通过定义规则,将客户端请求准确地转发到相应的后端服务,确保请求能够正确处理,简化了微服务架构中的服务调用逻辑。有效的路由配置能够提高系统的灵活性和可维护性。 一…...
【GESP】C++二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵
GESP二级模拟题,多层循环、分支语句练习,难度★✮☆☆☆。 题目题解详见:https://www.coderli.com/gesp-2-luogu-b3995/ 【GESP】C二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵 | OneCoderGESP二级模拟题,多层循环、分…...
监督学习——基于线性回归的波士顿房价预测:理论、实践与评估
基于线性回归的波士顿房价预测:理论、实践与评估 文章目录 基于线性回归的波士顿房价预测:理论、实践与评估一、引言二、线性回归基础理论2.1 线性回归原理2.2 线性回归在房价预测中的应用逻辑三、波士顿房价数据集介绍3.1 数据集概述3.2 特征说明3.3 目标变量四、波士顿房价…...
Selenium 调用模型接口实现功能测试
要使用 Selenium 调用模型接口实现功能测试,可按以下步骤进行: 1. 环境准备 安装 Selenium:使用 pip install selenium 安装 Selenium 库。安装浏览器驱动:根据使用的浏览器(如 Chrome、Firefox 等)下载对应的驱动,并将其添加到系统的环境变量中。例如,Chrome 浏览器需…...
回调函数的用法
回调函数的基本用法 回调函数是一种被作为参数传递给另一个函数的函数,接收回调函数作为参数的函数在合适的时候会调用这个回调函数。回调函数为代码提供了更高的灵活性和可扩展性,下面为你详细介绍回调函数的基本用法。 基本概念 回调函数的核心在于函…...
springboot实现文件上传到华为云的obs
一、前言 有时在项目中需要使用一些存储系统来存储文件,那么当项目要接入obs作为存储系统时,就会利用obs来进行文件的上传下载,具体实现如下。 二、如何通过obs实现文件的上传下载? 1.添加相关的obs的maven依赖。 <dependency…...
南京布局产业园剖析:成都树莓集团的战略逻辑
在数字产业飞速发展的当下,成都树莓集团在南京布局产业园,这一举措蕴含着深刻的战略考量,是基于对市场环境、产业趋势以及自身发展需求的综合研判。 一、政策利好与发展机遇 南京作为长三角地区的重要城市,在数字经济发展方面享有…...
C++ QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法
C QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法 记录一下 qmake .pro文件的配置 QT core gui printsupportgreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c17# You can make your code fail to compil…...
【算法】哈希表详解
【算法】哈希表详解 1. 哈希表的基本概念2. 哈希表的优缺点3. 哈希表的实现方法4. 哈希表的应用场景5. 哈希表的性能优化6. 哈希表 vs 其他数据结构7. 总结 哈希表(Hash Table) 是一种高效的数据结构,用于存储键值对(Key-Value Pa…...
【红队利器】单文件一键结束火绒6.0
关于我们 4SecNet 团队专注于网络安全攻防研究,目前团队成员分布在国内多家顶级安全厂商的核心部门,包括安全研究领域、攻防实验室等,汇聚了行业内的顶尖技术力量。团队在病毒木马逆向分析、APT 追踪、破解技术、漏洞分析、红队工具开发等多个…...
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏 前言项目介绍项目简介项目预览二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署star-battle网页小游戏下载镜像创建容器检查容器状态检查服务端口安全设置四、访问star-battle网页小游戏五、总…...
【EB-06】SystemCreator dbc转arxml
SystemCreator dbc转arxml 1. SystemCreator 意义2. SystemCreator使用方法2.1 实现步骤2.2 参考官方文档方法1. SystemCreator 意义 EB Tresos 对dbc直接导入的支持不是很完善,dbc也不是AUTOSAR标准的数据库文件,EB建议所有通信矩阵通过ARXML交互比较合理(AUTOSAR定义的)…...
(0)阿里云大模型ACP-考试回忆
这两天通过了阿里云大模型ACP考试,由于之前在网上没有找到真题,导致第一次考试没有过,后面又重新学习了一遍文档才顺利通过考试,这两次考试内容感觉考试题目90%内容是覆盖的,后面准备分享一下每一章的考题,…...
按键精灵鹰眼中控:ios多设备管理工具
在当今数字化时代,高效管理多设备已成为许多企业和个人的迫切需求。无论是游戏多开、自动化测试,还是电商运营,如何同时操作多台设备并确保精准执行,一直是一个难题。现在,按键精灵的鹰眼群控功能为您提供了完美的解决…...
(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...
Spring Boot 实现流式响应(兼容 2.7.x)
在实际开发中,我们可能会遇到一些流式数据处理的场景,比如接收来自上游接口的 Server-Sent Events(SSE) 或 流式 JSON 内容,并将其原样中转给前端页面或客户端。这种情况下,传统的 RestTemplate 缓存机制会…...
练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
Java - Mysql数据类型对应
Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...
苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...
C++.OpenGL (10/64)基础光照(Basic Lighting)
基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...
