MoCo 算法阅读记录
论文地址:🐰
何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。
一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模的数据集;另一方面 本次只是测试和阅读算法原理的实现,并不完整使用。因此,重写了一个低配版(流程不变,超参数没有严格要求设置,单GPU跑,数据集自己配置,几十张图片, no Shuffling BN)。
queue 即文中所构建的字典,这个好比如 C++ 中 的queue 容器,因为它是一种先进先出的数据结构。
目录
一、数据预处理
二、前向传播
网络结构
算法流程
一、数据预处理
对同一张图片进行数据增强操作,得到 query 和 key。
增强操作包括
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),transforms.RandomGrayscale(p=0.2),transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),transforms.RandomHorizontalFlip(),normalize,
所以,dataloader中的每个输入样本是一个样本对儿。
通过下列方法实现
class TwoCropsTransform:"""Take two random crops of one image as the query and key."""def __init__(self, base_transform):self.base_transform = base_transformdef __call__(self, x):q = self.base_transform(x)k = self.base_transform(x)return [q, k]
二、前向传播
网络结构
代码中 encoder q 和 encoder k的网络结构用的都是ReNet 。ResNet最终的输出层包含了
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=128, bias=True)
所以,输出的特征向量维度为 (N,C)。N为文中的Mini batch大小,代码中的超参数为batch size。C应该没有什么具体的含义,只是经验的设置为这一长度了(没找出来C的大小关乎什么)。

