当前位置: 首页 > news >正文

Pytorch从零开始实现Vision Transformer (from scratch)

Pytorch从零开始实现Vision Transformer

  • 前言
  • 一、Vision Transformer架构介绍
    • 1. Patch Embedding
    • 2. Multi-Head Attention
    • 3. Transformer Block
      • Feed Forward
  • 二、预备知识
    • 1. Einsum
    • 2. Einops
  • 三、Vision Transformer代码实现
    • 0. 导入库
    • 1. Patch Embedding
    • 2. Residual & Norm
    • 3. Multi-Head Attention & FeedForward
    • 4. Transformer Encoder
    • 6. Vision Transformer
    • 7. Test Code
    • 模型参数量计算
      • 1. 卷积核参数量计算
      • 2. 全连接层参数量计算
      • 3. ViT参数量计算
  • 总结
  • 日志
  • 参考文献


前言

Transformer在NLP领域大放异彩,而实际上NLP(Natural Language Processing,自然语言处理)领域技术的发展都要先于CV(Computer Vision,计算机视觉),那么如何将Transformer这类模型也能适用到图像数据上呢?
在2017年Transformer发布后,历经3年时间,Vision Transformer于2020年问世。与Transformer相同,Vision Transformer也是由Google Brain和Google Research团队开发,然而并不是同一批人(除了Jakob Uszkoreit)。
值得一提的是,Vision Transformer并不是第一个将Transformer应用到CV上的,因为这些巨头的存在(如Google,FaceBook),论文的名气也自然会更大,而且从如今ViT的泛用程度来看也是,大家对其认可度更高纷纷follow。和这些巨头庞大资源比,高校产出的论文光芒显得黯淡了许多。而在大模型时代更是如此,都是“大力出奇迹”的结果。可大模型大数据训练就是AI的最终形态了吗,我觉得不然……或许在AI真正具有“智能”时,深度学习的模型也并不需要这么大吧,因为人脑正是有了联想推理才能拥有知识和技能,而不完全单靠记忆。


一、Vision Transformer架构介绍

在这里插入图片描述

1. Patch Embedding

2. Multi-Head Attention

3. Transformer Block

如图,(a) 是最初Transformer的Encoder结构图, (b)则是ViT的。可以明显看出,Transformer是在multi-head attention和feedforward模块后进行残差操作(即Add)和Norm(标准化),而ViT则是在这些模块前使用Norm操作。

Feed Forward

ViT的Feed Forward模块使用两层全连接层(Linear)和GeLU激活函数。而Transformer使用的是ReLu激活函数。
GeLu于2016年被提出,见于Bridging Nonlinearities and Stochastic Regularizers with Gaussian Error Linear Units,后来经过论文修改改名为“Gaussian Error Linear Units (GELUs)”。论文给出了ReLu和GeLu的图示:
在这里插入图片描述
ReLu确实好用,但缺点也很明显,其在输入值小于0时都会输出0,这样“一刀切”的策略势必会丢掉信息,累计error。因此后来出现了GeLu、LeakyReLu等一系列激活函数来解决神经元”死亡“问题,让输入值小于0时输出不总是0。


二、预备知识

本节的两个操作都是为了方便编程人员更好对tensor进行操作,且让代码更具可读性。

1. Einsum

Einsum即爱因斯坦和,torch.einsum即可调用。

2. Einops

是大牛 受Einsum启发所开发的一个库,主要用于张量的变形等操作。


三、Vision Transformer代码实现

这次代码并不是直接取用某一份代码,而是参考包括Pytorch官方的代码库、网上博客、github项目综合出的一份Vision Transformer代码,尽可能还原ViT又兼顾代码可读性以便读者学习理解。此处引用比ViT原论文更加具体的ViT模型图:
ViT流程图
此图出自论文Vision Transformers for Remote Sensing Image Classification。

0. 导入库

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

1. Patch Embedding

