Vision Transformer(ViT)模型原理及PyTorch逐行实现
Vision Transformer(ViT)模型原理及PyTorch逐行实现
一、TRM模型结构
1.Encoder
- Position Embedding 注入位置信息
- Multi-head Self-attention 对各个位置的embedding融合(空间融合)
- LayerNorm & Residual
- Feedforward Neural Network 对每个位置上单独仿射变换(通道融合)
- Linear1(large)
- Linear2(d_model)
- LayerNorm & Residual
2.Decoder
- Position Embedding
- Casual Multi-head Self-attention
- LayerNorm & Residual
- Memory-base Multi-head Cross-attention
- LayerNorm & Residual
- Feedforward Neural Network
- Linear1(large)
- Linear2(d_model)
- LayerNorm & Residual
二、TRM使用类型
- Encoder only 【 ViT 所使用的】
- BERT、分类任务、非流式任务
- Decoder only
- GPT系列、语言建模、自回归生成任务、流式任务
- Encoder-Decoder
- 机器翻译、语音识别
三、TRM特点
- 无先验假设(例如:局部关联性、有序建模性)
- 核心计算在于自注意力机制,平方复杂度
- 数据量的要求与归纳偏置【人类通过归纳法得到的经验,把这些经验带入到模型中,很多事物的共性】的引入成反比
四、Vision Transformer(ViT)
- DNN perspective 图像的信息量主要还是聚集在一块区域上
- image2patch 将图片切分成很多个块
- patch2embedding 将每个块转换为向量
- CNN perspective 从卷积的角度得到向量
- 2D convolution over image 二维卷积
- flatten the output feature map 把输出的卷积图拉直
- class token embedding 占位符
- position embedding
- interpolation when inference
- Transformer Encoder 只使用的Encoder
- classification head 最后分类
五、ViT论文讲解
首先将一副图片分为很多个块,每个块的大小都是不会变化的,图片即使大一点,只是序列更长一点。先左到右,再上到下,把图片拉直成一个序列的形状。把每个块中的像素点进行归一化,范围变为0到1之间,再把块里面的所有值通过一个线性变换映射到模型的维度,得到patchembedding,得到以后,我们为了做分类任务,还需要在序列的开头加上一个可训练的embedding,这个是随机初始化的。这样就构造出了一个n+1长度的序列,然后我们再加入position embedding,加上后的这个序列的表征就可以送入到TRM的encoder当中,最后取出结果中的我们加入的可训练的embedding位置上的值(输出状态),经过一个MLP,得到各个类别的概率分布,再通过一个交叉熵函数算出分类的loss,这样就完成了一个ViT模型的搭建。
六、代码实现
1.convert image to embedding vector sequence
1.通过DNN实现
import torch
import torch.nn as nn
import torch.nn.functional as Fdef image2emb_naive(image,patch_size,weight):# image shape: bs*channel*h*wpatch = F.unfold(image,kernel_size=patch_size,stride=patch_size).transpose(-1,-2)patch_embedding = patch @ weightreturn patch_embedding# test code for image2emb
bs,ic,image_h,image_w=1,3,8,8
patch_size=4 # 每个块的大小为4*4(自定义)
model_dim=8 #将每个块映射成长度为8的向量(自定义)
patch_depth=patch_size*patch_size*ic
image=torch.randn(bs,ic,image_h,image_w) #初始化
weight=torch.randn(patch_depth,model_dim)#初始化patch_embedding_navie=image2emb_navie(image,patch_size,weight)
print(patch_embedding_naive.shape) # [1,4,8],分成四块了,每块对应一个长度为8的向量
2.通过CNN实现
import torch
import torch.nn as nn
import torch.nn.functional as Fdef image2emb_conv(image,kernel,stride):conv_output=F.conv2d(image,kernel,stride=stride) # bs*oc*oh*owbs,oc,oh,ow=conv_output.shapepatch_embedding=conv_output.reshape((bs,oc,oh*ow)).transpose(-1,-2)return patch_embedding# test code for image2emb
bs,ic,image_h,image_w=1,3,8,8
patch_size=4
model_dim=8
patch_depth=patch_size*patch_size*ic
image=torch.randn(bs,ic,image_h,image_w)
weight=torch.randn(patch_depth,model_dim) #model_dim是输出通道数目,patch_depth是卷积核的面积乘以输入通道数kernel=weight.transpose(0,1).reshape((-1,ic,patch_size,patch_size)) # oc*ic*kh*kw
patch_embedding_conv=image2emb_conv(image,kernel,patch_size) # 二维卷积的方法得到embedding
2.prepend CLS token embedding
cls_token_embedding = torch.randn(1,model_dim,requires_grad=True)
token_embedding = torch.cat([[bs,cls_token_embedding],patch_embedding_conv],dim=1)
提问:本身cls_token_embedding没有和任何样本矩阵有乘法联系,最后训练出来的也是一张确定的表,在做inference的时候,完全是一个常数的作用。送入transformer后,又与其他矩阵做了MHA,没搞懂用意何在啊?
答:有联系啊,就是与其他时刻的sample做MHSA。这个token其实是取代了avg pool的作用,也就是说,你可以用avg pool得到分类的logits,也可以用采用cls token来得到分类的logits
注意:cls_token_embedding作为batch_size中每一个序列的开始,应该对于每一个序列的开始都torch.cat同样的一个cls_token_embedding,然后都是对这同一个cls_token_embedding进行训练,所以这里的cls token embedding应该是二维的,1*model_dim,与batchsize无关。
3.add position embedding
max_num_token=16 #自定义
position_embedding_table = torch.randn(max_num_token,model_dim,requires_grad=True)
seq_len=token_embedding.shape[1] # 刚刚的1+4
position_embedding=torch.tile(position_embedding_table[:seq_len],[token_embedding.shape[0],1,1]) # 5,bs,1,1
token_embedding += position_embedding
4.pass embedding to Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim,nhead=8)
transformer_encoder=nn.TransformerEncoder(encoder_layer,num_layers=6)
encoder_output=transformer_encoder(token_embedding)
5.do classification
cls_token_output=encoder_output[:,0,:] #拿到TRM的输出值
num_classes=10 # 自定义的类别数目
label=torch.randint(10,(bs,)) # 自定义的生成的label
linear_layer = nn.Linear(model_dim,num_classes)
logits = linear_layer(cls_token_output)
loss_fn=nn.CrossEntropyLoss()
loss=loss_fn(logits,label)
print(loss)
相关文章:

Vision Transformer(ViT)模型原理及PyTorch逐行实现
Vision Transformer(ViT)模型原理及PyTorch逐行实现 一、TRM模型结构 1.Encoder Position Embedding 注入位置信息Multi-head Self-attention 对各个位置的embedding融合(空间融合)LayerNorm & ResidualFeedforward Neural Network 对每个位置上单…...

828华为云征文 | Flexus X实例CPU、内存及磁盘性能实测与分析
引言 随着云计算的普及,企业对于云资源的需求日益增加,而选择一款性能强劲、稳定性高的云实例成为了关键。华为云Flexus X实例作为华为云最新推出的高性能实例,旨在为用户提供更强的计算能力和更高的网络带宽支持。最近华为云828 B2B企业节正…...

FreeRTOS学习笔记(六)队列
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、队列的基本内容1.1 队列的引入1.2 FreeRTOS 队列的功能与作用1.3 队列的结构体1.4 队列的使用流程 二、相关API详解2.1 xQueueCreate2.2 xQueueSend2.3 xQu…...

【Python篇】PyQt5 超详细教程——由入门到精通(中篇一)
文章目录 PyQt5入门级超详细教程前言第4部分:事件处理与信号槽机制4.1 什么是信号与槽?4.2 信号与槽的基本用法4.3 信号与槽的基础示例代码详解: 4.4 处理不同的信号代码详解: 4.5 自定义信号与槽代码详解: 4.6 信号槽…...

LinuxQt下的一些坑之一
我们在使用Qt开发时,经常会遇到Windows上应用正常,但到Linux嵌入式下就会出现莫名奇妙的问题。这篇文章就举例分析下: 1.QPushButton按钮外侧虚线框问题 Windows下QPushButton按钮设置样式正常,但到了Linux下就会有一个虚线边框。…...

Statement batch
我们可以看到 Statement 和 PreparedStatement 为我们提供的批次执行 sql 操作 JDBC 引入上述 batch 功能的主要目的,是加快对客户端SQL的执行和响应速度,并进而提高数据库整体并发度,而 jdbc batch 能够提高对客户端SQL的执行和响应速度,其…...

