超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题:
- 真实的人脸识别数据噪声严重
- 真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少
- 人脸类别越来越多,全连接层训练成本越来越高,难度越来越大
于是,有研究人员提出Partial FC,拒绝全量更新负类别中心,而是仅更新少部分负类别中心。该做法优势在于
- 降低噪声数据被采样的概率
- 降低高频负类别中心被选中的概率
- 降低负类别中心的更新频率,降低训练难度
如下图所示:
问题建模
人脸识别领域,常用的分类损失公式化定义如下:
L = − 1 B ∑ i = 1 B l o g e W y i T ⋅ x i e W y i T ⋅ x i + ∑ j = 1 , j ≠ y i C e W j T ⋅ x i (1) L=-\frac{1}{B}\sum_{i=1}^{B}log\frac{e^{W_{y_i}^T\cdot x_i}}{e^{W_{y_i}^T}\cdot x_i+\sum_{j=1,j\neq y_i}^Ce^{W_j^T\cdot x_i}} \tag{1} L=−B1i=1∑BlogeWyiT⋅xi+∑j=1,j=yiCeWjT⋅xieWyiT⋅xi(1)
,其中, B B B表示batch size, C C C表示类别个数, W j T W_{j}^T WjT表示第 j j j个类别中心的特征, ( x i , y i ) (x_i,y_i) (xi,yi)表示第 i i i个样本的特征为 x i x_i xi,类别为 y i y_i yi。
真实大规模人脸数据实际使用时,有以下问题:
- 噪声问题:见上图(a),图片对都是一个人的图片,但是被分到不同的类别,这对模型训练有非常大的干扰。
- 长尾分布:见上图(b),大部分类别(identity)包含的图像数量很少,在WebFace42M中,44.57%的类别包含的图像数量少于10张。这会导致低频类别的类别中心更新缓慢,而高频类别的类别中心更新频繁。
- 训练资源:全连接层一般表示为 W ∈ R D × C W\in \mathbb{R}^{D\times C} W∈RD×C,其中 D D D表示维度, C C C表示类别数。假设 D = 512 D=512 D=512,如果类别数是1000,000(一百万)
- fp16下,全连接层的显存消耗为: 512 × 100 , 000 × 2 1024 × 1024 × 1024 = 0.95 G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}=0.95GB 1024×1024×1024512×100,000×2=0.95GB
- 公式(1)中,需要计算 B B B个 x i x_i xi属于类别中心 W j T W_{j}^T WjT的logit,维度是 R B × D × C \mathbb{R}^{B\times D\times C} RB×D×C,显存消耗为 512 × 100 , 000 × 2 1024 × 1024 × 1024 ⋅ B = 0.95 B G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}\cdot B=0.95B \,GB 1024×1024×1024512×100,000×2⋅B=0.95BGB,batchsize越大,需要的显存越大。
- 在下图,进行了模型并行和partial fc在显存消耗和训练速度上的比较,可以发现:
- partial fc显著降低了对logit的显存消耗
- partial fc略微降低了对存储类别中心的显存消耗
- partial fc未降低对特征抽取网络的显存消耗(将原图像转换为特征的模型的消耗)
- 由于partial fc减少了负类别中心的数量,降低了logit计算的复杂度,随着训练类别越多,加速比越高。
partial fc
为了缓解上述问题,提出了partial fc,通过稀疏更新全连接层的参数,来支持大规模人脸识别模型的训练。
整体架构如下图所示:
模型通过数据并行训练的,不同GPU包含了不同数据的特征,整体步骤如下:
- 汇总不同GPU里的图像特征和图像标签
- 将汇总的图像特征和图像标签送到每张GPU上
- 将全连接层(即 C C C个类别中心)均分到每张GPU上
- 在单张卡上,保留需要的正类别中心,以及采样固定比例的负类别中心
- 利用样本、正类别中心、负类别中心计算损失函数
代码实现
def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)# 汇总不同GPU里的图像特征和图像标签embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)labels = labels.view(-1, 1)# self.class_start表示该GPU中,分配的类别中心起始id# self.num_local表示该GPU中,分配的类别中心数量# 于是,该GPU的类别中心id范围是[类别中心起始id, 类别中心起始id + 类别中心数量]# 在单张卡上,仅保留需要的正类别中心index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)labels[~index_positive] = -1labels[index_positive] -= self.class_start# 在单张卡上,采样固定比例的负类别中心if self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)# 基于样本特征、样本标签、正类别中心,采样的负类别中心,计算损失logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss
优势
partial fc的核心思想是”降低训练中负类别中心数量,显式得减少需要更新的参数量“。负类别中心采样比例越低,节约的显存越多。
为了更好理解partial fc对长尾分布、噪声问题的影响,计算分类损失对于样本 x i x_i xi的梯度,如下:
∂ L ∂ x i = − ( ( 1 − p + ) W + − ∑ j ∈ S , j ≠ y i p j − W j − ) (2) \frac{\partial L}{\partial x_i}=-((1-p^+)W^+-\sum_{j\in \mathbb{S}, j\neq y_i}p_j^-W_j^-) \tag{2} ∂xi∂L=−((1−p+)W+−j∈S,j=yi∑pj−Wj−)(2)
其中, p + p^+ p+、 p − p^- p−分别表示通过样本特征 x i x_i xi计算的logit分数, S \mathbb{S} S表示负类别中心, ∣ S ∣ = C × r |\mathbb{S}|=C\times r ∣S∣=C×r,通过采样比例 r r r,调整训练时的负样本数量。
样本特征 x i x_i xi的更新方向和正类别中心和负类别中心都有关系,partial fc随机减少负类别中心数量,减低噪声数据被采样的概率,降低高频负类别中心被选中的概率,有效缓解长尾问题和噪声问题。
注意:采样率为1,等同于选取所有负类别中心,进行模型训练。相当于原始fc分类器
为了进一步验证partial fc的作用原理,做了下述验证下实验
探究采样率与类内、类间相似度关系
(a)图中,采样率越低,APCS收敛至更高数值。APCS表示类内距离 A P C S = 1 B ∑ i = 1 B W y i T x i ∣ ∣ W y i ∣ ∣ ⋅ ∣ ∣ x i ∣ ∣ APCS=\frac{1}{B}\sum_{i=1}^B\frac{W_{y_i}^Tx_i}{||W_{y_i}||\cdot ||x_i||} APCS=B1i=1∑B∣∣Wyi∣∣⋅∣∣xi∣∣WyiTxi,说明采样率越低,类内相似度越大,类内越紧密。
(b)图中,采样率越低,MICS分布越往右,整体数值越大。MICS表示最大的类间余弦相似度 MICS i = max j ≠ i W i T W j ∥ W i ∥ ∥ W j ∥ \text{MICS}_i = \max_{j \neq i} \frac{W_i^T W_j}{\|W_i\| \, \|W_j\|} MICSi=j=imax∥Wi∥∥Wj∥WiTWj,说明采样率越低,类间相似度越大,类间拉不开。
探究采样率与评测集合效果关系
随着采样率越来越低,IJB-C、MFR-All评测集上的效果越来越差,如下:
探究采样率对噪声数据的鲁棒性
为验证在噪声数据上的效果,做了如下实验:
构造了WebFace12M-Conflict数据集,随机将20万个类的样本放到另外60万个类别数据中。
图(a)纵坐标为AMNCS,指最小的负类中心距离(越大,说明样本和负类离得远)。可以发现,降低采样率(1.0->0.1),在干净数据上,效果接近;在噪声数据上,缓解过拟合问题。
图(b)的MICS,指最大的类间余弦相似度(越大,说明类别越相似,越区分不开类别)。可以发现,降低采样率(1.0->0.1),MICS分布往右,整体数值偏大。由于WebFace12M-Conflict数据集中,20万类的样本随机分布在其他类别中,类间余弦相似度本就很大,图(b)更好的刻画实际噪声分布。
上图定义了两个概念,分别是conflict-hard和conflict-noise。conflict-hard表示利用真实负样本计算AMNCS;conflict-noise表示利用噪声负样本计算AMNCS。结果表明:
- r=1.0时,针对AMNCS指标,conflict-hard>conflict-noise,表明负样本不采样,会使得模型过分拟合数据集,导致对噪声数据不鲁棒(按理说应该是conflict-noise>conflict-hard)。
- r=0.1时,针对AMNCS指标,conflict-hard<conflict-noise,刻画出真实数据特性。
消融实验
不同数据集、不同采样率下partial fc
不同网络结构下partial fc
对噪声数据鲁棒
采用WebFace12M-Conflict作为训练集合。
对长尾数据鲁棒
收敛速度、训练时间
相关文章:

