当前位置: 首页 > news >正文

【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数

论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》
论文地址:https://arxiv.org/pdf/2004.00288v1.pdf
代码地址:https://github.com/HuangYG123/CurricularFace

建议先了解下这篇文章:MV-softmax

1.背景

       人脸识别中常用损失函数主要包括两类,基于间隔和难样本挖掘,这两种方法损失函数的训练策略都存在缺陷。前一种方法是对所有样本都采用一个固定的间隔值,没有充分利用每个样本自身的难易信息,这可能导致在使用大边际时出现收敛问题;后一种方法则在整个网络训练周期都强调难样本,可能出现网络无法收敛问题。在本论文中,提出了一种新的自适应课程学习损失函数,称为CurricularFace,它能够很好地解决上述两类损失函数存在的问题。
       下图是CurricularFace跟ArcFace和 MV-Arc-Softmax两种方法的对比,可以看到CurricularFace的优势还是很明显的,通过自适应的方式实现,在早期突出易样本的作用(红色虚线),而在晚期突出难样本的作用(红色实线)

在这里插入图片描述

注:Curriculum Learning即课程学习,它是由Montreal大学的Bengio教授团队在2009年的ICML上提出的,其主要思想是模仿人类学习的特点,按照从简单到困难的程度来学习课程,这样容易使模型找到更好的局部最优,同时加快训练速度。

– MV-Sotamax存在的问题:从training起始阶段就开始强调semi-hard/hard-sample,可能会导致模型的收敛问题!

easy sample first, hard sample later!

2.方法

       论文中提出的一种新的自适应课程学习损失CurricularFace,是将课程学习的思想嵌入到损失函数中,以实现一种新的深度人脸识别训练策略。该策略主要针对早期训练阶段的易样本和后期训练阶段的难样本,使其在不同的训练阶段,通过一个课程表自适应地调整简单和困难样本的相对重要性。也就是说,在每个阶段,不同的样本根据其相应的困难程度被赋予不同的重要性。
       由于人类学习的本质是先易后难,CurricularFace是以一种适应性的方式将课程学习的理念融入到人脸识别中,这与传统的认知有两处明显不同:
       1)首先,课程设计的自适应性。在传统的课程学习中,样本是按照相应的难易程度排序的,这些难易程度往往是由先验知识定义的,然后固定下来建立课程。而在CurricularFace中,做法是由每个Batch随机抽取样本,通过在线挖掘难样本自适应地建立课程
       2)其次,难样本的重要性是自适应的。一方面,易样本和难样本的相对重要性是动态的,可以在不同的训练阶段进行调整。另一方面,当前Batch中每一个难样本的重要性取决于其自身的难易程度。
       具体来看,文中选择Batch中的被误分类样本作为难样本,通过调整样本与假类别中心向量之间的余弦相似度的调制系数来加权。为了在整个训练过程中实现自适应课程学习的目标,论文设计了一种新的系数函数,该函数包括以下两个因子:
       1)自适应估计参数t,该参数利用样本和其真类别间的Positive余弦相似度的移动平均值来实现自适应,以消除人工调整的负担。
       2)余弦角度参数,该参数定义难样本实现自适应分配的的难易性。

       上面介绍完了CurricularFace的基本原理,我们来看下其损失函数是如何定义的,如下:
在这里插入图片描述
其中,T(cos(θ_y)) = cos(θ_y + m), I (t, cos(θ_j))表示样本的权重函数,N(t, cos(θ_j))定义如下:

在这里插入图片描述

Adaptive Estimation of t.
       在不同的训练阶段决定一个恰当的t的值是十分重要的。理想情况下,t的值能够指示模型的训练阶段。我们通过经验发现正cosine相似度的平均值是一个好的指示器。可是min-batch的基于统计的方法往往面临一个问题:当许多极端数据被采样到一个mini-batch时,统计可能是一个很大的噪声,估计值可能很不稳定。Exponential Moving Average (EMA)方法是一个常用的解决该问题的方法,假设r(k)是第k个batch的正cosine相似度的平均值,r^(0) = 0,即:

