当前位置: 首页 > news >正文

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

相关文章:

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf 代码:https://github.com/tsb0601/EMP-SSL 其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA 主要…...

nginx的负载均衡

nginx的负载均衡 文章目录 nginx的负载均衡1.以多台虚拟机作服务器1.1 在不同的虚拟机上安装httpd服务1.2 在不同虚拟机所构建的服务端的默认路径下创建不同标识的文件1.3 使用windows本机的浏览器分别访问3台服务器的地址 2.在新的一台虚拟机上配置nginx实现反向代理以及负载均…...

linux系统服务学习(四)Linux系统下数据同步服务RSYNC

文章目录 Linux系统下数据同步服务RSYNC一、RSYNC概述1、什么是rsyncrsync的好姐妹数据同步过程 2、rsync特点3、rsync与scp的区别 二、RSYNC的使用1、基本语法2、本地文件同步3、远程文件同步思考:4、rsync作为系统服务Linux系统服务的思路: 三、任务解…...

走进 Linux

一、开关机 开机: 开机会启动许多程序。他们在windows叫做“服务”(service),在Linux就叫做“守护进程”(daemon)开机成功后,它会显示一个文本登录界面, 这个界面就是我们经常看到的登录界面,在这个登录界…...

Docker高级——Docker Swarm集群和部署应用

创建 Swarm 集群 初始化管理节点 [rootk8s-master ~]# docker swarm init --advertise-addr 192.168.192.133 Swarm initialized: current node (vy95txqo3pglh478e4qew1h28) is now a manager.To add a worker to this swarm, run the following command:docker swarm join …...

【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析

【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析 一、secpol 的编译流程:编译生成 secpol.bin 打包在 ifs2_la.img 中二、QNX 开启 secpol 功能三、为新进程 创建 新的secpol 安全策略:以 vmm_service 为例四、secpol 配置示例,以 I2…...

Docker入门使用

用一个hello world的小例子来入门docker 在 Docker 容器中部署 Python Flask 的简单 Hello World 项目,需要遵循以下流程: 编写应用程序 首先,在本地计算机上编写一个简单的 PythonFlask 应用程序,例如: # hello.…...

在SAP上使用 LiquidUI Android 扫描条形码/QR 码

LiquidUI Android 可使用安卓移动设备的内置摄像头扫描条形码和二维码,为输入框填充数值。因此,无需附加任何第三方设备进行扫描。 LiquidUI Android 还提供了扫描功能,如 Accessible-Enter(俗称自动输入)和 Accessib…...

Maven - 全面解析 Maven BOM (Bill of Materials):打造高效依赖管理与模块化开发

文章目录 Whats BOMWhy Bom常見的官方BOMSpring Maven BOM dependencySpringBoot SpringCloud Maven BOM dependencyJBOSS Maven BOM dependencyRESTEasy Maven BOM dependencyJersey Maven BOM dependency How Bom定义BOM其他工程使用的方法 BOM VS POM What’s BOM BOM&…...

Lua脚本对比redis事务区别是什么

redis官方对于lua脚本的解释:Redis使用同一个Lua解释器来执行所有命令,同时,Redis保证以一种原子性的方式来执行脚本:当lua脚本在执行的时候,不会有其他脚本和命令同时执行,这种语义类似于 MULTI/EXEC。从别…...

ES安装问题汇总

max file descriptors [4096] for elasticsearch process is too low, increase to at least [65535] 问题描述 ES启动报错。其原因是ES需要的的最小max file descriptors为65535,我们设置的是4096,需要增大max file descriptors的值。 解决方案 调大…...

煜邦转债,华设转债,兴瑞转债,神通转债上市价格预测

煜邦转债 基本信息 转债名称:煜邦转债,评级:A,发行规模:4.10806亿元。 正股名称:煜邦电力,今日收盘价:8.82元,转股价格:10.12元。 当前转股价值 转债面值 / …...

R语言生存分析算法的简单组合

library(survival) library(randomForestSRC)# 生成模拟数据 set.seed(123) n <- 200 time <- rexp(n, rate 0.1) status <- rbinom(n, size 1, prob 0.7) var1 <- rnorm(n) var2 <- rnorm(n) var3 <- rnorm(n) data1 <- data.frame(time time, statu…...

Qt应用开发(基础篇)——滚屏区域基类 QAbstractScrollArea

