当前位置: 首页 > article >正文

self.cls_token在 Vision Transformer (ViT) 模型中的训练阶段和推理阶段的行为和作用的异同

self.cls_token 在 Vision Transformer (ViT) 模型中,在训练阶段和推理阶段的行为和作用是不同的,而且它的值在训练过程中会发生变化。

1. self.cls_token 的作用

在 ViT 中,self.cls_token 是一个特殊的、可学习的嵌入向量(embedding vector),它被添加到输入序列(图像patch的embedding序列)的开头。这个 cls_token 的主要目的是在经过 Transformer Encoder 的多层自注意力计算后,其对应的输出向量能够聚合整个输入序列的信息,用于最终的分类任务。

可以把 cls_token 理解为一个“班长”的角色。每个图像块(patch)是一个“学生”。一开始,“班长”(cls_token)和“学生”(patches)互相不认识(都是随机初始化的)。在 Transformer 的每一层,“班长”都会和每个“学生”交流(自注意力机制),同时“学生”之间也互相交流。经过多层交流后,“班长”就逐渐了解了整个班级的情况(图像的全局信息)。最后,我们只用“班长”的输出来做分类。

2. 训练阶段

  1. 随机初始化:在模型初始化时,self.cls_token 是一个形状为 (1, 1, embed_dim) 的张量,其中的值通常是从某个分布(如正态分布)中随机采样的。这意味着在训练开始时,cls_token 没有任何关于图像的先验信息。

  2. 可学习参数self.cls_token 被定义为 nn.Parameter,这意味着它是一个模型的可学习参数。在训练过程中,它会随着其他模型参数一起通过反向传播和梯度下降进行更新。

  3. 与输入交互:在每个训练批次中,self.cls_token 会被复制并与每个输入图像的patch embeddings进行拼接(concatenate),形成 Transformer Encoder 的输入序列。

    # 假设 x 是图像patch embeddings, 形状为 (batch_size, num_patches, embed_dim)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # 扩展到与 batch_size 匹配
    x = torch.cat((cls_token, x), dim=1)  # 拼接
    
  4. 信息聚合:在 Transformer Encoder 的每一层,cls_token 对应的embedding都会与其他patch embeddings进行自注意力计算。通过这种方式,cls_token 逐渐“学习”到如何聚合来自所有patch的信息。

  5. 参数更新:在反向传播过程中,cls_token 的梯度会根据分类损失进行计算,并通过优化器进行更新。这意味着 cls_token 的值会不断调整,以更好地捕捉图像的全局特征。

3. 推理阶段

  1. 固定值:在推理阶段,模型的所有参数(包括 self.cls_token)都是固定的,不再进行更新。cls_token 使用的是训练结束时学习到的值。

  2. 相同操作:与训练阶段类似,self.cls_token 仍然会被复制并与输入图像的patch embeddings进行拼接,作为 Transformer Encoder 的输入。

  3. 信息提取:经过 Transformer Encoder 的处理后,cls_token 对应的输出向量被用作分类器的输入,进行最终的类别预测。

4. 总结

特性训练阶段推理阶段
随机初始化,通过反向传播更新固定(使用训练结束时学习到的值)
是否可学习是 (nn.Parameter)
作用与patch embeddings交互,聚合全局信息,参与梯度更新与patch embeddings交互,提取全局信息,用于分类

5. 代码示例 (简化)

import torch
import torch.nn as nnclass VisionTransformer(nn.Module):def __init__(self, embed_dim=768, ...):super().__init__()# ... 其他层 ...self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 可学习参数# ... 其他层 ...def forward(self, x):# ... patch embedding ...cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # 复制cls_tokenx = torch.cat((cls_token, x), dim=1)  # 拼接# ... Transformer Encoder ...x = x[:, 0]  # 取cls_token对应的输出# ... 分类器 ...return x

