MoCo 算法阅读记录
论文地址:🐰
何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。
一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模的数据集;另一方面 本次只是测试和阅读算法原理的实现,并不完整使用。因此,重写了一个低配版(流程不变,超参数没有严格要求设置,单GPU跑,数据集自己配置,几十张图片, no Shuffling BN)。
queue 即文中所构建的字典,这个好比如 C++ 中 的queue 容器,因为它是一种先进先出的数据结构。
目录
一、数据预处理
二、前向传播
网络结构
算法流程
一、数据预处理
对同一张图片进行数据增强操作,得到 query 和 key。
增强操作包括
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),transforms.RandomGrayscale(p=0.2),transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),transforms.RandomHorizontalFlip(),normalize,
所以,dataloader中的每个输入样本是一个样本对儿。
通过下列方法实现
class TwoCropsTransform:"""Take two random crops of one image as the query and key."""def __init__(self, base_transform):self.base_transform = base_transformdef __call__(self, x):q = self.base_transform(x)k = self.base_transform(x)return [q, k]
二、前向传播
网络结构
代码中 encoder q 和 encoder k的网络结构用的都是ReNet 。ResNet最终的输出层包含了
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=128, bias=True)
所以,输出的特征向量维度为 (N,C)。N为文中的Mini batch大小,代码中的超参数为batch size。C应该没有什么具体的含义,只是经验的设置为这一长度了(没找出来C的大小关乎什么)。
其输出还经过了L2归一化。
算法流程
1、 q 送入 encoder q 得到输出,并经过L2归一化, (N,C)
2、 momentum 更新 key encoder。
3、 Shuffling BN(当然我重写的代码并没有实现这个,因为它需要多GPU,但这并不妨碍认识它的作用)
文中所述
大致意思由于ResNet使用了BN操作,因此由于Batch 数据之间的交互,使得模型利用它欺骗预设任务从而简单的找到一个低损失的解决方案,然而这个解决方案效果并不好,使得模型学习不到好的表征能力。
其提出的Shuffling BN
首先,把所有进程的Tensor的收集起来(如果分布式训练,一般每个GPU包含一个进程,所以收集的数据总量大小为 num GPUs * batch size),参考这里🤖
x_gather = concat_all_gather(x)
接下来制作打乱的索引,整个过程如下所示
def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x) # 将所有进程的数据收集起来batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda() # torch.randperm 将[0,n)数随机排列# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0) # 将这个信息广播到所有其他进程# index for restoringidx_unshuffle = torch.argsort(idx_shuffle) # 按照值大小顺序返回下标# shuffled index for this gpugpu_idx = torch.distributed.get_rank() # 返回当前的进程idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] # idx_shuffle view 后 (num_gpus, batch size) 但是batch size中的索引是打乱顺序的return x_gather[idx_this], idx_unshuffle
最终返回 随机打乱顺序后挑选的当前进程的 batch size 大小的数据,也就是说进行 BN归一化后的数据已经不在 同一个原来的批 中了。
4、k 送入 encoder k 中,在经过L2 归一化, 和q一样。 (N,C)
5、Shuffling BN 对齐 q 和 k
如下面举例
# idx_shuffle tensor([10, 16, 13, 2, 4, 0, 6, 21, 22, 31, 29, 3, 19, 17, 14, 30, 28, 12,24, 26, 8, 25, 11, 18, 5, 7, 27, 1, 15, 23, 20, 9])# idx_unshuffle tensor([ 5, 27, 3, 11, 4, 24, 6, 25, 20, 31, 0, 22, 17, 2, 14, 28, 1, 13,23, 12, 30, 7, 8, 29, 18, 21, 19, 26, 16, 10, 15, 9])# q 的 idx_this tensor([10, 16, 13, 2, 4, 0, 6, 21])# k 的 idx_this tensor([ 5, 27, 3, 11, 4, 24, 6, 25])
这里主要关注的点是 这步是为了使 k对齐打乱顺序的q。q之前是打乱了顺序从而改变了每个batch的内容,相当于从所有的batch中随机挑选了 batch size的q,从而保证去除BN的影响。
而 k 不需要 再打乱了, 只需要从原有的batch size 数据分布中挑选出与q对应的数据即可。所以才在 shuffle BN q的过程中记录了indx unshuffle。
这里的对应关系举例,比如 index shuffle 中的 0 现在位于原来没打乱状态的索引 5处, 类似的 1 -->27, 2-->3, 以此类推。
注:不要被上面单进程的(即idx this)不对齐所迷惑,上面的只是分进程处理的,分布式训练最终会把所有进程的数据拼接起来一起处理,所以所有进程的数据对齐就行。
6、计算损失,即文中公式1
其中 用到的计算方法举例如下,分别用爱因斯坦求和公式实现,参考这里🤖
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 2]]) b = torch.tensor([[2, 2, 2], [2, 2, 2], [1, 1, 1]]) print(a) print(b) c = torch.einsum("nc, nc->n", [a, b]) # (3) d = c.unsqueeze(-1) # (3,1) print(c)#=== 输出 tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]]) tensor([[2, 2, 2],[2, 2, 2],[1, 1, 1]]) tensor([12, 6, 7]) tensor([[12],[ 6],[ 7]])
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 3]]) # (3,3) a1 = torch.tensor([[1, 2], [1, 1], [2, 2]]) # (3,2) c = torch.einsum("nc,ck->nk", [a, a1]) print(a) print(a1) print(c)# ===输出 tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]]) tensor([[1, 2],[1, 1],[2, 2]]) tensor([[ 9, 10],[ 4, 5],[10, 12]])
这里的self.queue 即文中的字典 queue,初始化为
self.register_buffer("queue", torch.randn(dim, K)) self.queue = nn.functional.normalize(self.queue, dim=0)
K为字典的长度,默认设置65536。这里为什么设置为这个可能是由于ImageNet数据集比较大,所以设置的字典比较长,具体的长度设置好像没有做固定的要求,
来源于github官网。但代码中有要求,K必须是batch size 的倍数,这个为了确保字典的更新,方便执行入栈和弹出操作。这个字典像是C++的 queue容器的FIFO数据结构,即先进先出
self.K % batch_size == 0
l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) # (8,1) 对应元素相乘并第一维加和# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) # (8,65536) 矩阵相乘# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1) # (8,65537)# apply temperaturelogits /= self.Tlabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() # (8,)loss = criterion(output, target)
这里看标签都是0,即第一个也就是0维数据为正样本。因为在拼接cat的时候正样本是在前面的。
7、更新字典
按mini batch 更新。具体地,如果 训练次数*mini batch size 小于字典长度,则字典queue每次都会填充新的key。若训练次数*mini batch size 大于 字典长度,则之前的被替换掉。
ptr = (ptr + batch_size) % self.K # move pointer 8
相关文章:

MoCo 算法阅读记录
论文地址:🐰 何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。 一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模…...

华为OD机试 - 数组连续和 - 滑动窗口(Java 2024 C卷 100分)
华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…...

微店micro获得微店micro商品详情,API接口封装系列
微店商品详情API接口封装系列主要涉及注册账号、获取API密钥、选择API接口、发送请求以及处理响应等步骤。以下是详细的流程: 请求示例,API接口接入Anzexi58 一、注册账号并获取API密钥 首先,你需要在微店开放平台注册一个账号。注册成功后…...

C语言中的数据结构--链表的应用1(2)
前言 上一节我们学习了链表的概念以及链表的实现,那么本节我们就来了解一下链表具体有什么用,可以解决哪些实质性的问题,我们借用习题来加强对链表的理解,那么废话不多说,我们正式进入今天的学习 单链表相关经典算法O…...

.Net6 使用Autofac进行依赖注入
一、背景 刚接触.net 6,记录一下在.net6上是怎么使用Autofac进行动态的依赖注入的 二、注入方式 1、新建一个webapi项目,框架选择net 6 2、引用Nuget包---Autofac.Extensions.Dependency 3、在Program.cs上添加如下代码 //依赖注入 builder.Host.Us…...

第十二届蓝桥杯省赛真题(C/C++大学B组)
目录 #A 空间 #B 卡片 #C 直线 #D 货物摆放 #E 路径 #F 时间显示 #G 砝码称重 #H 杨辉三角形 #I 双向排序 #J 括号序列 #A 空间 #include <bits/stdc.h> using namespace std;int main() {cout<<256 * 1024 * 1024 / 4<<endl;return 0; } #B 卡片…...