PPP 、PPPoE 浅析和配置示例
一、名词: PPP: Point to Point Protocol 点到点协议 LCP:Link Control Protocol 链路控制协议 NCP:Network Control Protocol 网络控制协议,对于上层协议的支持,N 可以为IPv4、IPv6…...

【Python机器学习】词向量推理——词向量
目录 面向向量的推理 使用词向量的更多原因 如何计算Word2vec表示 skip-gram方法 什么是softmax 神经网络如何学习向量表示 用线性代数检索词向量 连续词袋方法 skip-gram和CBOW:什么时候用哪种方法 word2vec计算技巧 高频2-gram 高频词条降采样 负采样…...

Python 语法糖:让编程更简单(续二)
Python 语法糖:让编程更简单(续) 10. Type hints Type hints 是 Python 中的一种语法糖,用于指定函数或变量的类型。例如: def greet(name: str) -> None:print(f"Hello, {name}!")这段代码将定义一个…...

6 - Shell编程之sed与awk编辑器
目录 一、sed 1.概述 2.sed命令格式 3.常用操作的语法演示 3.1 输出符合条件的文本 3.2 删除符合条件的文本 3.3 替换符合条件的文本 3.4 插入新行 二、awk 1.概述 2. awk命令格式 3.awk工作过程 4.awk内置变量 5.awk用法示例 5.1 按行输出文本 5.2 按字段输出文…...

什么是XML文件,以及如何打开和转换为其他文件格式
本文描述了什么是XML文件以及它们在哪里使用,哪些程序可以打开XML文件,以及如何将XML文件转换为另一种基于文本的格式,如JSON、PDF或CSV。 什么是XML文件 XML文件是一种可扩展标记语言文件。它们是纯文本文件,除了描述数据的传输、结构和存储外,本身什么也不做。 RSS提…...

海外直播对网速、带宽、安全的要求
要满足海外直播的要求,需要拥有合适的网络配置。在全球化的浪潮下,海外直播正逐渐成为企业、个人和各类组织的重要工具。不论是用于市场推广、品牌宣传,还是与观众互动,海外直播都为参与者带来了丰富的机会。然而,确保…...

UWB定位室外基站
定位基站,型号SW,是一款基于无线脉冲技术开发的UWB定位基站,基站可用于人员、车辆、物资的精确定位, 该基站专为恶劣环境使用而设计,防尘、防水等级IP67,工业级标准支持365天连续运行,本安防爆可…...

高斯平面直角坐标讲解,以及地理坐标转换高斯平面直角坐标
高斯平面直角坐标系(Gauss-Krger 坐标系)是基于 高斯-克吕格投影 的一种常见的平面坐标系统,主要用于地理信息系统 (GIS)、测绘和工程等领域。该坐标系将地球表面的经纬度(地理坐标)通过一种投影方式转换为平面直角坐标,以便在二维平面中进行距离、面积和角度的计算。 一…...

C++入门(06)安装QT并快速测试体验一个简单的C++GUI项目
文章目录 1. 清华镜像源下载2. 安装3. 开始菜单上的 QT 工具4. 打开 Qt Creator5. 简单的 GUI C 项目5.1 打开 Qt Creator 并创建新项目5.2 设计界面5.3 添加按钮的点击事件5.4 编译并运行项目 6. 信号和槽(Signals and Slots) 这里用到了C类与对象的很多…...

一篇文章告诉你小程序为什么最近这么火?
微信小程序之所以最近这么火,主要得益于其低成本获取高流量、线上线下流量互换、社交裂变引爆流量以及封闭商业生态闭环等优势。下面将详细探讨小程序火爆的多个原因: 一篇文章告诉你小程序为什么这么火爆? 低成本获取高流量 无需安装注册&…...

Qt-常用控件(3)-多元素控件、容器类控件和布局管理器
1. 多元素控件 Qt 中提供的多元素控件有: QListWidgetQListViewQTableWidgetQTableViewQTreeWidgetQTreeView xxWidget 和 xxView 之间的区别,以 QTableWidget 和 QTableView 为例. QTableView 是基于 MVC 设计的控件.QTableView 自身不持有数据,使用 QTableView 的…...

【系统设计】主动查询与主动推送:如何选择合适的数据传输策略
基本描述总结 主动查询机制:系统A主动向系统B请求数据,采用严格的权限控制和身份认证,防止未授权的数据访问。数据在传输过程中使用TLS加密,并通过动态脱敏处理隐藏敏感信息。 推送机制:系统B在数据更新时主动向系统…...