在这里插入图片描述
则有(t^(k)随着k的增加,会呈现出单调递增的趋势):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
Note : (a, b), a表示在训练过程中[某个时刻] curricular_loss和arcface-loss的比值;b表示max {cos(θ_j), j ≠ yi}

3.训练

3.1.训练步骤

在这里插入图片描述

3.2.训练曲线

在这里插入图片描述
在这里插入图片描述
1.x-axis : iterations, y-axis : 难样本的调整系数
2. t:adaptive parameter; M : MV-Arc-Softmax; M(ours) : gradient modulation coefficients
3.在训练早期,t --> 0,模型可以利用easy-sample加速收敛;在训练中后期t不断增大使得I(t, cos(θ_j)) > 1,这样模型可以更多地关注hard-smaples.

4.实验

       从Figure 4中可以看到,在整个训练阶段,CurricularFace对于难样本的决策边界从训练早期到后期自适应性的变化。
在这里插入图片描述
       最终,与其它方法相比,CurricularFace下的人脸识别效果得到明显改善(如Table4与Table6)
在这里插入图片描述
在这里插入图片描述

5.结论

       论文提出的自适应课程学习损失CurricularFace,将自适应课程学习的思想嵌入到人脸识别中。该方法易于实现,收敛性强,能够明显的提升人脸识别的准确率,而且它解决的是经常在训练过程中出现的问题(如:大边际和难样本),因而具备很高的实用价值。

pytorch代码:

class CurricularFace(nn.Module):"""Implementation for "CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition"."""def __init__(self, in_features, out_features, device_id=None, m = 0.5, s = 64., fp16 = False):super(CurricularFace, self).__init__()self.device_id = device_idself.fp16 = fp16self.m = mself.s = sself.cos_m = math.cos(m)self.sin_m = math.sin(m)self.threshold = math.cos(math.pi - m)self.mm = math.sin(math.pi - m) * mself.kernel = Parameter(torch.FloatTensor(out_features, in_features))self.register_buffer('t', torch.zeros(1))nn.init.xavier_uniform_(self.kernel)     #self.kernel = Parameter(torch.Tensor(in_features, out_features))#self.register_buffer('t', torch.zeros(1))#nn.init.normal_(self.kernel, std=0.01)def forward(self, feats, labels):#kernel_norm = F.normalize(self.kernel, dim=0)#feats = F.normalize(feats)#cos_theta = torch.mm(feats, kernel_norm)sub_weights = torch.chunk(self.kernel, len(self.device_id), dim=0)temp_x = feats.cuda(self.device_id[0])weight = sub_weights[0].cuda(self.device_id[0])cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))for i in range(1, len(self.device_id)):temp_x = x.cuda(self.device_id[i])weight = sub_weights[i].cuda(self.device_id[i])cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)cos_theta = cos_theta.clamp(-1.0, 1.0)  # for numerical stabilitywith torch.no_grad():origin_cos = cos_theta.clone()target_logit = cos_theta[torch.arange(0, temp_x.size(0)), labels].view(-1, 1)sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m #cos(target+margin)mask = cos_theta > cos_theta_mif self.fp16:cos_theta_m = cos_theta_m.half()final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)hard_example = cos_theta[mask]with torch.no_grad():self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.tif self.fp16:self.t = self.t.half()cos_theta[mask] = hard_example * (self.t + hard_example)if self.device_id != None:cos_theta = cos_theta.cuda(self.device_id[0])cos_theta.scatter_(1, labels.view(-1, 1).long(), final_target_logit)output = cos_theta * self.sreturn output

相关文章:

【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数