class PatchEmbedding(nn.Module):def __init__(self, embed_size=768, patch_size=16, channels=3, img_size=224):super(PatchEmbedding, self).__init__()self.patch_size = patch_size# Version 1.0# self.patch_projection = nn.Sequential(#     Rearrange("b c (h h1) (w w1) -> b (h w) (h1 w1 c)", h1=patch_size, w1=patch_size),#     nn.Linear(patch_size * patch_size * channels, embed_size)# )# Version 2.0self.patch_projection = nn.Sequential(nn.Conv2d(channels, embed_size, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)),Rearrange("b e (h) (w) -> b (h w) e"),)self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, embed_size))def forward(self, x):batch_size = x.shape[0]x = self.patch_projection(x)# prepend the cls token to the inputcls_tokens = repeat(self.cls_token, "() n e -> b n e", b=batch_size)x = torch.cat([cls_tokens, x], dim=1)# add position embeddingx += self.positionsreturn x

2. Residual & Norm

class Residual(nn.Module):def __init__(self, fn):super(Residual, self).__init__()self.fn = fndef forward(self, x, **kwargs):return self.fn(x, **kwargs) + xclass PreNorm(nn.Module):def __init__(self, dim, fn):super(PreNorm, self).__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)

3. Multi-Head Attention & FeedForward

class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super(FeedForward, self).__init__()self.mlp = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout),)def forward(self, x):return self.mlp(x)class MultiHeadAttention(nn.Module):def __init__(self, embed_dim=768, n_heads=8, dropout=0.):"""Args:embed_dim: dimension of embeding vector outputn_heads: number of self attention heads"""super(MultiHeadAttention, self).__init__()self.embed_dim = embed_dim  # 768 dimself.n_heads = n_heads  # 8self.head_dim = self.embed_dim // self.n_heads  # 768/8 = 96. each key,query,value will be of 96dself.scale = self.head_dim ** -0.5self.attn_drop = nn.Dropout(dropout)# key,query and value matrixesself.to_qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)self.to_out = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim),nn.Dropout(dropout))def forward(self, x):"""Args:x : a unified vector of key query valueReturns:output vector from multihead attention"""qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.n_heads), qkv)dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scaleattn = dots.softmax(dim=-1)attn = self.attn_drop(attn)out = torch.einsum('bhij,bhjd->bhid', attn, v)out = rearrange(out, "b h n d -> b n (h d)")out = self.to_out(out)return out

4. Transformer Encoder

class Transformer(nn.Module):def __init__(self, dim=768, depth=12, n_heads=8, mlp_expansions=4, dropout=0.):super(Transformer, self).__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Residual(PreNorm(dim, MultiHeadAttention(dim, n_heads, dropout))),Residual(FeedForward(dim, dim * mlp_expansions, dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x)x = ff(x)return x

6. Vision Transformer

class VisionTransformer(nn.Module):def __init__(self, dim=768,patch_size=16,channels=3,img_size=224,depth=12,n_heads=8,mlp_expansions=4,dropout=0.,num_classes=0,global_pool='avg'):super(VisionTransformer, self).__init__()assert global_pool in ('avg', 'token')self.global_pool = global_poolself.patch_embedding = PatchEmbedding(dim, patch_size, channels, img_size)self.transformer = Transformer(dim, depth, n_heads, mlp_expansions, dropout)self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes)) if num_classes > 0 else nn.Identity()def forward(self, img):x = self.patch_embedding(img)x = self.transformer(x)x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]x = self.mlp_head(x)return x

7. Test Code

if __name__ == '__main__':device = torch.device("cuda" if torch.cuda.is_available() else "cpu")images = torch.randn((16, 3, 224, 224)).to(device)vit = VisionTransformer(num_classes=4, global_pool="token").to(device)output = vit(images)print(output)torch.save(vit.state_dict(), "model.pth")

模型参数量计算

1. 卷积核参数量计算

对于二维卷积层,其参数量由输入通道数(C)、卷积核的大小(KxK)、卷积核的数量或者说输出通道数(F)、偏置项的数量等因素决定。计算公式为:
( K × K × C + 1 ) × F (K \times K \times C + 1)\times F (K×K×C+1)×F,其中1为偏置项。

