当前位置: 首页 > news >正文

Vision Transformer (ViT)原理

Vision Transformer (ViT)原理

flyfish
请添加图片描述
Transformer缺乏卷积神经网络(CNNs)的归纳偏差(inductive biases),比如平移不变性和局部受限的感受野。不变性意味着即使实体entity(即对象)的外观或位置发生变化,仍然可以在图像中识别出它。在计算机视觉中,平移意味着每个图像像素都以固定量朝特定方向移动。
卷积是一个线性局部操作符(linear local operator)。看到由卷积核kernel指示的邻近值。另一方面,Transformer的设计是permutation invariant 的。坏消息是它不能处理网格结构的数据。需要序列!将把空间非序列信号转换为序列!看看怎么做。
Vision Transformer(简称ViT)的工作原理概述如下:

  1. 将图像分割成小块(patches)
  2. 将小块展平 Flatten the patches
  3. 从展平的小块中产生低维线性嵌入 Produce lower-dimensional linear embeddings from the flattened patches
  4. 添加位置嵌入 Add positional embeddings
  5. 将序列作为输入输入到标准Transformer编码器 Feed the sequence as an input to a standard transformer encoder
  6. 使用图像标签对模型进行预训练(在大型数据集上进行全面监督)Pretrain the model with image labels (fully supervised on a huge dataset)
  7. 在下游数据集上进行微调以进行图像分类 Finetune on the downstream dataset for image classification

名词解释

“inductive”的中文意思是“归纳的;归纳法的;电感的”。在给定的上下文中,“inductive biases”可以理解为“归纳偏向”,指的是通过归纳推理而产生的某种倾向或偏好。例如在机器学习中,不同的模型可能具有不同的归纳偏向,这会影响它们对数据的处理方式和学习结果。

“permutation invariant”的中文意思是“排列不变性”。
在数学和计算机科学等领域,排列不变性指的是某个对象或函数在输入的元素排列顺序发生变化时,其值或性质保持不变。例如,对于一个集合的某种特征描述,如果无论集合中元素的排列顺序如何变化,这个特征描述的值都不变,那么就说这个特征具有排列不变性。
例如,计算一组数字的总和,无论这组数字以何种顺序排列,总和始终不变,这就体现了某种程度的排列不变性。在机器学习中,某些模型可能要求具有排列不变性,以确保对输入数据的不同排列方式具有相同的输出结果。
简化下就是4步

“Image Patching and Embedding”图像分块与嵌入,即将图像分割成小块并进行特定的编码嵌入操作。
“Positional Encoding”即位置编码,用于为模型提供序列中元素的位置信息。
“Transformer Encoder”编码器
“Classification Head (MLP Head)”是分类头(多层感知机头部),通常在模型的最后用于对输入进行分类任务。

详解点

Vision Transformer (ViT) 是一种将 Transformer 架构应用于计算机视觉任务的模型。它借鉴了自然语言处理 (NLP) 中 Transformer 的成功经验,旨在用 Transformer 替代传统卷积神经网络 (CNN),用于图像分类等任务。


1. 从图像到序列:图像分块与嵌入 (Image Patching and Embedding)

由于 Transformer 的输入需要是序列,而图像是二维网格数据,因此需要先将图像转换为序列形式。

  1. 图像分块

    • 将输入图像 X ∈ R H × W × C X \in \mathbb{R}^{H \times W \times C} XRH×W×C 划分为固定大小的非重叠小块(patches)。
      假设每个小块的大小为 P × P P \times P P×P,图像就被划分成 N = H × W P 2 N = \frac{H \times W}{P^2} N=P2H×W 个小块,每个小块展平为一个向量。
      每个向量的大小为 P × P × C P \times P \times C P×P×C,其中 C C C 是图像的通道数(如 RGB 图像中 C = 3 C=3 C=3)。
  2. 线性嵌入

    • 使用一个可学习的线性投影矩阵 E ∈ R ( P 2 ⋅ C ) × D E \in \mathbb{R}^{(P^2 \cdot C) \times D} ER(P2C)×D 将每个展平的小块嵌入到 D D D 维空间中。
    • 结果是一个大小为 N × D N \times D N×D 的嵌入序列,类似于 NLP 中的词嵌入(word embedding)。

2. 位置编码 (Positional Encoding)