超大规模分类(四):Partial FC
人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题: 真实的人脸识别数据噪声严重真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少人脸类别越来越多,全连接层…...
uniapp 小程序如何实现大模型流式交互?前端SSE技术完整实现解析
文章目录 一、背景概述二、核心流程图解三、代码模块详解1. UTF-8解码器(处理二进制流)2. 请求控制器(核心通信模块)3. 流式请求处理器(分块接收)4. 数据解析器(处理SSE格式)5. 回调…...
因子分析详解:从理论到MATLAB实战
内容摘要: 本文系统解析因子分析的核心原理与MATLAB实战,涵盖数学模型、载荷矩阵估计、因子旋转及得分计算。通过上市公司盈利能力、消费者偏好等案例,演示数据标准化、因子提取与解释的全流程,并提供完整代码实现。深入对比因子分…...

【组态PLC】基于三菱西门子S7-200PLC和组态王液料混合系统组态设计【含PLC组态源码 M016期】
控制要求 总体控制要求:如面板图所示,本装置为三种液体混合模拟装置,由液面传感器SL1、SL2、SL3,液体A、B、C阀门与混合液阀门由电磁阀YV1、YV2、YV3、YV4,搅匀电机M,加热器H,温度传感器T组成。…...
js:根据后端返回的数组取出每一个数组的keyword字段然后拼接成一个逗号分隔的字符串
问: 现在有一个el-select, 后端接口返回数据为keyword:xxx,referenceNum:1,tagId:132sf32fasdfaf组成的数组, 现在select是多选, 但是但我选择多个下拉框选项后,后端需要前端返回的数据tagIds字段需要时一个字符串…...
基于大模型的肺纤维化预测及临床方案研究报告
目录 一、引言 1.1 研究背景与意义 1.2 研究目的与方法 二、大模型技术概述 2.1 大模型的基本原理 2.2 大模型在医疗领域的应用现状 三、肺纤维化相关知识 3.1 肺纤维化的病因与发病机制 3.2 肺纤维化的临床症状与诊断方法 3.3 肺纤维化的治疗现状与挑战 四、大模型…...
7. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--路由
路由是API网关的核心功能,对系统性能和可靠性至关重要。路由通过定义规则,将客户端请求准确地转发到相应的后端服务,确保请求能够正确处理,简化了微服务架构中的服务调用逻辑。有效的路由配置能够提高系统的灵活性和可维护性。 一…...

