简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风
论文链接:https://arxiv.org/pdf/2304.03977.pdf
代码:https://github.com/tsb0601/EMP-SSL
其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA
主要思想
如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(
)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)


Total Coding Rate(TCR)
公式如下:

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size
查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。
源自 PPT 第 24 页:
https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf
至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~
核心代码解读
数据处理
https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27
class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(), normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x
由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。
主函数
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63
for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)
这里要稍微注意一下几个变量的 shape:
- data 被 cat 完后:(num_patches * B,C,H,W)
- z_proj:(num_patches * B,C)
- z_list:(num_patches,B,C)
- z_avg:(B,C)
其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值():
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67
def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)
loss
contractive_loss 就是计算每个 patch 的 embedding 和均值()的余弦距离:
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76
class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out
TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96
def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss
需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。
所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。
所以,公式中的 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。
每个组内求 TCR loss 的代码按照公式计算,如下:

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76
class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)
相关文章:
简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风
论文链接:https://arxiv.org/pdf/2304.03977.pdf 代码:https://github.com/tsb0601/EMP-SSL 其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA 主要…...
nginx的负载均衡
nginx的负载均衡 文章目录 nginx的负载均衡1.以多台虚拟机作服务器1.1 在不同的虚拟机上安装httpd服务1.2 在不同虚拟机所构建的服务端的默认路径下创建不同标识的文件1.3 使用windows本机的浏览器分别访问3台服务器的地址 2.在新的一台虚拟机上配置nginx实现反向代理以及负载均…...
linux系统服务学习(四)Linux系统下数据同步服务RSYNC
文章目录 Linux系统下数据同步服务RSYNC一、RSYNC概述1、什么是rsyncrsync的好姐妹数据同步过程 2、rsync特点3、rsync与scp的区别 二、RSYNC的使用1、基本语法2、本地文件同步3、远程文件同步思考:4、rsync作为系统服务Linux系统服务的思路: 三、任务解…...
走进 Linux
一、开关机 开机: 开机会启动许多程序。他们在windows叫做“服务”(service),在Linux就叫做“守护进程”(daemon)开机成功后,它会显示一个文本登录界面, 这个界面就是我们经常看到的登录界面,在这个登录界…...
Docker高级——Docker Swarm集群和部署应用
创建 Swarm 集群 初始化管理节点 [rootk8s-master ~]# docker swarm init --advertise-addr 192.168.192.133 Swarm initialized: current node (vy95txqo3pglh478e4qew1h28) is now a manager.To add a worker to this swarm, run the following command:docker swarm join …...
【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析
【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析 一、secpol 的编译流程:编译生成 secpol.bin 打包在 ifs2_la.img 中二、QNX 开启 secpol 功能三、为新进程 创建 新的secpol 安全策略:以 vmm_service 为例四、secpol 配置示例,以 I2…...
Docker入门使用
用一个hello world的小例子来入门docker 在 Docker 容器中部署 Python Flask 的简单 Hello World 项目,需要遵循以下流程: 编写应用程序 首先,在本地计算机上编写一个简单的 PythonFlask 应用程序,例如: # hello.…...
在SAP上使用 LiquidUI Android 扫描条形码/QR 码
LiquidUI Android 可使用安卓移动设备的内置摄像头扫描条形码和二维码,为输入框填充数值。因此,无需附加任何第三方设备进行扫描。 LiquidUI Android 还提供了扫描功能,如 Accessible-Enter(俗称自动输入)和 Accessib…...
Maven - 全面解析 Maven BOM (Bill of Materials):打造高效依赖管理与模块化开发
文章目录 Whats BOMWhy Bom常見的官方BOMSpring Maven BOM dependencySpringBoot SpringCloud Maven BOM dependencyJBOSS Maven BOM dependencyRESTEasy Maven BOM dependencyJersey Maven BOM dependency How Bom定义BOM其他工程使用的方法 BOM VS POM What’s BOM BOM&…...
Lua脚本对比redis事务区别是什么
redis官方对于lua脚本的解释:Redis使用同一个Lua解释器来执行所有命令,同时,Redis保证以一种原子性的方式来执行脚本:当lua脚本在执行的时候,不会有其他脚本和命令同时执行,这种语义类似于 MULTI/EXEC。从别…...
ES安装问题汇总
max file descriptors [4096] for elasticsearch process is too low, increase to at least [65535] 问题描述 ES启动报错。其原因是ES需要的的最小max file descriptors为65535,我们设置的是4096,需要增大max file descriptors的值。 解决方案 调大…...
煜邦转债,华设转债,兴瑞转债,神通转债上市价格预测
煜邦转债 基本信息 转债名称:煜邦转债,评级:A,发行规模:4.10806亿元。 正股名称:煜邦电力,今日收盘价:8.82元,转股价格:10.12元。 当前转股价值 转债面值 / …...
R语言生存分析算法的简单组合
library(survival) library(randomForestSRC)# 生成模拟数据 set.seed(123) n <- 200 time <- rexp(n, rate 0.1) status <- rbinom(n, size 1, prob 0.7) var1 <- rnorm(n) var2 <- rnorm(n) var3 <- rnorm(n) data1 <- data.frame(time time, statu…...
Qt应用开发(基础篇)——滚屏区域基类 QAbstractScrollArea
一、前言 QAbstractScrollArea滚屏区域抽象类继承于QFrame,QFrame继承于QWidget,是QListview(列表浏览器)、QTableview(表格浏览器)、QTextEdit(文本编辑器)、QTextBrowser(文本浏览器)等所有需要滚屏区域部件的抽象基类。 框架类QFrame介绍 QAbstractSc…...
HTTPS安全通信
HTTPS,TLS/SSL Hyper Text Transfer Protocol over Secure Socket Layer,安全的超文本传输协议,网景公式设计了SSL(Secure Sockets Layer)协议用于对Http协议传输的数据进行加密,保证会话过程中的安全性。 使用TCP端口默认为443 TLS:(Transport Layer Security,传输层…...
C语言暑假刷题冲刺篇——day1
目录 一、选择题 二、编程题 🎈个人主页:库库的里昂 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏✨收录专栏:C语言每日一练 ✨其他专栏:代码小游戏C语言初阶🤝希望作者的文章能对你…...
trollcave靶场
配置 第一步:启动靶机时按下 shift 键, 进入以下界面 第二步:选择第二个选项,然后按下 e 键,进入编辑界面 将这里的ro修改为rw single init/bin/bash,然后按ctrlx,进入一个相当于控制台的界面…...
反馈式编译
一、 反馈式编译 简介 PGO,即Profile-Guided Optimizations,反馈式优化。PGO是编译器的又一优化技术,PGO与其它的一些优化技术/选项有一个明显的区别是:PGO优化是分三步完成的,是一个动态的优化过程。 反馈优化过…...
sql-libs靶场-----0x00、环境准备
文章目录 一、PhPstudy下载、安装二、Sqli-libs下载、搭建三、启用Sqli-libs phpstudy地址:https://www.xp.cn/ sqli-libs地址:https://github.com/Audi-1/sqli-labs 一、PhPstudy下载、安装 1、下载–解压–安装,安装完成如下图 2、更换php…...
一百四十九、Kettle——Linux上安装的kettle8.2创建共享资源库时遇到的问题(持续更新中)
一、目的 在kettle8.2在Linux上安装好可以启动界面、并且可以连接MySQL、Hive、ClickHouse等数据库后开始创建共享资源库,但是遇到了一些问题 二、Linux系统以及kettle版本 (一)Linux:CentOS 7 英文的图形化界面模式 &#…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
阿里云ACP云计算备考笔记 (5)——弹性伸缩
目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...
从零实现富文本编辑器#5-编辑器选区模型的状态结构表达
先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...
【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...
AGain DB和倍数增益的关系
我在设置一款索尼CMOS芯片时,Again增益0db变化为6DB,画面的变化只有2倍DN的增益,比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析: 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...
IP如何挑?2025年海外专线IP如何购买?
你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...
