当前位置: 首页 > news >正文

MoCo 算法阅读记录

论文地址:🐰

何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。

一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模的数据集;另一方面 本次只是测试和阅读算法原理的实现,并不完整使用。因此,重写了一个低配版(流程不变,超参数没有严格要求设置,单GPU跑,数据集自己配置,几十张图片, no Shuffling BN)。

queue 即文中所构建的字典,这个好比如 C++ 中 的queue 容器,因为它是一种先进先出的数据结构。

目录

一、数据预处理

二、前向传播

网络结构

算法流程


一、数据预处理

对同一张图片进行数据增强操作,得到 query 和 key。

增强操作包括

transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),transforms.RandomGrayscale(p=0.2),transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),transforms.RandomHorizontalFlip(),normalize,

所以,dataloader中的每个输入样本是一个样本对儿。

通过下列方法实现

class TwoCropsTransform:"""Take two random crops of one image as the query and key."""def __init__(self, base_transform):self.base_transform = base_transformdef __call__(self, x):q = self.base_transform(x)k = self.base_transform(x)return [q, k]

二、前向传播

网络结构

代码中 encoder q 和 encoder k的网络结构用的都是ReNet 。ResNet最终的输出层包含了

(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=128, bias=True)

所以,输出的特征向量维度为 (N,C)。N为文中的Mini batch大小,代码中的超参数为batch size。C应该没有什么具体的含义,只是经验的设置为这一长度了(没找出来C的大小关乎什么)。

其输出还经过了L2归一化。 

算法流程

1、 q 送入 encoder q 得到输出,并经过L2归一化, (N,C)

2、 momentum 更新 key encoder。

3、 Shuffling BN(当然我重写的代码并没有实现这个,因为它需要多GPU,但这并不妨碍认识它的作用)

文中所述

大致意思由于ResNet使用了BN操作,因此由于Batch 数据之间的交互,使得模型利用它欺骗预设任务从而简单的找到一个低损失的解决方案,然而这个解决方案效果并不好,使得模型学习不到好的表征能力。

其提出的Shuffling BN

首先,把所有进程的Tensor的收集起来(如果分布式训练,一般每个GPU包含一个进程,所以收集的数据总量大小为 num GPUs * batch size),参考这里🤖

x_gather = concat_all_gather(x)

接下来制作打乱的索引,整个过程如下所示

    def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)  # 将所有进程的数据收集起来batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda()  # torch.randperm 将[0,n)数随机排列# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0)  # 将这个信息广播到所有其他进程# index for restoringidx_unshuffle = torch.argsort(idx_shuffle)  # 按照值大小顺序返回下标# shuffled index for this gpugpu_idx = torch.distributed.get_rank()  # 返回当前的进程idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]  # idx_shuffle view 后 (num_gpus, batch size) 但是batch size中的索引是打乱顺序的return x_gather[idx_this], idx_unshuffle

最终返回 随机打乱顺序后挑选的当前进程的 batch size 大小的数据,也就是说进行 BN归一化后的数据已经不在 同一个原来的批 中了。

4、k 送入 encoder k 中,在经过L2 归一化, 和q一样。  (N,C)

5、Shuffling BN 对齐 q 和 k

如下面举例

# idx_shuffle
tensor([10, 16, 13,  2,  4,  0,  6, 21, 22, 31, 29,  3, 19, 17, 14, 30, 28, 12,24, 26,  8, 25, 11, 18,  5,  7, 27,  1, 15, 23, 20,  9])# idx_unshuffle
tensor([ 5, 27,  3, 11,  4, 24,  6, 25, 20, 31,  0, 22, 17,  2, 14, 28,  1, 13,23, 12, 30,  7,  8, 29, 18, 21, 19, 26, 16, 10, 15,  9])# q 的 idx_this
tensor([10, 16, 13,  2,  4,  0,  6, 21])# k 的 idx_this
tensor([ 5, 27,  3, 11,  4, 24,  6, 25])

这里主要关注的点是 这步是为了使 k对齐打乱顺序的q。q之前是打乱了顺序从而改变了每个batch的内容,相当于从所有的batch中随机挑选了 batch size的q,从而保证去除BN的影响。

而 k 不需要 再打乱了, 只需要从原有的batch size 数据分布中挑选出与q对应的数据即可。所以才在 shuffle BN q的过程中记录了indx unshuffle。

这里的对应关系举例,比如 index shuffle 中的 0 现在位于原来没打乱状态的索引 5处, 类似的 1 -->27, 2-->3, 以此类推。