DC40V降压恒压芯片H4120 40V转5V 3A 40V降压12V 车充降压恒压控制器
同步整流恒压芯片在现代电子设备中发挥着重要作用,为各种设备提供了稳定、高效的电源管理解决方案。 同步整流恒压芯片是一种电源管理芯片,它能够在不同电压输入条件下保持输出电压恒定。这种芯片广泛应用于各种电子设备中,如通讯设备、液晶…...

2、Qt UI控件 -- qucsdk项目使用
前言:上一篇文章讲了qucsdk的环境部署,可以在QDesigner和Qt Creator中看到qucsdk控件,这一篇来讲下在项目中使用qucsdk库中的控件。 一、准备材料 要想使用第三方库,需要三个先决条件, 1、控件的头文件 2、动/静态链…...

MATLAB算法实战应用案例精讲-【人工智能】AIGC概念三部曲(三)
目录 前言 算法原理 大模型 什么是AIGC? AIGC和Chat GPT的关系 常见的AIGC应用...

外汇110:外汇交易不同货币类别及交易注意事项!
外汇市场是一个庞大而复杂的市场,其中有各种各样的货币品种。对于外汇投资者来说,了解外汇品种的特性和走势是比较重要的。1. 货币种类 外汇市场中的货币品种可以分为主要货币、次要货币和外围货币。 主要货币:主要指美元、欧元、英镑、日元、…...

gerrit 拉取失败
在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --gloable user.email设置的邮箱地址必须保持一致吗 在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --global user.email设置的邮箱地址并不一定需要保持一致。这两个邮箱地址是独立的&am…...

大数据行业英语单词巩固20240410
20240410 Communication - 沟通 Example: Effective communication is essential for project success. 有效的沟通对于项目的成功至关重要。 Collaboration - 协作 Example: Team collaboration is crucial in achieving our goals. 团队协作对于实现我们的目标至关重要。 …...