2. 全连接层参数量计算

对于某一层全连接层的参数量只由其输入维度和输出维度(是否带偏置项)决定,将全连接层理解为一个映射函数,假设输入为矩阵A(维度为HxW),输出为矩阵C(维度为HxH),那么一层全连接层参数量就来自其所代表的矩阵B根据矩阵乘法其维度应为WxH,即Linear(W,H),输入维度W,输出维度也是H。计算公式易得:
W × H + H × 1 W \times H + H\times 1 W×H+H×1,其中1代表偏置项,需要输出维度个偏置项。

3. ViT参数量计算

模块/变量名计算过程参数量
PatchEmbedding c o n v 2 d + c l s _ t o k e n + p o s t i t i o n s conv2d + cls\_token + postitions conv2d+cls_token+postitions742656
conv2d ( 16 × 16 × 3 + 1 ) × 768 (16\times 16\times 3 + 1)\times 768 (16×16×3+1)×768590592
cls_token 1 × 1 × 768 1\times1\times768 1×1×768768
postitions ( ( 224 ÷ 16 ) 2 + 1 ) × 768 ((224\div 16)^2+1)\times768 ((224÷16)2+1)×768151296
Feedforward ( 768 × ( 768 × 4 ) + ( 768 × 4 ) ) + ( ( 768 × 4 ) × 768 + 768 ) (768\times(768\times4)+(768\times4)) + ((768\times4)\times768+768) (768×(768×4)+(768×4))+((768×4)×768+768)4722432
MultiHeadAttention t o _ q k v + t o _ o u t to\_qkv + to\_out to_qkv+to_out2360064
to_qkv 768 × ( 768 × 3 ) 768\times(768\times3) 768×(768×3)1769472
to_out 768 × 768 + 768 768\times768+768 768×768+768590592
Transformer 12 × ( F e e d f o r w a r d + M u l t i H e a d A t t e n t i o n ) 12\times(Feedforward+MultiHeadAttention) 12×(Feedforward+MultiHeadAttention)84989952
ViT T r a n s f o r m e r + P a t c h E m b e d d i n g + m l p _ h e a d Transformer+PatchEmbedding+mlp\_head Transformer+PatchEmbedding+mlp_head85735684
mlp_head 768 × n u m _ c l a s s e s + n u m _ c l a s s e s ,本文设置 n u m _ c l a s s e s 为 4 768\times num\_classes+num\_classes,本文设置num\_classes为4 768×num_classes+num_classes,本文设置num_classes43076

最终参数量为 85735684 × 4 ( B ) = 342942736 ( B ) 85735684\times 4(B) = 342942736(B) 85735684×4(B)=342942736(B)为什么要乘以4字节呢?
因为这些参数权重默认为float32保存,需要用到32bits即4Bytes,最终通过换算得,
342942736 ( B ) ÷ 1024 ÷ 1024 = 327.055679321 ( M B ) 342942736(B)\div 1024\div 1024 = 327.055679321(MB) 342942736(B)÷1024÷1024=327.055679321(MB)
因为我们在Test code有保存模型权重为model.pth文件,可以查看model.pth属性来验证计算是否准确。
在这里插入图片描述
在字节数上有所偏差,但足以表明计算过程大致是正确的! 偏差可能原因是model.pth不止要保存权重,还会附带一些其他信息,所以实际文件大小会比参数量要略大。


总结

日志

参考文献

https://theaisummer.com/vision-transformer/
https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
https://www.kaggle.com/code/hannes82/vision-transformer-trained-from-scratch-pytorch
https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
https://github.com/FrancescoSaverioZuppichini/ViT

相关文章:

Pytorch从零开始实现Vision Transformer (from scratch)