由于 Transformer 对输入序列是排列不变的(Permutation Invariant),它无法直接利用图像中的空间关系。因此,需要为嵌入序列添加位置信息。

  • 使用可学习的位置嵌入 P ∈ R N × D P \in \mathbb{R}^{N \times D} PRN×D 或固定的位置编码,将每个小块的嵌入与其位置对应起来。
  • 最终输入变为 Z 0 = [ x 1 + p 1 ; x 2 + p 2 ; … ; x N + p N ] Z_0 = [x_1 + p_1; x_2 + p_2; \dots; x_N + p_N] Z0=[x1+p1;x2+p2;;xN+pN],其中 x i x_i xi 是第 i i i 个小块的嵌入, p i p_i pi 是其位置嵌入。

3. Transformer 编码器 (Transformer Encoder)

ViT 的核心部分是标准 Transformer 编码器,它包含以下模块:

  1. 多头自注意力机制 (Multi-Head Self-Attention, MHSA)

    • 通过自注意力机制捕获小块之间的全局关系,无需限制在局部感受野内操作(如 CNN)。
    • 每个小块与其他小块计算注意力权重,以关注哪些部分对当前任务最重要。
  2. 前馈神经网络 (Feed-Forward Neural Network, FFN)

    • 对每个小块的表示单独进行非线性变换,提升模型的表达能力。
  3. 残差连接与层归一化 (Residual Connections and Layer Normalization)

    • 使用残差连接缓解梯度消失问题,同时通过层归一化稳定训练。

多个 Transformer 编码器堆叠后,输出序列保持大小不变,为 N × D N \times D N×D


4. 分类头 (Classification Head)
  1. 类别标记 (Class Token)

    • 在输入序列中引入一个可学习的 [CLS] 标记,其嵌入表示整个图像的信息。
    • 类别标记通过编码器与其他小块交互,最终用于分类任务。
  2. 多层感知机 (MLP Head)

    • 最终 [CLS] 标记的输出经过一个多层感知机(通常是全连接层)进行分类,得到最终的预测结果。

5. 预训练与微调 (Pretraining and Finetuning)
  1. 预训练

    • ViT 通常在大规模数据集(如 ImageNet-21k 或 JFT-300M)上进行监督学习预训练。
    • 通过大量的标注数据,模型学习到丰富的视觉特征。
  2. 微调

    • 在特定任务的小型数据集上进行微调,例如 CIFAR-10、ImageNet 等。

ViT 的优点与局限性

优点
  1. 全局感受野

    • 自注意力机制可以直接建模全局依赖关系,而 CNN 的局部感受野需要通过多层叠加才能实现。
  2. 更少的归纳偏差

    • ViT 依赖数据学习到特征,而非依赖 CNN 的平移不变性等归纳偏差,更适合大规模数据。
  3. 灵活性

    • 不受限于卷积核的大小,可适配不同的输入尺寸和任务。
局限性
  1. 数据需求大

    • 由于缺乏 CNN 的归纳偏差,ViT 在小数据集上表现较差,需要大量预训练数据。
  2. 计算成本高

    • 多头自注意力的计算复杂度为 O ( N 2 ⋅ D ) O(N^2 \cdot D) O(N2D),对高分辨率图像或较多的分块数量计算开销较大。