天软特色因子看板 (2024.4 第3期)
该因子看板跟踪天软特色因子A05005(近一月单笔流出金额占比(%),该因子为近一月单笔流出金额占比(% 均值因子,用以刻画下跌时的 单成交中可能存在的抄底现象 今日为该因子跟踪第3期,跟踪其在SH000852 (中证1000) 中的表现,要点如下…...

使用QT 开发不规则窗体
使用QT 开发不规则窗体 不规则窗体贴图法的不规则窗体创建UI模板创建一个父类创建业务窗体main函数直接调用user_dialog创建QSS文件 完整的QT工程 不规则窗体 QT中开发不规则窗体有两种方法:(1)第一种方法,使用QWidget::setMask函…...

如何构建企业经营所需的商业智能(BI)能力
构建企业经营所需的商业智能(BI)能力是一项涉及诸多关键环节与细致考量的系统工程,通过科学的数据处理、分析与应用,赋能企业实现精准决策,提升运营效率,优化业务流程,并在竞争激烈的市场环境中…...

【vue】watch监听取不到this指向的数?
今天同事问我,watch里this指向的数值,别的地方却可以打印出来。工具也能看到数值,但打印出来却是undifined,先看看代码: 懒得打字了直接上截图吧 ps: 在Vue组件中,如果你在watch选项中访问this…...

Ubuntu-22.04安装VMware虚拟机并安装Windows10
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、VMware是什么?二、安装VMware1.注册VMware账号2.下载虚拟机3.编译vmmon&vmnet4.加载module5.安装bundle 三、安装Windows101.基础配置2.进阶…...

ELK企业日志分析系统介绍
前言 随着企业级应用系统日益复杂,随之产生的海量日志数据。传统的日志管理和分析手段,难以做到高效检索、实时监控以及深度挖掘潜在价值。在此背景下,ELK日志分析系统应运而生。本文将从ELK 日志分析系统的原理、架构及其在实践中的应用做相…...

在C#中读取写入字节流与读取写入二进制数据, 有何差异?
在C#中,读取和写入字节流与读取和写入二进制数据有些许不同,尽管它们在某些情况下可能会重叠使用。以下是它们之间的主要区别: 读取和写入字节流: 读取和写入字节流通常指的是处理文件或流中的原始字节数据。在C#中,可…...

数据库相关知识总结
一、数据库三级模式 三个抽象层次: 1. 视图层:最高层次的抽象,描述整个数据库的某个部分的数据 2. 逻辑层:描述数据库中存储的数据以及这些数据存在的关联 3. 物理层:最低层次的抽象,描述数据在存储器中时如…...

【汇编语言实战】输出数组中特定元素
C语言描述: #include <stdio.h> int main() {int a[]{1,2,3,4,5,6};printf("%d",a[3]); }汇编语言: include irvine32.inc .data arr dword 1,2,3,4,5,6 num dword 1 ;输出第二个元素 .code main proc mov esi,offset arr mov edx,nu…...

WordPress LayerSlider插件SQL注入漏洞复现(CVE-2024-2879)
0x01 产品简介 WordPress插件LayerSlider是一款可视化网页内容编辑器、图形设计软件和数字视觉效果应用程序,全球活跃安装量超过 1,000,000 次。 0x02 漏洞概述 WordPress LayerSlider插件版本7.9.11 – 7.10.0中,由于对用户提供的参数转义不充分以及缺少wpdb::prepare(),…...

MOS管的判别符号记忆与导通条件
参考链接 MOS管的判别与导通条件 (qq.com)https://mp.weixin.qq.com/s?__bizMzU3MDU1Mzg2OQ&mid2247520228&idx1&sn5996780179fbf01f66b5db0c71622ac3&chksmfcef6c86cb98e590e3d3734ee27797bdded17b6b648b3b0d3b1599e8a4496a1fa4e457be6516&mpshare1&…...

数据指标与经营智慧:构建有洞见的经营分析报告
经营分析报告不仅仅是数字的堆砌,它是企业运营状况的“晴雨表”,能够反映企业的健康状况和发展潜力。一个有洞见的经营分析报告能够帮助管理层识别问题、评估风险、发现机会,并据此制定相应的战略和行动计划。 关注【数据化运营圈】共同探讨…...

Spring 中类似 aBbb 单字母单词序列化与反序列问题
文章目录 前言代码准备问题排查lombok自定义生成 get、set 结合源码解析使用 lombok使用 lombok 自定义生成 user 对象 get、set 方法 如何解决使用注解 JsonProperty("aTest")自定义实现符合 Spring 规范的 get set 方法 个人简介 前言 最近在使用 spring boot mvc…...

TiDB 慢查询日志分析
导读 TiDB 中的慢查询日志是一项 关键的性能监控工具,其主要作用在于协助数据库管理员追踪执行时间较长的 SQL 查询语句。 通过记录那些超过设定阈值的查询,慢查询日志为性能优化提供了关键的线索,有助于发现潜在的性能瓶颈,优化…...

网页文件批量下载工具有哪些 网页文件批量下载工具推荐 IDM免费激活 网络下载加速器
把任务丢给软件,把时间还给自己,批量下载功能让下载变得更高效。它可以有效减少重复性操作,只需要一次简单的设置,就能把大量文件下载到电脑。有关网页文件批量下载工具有哪些,网页文件批量下载工具推荐的问题…...

嵌入式算法开发系列之图像处理算法
嵌入式系统中的图像处理算法及其应用 文章目录 嵌入式系统中的图像处理算法及其应用前言一、图像处理算法的原理二、图像处理算法的应用三、C 语言实现总结 前言 在嵌入式系统中,图像处理算法是一项重要的技术,用于实现各种视觉应用,如机器视…...

HarmonyOS4-ArkUI组件动画
一、ArkUI组件属性动画和显示动画 显示动画: 案例:上下左右箭头控制小鱼的游动 具体代码如下: import router from ohos.routerEntry Component struct AnimationPage {// 小鱼坐标State fishX: number 200State fishY: number 180// 小鱼…...

模块化——如何导入模块?(内置模块与自定义模块)
在Node.js中,要导入另一个模块,我们可以使用require函数。这个函数接受一个文件路径参数,并返回导入的模块。 注意:require导入包场景:内置模块、自定义模块、npm包的导入... 下面介绍内置模块与自定义模块。npm包的…...