论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》 论文地址:https://arxiv.org/pdf/2004.00288v1.pdf 代码地址:https://github.com/HuangYG123/CurricularFace 建议先了解下这篇文章&#xff1a…...

springmvc之rest风格(RESTFUL)

目录 一、介绍 1.什么是REST? 2.REST的实质 3.REST风格的优点 4.REST风格的缺点 3.什么是RESTful? 二、代码理解 一、介绍 1.什么是REST? 答:REST(Representational State Transfer) ,表现形式转…...

django项目实战十四(django+bootstrap实现增删改查)进阶混合数据使用modelform上传

目录 一、启用media 1、URL设置 2、settings.py配置 二、url 三、upload.py 新增upload_modelform方法 四、form.py新增UpModelForm 五、创建city表 六、创建city_list.html 接上一篇《django项目实战十三(djangobootstrap实现增删改查)进阶混合数据f…...

2023年CDGA考试模拟题库(1-100)

2023年CDGA考试模拟题库(1-100) 1.以下哪种活动中 ,混淆是不足以保护数据 的?[1分] A.数据共享 B.数据转换 C.数据脱敏 D.以上都正确 答案C 2.关于受控词表描述不正确的是?[1分] A.系统地组织文件档案和内容离不开受控词表 B.受控词表的一个例子是用于出版物分类的都…...

HTML常用基础内容总结

文章目录一、对HTML的感性认知前置知识什么是web前端,什么是web后端前端技术栈、后端技术栈开发与运行的区别浏览器的功能是什么简介写一个简单可运行的的html代码前端开发方式二、VSCode的简单使用三、常用的HTML标签最最基本的HTML结构HTML代码特点注释标签标题标…...

Gorm-学习笔记

1 基本使用 2 创建数据 2.1 如何使用Upsert 使用clause.OnConflict处理数据冲突 2.2 如何使用默认值 通过使用default标签为字段定义默认值 3 查询数据 3.1 First与Find 使用First时,需要注意查询不到数据会返回ErrRecordNotFound。 使用Find查询多条数据&#x…...

【Neo4j】图数据库CypherQueryLanguage随笔

CQL语言随笔 一、Cyther关系描述 如图&#xff1a;唐僧&#xff0c;孙悟空&#xff0c;白骨精三者的关系图&#xff1a; Cypher语言描述他们的关系&#xff1a; (孙悟空)<-[:赶走]-(唐僧)-[:被骗]->(白骨精)-[:被打死]->(孙悟空) 二、CQL语言的使用案例 创建结点…...

STM32Cube串口USART发送接收数据

本文代码使用 HAL 库。 文章目录前言一、USART 同步/异步串行接收/发送器二、USART 原理图三、CubeMX 创建工程四、usart.c 文件解析五&#xff0c;设计实验&#xff1a;在 串口输入字符点亮led实验现象&#xff1a;总结前言 这篇文章介绍 实现 USART 异步模式下 通过 串口助手…...

OpenFeign详解

OpenFeign是什么&#xff1f; OpenFeign&#xff1a; OpenFeign是Spring Cloud 在Feign的基础上支持了SpringMVC的注解&#xff0c;如RequesMapping等等。OpenFeign的FeignClient可以解析SpringMVC的RequestMapping注解下的接口&#xff0c;并通过动态代理的方式产生实现类&am…...

python多线程网络编程

背景 使用过flask框架后&#xff0c;我对request这个全局实例非常感兴趣。它在客户端发起请求后会保存着所有的客户端数据&#xff0c;例如用户上传的表单或者文件等。那么在很多客户端发起请求时&#xff0c;服务器是怎么去区分不同的request对象呢&#xff1f;当查看了大量的…...

BFS-走迷宫

