当前位置: 首页 > news >正文

【center-loss 中心损失函数】 原理及程序解释(更新中)

文章目录

  • 前言
  • 问题引出
    • open-set问题
    • 抛出
  • 解决方法
    • softmax函数、softmax-loss函数
    • 解决
    • 代码(center_loss.py)
      • 原理
      • 程序解释
  • 如何梯度更新
    • 首先了解一下基本的梯度下降算法
    • 然后


前言

学习一下: 中心损失函数,用于用于深度人脸识别的特征判别方法
论文:https://ydwen.github.io/papers/WenECCV16.pdf
github代码:https://github.com/KaiyangZhou/pytorch-center-loss

参考:史上最全MNIST系列(三)——Centerloss在MNIST上的Pytorch实现(可视化)


问题引出

open-set问题

open-set 问题是一种模式识别中的问题,它指的是当训练集和测试集的类别不完全相同的情况。例如,如果训练集只包含 0 到 9 的数字,而测试集包含了 A 到 Z 的字母,那么就是一个 open-set 问题。这种情况下,分类器不仅要正确识别已知的类别,还要能够拒绝未知的类别,即将它们标记为 unknown 或 outlier。(后来我想这就是一个聚类问题)

open-set 问题与 closed-set 问题相对应,closed-set 问题是指训练集和测试集的类别完全相同的情况。例如,如果训练集和测试集都只包含 0 到 9 的数字,那么就是一个 closed-set 问题。这种情况下,分类器只需要正确识别已知的类别即可。

抛出

当我们要预测的人脸不在训练集中出现过时,我们需要让它不识别出,而不要因为与训练过的人脸相像而误判。

对于常见的图像分类问题,我们常常用softmax loss来求损失。以MNIST数据集为例,如果你的损失采用softmax loss,那么最后各个类别学出来的特征分布大概如下图:(在倒数第二层全连接层输出了一个2维的特征向量)
在这里插入图片描述
左图为训练集,右图为测试集,发现结果分类还算不错,但是每一类的界限太过模糊,若从中再加一列,则有可能出现误判。

解决方法

softmax函数、softmax-loss函数

已知softmax的函数为:
在这里插入图片描述
在深度学习的分类问题上,意为对应类的概率(符合每个类的概率相加和为1且0<每个类的概率<1)。(输出层后的结果)
其中 zi​ 是第 i 个输出节点的值,K是输出节点的个数,即分类的类别数。softmax函数的作用是将输出节点的值归一化为范围在 [0, 1] 且和为 1 的概率值,表示属于每个类别的可能性。

softmax的损失函数:(具体详见:一文详解Softmax函数)
在softmax的函数的基础上,我们要求正确对应类的概率最大,
在这里插入图片描述
即,损失函数为:(越小越好)
在这里插入图片描述
其中z = wTx+b,m是小批量的样本的个数,n是类别个数,w是全连接层的权重,b为偏置项,(一般来说w,b是要学习出来的),xi是第i个深层特征,属于第yi类
在这里插入图片描述

解决

在分类的基础上,我们还要求,每个类,往自己的中心的特征靠拢,使类内间距减少,这样才能更加显出差别。
在这里插入图片描述
于是,论文作者提出如下新的损失函数:
在这里插入图片描述
其中 xi​ 是第 i 个样本的特征向量,cyi​​ 是第 yi​ 个类别的中心向量,m 是样本的个数。
小的注意点:这里用的是样本数,也就是说,是在小批量分类任务完成后再进行的聚类任务
在这里插入图片描述
我们不难发现,作者这里用的是欧式距离的平方,即在多维空间上的两点之间真实距离的平方。
实际上,我们很容易发现,这实际上是一个聚类问题,常见的聚类问题上用的是误差平方和(SSE)损失函数:
在这里插入图片描述

论文作者仅从此推出的欧式距离的平方,仅从形式上十分相像,当然可以作为此聚类问题的解决。
在这里插入图片描述
(在二维上,可以简单看成是直角三角形斜边的平方=两直角边平方和)
为何要加以平方:欧式距离的平方相比欧式距离有一些优点,例如:
1、计算更快,省去了开方的运算。
2、更加敏感,能够放大距离的差异,使得离群点更容易被发现。
3、更加方便,能够与其他平方项相结合,如方差、协方差等。

最后,作者沿用softmax损失函数与中心损失函数相加的方法(在损失函数上很常见),
在这里插入图片描述
这里的1/2很容易想到是作为梯度下降时与后面的平方用来抵消的项,这里λ 可以看作调节两者损失函数的比例