其输出还经过了L2归一化。
算法流程
1、 q 送入 encoder q 得到输出,并经过L2归一化, (N,C)
2、 momentum 更新 key encoder。
3、 Shuffling BN(当然我重写的代码并没有实现这个,因为它需要多GPU,但这并不妨碍认识它的作用)
文中所述
大致意思由于ResNet使用了BN操作,因此由于Batch 数据之间的交互,使得模型利用它欺骗预设任务从而简单的找到一个低损失的解决方案,然而这个解决方案效果并不好,使得模型学习不到好的表征能力。
其提出的Shuffling BN
首先,把所有进程的Tensor的收集起来(如果分布式训练,一般每个GPU包含一个进程,所以收集的数据总量大小为 num GPUs * batch size),参考这里🤖
x_gather = concat_all_gather(x)接下来制作打乱的索引,整个过程如下所示
def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x) # 将所有进程的数据收集起来batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda() # torch.randperm 将[0,n)数随机排列# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0) # 将这个信息广播到所有其他进程# index for restoringidx_unshuffle = torch.argsort(idx_shuffle) # 按照值大小顺序返回下标# shuffled index for this gpugpu_idx = torch.distributed.get_rank() # 返回当前的进程idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] # idx_shuffle view 后 (num_gpus, batch size) 但是batch size中的索引是打乱顺序的return x_gather[idx_this], idx_unshuffle最终返回 随机打乱顺序后挑选的当前进程的 batch size 大小的数据,也就是说进行 BN归一化后的数据已经不在 同一个原来的批 中了。
4、k 送入 encoder k 中,在经过L2 归一化, 和q一样。 (N,C)
5、Shuffling BN 对齐 q 和 k
如下面举例
# idx_shuffle tensor([10, 16, 13, 2, 4, 0, 6, 21, 22, 31, 29, 3, 19, 17, 14, 30, 28, 12,24, 26, 8, 25, 11, 18, 5, 7, 27, 1, 15, 23, 20, 9])# idx_unshuffle tensor([ 5, 27, 3, 11, 4, 24, 6, 25, 20, 31, 0, 22, 17, 2, 14, 28, 1, 13,23, 12, 30, 7, 8, 29, 18, 21, 19, 26, 16, 10, 15, 9])# q 的 idx_this tensor([10, 16, 13, 2, 4, 0, 6, 21])# k 的 idx_this tensor([ 5, 27, 3, 11, 4, 24, 6, 25])这里主要关注的点是 这步是为了使 k对齐打乱顺序的q。q之前是打乱了顺序从而改变了每个batch的内容,相当于从所有的batch中随机挑选了 batch size的q,从而保证去除BN的影响。
而 k 不需要 再打乱了, 只需要从原有的batch size 数据分布中挑选出与q对应的数据即可。所以才在 shuffle BN q的过程中记录了indx unshuffle。
这里的对应关系举例,比如 index shuffle 中的 0 现在位于原来没打乱状态的索引 5处, 类似的 1 -->27, 2-->3, 以此类推。
注:不要被上面单进程的(即idx this)不对齐所迷惑,上面的只是分进程处理的,分布式训练最终会把所有进程的数据拼接起来一起处理,所以所有进程的数据对齐就行。
6、计算损失,即文中公式1
其中 用到的计算方法举例如下,分别用爱因斯坦求和公式实现,参考这里🤖
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 2]]) b = torch.tensor([[2, 2, 2], [2, 2, 2], [1, 1, 1]]) print(a) print(b) c = torch.einsum("nc, nc->n", [a, b]) # (3) d = c.unsqueeze(-1) # (3,1) print(c)#=== 输出 tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]]) tensor([[2, 2, 2],[2, 2, 2],[1, 1, 1]]) tensor([12, 6, 7]) tensor([[12],[ 6],[ 7]])a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 3]]) # (3,3) a1 = torch.tensor([[1, 2], [1, 1], [2, 2]]) # (3,2) c = torch.einsum("nc,ck->nk", [a, a1]) print(a) print(a1) print(c)# ===输出 tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]]) tensor([[1, 2],[1, 1],[2, 2]]) tensor([[ 9, 10],[ 4, 5],[10, 12]])这里的self.queue 即文中的字典 queue,初始化为
self.register_buffer("queue", torch.randn(dim, K)) self.queue = nn.functional.normalize(self.queue, dim=0)K为字典的长度,默认设置65536。这里为什么设置为这个可能是由于ImageNet数据集比较大,所以设置的字典比较长,具体的长度设置好像没有做固定的要求,
来源于github官网。但代码中有要求,K必须是batch size 的倍数,这个为了确保字典的更新,方便执行入栈和弹出操作。这个字典像是C++的 queue容器的FIFO数据结构,即先进先出
self.K % batch_size == 0l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) # (8,1) 对应元素相乘并第一维加和# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) # (8,65536) 矩阵相乘# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1) # (8,65537)# apply temperaturelogits /= self.Tlabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() # (8,)loss = criterion(output, target)这里看标签都是0,即第一个也就是0维数据为正样本。因为在拼接cat的时候正样本是在前面的。
7、更新字典
按mini batch 更新。具体地,如果 训练次数*mini batch size 小于字典长度,则字典queue每次都会填充新的key。若训练次数*mini batch size 大于 字典长度,则之前的被替换掉。
ptr = (ptr + batch_size) % self.K # move pointer 8
相关文章:
MoCo 算法阅读记录
论文地址:🐰 何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。 一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模…...
华为OD机试 - 数组连续和 - 滑动窗口(Java 2024 C卷 100分)
华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…...
微店micro获得微店micro商品详情,API接口封装系列
微店商品详情API接口封装系列主要涉及注册账号、获取API密钥、选择API接口、发送请求以及处理响应等步骤。以下是详细的流程: 请求示例,API接口接入Anzexi58 一、注册账号并获取API密钥 首先,你需要在微店开放平台注册一个账号。注册成功后…...
C语言中的数据结构--链表的应用1(2)
前言 上一节我们学习了链表的概念以及链表的实现,那么本节我们就来了解一下链表具体有什么用,可以解决哪些实质性的问题,我们借用习题来加强对链表的理解,那么废话不多说,我们正式进入今天的学习 单链表相关经典算法O…...
.Net6 使用Autofac进行依赖注入
一、背景 刚接触.net 6,记录一下在.net6上是怎么使用Autofac进行动态的依赖注入的 二、注入方式 1、新建一个webapi项目,框架选择net 6 2、引用Nuget包---Autofac.Extensions.Dependency 3、在Program.cs上添加如下代码 //依赖注入 builder.Host.Us…...
第十二届蓝桥杯省赛真题(C/C++大学B组)
目录 #A 空间 #B 卡片 #C 直线 #D 货物摆放 #E 路径 #F 时间显示 #G 砝码称重 #H 杨辉三角形 #I 双向排序 #J 括号序列 #A 空间 #include <bits/stdc.h> using namespace std;int main() {cout<<256 * 1024 * 1024 / 4<<endl;return 0; } #B 卡片…...
DC40V降压恒压芯片H4120 40V转5V 3A 40V降压12V 车充降压恒压控制器
同步整流恒压芯片在现代电子设备中发挥着重要作用,为各种设备提供了稳定、高效的电源管理解决方案。 同步整流恒压芯片是一种电源管理芯片,它能够在不同电压输入条件下保持输出电压恒定。这种芯片广泛应用于各种电子设备中,如通讯设备、液晶…...
2、Qt UI控件 -- qucsdk项目使用
前言:上一篇文章讲了qucsdk的环境部署,可以在QDesigner和Qt Creator中看到qucsdk控件,这一篇来讲下在项目中使用qucsdk库中的控件。 一、准备材料 要想使用第三方库,需要三个先决条件, 1、控件的头文件 2、动/静态链…...
MATLAB算法实战应用案例精讲-【人工智能】AIGC概念三部曲(三)
目录 前言 算法原理 大模型 什么是AIGC? AIGC和Chat GPT的关系 常见的AIGC应用...
外汇110:外汇交易不同货币类别及交易注意事项!
外汇市场是一个庞大而复杂的市场,其中有各种各样的货币品种。对于外汇投资者来说,了解外汇品种的特性和走势是比较重要的。1. 货币种类 外汇市场中的货币品种可以分为主要货币、次要货币和外围货币。 主要货币:主要指美元、欧元、英镑、日元、…...
gerrit 拉取失败
在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --gloable user.email设置的邮箱地址必须保持一致吗 在浏览器gerrit的设置界面设置的邮箱地址和在命令行使用git config --global user.email设置的邮箱地址并不一定需要保持一致。这两个邮箱地址是独立的&am…...
大数据行业英语单词巩固20240410
20240410 Communication - 沟通 Example: Effective communication is essential for project success. 有效的沟通对于项目的成功至关重要。 Collaboration - 协作 Example: Team collaboration is crucial in achieving our goals. 团队协作对于实现我们的目标至关重要。 …...
天软特色因子看板 (2024.4 第3期)
该因子看板跟踪天软特色因子A05005(近一月单笔流出金额占比(%),该因子为近一月单笔流出金额占比(% 均值因子,用以刻画下跌时的 单成交中可能存在的抄底现象 今日为该因子跟踪第3期,跟踪其在SH000852 (中证1000) 中的表现,要点如下…...
使用QT 开发不规则窗体
使用QT 开发不规则窗体 不规则窗体贴图法的不规则窗体创建UI模板创建一个父类创建业务窗体main函数直接调用user_dialog创建QSS文件 完整的QT工程 不规则窗体 QT中开发不规则窗体有两种方法:(1)第一种方法,使用QWidget::setMask函…...
如何构建企业经营所需的商业智能(BI)能力
构建企业经营所需的商业智能(BI)能力是一项涉及诸多关键环节与细致考量的系统工程,通过科学的数据处理、分析与应用,赋能企业实现精准决策,提升运营效率,优化业务流程,并在竞争激烈的市场环境中…...
【vue】watch监听取不到this指向的数?
今天同事问我,watch里this指向的数值,别的地方却可以打印出来。工具也能看到数值,但打印出来却是undifined,先看看代码: 懒得打字了直接上截图吧 ps: 在Vue组件中,如果你在watch选项中访问this…...
Ubuntu-22.04安装VMware虚拟机并安装Windows10
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、VMware是什么?二、安装VMware1.注册VMware账号2.下载虚拟机3.编译vmmon&vmnet4.加载module5.安装bundle 三、安装Windows101.基础配置2.进阶…...
ELK企业日志分析系统介绍
前言 随着企业级应用系统日益复杂,随之产生的海量日志数据。传统的日志管理和分析手段,难以做到高效检索、实时监控以及深度挖掘潜在价值。在此背景下,ELK日志分析系统应运而生。本文将从ELK 日志分析系统的原理、架构及其在实践中的应用做相…...
在C#中读取写入字节流与读取写入二进制数据, 有何差异?
在C#中,读取和写入字节流与读取和写入二进制数据有些许不同,尽管它们在某些情况下可能会重叠使用。以下是它们之间的主要区别: 读取和写入字节流: 读取和写入字节流通常指的是处理文件或流中的原始字节数据。在C#中,可…...
数据库相关知识总结
一、数据库三级模式 三个抽象层次: 1. 视图层:最高层次的抽象,描述整个数据库的某个部分的数据 2. 逻辑层:描述数据库中存储的数据以及这些数据存在的关联 3. 物理层:最低层次的抽象,描述数据在存储器中时如…...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
Cesium1.95中高性能加载1500个点
一、基本方式: 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...
UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
1688商品列表API与其他数据源的对接思路
将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...
大语言模型如何处理长文本?常用文本分割技术详解
为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
智能在线客服平台:数字化时代企业连接用户的 AI 中枢
随着互联网技术的飞速发展,消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁,不仅优化了客户体验,还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用,并…...
P3 QT项目----记事本(3.8)
3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...
从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...