import math
from collections import OrderedDict
from typing import Callable, List, Optionalimport torch
import torch.nn as nn
from torch import Tensor# 假设 Encoder 和 Conv2dNormActivation 已经被定义,并且 ConvStemConfig 是一个数据类或命名元组。
# 这些通常来自其他文件或库,例如 torchvision 或自定义实现。class VisionTransformer(nn.Module):"""Vision Transformer 如 https://arxiv.org/abs/2010.11929 所述."""def __init__(self,image_size: int,  # 输入图像的大小(假设为正方形)patch_size: int,  # 每个patch的大小(也假设为正方形)num_layers: int,  # 编码器中的层数num_heads: int,   # 注意力机制中的头数hidden_dim: int,  # 隐藏层维度mlp_dim: int,     # MLP 层的维度dropout: float = 0.0,  # Dropout 概率attention_dropout: float = 0.0,  # 注意力机制中的dropout概率num_classes: int = 1000,  # 分类的数量representation_size: Optional[int] = None,  # 表示层的大小,可选norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),  # 归一化层conv_stem_configs: Optional[List[ConvStemConfig]] = None,  # 卷积干配置,可选):super().__init__()_log_api_usage_once(self)  # 记录API使用情况torch._assert(image_size % patch_size == 0, "输入形状不能被patch大小整除!")  # 确保图像尺寸能被patch大小整除self.image_size = image_sizeself.patch_size = patch_sizeself.hidden_dim = hidden_dimself.mlp_dim = mlp_dimself.attention_dropout = attention_dropoutself.dropout = dropoutself.num_classes = num_classesself.representation_size = representation_sizeself.norm_layer = norm_layerif conv_stem_configs is not None:# 根据论文 https://arxiv.org/abs/2106.14881 使用卷积干seq_proj = nn.Sequential()  # 创建一个序列容器来保存卷积干的层prev_channels = 3  # 初始通道数为3(RGB图像)for i, conv_stem_layer_config in enumerate(conv_stem_configs):# 对于每个卷积干配置,添加一个卷积、归一化和激活层seq_proj.add_module(f"conv_bn_relu_{i}",Conv2dNormActivation(in_channels=prev_channels,out_channels=conv_stem_layer_config.out_channels,kernel_size=conv_stem_layer_config.kernel_size,stride=conv_stem_layer_config.stride,norm_layer=conv_stem_layer_config.norm_layer,activation_layer=conv_stem_layer_config.activation_layer,),)prev_channels = conv_stem_layer_config.out_channels# 添加最后一个1x1卷积层,将通道数转换为隐藏维度seq_proj.add_module("conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1))self.conv_proj: nn.Module = seq_proj  # 将卷积干设置为self.conv_projelse:# 如果没有提供卷积干配置,则使用简单的卷积投影self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)# 计算序列长度(包括类别标记)seq_length = (image_size // patch_size) ** 2# 添加一个类别标记self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))seq_length += 1  # 类别标记增加序列长度# 初始化编码器self.encoder = Encoder(seq_length,num_layers,num_heads,hidden_dim,mlp_dim,dropout,attention_dropout,norm_layer,)self.seq_length = seq_length# 定义分类头heads_layers: OrderedDict[str, nn.Module] = OrderedDict()if representation_size is None:heads_layers["head"] = nn.Linear(hidden_dim, num_classes)else:heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)heads_layers["act"] = nn.Tanh()heads_layers["head"] = nn.Linear(representation_size, num_classes)self.heads = nn.Sequential(heads_layers)# 初始化权重if isinstance(self.conv_proj, nn.Conv2d):# 初始化patchify stem的权重fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))if self.conv_proj.bias is not None:nn.init.zeros_(self.conv_proj.bias)elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):# 初始化卷积干中最后的1x1卷积层nn.init.normal_(self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels))if self.conv_proj.conv_last.bias is not None:nn.init.zeros_(self.conv_proj.conv_last.bias)if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):fan_in = self.heads.pre_logits.in_featuresnn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))nn.init.zeros_(self.heads.pre_logits.bias)if isinstance(self.heads.head, nn.Linear):nn.init.zeros_(self.heads.head.weight)nn.init.zeros_(self.heads.head.bias)def _process_input(self, x: Tensor) -> Tensor:# 获取输入张量的形状信息n, c, h, w = x.shapep = self.patch_sizetorch._assert(h == self.image_size, f"错误的图像高度!预期 {self.image_size} 但得到 {h}!")torch._assert(w == self.image_size, f"错误的图像宽度!预期 {self.image_size} 但得到 {w}!")n_h = h // pn_w = w // p# 使用卷积投影将图像划分为patch,并调整形状# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)x = self.conv_proj(x)# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))x = x.reshape(n, self.hidden_dim, n_h * n_w)# 转换到 (N, S, E) 的格式,其中 N 是批次大小,S 是源序列长度,E 是嵌入维度# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)x = x.permute(0, 2, 1)return xdef forward(self, x: Tensor):# 对输入张量进行重塑和转置x = self._process_input(x)n = x.shape[0]# 将类别标记扩展到整个批次,并与patch序列连接batch_class_token = self.class_token.expand(n, -1, -1)x = torch.cat([batch_class_token, x], dim=1)# 通过编码器处理x = self.encoder(x)# 只取每个样本的第一个token(即类别标记)作为表示x = x[:, 0]# 通过分类头进行最终预测x = self.heads(x)return x

参考
https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
论文 Worth 16x16 Words: Transformers for Image Recognition at Scale
https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

