超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题:
- 真实的人脸识别数据噪声严重
- 真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少
- 人脸类别越来越多,全连接层训练成本越来越高,难度越来越大
于是,有研究人员提出Partial FC,拒绝全量更新负类别中心,而是仅更新少部分负类别中心。该做法优势在于
- 降低噪声数据被采样的概率
- 降低高频负类别中心被选中的概率
- 降低负类别中心的更新频率,降低训练难度
如下图所示:
![![[Pasted image 20250219215540.png]]](https://i-blog.csdnimg.cn/direct/5a1cf3c3680e4880b811c4b3290956ec.png)
问题建模
人脸识别领域,常用的分类损失公式化定义如下:
L = − 1 B ∑ i = 1 B l o g e W y i T ⋅ x i e W y i T ⋅ x i + ∑ j = 1 , j ≠ y i C e W j T ⋅ x i (1) L=-\frac{1}{B}\sum_{i=1}^{B}log\frac{e^{W_{y_i}^T\cdot x_i}}{e^{W_{y_i}^T}\cdot x_i+\sum_{j=1,j\neq y_i}^Ce^{W_j^T\cdot x_i}} \tag{1} L=−B1i=1∑BlogeWyiT⋅xi+∑j=1,j=yiCeWjT⋅xieWyiT⋅xi(1)
,其中, B B B表示batch size, C C C表示类别个数, W j T W_{j}^T WjT表示第 j j j个类别中心的特征, ( x i , y i ) (x_i,y_i) (xi,yi)表示第 i i i个样本的特征为 x i x_i xi,类别为 y i y_i yi。
真实大规模人脸数据实际使用时,有以下问题:![![[Pasted image 20250219222626.png]]](https://i-blog.csdnimg.cn/direct/72a04a75e2ee4ffdb4f6fb3a1ed68c62.png)
- 噪声问题:见上图(a),图片对都是一个人的图片,但是被分到不同的类别,这对模型训练有非常大的干扰。
- 长尾分布:见上图(b),大部分类别(identity)包含的图像数量很少,在WebFace42M中,44.57%的类别包含的图像数量少于10张。这会导致低频类别的类别中心更新缓慢,而高频类别的类别中心更新频繁。
- 训练资源:全连接层一般表示为 W ∈ R D × C W\in \mathbb{R}^{D\times C} W∈RD×C,其中 D D D表示维度, C C C表示类别数。假设 D = 512 D=512 D=512,如果类别数是1000,000(一百万)
- fp16下,全连接层的显存消耗为: 512 × 100 , 000 × 2 1024 × 1024 × 1024 = 0.95 G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}=0.95GB 1024×1024×1024512×100,000×2=0.95GB
- 公式(1)中,需要计算 B B B个 x i x_i xi属于类别中心 W j T W_{j}^T WjT的logit,维度是 R B × D × C \mathbb{R}^{B\times D\times C} RB×D×C,显存消耗为 512 × 100 , 000 × 2 1024 × 1024 × 1024 ⋅ B = 0.95 B G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}\cdot B=0.95B \,GB 1024×1024×1024512×100,000×2⋅B=0.95BGB,batchsize越大,需要的显存越大。
- 在下图,进行了模型并行和partial fc在显存消耗和训练速度上的比较,可以发现:
- partial fc显著降低了对logit的显存消耗
- partial fc略微降低了对存储类别中心的显存消耗
- partial fc未降低对特征抽取网络的显存消耗(将原图像转换为特征的模型的消耗)
- 由于partial fc减少了负类别中心的数量,降低了logit计算的复杂度,随着训练类别越多,加速比越高。
![![[Pasted image 20250219222640.png]]](https://i-blog.csdnimg.cn/direct/5120c47931814a75aff2bfa7564d4914.png)
partial fc
为了缓解上述问题,提出了partial fc,通过稀疏更新全连接层的参数,来支持大规模人脸识别模型的训练。
整体架构如下图所示:![![[Pasted image 20250219222722.png]]](https://i-blog.csdnimg.cn/direct/e1214ff7f3ee4102a614396d3e009a0e.png)
模型通过数据并行训练的,不同GPU包含了不同数据的特征,整体步骤如下:
- 汇总不同GPU里的图像特征和图像标签
- 将汇总的图像特征和图像标签送到每张GPU上
- 将全连接层(即 C C C个类别中心)均分到每张GPU上
- 在单张卡上,保留需要的正类别中心,以及采样固定比例的负类别中心
- 利用样本、正类别中心、负类别中心计算损失函数
代码实现
def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)# 汇总不同GPU里的图像特征和图像标签embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)labels = labels.view(-1, 1)# self.class_start表示该GPU中,分配的类别中心起始id# self.num_local表示该GPU中,分配的类别中心数量# 于是,该GPU的类别中心id范围是[类别中心起始id, 类别中心起始id + 类别中心数量]# 在单张卡上,仅保留需要的正类别中心index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)labels[~index_positive] = -1labels[index_positive] -= self.class_start# 在单张卡上,采样固定比例的负类别中心if self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)# 基于样本特征、样本标签、正类别中心,采样的负类别中心,计算损失logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss
优势
partial fc的核心思想是”降低训练中负类别中心数量,显式得减少需要更新的参数量“。负类别中心采样比例越低,节约的显存越多。
为了更好理解partial fc对长尾分布、噪声问题的影响,计算分类损失对于样本 x i x_i xi的梯度,如下:
∂ L ∂ x i = − ( ( 1 − p + ) W + − ∑ j ∈ S , j ≠ y i p j − W j − ) (2) \frac{\partial L}{\partial x_i}=-((1-p^+)W^+-\sum_{j\in \mathbb{S}, j\neq y_i}p_j^-W_j^-) \tag{2} ∂xi∂L=−((1−p+)W+−j∈S,j=yi∑pj−Wj−)(2)
其中, p + p^+ p+、 p − p^- p−分别表示通过样本特征 x i x_i xi计算的logit分数, S \mathbb{S} S表示负类别中心, ∣ S ∣ = C × r |\mathbb{S}|=C\times r ∣S∣=C×r,通过采样比例 r r r,调整训练时的负样本数量。
样本特征 x i x_i xi的更新方向和正类别中心和负类别中心都有关系,partial fc随机减少负类别中心数量,减低噪声数据被采样的概率,降低高频负类别中心被选中的概率,有效缓解长尾问题和噪声问题。
注意:采样率为1,等同于选取所有负类别中心,进行模型训练。相当于原始fc分类器
为了进一步验证partial fc的作用原理,做了下述验证下实验
探究采样率与类内、类间相似度关系
![![[Pasted image 20250224175118.png]]](https://i-blog.csdnimg.cn/direct/2541376000dc4efdaa6abf21036b4a60.png)
(a)图中,采样率越低,APCS收敛至更高数值。APCS表示类内距离 A P C S = 1 B ∑ i = 1 B W y i T x i ∣ ∣ W y i ∣ ∣ ⋅ ∣ ∣ x i ∣ ∣ APCS=\frac{1}{B}\sum_{i=1}^B\frac{W_{y_i}^Tx_i}{||W_{y_i}||\cdot ||x_i||} APCS=B1i=1∑B∣∣Wyi∣∣⋅∣∣xi∣∣WyiTxi,说明采样率越低,类内相似度越大,类内越紧密。
(b)图中,采样率越低,MICS分布越往右,整体数值越大。MICS表示最大的类间余弦相似度 MICS i = max j ≠ i W i T W j ∥ W i ∥ ∥ W j ∥ \text{MICS}_i = \max_{j \neq i} \frac{W_i^T W_j}{\|W_i\| \, \|W_j\|} MICSi=j=imax∥Wi∥∥Wj∥WiTWj,说明采样率越低,类间相似度越大,类间拉不开。
探究采样率与评测集合效果关系
随着采样率越来越低,IJB-C、MFR-All评测集上的效果越来越差,如下:
![![[Pasted image 20250226180734.png]]](https://i-blog.csdnimg.cn/direct/63452da8fae1489ca5d26791f6649afe.png)
探究采样率对噪声数据的鲁棒性
为验证在噪声数据上的效果,做了如下实验:
![![[Pasted image 20250226181829.png]]](https://i-blog.csdnimg.cn/direct/a0501aed8a2d4647be2707b3f91dd07e.png)
构造了WebFace12M-Conflict数据集,随机将20万个类的样本放到另外60万个类别数据中。
图(a)纵坐标为AMNCS,指最小的负类中心距离(越大,说明样本和负类离得远)。可以发现,降低采样率(1.0->0.1),在干净数据上,效果接近;在噪声数据上,缓解过拟合问题。
图(b)的MICS,指最大的类间余弦相似度(越大,说明类别越相似,越区分不开类别)。可以发现,降低采样率(1.0->0.1),MICS分布往右,整体数值偏大。由于WebFace12M-Conflict数据集中,20万类的样本随机分布在其他类别中,类间余弦相似度本就很大,图(b)更好的刻画实际噪声分布。
![![[Pasted image 20250226210518.png]]](https://i-blog.csdnimg.cn/direct/706a7c96b908436cb6d91343dd251d90.png)
上图定义了两个概念,分别是conflict-hard和conflict-noise。conflict-hard表示利用真实负样本计算AMNCS;conflict-noise表示利用噪声负样本计算AMNCS。结果表明:
- r=1.0时,针对AMNCS指标,conflict-hard>conflict-noise,表明负样本不采样,会使得模型过分拟合数据集,导致对噪声数据不鲁棒(按理说应该是conflict-noise>conflict-hard)。
- r=0.1时,针对AMNCS指标,conflict-hard<conflict-noise,刻画出真实数据特性。
消融实验
不同数据集、不同采样率下partial fc
![![[Pasted image 20250226212125.png]]](https://i-blog.csdnimg.cn/direct/5cf2334d4bf449548ede42866290f39a.png)
不同网络结构下partial fc
![![[Pasted image 20250226212145.png]]](https://i-blog.csdnimg.cn/direct/e14f72a5958e4205b1675de850e32f7e.png)
对噪声数据鲁棒
采用WebFace12M-Conflict作为训练集合。
![![[Pasted image 20250226212355.png]]](https://i-blog.csdnimg.cn/direct/cf0b8d19dbe9428198916461333838ad.png)
![![[Pasted image 20250226212449.png]]](https://i-blog.csdnimg.cn/direct/24e5ed17deb341028d028001f919021c.png)
对长尾数据鲁棒
![![[Pasted image 20250226212618.png]]](https://i-blog.csdnimg.cn/direct/6a69a7a80869480d8c5746206f475622.png)
收敛速度、训练时间
![![[Pasted image 20250226212802.png]]](https://i-blog.csdnimg.cn/direct/dbfc930f35ac4ed59fde5491813881e7.png)
相关文章:
超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题: 真实的人脸识别数据噪声严重真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少人脸类别越来越多,全连接层…...
uniapp 小程序如何实现大模型流式交互?前端SSE技术完整实现解析
文章目录 一、背景概述二、核心流程图解三、代码模块详解1. UTF-8解码器(处理二进制流)2. 请求控制器(核心通信模块)3. 流式请求处理器(分块接收)4. 数据解析器(处理SSE格式)5. 回调…...
因子分析详解:从理论到MATLAB实战
内容摘要: 本文系统解析因子分析的核心原理与MATLAB实战,涵盖数学模型、载荷矩阵估计、因子旋转及得分计算。通过上市公司盈利能力、消费者偏好等案例,演示数据标准化、因子提取与解释的全流程,并提供完整代码实现。深入对比因子分…...
【组态PLC】基于三菱西门子S7-200PLC和组态王液料混合系统组态设计【含PLC组态源码 M016期】
控制要求 总体控制要求:如面板图所示,本装置为三种液体混合模拟装置,由液面传感器SL1、SL2、SL3,液体A、B、C阀门与混合液阀门由电磁阀YV1、YV2、YV3、YV4,搅匀电机M,加热器H,温度传感器T组成。…...
js:根据后端返回的数组取出每一个数组的keyword字段然后拼接成一个逗号分隔的字符串
问: 现在有一个el-select, 后端接口返回数据为keyword:xxx,referenceNum:1,tagId:132sf32fasdfaf组成的数组, 现在select是多选, 但是但我选择多个下拉框选项后,后端需要前端返回的数据tagIds字段需要时一个字符串…...
基于大模型的肺纤维化预测及临床方案研究报告
目录 一、引言 1.1 研究背景与意义 1.2 研究目的与方法 二、大模型技术概述 2.1 大模型的基本原理 2.2 大模型在医疗领域的应用现状 三、肺纤维化相关知识 3.1 肺纤维化的病因与发病机制 3.2 肺纤维化的临床症状与诊断方法 3.3 肺纤维化的治疗现状与挑战 四、大模型…...
7. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--路由
路由是API网关的核心功能,对系统性能和可靠性至关重要。路由通过定义规则,将客户端请求准确地转发到相应的后端服务,确保请求能够正确处理,简化了微服务架构中的服务调用逻辑。有效的路由配置能够提高系统的灵活性和可维护性。 一…...
【GESP】C++二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵
GESP二级模拟题,多层循环、分支语句练习,难度★✮☆☆☆。 题目题解详见:https://www.coderli.com/gesp-2-luogu-b3995/ 【GESP】C二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵 | OneCoderGESP二级模拟题,多层循环、分…...
监督学习——基于线性回归的波士顿房价预测:理论、实践与评估
基于线性回归的波士顿房价预测:理论、实践与评估 文章目录 基于线性回归的波士顿房价预测:理论、实践与评估一、引言二、线性回归基础理论2.1 线性回归原理2.2 线性回归在房价预测中的应用逻辑三、波士顿房价数据集介绍3.1 数据集概述3.2 特征说明3.3 目标变量四、波士顿房价…...
Selenium 调用模型接口实现功能测试
要使用 Selenium 调用模型接口实现功能测试,可按以下步骤进行: 1. 环境准备 安装 Selenium:使用 pip install selenium 安装 Selenium 库。安装浏览器驱动:根据使用的浏览器(如 Chrome、Firefox 等)下载对应的驱动,并将其添加到系统的环境变量中。例如,Chrome 浏览器需…...
回调函数的用法
回调函数的基本用法 回调函数是一种被作为参数传递给另一个函数的函数,接收回调函数作为参数的函数在合适的时候会调用这个回调函数。回调函数为代码提供了更高的灵活性和可扩展性,下面为你详细介绍回调函数的基本用法。 基本概念 回调函数的核心在于函…...
springboot实现文件上传到华为云的obs
一、前言 有时在项目中需要使用一些存储系统来存储文件,那么当项目要接入obs作为存储系统时,就会利用obs来进行文件的上传下载,具体实现如下。 二、如何通过obs实现文件的上传下载? 1.添加相关的obs的maven依赖。 <dependency…...
南京布局产业园剖析:成都树莓集团的战略逻辑
在数字产业飞速发展的当下,成都树莓集团在南京布局产业园,这一举措蕴含着深刻的战略考量,是基于对市场环境、产业趋势以及自身发展需求的综合研判。 一、政策利好与发展机遇 南京作为长三角地区的重要城市,在数字经济发展方面享有…...
C++ QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法
C QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法 记录一下 qmake .pro文件的配置 QT core gui printsupportgreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c17# You can make your code fail to compil…...
【算法】哈希表详解
【算法】哈希表详解 1. 哈希表的基本概念2. 哈希表的优缺点3. 哈希表的实现方法4. 哈希表的应用场景5. 哈希表的性能优化6. 哈希表 vs 其他数据结构7. 总结 哈希表(Hash Table) 是一种高效的数据结构,用于存储键值对(Key-Value Pa…...
【红队利器】单文件一键结束火绒6.0
关于我们 4SecNet 团队专注于网络安全攻防研究,目前团队成员分布在国内多家顶级安全厂商的核心部门,包括安全研究领域、攻防实验室等,汇聚了行业内的顶尖技术力量。团队在病毒木马逆向分析、APT 追踪、破解技术、漏洞分析、红队工具开发等多个…...
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏 前言项目介绍项目简介项目预览二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署star-battle网页小游戏下载镜像创建容器检查容器状态检查服务端口安全设置四、访问star-battle网页小游戏五、总…...
【EB-06】SystemCreator dbc转arxml
SystemCreator dbc转arxml 1. SystemCreator 意义2. SystemCreator使用方法2.1 实现步骤2.2 参考官方文档方法1. SystemCreator 意义 EB Tresos 对dbc直接导入的支持不是很完善,dbc也不是AUTOSAR标准的数据库文件,EB建议所有通信矩阵通过ARXML交互比较合理(AUTOSAR定义的)…...
(0)阿里云大模型ACP-考试回忆
这两天通过了阿里云大模型ACP考试,由于之前在网上没有找到真题,导致第一次考试没有过,后面又重新学习了一遍文档才顺利通过考试,这两次考试内容感觉考试题目90%内容是覆盖的,后面准备分享一下每一章的考题,…...
按键精灵鹰眼中控:ios多设备管理工具
在当今数字化时代,高效管理多设备已成为许多企业和个人的迫切需求。无论是游戏多开、自动化测试,还是电商运营,如何同时操作多台设备并确保精准执行,一直是一个难题。现在,按键精灵的鹰眼群控功能为您提供了完美的解决…...
【JavaEE】-- HTTP
1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...
【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...
Java毕业设计:WML信息查询与后端信息发布系统开发
JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发,实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构,服务器端使用Java Servlet处理请求,数据库采用MySQL存储信息࿰…...
Linux系统部署KES
1、安装准备 1.版本说明V008R006C009B0014 V008:是version产品的大版本。 R006:是release产品特性版本。 C009:是通用版 B0014:是build开发过程中的构建版本2.硬件要求 #安全版和企业版 内存:1GB 以上 硬盘…...
Windows电脑能装鸿蒙吗_Windows电脑体验鸿蒙电脑操作系统教程
鸿蒙电脑版操作系统来了,很多小伙伴想体验鸿蒙电脑版操作系统,可惜,鸿蒙系统并不支持你正在使用的传统的电脑来安装。不过可以通过可以使用华为官方提供的虚拟机,来体验大家心心念念的鸿蒙系统啦!注意:虚拟…...
CppCon 2015 学习:Time Programming Fundamentals
Civil Time 公历时间 特点: 共 6 个字段: Year(年)Month(月)Day(日)Hour(小时)Minute(分钟)Second(秒) 表示…...
ThreadLocal 源码
ThreadLocal 源码 此类提供线程局部变量。这些变量不同于它们的普通对应物,因为每个访问一个线程局部变量的线程(通过其 get 或 set 方法)都有自己独立初始化的变量副本。ThreadLocal 实例通常是类中的私有静态字段,这些类希望将…...
【51单片机】4. 模块化编程与LCD1602Debug
1. 什么是模块化编程 传统编程会将所有函数放在main.c中,如果使用的模块多,一个文件内会有很多代码,不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里,在.h文件里提供外部可调用函数声明,其他.c文…...
跨平台商品数据接口的标准化与规范化发展路径:淘宝京东拼多多的最新实践
在电商行业蓬勃发展的当下,多平台运营已成为众多商家的必然选择。然而,不同电商平台在商品数据接口方面存在差异,导致商家在跨平台运营时面临诸多挑战,如数据对接困难、运营效率低下、用户体验不一致等。跨平台商品数据接口的标准…...