【GESP】C++二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵
GESP二级模拟题,多层循环、分支语句练习,难度★✮☆☆☆。 题目题解详见:https://www.coderli.com/gesp-2-luogu-b3995/ 【GESP】C二级模拟 luogu-b3995, [GESP 二级模拟] 小洛的田字矩阵 | OneCoderGESP二级模拟题,多层循环、分…...
监督学习——基于线性回归的波士顿房价预测:理论、实践与评估
基于线性回归的波士顿房价预测:理论、实践与评估 文章目录 基于线性回归的波士顿房价预测:理论、实践与评估一、引言二、线性回归基础理论2.1 线性回归原理2.2 线性回归在房价预测中的应用逻辑三、波士顿房价数据集介绍3.1 数据集概述3.2 特征说明3.3 目标变量四、波士顿房价…...
Selenium 调用模型接口实现功能测试
要使用 Selenium 调用模型接口实现功能测试,可按以下步骤进行: 1. 环境准备 安装 Selenium:使用 pip install selenium 安装 Selenium 库。安装浏览器驱动:根据使用的浏览器(如 Chrome、Firefox 等)下载对应的驱动,并将其添加到系统的环境变量中。例如,Chrome 浏览器需…...
回调函数的用法
回调函数的基本用法 回调函数是一种被作为参数传递给另一个函数的函数,接收回调函数作为参数的函数在合适的时候会调用这个回调函数。回调函数为代码提供了更高的灵活性和可扩展性,下面为你详细介绍回调函数的基本用法。 基本概念 回调函数的核心在于函…...
springboot实现文件上传到华为云的obs
一、前言 有时在项目中需要使用一些存储系统来存储文件,那么当项目要接入obs作为存储系统时,就会利用obs来进行文件的上传下载,具体实现如下。 二、如何通过obs实现文件的上传下载? 1.添加相关的obs的maven依赖。 <dependency…...

南京布局产业园剖析:成都树莓集团的战略逻辑
在数字产业飞速发展的当下,成都树莓集团在南京布局产业园,这一举措蕴含着深刻的战略考量,是基于对市场环境、产业趋势以及自身发展需求的综合研判。 一、政策利好与发展机遇 南京作为长三角地区的重要城市,在数字经济发展方面享有…...