Pytorch从零开始实现Vision Transformer 前言一、Vision Transformer架构介绍1. Patch Embedding2. Multi-Head Attention3. Transformer BlockFeed Forward 二、预备知识1. Einsum2. Einops 三、Vision Transformer代码实现0. 导入库1. Patch Embedding2. Residual & Norm…...

ES6函数新增了哪些扩展?

目录 一、参数二、属性函数的length属性name属性 三、作用域四、严格模式五、箭头函数 一、参数 ES6允许为函数的参数设置默认值 function log(x, y World) {console.log(x, y); }console.log(Hello) // Hello World console.log(Hello, China) // Hello China console.log(…...

【firewalld防火墙】

目录 一、firewalld概述二、firewalld 与 iptables 的区别1、firewalld 区域的概念 三、firewalld防火墙默认的9个区域四、Firewalld 网络区域1、区域介绍2、firewalld数据处理流程 五、firewalld防火墙的配置方法1、使用firewall-cmd 命令行工具。2、使用firewall-config 图形…...

CNNs: ZFNet之CNN的可视化网络介绍

CNNs: ZFNet之CNN的可视化网络介绍 导言Deconvnet1. Unpooling2. ReLU3. Transpose conv AlexNet网络修改AlexNet Deconv网络介绍特征可视化 导言 上一个内容,我们主要学习了AlexNet网络的实现、超参数对网络结果的影响以及网络中涉及到一些其他的知识点&#xff0…...

云原生之深入解析Airbnb的动态Kubernetes集群扩缩容

一、前言 Airbnb 基础设施的一个重要作用是保证我们的云能够根据需求上升或下降进行自动扩缩容,我们每天的流量波动都非常大,需要依靠动态扩缩容来保证服务的正常运行。为了支持扩缩容,Airbnb 使用了 Kubernetes 编排系统,并且使…...

Django框架之模板其他补充