相关文章:

Vision Transformer (ViT)原理

Vision Transformer (ViT)原理 flyfish Transformer缺乏卷积神经网络(CNNs)的归纳偏差(inductive biases),比如平移不变性和局部受限的感受野。不变性意味着即使实体entity(即对象)的外观或位…...

移动云自研云原生数据库入围国采!

近日,中央国家机关2024年度事务型数据库软件框架协议联合征集采购项目产品名单正式公布,移动云自主研发的云原生数据库产品顺利入围。这一成就不仅彰显了移动云在数据库领域深耕多年造就的领先技术优势,更标志着国家权威评审机构对移动云在数…...

Unity中对象池的使用(用一个简单粗暴的例子)

问题描述:Unity在创建和销毁对象的时候是很消耗性能的,所以我们在销毁一个对象的时候,可以不用Destroy,而是将这个物体隐藏后放到回收池里面,当再次需要的时候如果回收池里面有之前回收的对象,就直接拿来用…...

linux命令行连接Postgresql常用命令

1.linux系统命令行连接数据库命令 psql -h hostname -p port -U username -d databasename -h 主机名或IP地址 -p 端口 -U 用户名 -d 连接的数据库 2.查询数据库表命令 select version() #查看版本号 \dg #查看用户 \l #查询数据库 \c mydb #切换…...

每日一题-单链表排序