一、前言 QAbstractScrollArea滚屏区域抽象类继承于QFrame&#xff0c;QFrame继承于QWidget&#xff0c;是QListview(列表浏览器)、QTableview(表格浏览器)、QTextEdit(文本编辑器)、QTextBrowser(文本浏览器)等所有需要滚屏区域部件的抽象基类。 框架类QFrame介绍 QAbstractSc…...

HTTPS安全通信

HTTPS,TLS/SSL Hyper Text Transfer Protocol over Secure Socket Layer,安全的超文本传输协议,网景公式设计了SSL(Secure Sockets Layer)协议用于对Http协议传输的数据进行加密,保证会话过程中的安全性。 使用TCP端口默认为443 TLS:(Transport Layer Security,传输层…...

C语言暑假刷题冲刺篇——day1

目录 一、选择题 二、编程题 &#x1f388;个人主页&#xff1a;库库的里昂 &#x1f390;CSDN新晋作者 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏✨收录专栏&#xff1a;C语言每日一练 ✨其他专栏&#xff1a;代码小游戏C语言初阶&#x1f91d;希望作者的文章能对你…...

trollcave靶场

配置 第一步&#xff1a;启动靶机时按下 shift 键&#xff0c; 进入以下界面 第二步&#xff1a;选择第二个选项&#xff0c;然后按下 e 键&#xff0c;进入编辑界面 将这里的ro修改为rw single init/bin/bash&#xff0c;然后按ctrlx&#xff0c;进入一个相当于控制台的界面…...

反馈式编译

一、 反馈式编译 简介 PGO&#xff0c;即Profile-Guided Optimizations&#xff0c;反馈式优化。PGO是编译器的又一优化技术&#xff0c;PGO与其它的一些优化技术/选项有一个明显的区别是&#xff1a;PGO优化是分三步完成的&#xff0c;是一个动态的优化过程。 反馈优化过…...

sql-libs靶场-----0x00、环境准备

文章目录 一、PhPstudy下载、安装二、Sqli-libs下载、搭建三、启用Sqli-libs phpstudy地址&#xff1a;https://www.xp.cn/ sqli-libs地址&#xff1a;https://github.com/Audi-1/sqli-labs 一、PhPstudy下载、安装 1、下载–解压–安装&#xff0c;安装完成如下图 2、更换php…...

一百四十九、Kettle——Linux上安装的kettle8.2创建共享资源库时遇到的问题(持续更新中)

一、目的 在kettle8.2在Linux上安装好可以启动界面、并且可以连接MySQL、Hive、ClickHouse等数据库后开始创建共享资源库&#xff0c;但是遇到了一些问题 二、Linux系统以及kettle版本 &#xff08;一&#xff09;Linux&#xff1a;CentOS 7 英文的图形化界面模式 &#…...

网络编程(Modbus进阶)

思维导图 Modbus RTU&#xff08;先学一点理论&#xff09; 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议&#xff0c;由 Modicon 公司&#xff08;现施耐德电气&#xff09;于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook&#xff0c;用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途&#xff0c;下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽&#xff0c;大家好&#xff0c;我是左手python&#xff01; Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库&#xff0c;用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

视频字幕质量评估的大规模细粒度基准

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 摘要 视频字幕在文本到视频生成任务中起着至关重要的作用&#xff0c;因为它们的质量直接影响所生成视频的语义连贯性和视觉保真度。尽管大型视觉-语言模型&#xff08;VLMs&#xff09;在字幕生成方面…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI&#xff08;https://spring.io/projects/spring-ai&#xff09;作为Spring生态中的AI集成框架&#xff0c;其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似&#xff0c;但特别为多语…...

C++:多态机制详解

目录 一. 多态的概念 1.静态多态&#xff08;编译时多态&#xff09; 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1&#xff09;.协变 2&#xff09;.析构函数的重写 5.override 和 final关键字 1&#…...

永磁同步电机无速度算法--基于卡尔曼滤波器的滑模观测器

一、原理介绍 传统滑模观测器采用如下结构&#xff1a; 传统SMO中LPF会带来相位延迟和幅值衰减&#xff0c;并且需要额外的相位补偿。 采用扩展卡尔曼滤波器代替常用低通滤波器(LPF)&#xff0c;可以去除高次谐波&#xff0c;并且不用相位补偿就可以获得一个误差较小的转子位…...

Oracle11g安装包

Oracle 11g安装包 适用于windows系统&#xff0c;64位 下载路径 oracle 11g 安装包...