mac 安装brew并配置国内源
前置条件 - Xcode 命令行工具 一行代码安装Homebrew 添加到路径(PATH) - zsh shell为例 背景介绍 最近重装了我的MAC mini (m1 芯片), 很多软件都需要重新安装,因为后续还需要安装一些软件,所以想着安装个包管理软件 什么…...

Temu官方宣导务必将所有的点位材料进行检测-RSL资质检测
关于饰品类产品合规问题宣导: 产品法规RSL要求 RSL测试是根据REACH法规及附录17的要求进行测试。REACH法规是欧洲一项重要的法规,其中包含许多对化学物质进行限制的规定和高度关注物质。 为了确保珠宝首饰的安全性,欧盟REACH法规规定&#…...

mysql高级sql
文章目录 一,查询1.按关键字排序1.1按关键字排序操作(1)按分数排序查询(不加asc默认为升序)(2)按分数降序查询(DESC)(3)使用where进行条件查询(4)使用ORDER BY语句对多个字段排序 1.2使用区间判断查询(and/…...

Linux CentOS 7.9 安装mysql8
1、新建mysql文件夹 数据比较大,所以我在服务器另外挂了一个盘装mysql,和默认安装一个道理,换路径即可 cd ../ //创建文件夹 mkdir mysql //进入mysql文件夹 cd mysql 2、下载mysql8.0安装包并解压、重命名 //下载安装包 wget https://dev…...

替代 Django 默认 User 模型并使用 `django-mysql` 添加数据库备注20240904
替代 Django 默认 User 模型并使用 django-mysql 添加数据库备注 前言 在 Django 项目开发中,默认的 User 模型虽然能够满足许多基础需求,但在实际项目中我们常常需要对用户模型进行定制化。通过覆盖默认的 User 模型,我们可以根据具体的业…...

三维激光扫描点云配准外业棋盘的布设与棋盘坐标测量
文章目录 一、棋盘标定板准备二、棋盘标定板布设三、棋盘标定板坐标测量一、棋盘标定板准备 三维激光扫描棋盘是用来校准和校正激光扫描仪的重要工具,主要用于提高扫描精度。棋盘标定板通常具有以下特点: 高对比度图案:通常是黑白相间的棋盘格,便于识别。已知尺寸:每个格…...

【Python知识宝库】文件操作:读写文件的最佳实践
🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 前言一、文件读取1. 使用open函数2. 逐行读取3. 使用readlines和readline 二、文件写入1. 写入文本2. 追加内容3. 写入…...

Chapter 13 普通组件的注册使用
欢迎大家订阅【Vue2Vue3】入门到实践 专栏,开启你的 Vue 学习之旅! 文章目录 前言一、组件创建二、局部注册三、全局注册 前言 在 Vue.js 中,组件是构建应用程序的基本单元。本章详细讲解了注册和使用 Vue 的普通组件的两种方式:…...

u盘显示需要格式化才能用预警下的数据拯救恢复指南
U盘困境:需要格式化的紧急应对 在数字信息爆炸的时代,U盘作为便携的数据存储介质,承载着我们工作、学习乃至生活中的大量重要资料。然而,当U盘突然弹出“需要格式化才能用”的提示时,这份便捷瞬间转化为焦虑与不安。这…...

还不懂BIO,NIO,AIO吗
BIO(Blocking I/O)、NIO(Non-blocking I/O)和 AIO(Asynchronous I/O)是 Java 中三种不同的 I/O 模型,主要用于处理输入 / 输出操作。 一、BIO(Blocking I/O) 定义与工作原…...

物联网——DMA+AD多通道
DMA简介 存储器映像 某些数据在运行时不会发生变化,则设置为常量,存在Flash存储器中,节省运行内存的空间 DMA结构图 DMA访问权限高于cpu 结构要素 软件触发源:存储器到存储器传输完成后,计数器清零 硬件触发源&…...

Vue 中 watch 和 watchEffect 的区别
watch 和 watcheffect 都是 vue 中用于监视响应式数据的 api,它们的区别在于:watch 用于监视特定响应式属性并执行回调函数。watcheffect 用于更通用的响应式数据监视,但回调函数中不能更新响应式数据。Vue 中 watch 和 watchEffect 的区别 …...