因此,self.cls_token 在训练阶段是随机初始化的可学习参数,通过与图像patch embeddings的交互和反向传播不断更新;在推理阶段,self.cls_token 的值是固定的,它利用训练中学到的知识来提取图像的全局特征,用于分类。
这种设计使得 ViT 能够有效地处理图像数据,并在各种视觉任务中取得了出色的性能。

相关文章:

self.cls_token在 Vision Transformer (ViT) 模型中的训练阶段和推理阶段的行为和作用的异同

self.cls_token 在 Vision Transformer (ViT) 模型中,在训练阶段和推理阶段的行为和作用是不同的,而且它的值在训练过程中会发生变化。 1. self.cls_token 的作用 在 ViT 中,self.cls_token 是一个特殊的、可学习的嵌入向量(emb…...

【量化科普】Leverage,杠杆

【量化科普】Leverage,杠杆 🚀量化软件开通 🚀量化实战教程 在量化投资领域,杠杆(Leverage)是一个核心概念,它允许投资者通过借入资金来增加投资规模,从而放大投资收益或亏损。简…...

247g 的工业级电调,如何让无人机飞得更 “聪明“?——STONE 200A-M 深度测评

一、轻量化设计背后的技术取舍 当拿到 STONE 200A-M 时,247g 的重量让人意外 —— 这个接近传统 200A 电调 70% 的重量,源自 1205624.5mm 的紧凑结构(0.1mm 公差控制)。实测装机显示,相比同规格产品,其体积…...

Maven Deploy Plugin如何使用?

在Java开发中,Maven是一个非常重要的构建工具。它不仅可以管理项目的依赖关系,还能帮助我们打包和发布项目。在Maven中,deploy插件是一个很实用的功能,它可以将构建好的项目发布到远程仓库。今天,就来聊聊如何使用Mave…...

Node.js:快速启动你的第一个Web服务器

Node.js 全面入门指南 文章目录 Node.js 全面入门指南一 安装Node.js1. Windows2. MacOS/Linux 二 配置开发环境1. VSCode集成 三 第一个Node.js程序1. 创建你的第一个Node.js程序 四 使用Express框架1. 快速搭建服务器 一 安装Node.js 1. Windows 以下是Windows环境下Node.j…...

自定义日志回调函数实现第三方库日志集成:从理论到实战

一、应用场景与痛点分析 在开发过程中,我们经常会遇到以下场景: 日志格式统一:第三方库使用自己的日志格式,导致系统日志混杂,难以统一管理和分析。日志分级过滤:需要动态调整第三方库的日志输出级别&…...

Linux练级宝典->任务管理和守护进程

任务管理 进程组概念 每个进程除了进程ID以外,还有一个进程组,进程组就是一个或多个进程的集合 同一个进程组,代表着他们是共同作业的,可以接收同一个终端的各种信号,进程组也有其唯一的进程组号。还有一个组长进程&a…...

C语言:计算并输出三个整数的最大值 并对三个数排序

这是《C语言程序设计》73页的思考题。下面分享自己的思路和代码 思路&#xff1a; 代码&#xff1a; #include <stdio.h> int main() {int a,b,c,max,min,mid ; //设置大中小的数分别为max&#xff0c;mid&#xff0c;min&#xff0c;abc为输入的三个数printf("ple…...

工具(十二):Java导出MySQL数据库表结构信息到excel

一、背景 遇到需求&#xff1a;将指定数据库表设计&#xff0c;统一导出到一个Excel中&#xff0c;存档查看。 如果一个一个弄&#xff0c;很复杂&#xff0c;耗时长。 二、写一个工具导出下 废话少絮&#xff0c;上码&#xff1a; 2.1 pom导入 <dependency><grou…...

如何设计微服务及其设计原则?

微服务架构是一种将大型单体应用拆分成多个小型、自治服务的设计方式&#xff0c;每个服务专注于单一的业务功能。设计微服务时&#xff0c;需要遵循以下原则和最佳实践&#xff1a; 1. 单一职责原则 核心思想&#xff1a; 每个微服务都应该只负责一块独立的业务功能。这使得…...

ACL初级总结

ACL–访问控制列表 1.访问控制 在路由器流量流入或者流出的接口上,匹配流量,然后执行相应动作 permit允许 deny拒绝 2.抓取感兴趣流 3.ACL匹配规则 自上而下逐一匹配,若匹配到了则按照对应规则执行动作,而不再向下继续匹配 思科:ACL列表末尾隐含一条拒绝所有的规则 华为:AC…...

调优案例一:堆空间扩容提升吞吐量实战记录

&#x1f4dd; 调优案例一&#xff1a;堆空间扩容提升吞吐量实战记录 &#x1f527; 调优策略&#xff1a;堆空间扩容三部曲 # 原配置&#xff08;30MB堆空间&#xff09; export CATALINA_OPTS"$CATALINA_OPTS -Xms30m -Xmx30m"# 新配置&#xff08;扩容至120MB&am…...

C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷一)