来作为总体的损失函数,作为此问题的解决。

代码(center_loss.py)

原理

首先确定三个事实:
1、在之前的学习中通过实践得知(最小二乘法那块),在拟合的最后结果上,在利用SSE损失函数与MSE损失函数作为损失值参与到程序中时,拟合出的结果并无差别,只有在结果出来后,作为评判模型的好坏时,才有数值上的差别。(两者区别在于是否求平均)。
2、1/2的乘或不乘,只对运算的过程的简便程度有影响,与结果无影响。
3、图像是二维的。

于是我们将作者的的中心损失函数稍加变形。就得到了如下形式(事实上,github上给出的代码就是这么写的)
在这里插入图片描述

程序解释

参考:Center loss-pytorch代码详解
这里只做补充:

import torch
import torch.nn as nnclass CenterLoss(nn.Module):"""Center loss.# 参考Reference:Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.# 参数Args:num_classes (int): number of classes. # 类别数feat_dim (int): feature dimension.  # 特征维度"""# 初始化        默认参数:类别数为10 特征维度为2 使用GPUdef __init__(self, num_classes=10, feat_dim=2, use_gpu=True):super(CenterLoss, self).__init__() # 继承父类的所有属性self.num_classes = num_classes self.feat_dim = feat_dimself.use_gpu = use_gpuif self.use_gpu: # 如果使用GPU#nn.Parameter()将一个不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到一个module里面,参与到模型的训练和优化中。#nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torth.Tensor对象的默认值相反。self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) # 初始化中心矩阵 .cuda()表示将数据放到GPU上else:self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))def forward(self, x, labels): # 前向传播"""Args:x: feature matrix with shape (batch_size, feat_dim). # 特征矩阵labels: ground truth labels with shape (batch_size). # 真实标签"""batch_size = x.size(0) #batch_size x的形式为tensor张量# .pow对x中的每一个元素求平方 dim=1表示按行求和 keepdim=True表示保持原来的维度 expand是扩展维度distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() #.t()表示转置 均转成batch_size x num_classes的形式# .addmm_()表示进行1*distmat + (-2)*x@self.centers.t()的运算    @表示矩阵乘法distmat.addmm_(1, -2, x, self.centers.t())classes = torch.arange(self.num_classes).long()# 生成一个从0到num_classes-1的整数序列 long表示数据类型if self.use_gpu: classes = classes.cuda() #这里 .unsqueeze(0) 例[0,4,2] -> [[0,4,2]]  .unsqueeze(1) 例[0,4,2] -> [[0],[4],[2]] labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)# .unsqueeze(1)表示在第1维增加一个维度  .expand()表示扩展维度mask = labels.eq(classes.expand(batch_size, self.num_classes)) #eq是比较两个tensor是否相等,相等返回1,不相等返回0# *表示对应元素相乘dist = distmat * mask.float() # mask.float()将mask转换为float类型loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size # clamp表示将dist中的元素限制在1e-12和1e+12之间return loss

如何梯度更新

首先了解一下基本的梯度下降算法

【点云、图像】学习中 常见的数学知识及其中的关系与python实战
里的小标题:基于迭代的梯度下降算法,了解到超参数(人为设定的参数):
学习率,w,b等。此算法为拟合出一条线。

伪代码:
1、未达到设定迭代次数
2、迭代次数epoch+1
3、计算损失值
4、计算梯度
5、更新w、b
6、达到迭代次数,结束

然后

看下这里的更新方法
在这里插入图片描述
其中,卷积层中初始化的参数 θc,超参数 :λ、α 和学习率 μ ,迭代次数 t 。(λ、α、 μ均可调)
伪代码:
1、当未收敛时:
2、迭代次数t+1
3、计算损失函数 L=Ls+Lc
4、计算反向传播误差(即梯度)
5、更新w
6、更新cj
7、更新θc
8、结束

其中
在这里插入图片描述
其中,如果满足条件,则 δ(条件) = 1,如果不满足,则 δ(条件) = 0。

首先:第一个公式不用多说,为中心损失函数对特征xi求偏导,即求梯度。
其次:第二个公式:加入了判断,当条件满足时(即yi=j,即就是在同一类时)Δcj=cj-xi 的小批量的所有和/m+1。相当于累加了误差求平均,以此来作为梯度。
当条件不满足时,即Δcj =0,此时这个分母的+1的作用就体现出来了,分母不能为0嘛。

补充:
第5步后面一个等号成立是因为Lc中没有W项,所以Lc对W的求导为0
第7步也是一样,求梯度,只是写成了求导的链式法则。

