简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风
论文链接:https://arxiv.org/pdf/2304.03977.pdf
代码:https://github.com/tsb0601/EMP-SSL
其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA
主要思想
如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(
)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)


Total Coding Rate(TCR)
公式如下:

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size
查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。
源自 PPT 第 24 页:
https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf
至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~
核心代码解读
数据处理
https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27
class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(), normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x
由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。
主函数
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63
for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)
这里要稍微注意一下几个变量的 shape:
- data 被 cat 完后:(num_patches * B,C,H,W)
- z_proj:(num_patches * B,C)
- z_list:(num_patches,B,C)
- z_avg:(B,C)
其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值():
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67
def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)
loss
contractive_loss 就是计算每个 patch 的 embedding 和均值()的余弦距离:
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76
class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out
TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96
def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss
需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。
所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。
所以,公式中的 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。
每个组内求 TCR loss 的代码按照公式计算,如下:

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76
class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)
相关文章:
简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风
论文链接:https://arxiv.org/pdf/2304.03977.pdf 代码:https://github.com/tsb0601/EMP-SSL 其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA 主要…...
nginx的负载均衡
nginx的负载均衡 文章目录 nginx的负载均衡1.以多台虚拟机作服务器1.1 在不同的虚拟机上安装httpd服务1.2 在不同虚拟机所构建的服务端的默认路径下创建不同标识的文件1.3 使用windows本机的浏览器分别访问3台服务器的地址 2.在新的一台虚拟机上配置nginx实现反向代理以及负载均…...
linux系统服务学习(四)Linux系统下数据同步服务RSYNC
文章目录 Linux系统下数据同步服务RSYNC一、RSYNC概述1、什么是rsyncrsync的好姐妹数据同步过程 2、rsync特点3、rsync与scp的区别 二、RSYNC的使用1、基本语法2、本地文件同步3、远程文件同步思考:4、rsync作为系统服务Linux系统服务的思路: 三、任务解…...
走进 Linux
一、开关机 开机: 开机会启动许多程序。他们在windows叫做“服务”(service),在Linux就叫做“守护进程”(daemon)开机成功后,它会显示一个文本登录界面, 这个界面就是我们经常看到的登录界面,在这个登录界…...
Docker高级——Docker Swarm集群和部署应用
创建 Swarm 集群 初始化管理节点 [rootk8s-master ~]# docker swarm init --advertise-addr 192.168.192.133 Swarm initialized: current node (vy95txqo3pglh478e4qew1h28) is now a manager.To add a worker to this swarm, run the following command:docker swarm join …...
【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析
【SA8295P 源码分析】74 - QNX secpol 安全策略文件配置详解 及 secpol.bin 编译过程分析 一、secpol 的编译流程:编译生成 secpol.bin 打包在 ifs2_la.img 中二、QNX 开启 secpol 功能三、为新进程 创建 新的secpol 安全策略:以 vmm_service 为例四、secpol 配置示例,以 I2…...
Docker入门使用
用一个hello world的小例子来入门docker 在 Docker 容器中部署 Python Flask 的简单 Hello World 项目,需要遵循以下流程: 编写应用程序 首先,在本地计算机上编写一个简单的 PythonFlask 应用程序,例如: # hello.…...
在SAP上使用 LiquidUI Android 扫描条形码/QR 码
LiquidUI Android 可使用安卓移动设备的内置摄像头扫描条形码和二维码,为输入框填充数值。因此,无需附加任何第三方设备进行扫描。 LiquidUI Android 还提供了扫描功能,如 Accessible-Enter(俗称自动输入)和 Accessib…...
Maven - 全面解析 Maven BOM (Bill of Materials):打造高效依赖管理与模块化开发
文章目录 Whats BOMWhy Bom常見的官方BOMSpring Maven BOM dependencySpringBoot SpringCloud Maven BOM dependencyJBOSS Maven BOM dependencyRESTEasy Maven BOM dependencyJersey Maven BOM dependency How Bom定义BOM其他工程使用的方法 BOM VS POM What’s BOM BOM&…...
Lua脚本对比redis事务区别是什么
redis官方对于lua脚本的解释:Redis使用同一个Lua解释器来执行所有命令,同时,Redis保证以一种原子性的方式来执行脚本:当lua脚本在执行的时候,不会有其他脚本和命令同时执行,这种语义类似于 MULTI/EXEC。从别…...
ES安装问题汇总
max file descriptors [4096] for elasticsearch process is too low, increase to at least [65535] 问题描述 ES启动报错。其原因是ES需要的的最小max file descriptors为65535,我们设置的是4096,需要增大max file descriptors的值。 解决方案 调大…...
煜邦转债,华设转债,兴瑞转债,神通转债上市价格预测
煜邦转债 基本信息 转债名称:煜邦转债,评级:A,发行规模:4.10806亿元。 正股名称:煜邦电力,今日收盘价:8.82元,转股价格:10.12元。 当前转股价值 转债面值 / …...
R语言生存分析算法的简单组合
library(survival) library(randomForestSRC)# 生成模拟数据 set.seed(123) n <- 200 time <- rexp(n, rate 0.1) status <- rbinom(n, size 1, prob 0.7) var1 <- rnorm(n) var2 <- rnorm(n) var3 <- rnorm(n) data1 <- data.frame(time time, statu…...
Qt应用开发(基础篇)——滚屏区域基类 QAbstractScrollArea
一、前言 QAbstractScrollArea滚屏区域抽象类继承于QFrame,QFrame继承于QWidget,是QListview(列表浏览器)、QTableview(表格浏览器)、QTextEdit(文本编辑器)、QTextBrowser(文本浏览器)等所有需要滚屏区域部件的抽象基类。 框架类QFrame介绍 QAbstractSc…...
HTTPS安全通信
HTTPS,TLS/SSL Hyper Text Transfer Protocol over Secure Socket Layer,安全的超文本传输协议,网景公式设计了SSL(Secure Sockets Layer)协议用于对Http协议传输的数据进行加密,保证会话过程中的安全性。 使用TCP端口默认为443 TLS:(Transport Layer Security,传输层…...
C语言暑假刷题冲刺篇——day1
目录 一、选择题 二、编程题 🎈个人主页:库库的里昂 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏✨收录专栏:C语言每日一练 ✨其他专栏:代码小游戏C语言初阶🤝希望作者的文章能对你…...
trollcave靶场
配置 第一步:启动靶机时按下 shift 键, 进入以下界面 第二步:选择第二个选项,然后按下 e 键,进入编辑界面 将这里的ro修改为rw single init/bin/bash,然后按ctrlx,进入一个相当于控制台的界面…...
反馈式编译
一、 反馈式编译 简介 PGO,即Profile-Guided Optimizations,反馈式优化。PGO是编译器的又一优化技术,PGO与其它的一些优化技术/选项有一个明显的区别是:PGO优化是分三步完成的,是一个动态的优化过程。 反馈优化过…...
sql-libs靶场-----0x00、环境准备
文章目录 一、PhPstudy下载、安装二、Sqli-libs下载、搭建三、启用Sqli-libs phpstudy地址:https://www.xp.cn/ sqli-libs地址:https://github.com/Audi-1/sqli-labs 一、PhPstudy下载、安装 1、下载–解压–安装,安装完成如下图 2、更换php…...
一百四十九、Kettle——Linux上安装的kettle8.2创建共享资源库时遇到的问题(持续更新中)
一、目的 在kettle8.2在Linux上安装好可以启动界面、并且可以连接MySQL、Hive、ClickHouse等数据库后开始创建共享资源库,但是遇到了一些问题 二、Linux系统以及kettle版本 (一)Linux:CentOS 7 英文的图形化界面模式 &#…...
如何用Hearthstone-Script解放炉石传说玩家双手?开源自动化工具全解析
如何用Hearthstone-Script解放炉石传说玩家双手?开源自动化工具全解析 【免费下载链接】Hearthstone-Script Hearthstone script(炉石传说脚本) 项目地址: https://gitcode.com/gh_mirrors/he/Hearthstone-Script 你是否也曾为炉石传说…...
百度网盘直链解析开源工具完全指南:从入门到精通
百度网盘直链解析开源工具完全指南:从入门到精通 【免费下载链接】baidu-wangpan-parse 获取百度网盘分享文件的下载地址 项目地址: https://gitcode.com/gh_mirrors/ba/baidu-wangpan-parse 你是否曾经历过这样的困扰:明明网络带宽充足ÿ…...
UABEA:Unity游戏资源编辑与分析的终极解决方案
UABEA:Unity游戏资源编辑与分析的终极解决方案 【免费下载链接】UABEA c# uabe for newer versions of unity 项目地址: https://gitcode.com/gh_mirrors/ua/UABEA 在Unity游戏开发和模组制作领域,处理Asset Bundle资源文件是每个开发者都会面临的…...
LumiPixel Canvas Quest光影艺术展:极致光影效果人像作品集
LumiPixel Canvas Quest光影艺术展:极致光影效果人像作品集 1. 光影艺术的数字革命 摄影圈最近有个热议话题:当AI开始玩光影,专业摄影师该紧张了吗?这场由LumiPixel Canvas Quest带来的光影艺术展,或许能给我们一些启…...
EcomGPT-7B电商大模型网络安全应用:智能识别钓鱼商品与欺诈文案
EcomGPT-7B电商大模型网络安全应用:智能识别钓鱼商品与欺诈文案 最近和几个做电商平台的朋友聊天,他们都在头疼同一个问题:平台上的商品和文案越来越“花”,有些商家为了引流,标题和描述写得天花乱坠,甚至…...
C++的std--chrono时间库与steady_clock在性能测量中的正确使用
在C高性能程序开发中,精确测量代码执行时间是优化和调试的关键环节。std::chrono时间库作为现代C的标准工具,提供了高精度、类型安全的计时能力,其中steady_clock因其单调递增的特性成为性能测量的首选。本文将深入解析其正确使用方式&#x…...
RMBG-2.0开源模型教程:微调BiRefNet适配特定行业(如医疗影像标记)
RMBG-2.0开源模型教程:微调BiRefNet适配特定行业(如医疗影像标记) 1. 项目概述与核心价值 RMBG-2.0(BiRefNet)是一个基于先进架构开发的图像背景扣除模型,能够精确识别并移除图像背景,保留高质…...
Qwen3-ASR性能优化:基于CNN的语音特征提取技术
Qwen3-ASR性能优化:基于CNN的语音特征提取技术 语音识别技术发展到今天,已经不再是实验室里的新奇玩具,而是我们日常生活中随处可见的实用工具。从手机语音助手到会议记录软件,从智能家居控制到车载语音交互,语音识别…...
AIGlasses OS Pro保姆级教程:从环境配置到四大模式实战体验
AIGlasses OS Pro保姆级教程:从环境配置到四大模式实战体验 1. 系统概述与核心价值 AIGlasses OS Pro是一款专为智能眼镜设计的本地化视觉辅助系统,它巧妙融合了YOLO11目标检测与MediaPipe骨骼识别两大引擎。与市面上依赖云服务的方案不同,…...
Kandinsky-5.0-I2V-Lite-5s惊艳效果展示:古风人物图→衣袖飘动+发带飞扬动态视频
Kandinsky-5.0-I2V-Lite-5s惊艳效果展示:古风人物图→衣袖飘动发带飞扬动态视频 1. 模型效果震撼开场 想象一下,你有一张精美的古风人物插画,画中女子衣袂飘飘、发带轻扬。现在,只需一个简单的操作,就能让这幅静态画…...