题目描述 给定一个 NM 的网格迷宫 G。G 的每个格子要么是道路,要么是障碍物(道路用 1 表示,障碍物用 0 表示)。 已知迷宫的入口位置为 (x1,y1),出口位置为 (x2...

【蓝牙mesh】Lower协议层介绍

【蓝牙mesh】Lower协议层介绍 Lower层简介 Lower协议层用于处理网络层以下的功能&#xff0c;包括节点的广播、重传、路由和网络拓扑等&#xff0c;是实现蓝牙mesh网络的关键协议之一。其中Lower协议层中最主要的一部分工作就是mesh数据的分片和组包。 Lower层是将Upper层发过…...

Java-重排序,happens-before 和 as-if-serial 语义

目录1. 如何解决重排序带来的问题2. happens-before1. 如何解决重排序带来的问题 对于编译器&#xff0c;JMM 的编译器重排序规则会禁止特定类型的编译器重排序。对于处理器重排序&#xff0c;JMM 的处理器重排序规则会要求编译器在生成指令序列时&#xff0c;插入特定类型的内…...

Nginx安装及介绍

前言&#xff1a;传统结构上(如下图所示)我们只会部署一台服务器用来跑服务&#xff0c;在并发量小&#xff0c;用户访问少的情况下基本够用但随着用户访问的越来越多&#xff0c;并发量慢慢增多了&#xff0c;这时候一台服务器已经不能满足我们了&#xff0c;需要我们增加服务…...

【华为OD机试模拟题】用 C++ 实现 - 寻找路径 or 数组二叉树(2023.Q1)

最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 获得完美走位(2023.Q1) 文章目录 最近更新的博客使用说明寻找路径 or 数组二叉树题目输入输出描述示例一输入输出示例二输入输出Code使用说明 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过…...

LINUX学习记录

回顾系列&#xff1a;两天的时间&#xff08;2023.2.24-2023.2.25&#xff09;重新学了遍Linux基础课&#xff0c;收获非常多&#xff0c;以前只会一些简单的Linux命令&#xff0c;对shell&#xff0c;git&#xff0c;管道&#xff0c;复杂Linux命令都不熟悉&#xff0c;学完之…...

华为OD机试用Python实现 -【狼羊过河 or 羊、狼、农夫过河】(2023-Q1 新题)

华为OD机试题 华为OD机试300题大纲狼羊过河 or 羊、狼、农夫过河题目描述输入描述输出描述说明示例一输入输出说明Python 代码实现代码实现思路华为OD机试300题大纲 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址…...

【SAP Abap】X-DOC:SAP ABAP 语法更新之Open SQL

SAP ABAP 语法更新之Open SQL1、前言2、演示1、前言 自从 SAP 推出 SAP ON HANA&#xff0c;与之相随的 AS ABAP NW 7.40 版本以后&#xff0c;ABAP 语法也有了较多的更新&#xff0c;本篇对 Open Sql的语法更新部分做一个DEMO演示。 NW 7.40 以前 OpenSQL 的限制&#xff1a…...

leetcode 困难 —— 数组中的逆序对(分治法)

题目&#xff1a; 在数组中的两个数字&#xff0c;如果前面一个数字大于后面的数字&#xff0c;则这两个数字组成一个逆序对。输入一个数组&#xff0c;求出这个数组中的逆序对的总数。 题解&#xff1a; ① 我最开始想的蠢方法&#xff08;会超时&#xff0c;可跳过&#xff…...

02.24:图片的风格转换

Github网址&#xff1a;https://github.com/lengstrom/fast-style-transfer 在anaconda prompt中切换环境命令&#xff1a;activate 环境名 列出所有环境名&#xff1a;conda info --envs 安装环境&#xff1a;conda create -n 环境名 pythonx.x.x 删除某环境&#xff1a;co…...

Ubuntu系统下交叉编译openssl

一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机&#xff1a;Ubuntu 20.04.6 LTSHost&#xff1a;ARM32位交叉编译器&#xff1a;arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询操作。 缓存逻辑分析&#xff1a; ①每个分类下的菜品保持一份缓存数据…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

初学 pytest 记录

安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...