相关文章:

【center-loss 中心损失函数】 原理及程序解释(更新中)

文章目录 前言问题引出open-set问题抛出 解决方法softmax函数、softmax-loss函数解决代码&#xff08;center_loss.py&#xff09;原理程序解释 如何梯度更新首先了解一下基本的梯度下降算法然后 前言 学习一下&#xff1a; 中心损失函数&#xff0c;用于用于深度人脸识别的特…...

什么是 HTTPS 证书?作用是什么?

HTTPS 证书&#xff0c;即超文本传输安全协议证书&#xff08;Hypertext Transfer Protocol Secure&#xff09;&#xff0c;是网站安全的关键组成部分。它通过 SSL/TLS 加密协议&#xff0c;确保用户与网站之间的数据传输是加密和安全的。 什么是 HTTPS 证书&#xff1f; HT…...

为什么软考报名人数越来越多?

2020年软考报名人数404666人&#xff0c;广东省报考人数超过14万人。 ●2021年软考通信考试报名人数突破100万人&#xff0c;估计软考有90多万。 ●2022年软考通信考试共129万人&#xff0c;估计软考占了120多万人。 ●2023年软考具体报名人数没有公布&#xff0c;但工业和信…...

【投稿优惠|快速见刊】2024年图像,机器学习和人工智能国际会议(ICIMLAI 2024)

【投稿优惠|快速见刊】2024年图像&#xff0c;机器学习和人工智能国际会议&#xff08;ICIMLAI 2024&#xff09; 重要信息 会议官网&#xff1a;http://www.icimlai.com会议地址&#xff1a;深圳召开日期&#xff1a;2024.03.30截稿日期&#xff1a;2024.03.20 &#xff08;先…...

html2canvas 将DOM节点转成图片

官网地址&#xff1a;html2canvas - Screenshots with JavaScript 将js文件保存到本地 可以新建一个txt文件&#xff0c;然后丢进去修改后缀名称即可。 在项目中引入js文件&#xff1a; import html2canvas from "../html2canvas.min.js" 这是我准备画的DOM节点。…...

【多线程】常见锁策略详解(面试常考题型)

目录 &#x1f334; 乐观锁 vs 悲观锁&#x1f38d;重量级锁 vs 轻量级锁&#x1f340;自旋锁&#xff08;Spin Lock&#xff09;&#x1f38b;公平锁 vs ⾮公平锁&#x1f333;可重⼊锁 vs 不可重⼊锁&#x1f384;读写锁⭕相关面试题 常⻅的锁策略 注意: 接下来讲解的锁策略不…...

Python列表操作函数

在Python中&#xff0c;列表&#xff08;list&#xff09;是一种可变的数据类型&#xff0c;它包含一系列有序的元素。Python提供了一系列内置的函数和方法来操作列表。以下是一些常用的Python列表操作函数和方法&#xff1a; 列表方法 append(x) 将元素x添加到列表的末尾。 …...

Qt注册类对象单例与单类型区别

1.实现类型SingletonTypeExample #ifndef SINGLETONTYPEEXAMPLE_H #define SINGLETONTYPEEXAMPLE_H#include <QObject>class SingletonTypeExample : public QObject {Q_OBJECT public://只能显示构造类对象explicit SingletonTypeExample(QObject *parent nullptr);//…...

Rocky Linux 运维工具yum

一、yum的简介 ​​yum​是用于在基于RPM包管理系统的包管理工具。用户可以通过 ​yum​来搜索、安装、更新和删除软件包&#xff0c;自动处理依赖关系&#xff0c;方便快捷地管理系统上的软件。 二、yum的参数说明 1、install 用于在系统的上安装一个或多个软件包 2、seach 用…...

linux下的ollama

refs: https://github.com/ollama/ollama/blob/main/docs/linux.md 1)安装 curl -fsSL https://ollama.com/install.sh | sh 2)修改服务配置&#xff0c;打开端口允许所有IP地址 refs(https://github.com/ollama/ollama/blob/main/docs/faq.md#where-are-models-stored) C…...

YOLOv9详细解读,改进提升全面分析(附YOLOv9结构图)

&#x1f951; Welcome to Aedream同学 s blog! &#x1f951; 文章目录 1. 概要1.1 模型结构上的改动:1.2 训练脚本上的改动&#xff1a; 2. 介绍2.1 背景2.2 主要贡献 3. 总体框架3.1 可编程梯度信息&#xff08;PGI&#xff09;3.1.1 辅助可逆分支3.1.2 多级辅助信息 3.2 Ge…...