目录 1. 内存和地址 2. 指针变量和地址 2.1 取地址操作符&#xff08;&&#xff09; 2.2 指针变量 2.3 解引用操作符 &#xff08;*&#xff09; 3. 指针的解引用 3.1 指针 - 整数 3.2 void* 指针 4. const修饰指针 4.1 const修饰变量 4.2 const修饰指针变量 5…...

计算机毕业设计:留守儿童的可视化界面

留守儿童的可视化界面mysql数据库创建语句留守儿童的可视化界面oracle数据库创建语句留守儿童的可视化界面sqlserver数据库创建语句留守儿童的可视化界面springspringMVChibernate框架对象(javaBean,pojo)设计留守儿童的可视化界面springspringMVCmybatis框架对象(javaBean,poj…...

golang算法二叉树对称平衡右视图

100. 相同的树 给你两棵二叉树的根节点 p 和 q &#xff0c;编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同&#xff0c;并且节点具有相同的值&#xff0c;则认为它们是相同的。 示例 1&#xff1a; 输入&#xff1a;p [1,2,3], q [1,2,3] 输出&#xff1a…...

c++20 Concepts的简写形式与requires 从句形式

c20 Concepts的简写形式与requires 从句形式 原始写法&#xff08;简写形式&#xff09;等效写法&#xff08;requires 从句形式&#xff09;关键区别说明&#xff1a;组合多个约束的示例&#xff1a;两种形式的编译结果&#xff1a;更复杂的约束示例&#xff1a;标准库风格的约…...

Chatbox通过百炼调用DeepSeek

解决方案链接&#xff1a;评测&#xff5c;零门槛&#xff0c;即刻拥有DeepSeek-R1满血版 方案概览 本方案以 DeepSeek-R1 满血版为例进行演示&#xff0c;通过百炼模型服务进行 DeepSeek 开源模型调用&#xff0c;可以根据实际需求选择其他参数规模的 DeepSeek 模型。百炼平台…...

【数据结构】6栈

0 章节 3&#xff0e;1到3&#xff0e;3小节。 认知与理解栈结构&#xff1b; 列举栈的操作特点。 理解并列举栈的应用案例。 重点 栈的特点与实现&#xff1b; 难点 栈的灵活实现与应用 作业或思考题 完成学习测试&#xff12;&#xff0c;&#xff1f; 内容达成以下标准(考核…...

PyTorch 入门学习

目录 PyTorch 定义 核心作用 应用场景 Pytorch 基本语法 1. 张量的创建 2. 张量的类型转换 3. 张量数值计算 4. 张量运算函数 5. 张量索引操作 6. 张量形状操作 7. 张量拼接操作 8. 自动微分模块 9. 案例-线性回归案例 PyTorch 定义 PyTorch 是一个基于 Python 深…...

mov格式视频如何转换mp4?

mov格式视频如何转换mp4&#xff1f;在日常的视频处理中&#xff0c;经常需要将MOV格式的视频转换为MP4格式&#xff0c;以兼容更多的播放设备和平台。下面给大家分享如何将MOV视频转换为MP4&#xff0c;4款视频格式转换工具分享。 一、牛学长转码大师 牛学长转码大师是一款功…...

数据结构与算法:动态规划dp:子序列相关力扣题(下):392. 判断子序列、115.不同的子序列

392. 判断子序列 1.套最长公共子序列问题的板子 class Solution:def isSubsequence(self, s: str, t: str) -> bool:"""最长公共子序列长度是否len(s)&#xff0c;是就是true&#xff0c;否就是falsedp[i][j]考虑以s[i-1]&#xff0c;t[j-1]的最长公共子序…...

二进制求和(js实现,LeetCode:67)

这道题我的解决思路是先将a和b的长度保持一致以方便后续按位加减 let lena a.length let lenb b.length if (lena ! lenb) {if (lena > lenb) {for (let i 0; i <lena-lenb; i) {b 0 b}} else {for (let i 0; i < lenb-lena; i) {a 0 a}} } 下一步直接进行按…...

【C#】使用DeepSeek帮助评估数据库性能问题,C# 使用定时任务,每隔一分钟移除一次表,再重新创建表,和往新创建的表追加5万多条记录

&#x1f339;欢迎来到《小5讲堂》&#x1f339; &#x1f339;这是《C#》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解。&#x1f339; &#x1f339;温馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不对之处望指正&#xff01;&#…...

【openGauss】物理备份恢复

文章目录 1. gs_backup&#xff08;1&#xff09;备份&#xff08;2&#xff09;恢复&#xff08;3&#xff09;手动恢复的办法 2. gs_basebackup&#xff08;1&#xff09;备份&#xff08;2&#xff09;恢复① 伪造数据目录丢失② 恢复 3. gs_probackup&#xff08;1&#xf…...

蓝桥杯备赛-基础练习 day1

1、闰年判断 问题描述 给定一个年份&#xff0c;判断这一年是不是闰年。 当以下情况之一满足时&#xff0c;这一年是闰年:1.年份是4的倍数而不是100的倍数 2&#xff0e;年份是400的倍数。 其他的年份都不是闰年。 输入格式 输入包含一个…...

实验四 Python聚类决策树训练与预测 基于神经网络的MNIST手写体识别

一、实验目的 Python聚类决策树训练与预测&#xff1a; 1、掌握决策树的基本原理并理解监督学习的基本思想。 2、掌握Python实现决策树的方法。 基于神经网络的MNIST手写体识别&#xff1a; 1、学习导入和使用Tensorflow。 2、理解学习神经网络的基本原理。 3、学习使用…...

【原创】在高性能服务器上,使用受限用户运行Nginx,充当反向代理服务器[未完待续]

起因 在公共高性能服务器上运行OllamaDeepSeek&#xff0c;如果按照默认配置启动Ollama程序&#xff0c;则自己在远程无法连接你启动的Ollama服务。 如果修改配置&#xff0c;则会遇到你的Ollama被他人完全控制的安全风险。 不过&#xff0c;我们可以使用一个方向代理&#…...

网络_面试_HTTP请求报文和HTTP响应报文

简介&#xff1a; HTTP报文是面向文本的&#xff0c;报文中的每一个字段都是一些ASCII码串&#xff0c;各个字段的长度是不确定的。HTTP有两类报文&#xff1a;请求报文和响应报文。 HTTP请求报文 一个HTTP请求报文由请求行&#xff08;request line&#xff09;、请求头部&…...

详解CPU的组成与功能

CPU的组成与功能 一、 控制单元&#xff08;Control Unit, CU&#xff09;二、 算术逻辑单元&#xff08;Arithmetic Logic Unit, ALU&#xff09;三、 寄存器&#xff08;Registers&#xff09;四、 高速缓存&#xff08;Cache&#xff09;五、 辅助结构与技术译码器&#xff…...

Spring boot3-WebClient远程调用非阻塞、响应式HTTP客户端

来吧&#xff0c;会用就行具体理论不讨论 1、首先pom.xml引入webflux依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId> </dependency> 别问为什么因为是响应式....…...