注:不要被上面单进程的(即idx this)不对齐所迷惑,上面的只是分进程处理的,分布式训练最终会把所有进程的数据拼接起来一起处理,所以所有进程的数据对齐就行。

6、计算损失,即文中公式1

其中 用到的计算方法举例如下,分别用爱因斯坦求和公式实现,参考这里🤖

a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 2]])
b = torch.tensor([[2, 2, 2], [2, 2, 2], [1, 1, 1]])
print(a)
print(b)
c = torch.einsum("nc, nc->n", [a, b])  # (3)
d = c.unsqueeze(-1)  # (3,1)
print(c)#=== 输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[2, 2, 2],[2, 2, 2],[1, 1, 1]])
tensor([12,  6,  7])
tensor([[12],[ 6],[ 7]])
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 3]])  # (3,3)
a1 = torch.tensor([[1, 2], [1, 1], [2, 2]])  # (3,2)
c = torch.einsum("nc,ck->nk", [a, a1])
print(a)
print(a1)
print(c)# ===输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[1, 2],[1, 1],[2, 2]])
tensor([[ 9, 10],[ 4,  5],[10, 12]])

这里的self.queue 即文中的字典 queue,初始化为

self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)

K为字典的长度,默认设置65536。这里为什么设置为这个可能是由于ImageNet数据集比较大,所以设置的字典比较长,具体的长度设置好像没有做固定的要求,

来源于github官网。但代码中有要求,K必须是batch size 的倍数,这个为了确保字典的更新,方便执行入栈和弹出操作。这个字典像是C++的 queue容器的FIFO数据结构,即先进先出

self.K % batch_size == 0
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)  #  (8,1)  对应元素相乘并第一维加和# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])  # (8,65536)  矩阵相乘# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)  # (8,65537)# apply temperaturelogits /= self.Tlabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()  # (8,)loss = criterion(output, target)

这里看标签都是0,即第一个也就是0维数据为正样本。因为在拼接cat的时候正样本是在前面的。

7、更新字典

按mini batch 更新。具体地,如果 训练次数*mini batch size 小于字典长度,则字典queue每次都会填充新的key。若训练次数*mini batch size 大于 字典长度,则之前的被替换掉。

ptr = (ptr + batch_size) % self.K  # move pointer  8

相关文章:

MoCo 算法阅读记录

论文地址:🐰 何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。 一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模…...