为了对给定的单链表按升序排序,我们可以考虑以下解决方法: 思路 归并排序(Merge Sort):由于归并排序的时间复杂度为 O ( n log ⁡ n ) O(n \log n) O(nlogn),并且归并排序不需要额外的空间(空…...

webpack04服务器配置

webpack配置 entryoutput filenamepathpublicPath 。。 打包引入的基本路径,,,比如引入一个bundle.js,。引用之后的路径就是 publicPathfilename -devServer:static : 静态文件的位置。。。hostportopencompress : 静态资源是否用gzip压缩hi…...

JDK下载安装配置

一.JDK安装配置。 1.安装注意路径,其他直接下一步。 2.配置。 下接第4步. 或者 代码复制: JAVA_HOME D:\Program Files\Java\jdk1.8.0_91 %JAVA_HOME%\bin 或者直接配置 D:\Program Files\Java\jdk1.8.0_91\bin 3.验证(CMD)。 java javac java -version javac -version 二.下…...

30_Redis哨兵模式

在Redis主从复制模式中,因为系统不具备自动恢复的功能,所以当主服务器(master)宕机后,需要手动把一台从服务器(slave)切换为主服务器。在这个过程中,不仅需要人为干预,而且还会造成一段时间内服务器处于不可用状态,同时数据安全性也得不到保障,因此主从模式的可用性…...

NLP三大特征抽取器:CNN、RNN与Transformer全面解析

引言 自然语言处理(NLP)领域的快速发展离不开深度学习技术的推动。随着应用需求的不断增加,如何高效地从文本中抽取特征成为NLP研究中的核心问题。深度学习中三大主要特征抽取器——卷积神经网络(Convolutional Neural Network, …...

《使用 YOLOV8 和 KerasCV 进行高效目标检测》

《使用 YOLOV8 和 KerasCV 进行高效目标检测》 作者:Gitesh Chawda创建日期:2023/06/26最后修改时间:2023/06/26描述:使用 KerasCV 训练自定义 YOLOV8 对象检测模型。 (i) 此示例使用 Keras 2 在 Colab 中…...

从MySQL迁移到PostgreSQL的完整指南

1.引言 在现代数据库管理中,选择合适的数据库系统对业务的成功至关重要。随着企业数据量的增长和对性能要求的提高,许多公司开始考虑从MySQL迁移到PostgreSQL。这一迁移的主要原因包括以下几个方面: 1.1 性能和扩展性 PostgreSQL以其高性能…...

服务器一次性部署One API + ChatGPT-Next-Web

服务器一次性部署One API ChatGPT-Next-Web One API ChatGPT-Next-Web 介绍One APIChatGPT-Next-Web docker-compose 部署One API ChatGPT-Next-WebOpen API docker-compose 配置ChatGPT-Next-Web docker-compose 配置docker-compose 启动容器 后续配置 同步发布在个人笔记服…...

51单片机 和 STM32 的烧录方式和通信协议的区别

51单片机 和 STM32 的烧录方式和通信协议的区别 1. 为什么51单片机需要额外的软件(如ISP)? (1)51单片机的烧录方式 ISP(In-System Programming): 51单片机通常通过 串口&#xff08…...

(STM32笔记)十二、DMA的基础知识与用法 第二部分

我用的是正点的STM32F103来进行学习,板子和教程是野火的指南者。 之后的这个系列笔记开头未标明的话,用的也是这个板子和教程。 DMA的基础知识与用法 二、DMA传输设置1、数据来源与数据去向外设到存储器存储器到外设存储器到存储器 2、每次传输大小3、传…...

【优选算法篇】:模拟算法的力量--解决复杂问题的新视角

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:优选算法篇–CSDN博客 文章目录 一.模拟算法二.例题1.替换所有的问号2.提莫攻击3.外观数列4…...

探秘 JMeter (Interleave Controller)交错控制器:解锁性能测试的隐藏密码

嘿,小伙伴们!今天咱们要把 JMeter 里超厉害的 Interleave Controller(交错控制器)研究个透,让你从新手直接进阶成高手,轻松拿捏各种性能测试难题! 一、Interleave Controller 深度剖析 所属家族…...

脚本化挂在物理盘、nfs、yum、pg数据库、nginx(已上传脚本)

文章目录 前言一、什么是脚本化安装二、使用步骤1.物理磁盘脚本挂载(离线)2.yum脚本化安装(离线)3.nfs脚本化安装(离线)4.pg数据库脚本化安装(离线)5.nginx脚本化安装(离…...

ESP嵌入式开发环境安装

前期准备,虚拟机,ios镜像,VSCode。 centOS8:centos安装包下载_开源镜像站-阿里云 虚拟机:vmware VSCode:Visual Studio Code - Code Editing. Redefined 如何安装镜像自行查找 完成以上环境后进行一下操…...

Elasticsearch入门学习

Elasticsearch是什么 Elasticsearch 是一个基于 Apache Lucene 构建的分布式搜索和分析引擎、可扩展的数据存储和矢量数据库。 它针对生产规模工作负载的速度和相关性进行了优化。 使用 Elasticsearch 近乎实时地搜索、索引、存储和分析各种形状和大小的数据。 特点 分布式&a…...

黑马linux笔记(03)在Linux上部署各类软件 MySQL5.7/8.0 Tomcat(JDK) Nginx RabbitMQ

文章目录 实战章节:在Linux上部署各类软件tar -zxvf各个选项的含义 为什么学习各类软件在Linux上的部署 一 MySQL数据库管理系统安装部署【简单】MySQL5.7版本在CentOS系统安装MySQL8.0版本在CentOS系统安装MySQL5.7版本在Ubuntu(WSL环境)系统…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

STM32F4基本定时器使用和原理详解

STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)​现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

GruntJS-前端自动化任务运行器从入门到实战

Grunt 完全指南:从入门到实战 一、Grunt 是什么? Grunt是一个基于 Node.js 的前端自动化任务运行器,主要用于自动化执行项目开发中重复性高的任务,例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...

华为OD机试-最短木板长度-二分法(A卷,100分)

此题是一个最大化最小值的典型例题, 因为搜索范围是有界的,上界最大木板长度补充的全部木料长度,下界最小木板长度; 即left0,right10^6; 我们可以设置一个候选值x(mid),将木板的长度全部都补充到x,如果成功…...

【Linux手册】探秘系统世界:从用户交互到硬件底层的全链路工作之旅

目录 前言 操作系统与驱动程序 是什么,为什么 怎么做 system call 用户操作接口 总结 前言 日常生活中,我们在使用电子设备时,我们所输入执行的每一条指令最终大多都会作用到硬件上,比如下载一款软件最终会下载到硬盘上&am…...

【Veristand】Veristand环境安装教程-Linux RT / Windows

首先声明,此教程是针对Simulink编译模型并导入Veristand中编写的,同时需要注意的是老用户编译可能用的是Veristand Model Framework,那个是历史版本,且NI不会再维护,新版本编译支持为VeriStand Model Generation Suppo…...

数据结构:递归的种类(Types of Recursion)

目录 尾递归(Tail Recursion) 什么是 Loop(循环)? 复杂度分析 头递归(Head Recursion) 树形递归(Tree Recursion) 线性递归(Linear Recursion)…...