html基础操练和进阶修炼宝典

文章目录 1.超链接标签2.跳锚点3.图片标签4.表格5.表格的方向属性6.子窗口7.音视频标签8.表单9.文件上传10.input属性 html修炼必经之路—各种类型标签详解加展示&#xff0c;关注点赞加收藏&#xff0c;防止迷路哦 1.超链接标签 <!DOCTYPE html> <html lang"en…...

从Mysql 数据库删除重复记录只保留其中一条(删除id最小的一条)

准备工作&#xff1a;新建表tb_coupon /*Navicat Premium Data TransferSource Server : rootlocalhostSource Server Type : MySQLSource Server Version : 50527Source Host : localhost:3306Source Schema : leyouTarget Server Type : My…...

从http到websocket

阅读本文之前&#xff0c;你最好已经做过一些websocket的简单应用 从http到websocket HTTP101HTTP 轮询、长轮询和流化其他技术1. 服务器发送事件2. SPDY3. web实时通信 互联网简史web和httpWebsocket协议1. 简介2. 初始握手3. 计算响应健值4. 消息格式5. WebSocket关闭握手 实…...

UE5 C++ Widget练习 Button 和 ProgressBar创建血条

一. 1.C创建一个继承Widget类的子类&#xff0c; 命名为MyUserWidget 2.加上Button 和 UserWidget的头文件 #include "CoreMinimal.h" #include "Components/Button.h" #include "Blueprint/UserWidget.h" #include "MyUserWidget.genera…...

抖店无货源违规频发,不能入驻?这个是真的吗?

我是电商珠珠 还没有踏入抖店这个电商行业的新手&#xff0c;单从别人的口中&#xff0c;听说了抖店无货源特别容易违规&#xff0c;还会被扣除全部的保证金&#xff0c;得不偿失之类的话。有的还专门劝诫新手不要做抖店&#xff0c;做了就会亏本之类的话&#xff0c;这搞得人…...

HarmonyOS—开发云数据库

您可以在云侧工程下开发云数据库资源&#xff0c;包括创建对象类型、在对象类型中添加数据条目、部署云数据库。 创建对象类型 对象类型&#xff08;即ObjectType&#xff09;用于定义存储对象的集合&#xff0c;不同的对象类型对应的不同数据结构。每创建一个对象类型&#…...

mysql查询某个数据库的数量有多少GB

要查询MySQL数据库中某个数据库&#xff08;或称为“schema”&#xff09;所占用的磁盘空间大小&#xff08;以GB为单位&#xff09;&#xff0c;你可以使用information_schema数据库中的TABLES和DATA_LENGTH、INDEX_LENGTH字段来获取每个表的数据和索引的大小&#xff0c;然后…...

table展示子级踩坑

##elemenui中table通过row中是否有children进行判断是否展示子集&#xff0c;通过设置tree-prop的属性进行设置&#xff0c;子级的children的名字可以根据自己的子级名字进行替换&#xff0c;当然同样可以对数据处理成含有chilren的子级list。 问题&#xff1a; 1.如果是根据后…...

xss过waf的小姿势

今天看大佬的视频学到了几个操作 首先是拆分发可以用self将被过滤的函数进行拆分 如下图我用self将alert拆分成两段依然成功执行 然后学习另一种姿势 <svg id"YWxlcnQoIlhTUyIp"><img src1 οnerrοr"window[eval](atob(document.getElementsByTagNa…...

遍历 Map 类型集合的方法汇总

1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练

前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1)&#xff1a;从基础到实战的深度解析-CSDN博客&#xff0c;但实际面试中&#xff0c;企业更关注候选人对复杂场景的应对能力&#xff08;如多设备并发扫描、低功耗与高发现率的平衡&#xff09;和前沿技术的…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

Keil 中设置 STM32 Flash 和 RAM 地址详解

文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...

Java编程之桥接模式

定义 桥接模式&#xff08;Bridge Pattern&#xff09;属于结构型设计模式&#xff0c;它的核心意图是将抽象部分与实现部分分离&#xff0c;使它们可以独立地变化。这种模式通过组合关系来替代继承关系&#xff0c;从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...

[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.

ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #&#xff1a…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

Ubuntu Cursor升级成v1.0

0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开&#xff0c;快捷键也不好用&#xff0c;当看到 Cursor 升级后&#xff0c;还是蛮高兴的 1. 下载 Cursor 下载地址&#xff1a;https://www.cursor.com/cn/downloads 点击下载 Linux (x64) &#xff0c;…...