华为OD机试 - 数组连续和 - 滑动窗口(Java 2024 C卷 100分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…...

微店micro获得微店micro商品详情,API接口封装系列

微店商品详情API接口封装系列主要涉及注册账号、获取API密钥、选择API接口、发送请求以及处理响应等步骤。以下是详细的流程: 请求示例,API接口接入Anzexi58 一、注册账号并获取API密钥 首先,你需要在微店开放平台注册一个账号。注册成功后…...

C语言中的数据结构--链表的应用1(2)

前言 上一节我们学习了链表的概念以及链表的实现,那么本节我们就来了解一下链表具体有什么用,可以解决哪些实质性的问题,我们借用习题来加强对链表的理解,那么废话不多说,我们正式进入今天的学习 单链表相关经典算法O…...

.Net6 使用Autofac进行依赖注入

一、背景 刚接触.net 6,记录一下在.net6上是怎么使用Autofac进行动态的依赖注入的 二、注入方式 1、新建一个webapi项目,框架选择net 6 2、引用Nuget包---Autofac.Extensions.Dependency   3、在Program.cs上添加如下代码 //依赖注入 builder.Host.Us…...

第十二届蓝桥杯省赛真题(C/C++大学B组)

目录 #A 空间 #B 卡片 #C 直线 #D 货物摆放 #E 路径 #F 时间显示 #G 砝码称重 #H 杨辉三角形 #I 双向排序 #J 括号序列 #A 空间 #include <bits/stdc.h> using namespace std;int main() {cout<<256 * 1024 * 1024 / 4<<endl;return 0; } #B 卡片…...

DC40V降压恒压芯片H4120 40V转5V 3A 40V降压12V 车充降压恒压控制器

同步整流恒压芯片在现代电子设备中发挥着重要作用&#xff0c;为各种设备提供了稳定、高效的电源管理解决方案。 同步整流恒压芯片是一种电源管理芯片&#xff0c;它能够在不同电压输入条件下保持输出电压恒定。这种芯片广泛应用于各种电子设备中&#xff0c;如通讯设备、液晶…...

2、Qt UI控件 -- qucsdk项目使用

前言&#xff1a;上一篇文章讲了qucsdk的环境部署&#xff0c;可以在QDesigner和Qt Creator中看到qucsdk控件&#xff0c;这一篇来讲下在项目中使用qucsdk库中的控件。 一、准备材料 要想使用第三方库&#xff0c;需要三个先决条件&#xff0c; 1、控件的头文件 2、动/静态链…...

MATLAB算法实战应用案例精讲-【人工智能】AIGC概念三部曲(三)

目录 前言 算法原理 大模型 什么是AIGC? AIGC和Chat GPT的关系 常见的AIGC应用...

外汇110:外汇交易不同货币类别及交易注意事项!

外汇市场是一个庞大而复杂的市场&#xff0c;其中有各种各样的货币品种。对于外汇投资者来说&#xff0c;了解外汇品种的特性和走势是比较重要的。1. 货币种类 外汇市场中的货币品种可以分为主要货币、次要货币和外围货币。 主要货币&#xff1a;主要指美元、欧元、英镑、日元、…...

gerrit 拉取失败

在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --gloable user.email设置的邮箱地址必须保持一致吗 在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --global user.email设置的邮箱地址并不一定需要保持一致。这两个邮箱地址是独立的&am…...

大数据行业英语单词巩固20240410

20240410 Communication - 沟通 Example: Effective communication is essential for project success. 有效的沟通对于项目的成功至关重要。 Collaboration - 协作 Example: Team collaboration is crucial in achieving our goals. 团队协作对于实现我们的目标至关重要。 …...

天软特色因子看板 (2024.4 第3期)

该因子看板跟踪天软特色因子A05005(近一月单笔流出金额占比(%)&#xff0c;该因子为近一月单笔流出金额占比(% 均值因子&#xff0c;用以刻画下跌时的 单成交中可能存在的抄底现象 今日为该因子跟踪第3期&#xff0c;跟踪其在SH000852 (中证1000) 中的表现&#xff0c;要点如下…...

使用QT 开发不规则窗体

使用QT 开发不规则窗体 不规则窗体贴图法的不规则窗体创建UI模板创建一个父类创建业务窗体main函数直接调用user_dialog创建QSS文件 完整的QT工程 不规则窗体 QT中开发不规则窗体有两种方法&#xff1a;&#xff08;1&#xff09;第一种方法&#xff0c;使用QWidget::setMask函…...

如何构建企业经营所需的商业智能(BI)能力

构建企业经营所需的商业智能&#xff08;BI&#xff09;能力是一项涉及诸多关键环节与细致考量的系统工程&#xff0c;通过科学的数据处理、分析与应用&#xff0c;赋能企业实现精准决策&#xff0c;提升运营效率&#xff0c;优化业务流程&#xff0c;并在竞争激烈的市场环境中…...

【vue】watch监听取不到this指向的数?

今天同事问我&#xff0c;watch里this指向的数值&#xff0c;别的地方却可以打印出来。工具也能看到数值&#xff0c;但打印出来却是undifined&#xff0c;先看看代码&#xff1a; 懒得打字了直接上截图吧 ps&#xff1a; 在Vue组件中&#xff0c;如果你在watch选项中访问this…...

Ubuntu-22.04安装VMware虚拟机并安装Windows10

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、VMware是什么&#xff1f;二、安装VMware1.注册VMware账号2.下载虚拟机3.编译vmmon&vmnet4.加载module5.安装bundle 三、安装Windows101.基础配置2.进阶…...

ELK企业日志分析系统介绍

前言 随着企业级应用系统日益复杂&#xff0c;随之产生的海量日志数据。传统的日志管理和分析手段&#xff0c;难以做到高效检索、实时监控以及深度挖掘潜在价值。在此背景下&#xff0c;ELK日志分析系统应运而生。本文将从ELK 日志分析系统的原理、架构及其在实践中的应用做相…...

在C#中读取写入字节流与读取写入二进制数据, 有何差异?

在C#中&#xff0c;读取和写入字节流与读取和写入二进制数据有些许不同&#xff0c;尽管它们在某些情况下可能会重叠使用。以下是它们之间的主要区别&#xff1a; 读取和写入字节流&#xff1a; 读取和写入字节流通常指的是处理文件或流中的原始字节数据。在C#中&#xff0c;可…...

数据库相关知识总结

一、数据库三级模式 三个抽象层次&#xff1a; 1. 视图层&#xff1a;最高层次的抽象&#xff0c;描述整个数据库的某个部分的数据 2. 逻辑层&#xff1a;描述数据库中存储的数据以及这些数据存在的关联 3. 物理层&#xff1a;最低层次的抽象&#xff0c;描述数据在存储器中时如…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”

目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...