本篇文章是对django框架模板内容的一些补充。包含注释、html转义和csrf内容。 目录 注释 单行注释 多行注释 HTML转义 Escape Safe Autoescape CSRF 防止csrf方式 表单中使用 ajax请求添加 注释 单行注释 语法:{# 注释内容 #} 示例: {# 注…...

安装Maven 3.6.1:图文详细教程(适用于Windows系统)

一、官网下载对应版本 推荐使用maven3.6.1版本,对应下载链接: Maven3.6.1下载地址 或者,这里提供csdn下载地址,点击下载即可: Maven3.6.1直链下载 其他版本下载地址: 进入网址:http://mave…...

计算机图形学 | 实验八:Phong模型

计算机图形学 | 实验八:Phong模型 计算机图形学 | 实验八:Phong模型Phong模型光源设置 光照计算定向光点光源聚光 华中科技大学《计算机图形学》课程 MOOC地址:计算机图形学(HUST) 计算机图形学 | 实验八&#xff1a…...

第三十一回:GestureDetector Widget

文章目录 概念介绍使用方法示例代码 我们在上一章回中介绍了ListView响应事件的内容t,本章回中将介绍 GestureDetector Widget.闲话休提,让我们一起Talk Flutter吧。 概念介绍 我们在这里介绍的GestureDetector是一个事件响应Widget,它可以响应双击事件&#xff0…...

Java面试知识点(全)-Java并发-多线程JUC三- JUC集合/线程池

Java面试知识点(全) 导航: https://nanxiang.blog.csdn.net/article/details/130640392 注:随时更新 JUC集合类 为什么HashTable慢? 它的并发度是什么? 那么ConcurrentHashMap并发度是什么? Hashtable之所以效率低下主要是因为其实现使用了synchro…...

Android 如何获取有效的DeviceId

目录 前言官方唯一标识符建议使用广告 ID使用实例 ID 和 GUID不要使用 MAC 地址标识符特性常见用例和适用的标识符 解决方案DeviceIdANDROID_IDMac地址UUID补充 总结 前言 从 Android 10 开始,应用必须具有 READ_PRIVILEGED_PHONE_STATE 特许权限才能访问设备的不可…...

<SQL>《SQL命令(含例句)精心整理版(2)》

《SQL命令(含例句)精心整理版(2)》 跳转《SQL命令(含例句)精心整理版(1)8 函数8.1 文本处理函数8.2 数值处理函数8.3 时间处理函数8.3.1 时间戳转化为自定义格式from_unixtime8.3.2 …...

完全自主研发,聚芯微发布3D dToF图像传感器芯片!

日前,由中国半导体行业协会IC设计分会(ICCAD)、芯原股份、松山湖管委会主办的主题为“AR/VR/XR元宇宙”的“2023松山湖中国IC创新高峰论坛”正式在广东东莞松山湖召开。武汉市聚芯微电子有限责任公司发布了完全自主知识产权的3D dToF图像传感…...

MySQL 事物(w字)

目录 事物 首先我们来看一个简单的问题 什么是事务 为什么会出现事务 事务的版本支持 事务提交方式 事务常见操作方式 设置隔离级别 事物操作 事物结论 事务隔离级别 理解隔离性 隔离级别 查看与设置隔离性 注意可重复读【Repeatable Read】的可能问题&#xff…...

字节跳动测试岗四面总结....

字节一面 1、 简单做一下自我介绍 2、 简要介绍一下项目/你负责的模块/选一个模块说一下你设计的用例 3 、get请求和post请求的区别 4、 如何判断前后端bug/3xx是什么意思 5、 说一下XXX项目中你做的接口测试/做了多少次 6、 http和https的区别 7、 考了几个ADB命令/查看…...

基于.NetCore开源的Windows的GIF录屏工具

推荐一个Github上Start超过20K的超火、好用的屏幕截图转换为 GIF 动图开源项目。 项目简介 这是基于.Net Core WPF 开发的、开源项目,可将屏幕截图转为 GIF 动画。它的核心功能是能够简单、快速地截取整个屏幕或者选定区域,并将其转为 GIF动画&#x…...

PCB 基础~典型的PCB设计流程,典型的PCB制造流程

典型的PCB设计流程 典型的PCB制造流程 • 从客户手中拿到Gerber, Drill以及其它PCB相关文件 • 准备PCB基片和薄片 – 铜箔的底片会被粘合在基材上 • 内层图像蚀刻 – 抗腐蚀的化学药水会涂在需要保留的铜箔上(例如走线和过孔) – 其他药水…...

Python logging使用

目录 logging模块 logging核心组件 logger handler StreamHandler:把日志内容在控制台中输出 FileHandler:把日志内容写入到文件中 filter formatter 注意日志级别的继承问题 logger.exception 上述样例的整体代码 日志的配置文件及其模板 lo…...

红黑树的实现原理和应用场景

红黑树的实现原理和应用场景; 有如图所示的表,现在希望查询的结果将列成行 建表语句如下: CREATE TABLE TEST_TB_GRADE2 ( ID int(10) NOT NULL AUTO_INCREMENT, USER_NAME varchar(20) DEFAULT NULL, CN_SCORE float DEFAULT NU…...

idea插件完成junit代码生成,和springboot代码示例

在idea环境下,可以用过插件的方式自动生成juint模板代码。不过具体要需要自己手动编写。 1、安装插件 打开idea,file–settings–plugins,搜索和安装插件(JunitGenerator V2.0和JUnit),安装后,后…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...

内存分配函数malloc kmalloc vmalloc

内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...

黑马Mybatis

Mybatis 表现层&#xff1a;页面展示 业务层&#xff1a;逻辑处理 持久层&#xff1a;持久数据化保存 在这里插入图片描述 Mybatis快速入门 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6501c2109c4442118ceb6014725e48e4.png //logback.xml <?xml ver…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql

智慧工地管理云平台系统&#xff0c;智慧工地全套源码&#xff0c;java版智慧工地源码&#xff0c;支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求&#xff0c;提供“平台网络终端”的整体解决方案&#xff0c;提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...

【C++进阶篇】智能指针

C内存管理终极指南&#xff1a;智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...