C++ QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法
C QT 6.6.1 QCustomPlot的导入及使用注意事项和示例 | 关于高版本QT使用QCustomPlot报错问题解决的办法 记录一下 qmake .pro文件的配置 QT core gui printsupportgreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c17# You can make your code fail to compil…...
【算法】哈希表详解
【算法】哈希表详解 1. 哈希表的基本概念2. 哈希表的优缺点3. 哈希表的实现方法4. 哈希表的应用场景5. 哈希表的性能优化6. 哈希表 vs 其他数据结构7. 总结 哈希表(Hash Table) 是一种高效的数据结构,用于存储键值对(Key-Value Pa…...

【红队利器】单文件一键结束火绒6.0
关于我们 4SecNet 团队专注于网络安全攻防研究,目前团队成员分布在国内多家顶级安全厂商的核心部门,包括安全研究领域、攻防实验室等,汇聚了行业内的顶尖技术力量。团队在病毒木马逆向分析、APT 追踪、破解技术、漏洞分析、红队工具开发等多个…...
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏
Docker小游戏 | 使用Docker部署star-battle太空飞船射击小游戏 前言项目介绍项目简介项目预览二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署star-battle网页小游戏下载镜像创建容器检查容器状态检查服务端口安全设置四、访问star-battle网页小游戏五、总…...

【EB-06】SystemCreator dbc转arxml
SystemCreator dbc转arxml 1. SystemCreator 意义2. SystemCreator使用方法2.1 实现步骤2.2 参考官方文档方法1. SystemCreator 意义 EB Tresos 对dbc直接导入的支持不是很完善,dbc也不是AUTOSAR标准的数据库文件,EB建议所有通信矩阵通过ARXML交互比较合理(AUTOSAR定义的)…...

(0)阿里云大模型ACP-考试回忆
这两天通过了阿里云大模型ACP考试,由于之前在网上没有找到真题,导致第一次考试没有过,后面又重新学习了一遍文档才顺利通过考试,这两次考试内容感觉考试题目90%内容是覆盖的,后面准备分享一下每一章的考题,…...
按键精灵鹰眼中控:ios多设备管理工具
在当今数字化时代,高效管理多设备已成为许多企业和个人的迫切需求。无论是游戏多开、自动化测试,还是电商运营,如何同时操作多台设备并确保精准执行,一直是一个难题。现在,按键精灵的鹰眼群控功能为您提供了完美的解决…...

XCTF-web-easyupload
试了试php,php7,pht,phtml等,都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接,得到flag...
Nginx server_name 配置说明
Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
WEB3全栈开发——面试专业技能点P4数据库
一、mysql2 原生驱动及其连接机制 概念介绍 mysql2 是 Node.js 环境中广泛使用的 MySQL 客户端库,基于 mysql 库改进而来,具有更好的性能、Promise 支持、流式查询、二进制数据处理能力等。 主要特点: 支持 Promise / async-await…...
raid存储技术
1. 存储技术概念 数据存储架构是对数据存储方式、存储设备及相关组件的组织和规划,涵盖存储系统的布局、数据存储策略等,它明确数据如何存储、管理与访问,为数据的安全、高效使用提供支撑。 由计算机中一组存储设备、控制部件和管理信息调度的…...

Axure零基础跟我学:展开与收回
亲爱的小伙伴,如有帮助请订阅专栏!跟着老师每课一练,系统学习Axure交互设计课程! Axure产品经理精品视频课https://edu.csdn.net/course/detail/40420 课程主题:Axure菜单展开与收回 课程视频:...

NineData数据库DevOps功能全面支持百度智能云向量数据库 VectorDB,助力企业 AI 应用高效落地
NineData 的数据库 DevOps 解决方案已完成对百度智能云向量数据库 VectorDB 的全链路适配,成为国内首批提供 VectorDB 原生操作能力的服务商。此次合作聚焦 AI 开发核心场景,通过标准化 SQL 工作台与细粒度权限管控两大能力,助力企业安全高效…...
数据库优化实战指南:提升性能的黄金法则
在现代软件系统中,数据库性能直接影响应用的响应速度和用户体验。面对数据量激增、访问压力增大,数据库性能瓶颈经常成为项目痛点。如何科学有效地优化数据库,提升查询效率和系统稳定性,是每位开发与运维人员必备的技能。 本文结…...
Ubuntu 可执行程序自启动方法
使用 autostart(适用于桌面环境) 适用于 GNOME/KDE 桌面环境(如 Ubuntu 图形界面) 1. 创建 .desktop 文件 sudo vi ~/.config/autostart/my_laser.desktop[Desktop Entry] TypeApplication NameMy Laser Program Execbash -c &…...