当前位置: 首页 > news >正文

Vision Transfomer系列第一节---从0到1的源码实现

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

这里写目录标题

  • 准备
  • 逐步源码实现
    • 数据集读取
    • VIt模型搭建
      • hand
      • 类别和位置编码
          • 类别编码
          • 位置编码
      • blocks
      • head
      • VIT整体
    • Runner(参考mmlab)
    • 可视化
  • 总结

准备

在这里插入图片描述
本博客完成Vision Transfomer(VIT)模型的搭建和flowers数据集的训练测试.整个源码包括如下几个任务:
1.读取flowers数据集的dataset类,对应文件dataset.py
2.VIT模型搭建,主要依赖于上几篇博客,对应model.py

1.transfomer中Multi-Head Attention的源码实现的MultiheadAttention类,用于搭建BaseTransformerLayer类,实现encoder和decoder功能
2.transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类,帮助我们丝滑地搭建各类transformer网络
3.transfomer中正余弦位置编码的源码实现[可选]

3.设置优化器学习率和训练/验证模型,对应runner.py和train.py
4.可视化测试单个图片的预测结果,对应demo.py

逐步源码实现

源码结构如下
在这里插入图片描述

数据集读取

主要原理:根据dataset的路径,存储各个图片对应的路径,label隐藏在路径中.
在getitem函数中完成指定index图片和label的读取和数据增强功能

class Flowers(Dataset):# 用于读取flower数据集def __init__(self, dataset_path: str, transforms=None):'''存储所有数据 data路径和label:param dataset_path:'''super(Dataset, self).__init__()flowers = os.listdir(dataset_path)flowers = sorted(flowers) # 必须排序,否在每一次顺序不一样训练测试类别就会乱self.flower_paths = []self.class2label = {}  # 类别str 转 labellabel = 0for _, flower in enumerate(flowers):flowers_path = os.path.join(dataset_path, flower)if os.path.isdir(flowers_path):self.class2label[flower] = labellabel +=1sub_flowers = os.listdir(flowers_path)for sub_flower in sub_flowers:self.flower_paths.append(os.path.join(flowers_path, sub_flower))self.label2class = label2class(self.class2label)  # label 转 类别strself.transforms = transforms''''''def __getitem__(self, item):# 读取数据和labelimg = Image.open(self.flower_paths[item])label = self.class2label[self.flower_paths[item].split('/')[-2]]if self.transforms is not None:img = self.transforms(img)  # 数据增强return img, label

VIt模型搭建

将整个深度学习模型按照人体分为hand+backbone+neck+head 4个部分,Vit模型不同CNN模型,它的backbone+neck为多个MultiHeadAttention堆叠组成,称之为blocks.

hand

hand主用完成预处理,将数据用"手"揉捏成想要的类型.本处主要完成图片的patch操作,将图片分割成一个个小块,使用大核的卷积完成.然后把w和h拉平后shape就和NLP(b,n,d)一样了.

class PatchLayer(nn.Module):def __init__(self, img_size, patch_size=20, embeding_dim=64):super(PatchLayer, self).__init__()self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = nn.Conv2d(in_channels=3,out_channels=embeding_dim,kernel_size=(patch_size, patch_size),stride=patch_size,padding=0)self.norm = nn.LayerNorm(normalized_shape=embeding_dim)def forward(self, img):img = self.proj(img)  # 图片分割img = img.flatten(start_dim=2)  # wh拉平img = img.permute(0, 2, 1)  # [b wh c]img = self.norm(img)return img

类别和位置编码

类别编码

直接cat到input上面,那么最后也取出对应的那一列作为类别输出.这是transformer类型网络的常用手段.
个人解释:训练出类别的访问者,这个访问者可以从特征信息(原input)中提取类别信息.训练访问者方法就是类别loss回归,训练时候先推出,推理时推出

位置编码

add到input上,可以使用可学习式的位置编码也可以使用正余弦位置编码.这是transformer类型网络的常用手段,还要特征层编码等
个人解释:训练出位置的标记者

        # 类别编码self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))# 固定位置编码和可学习位置编码# self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

blocks

blocks使用注意力机制完成特征提取,
个人解释:
input线性映射为[query,key,value],需求侧(query)从供给侧(value)中取值,取值的根据是qurey@key转置生成的注意力矩阵(需求侧和供给侧每个像素之间的相似度),最后输出与输入shape相同.所以我们重复depth次,多次特征提取.
源码直接调用:transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类

head

主要对transfomer输出的类别特征进行映射,embed维度映射为num_class维度

self.head = nn.Linear(embed_dim, num_classes)

VIT整体

主要是上述几个模块的集合及其正向传播过程:
完成二维图片变一维特征,一维特征transfomer特征提取,分类头输出.

class Vit(nn.Module):def __init__(self, img_size=[224, 224], patch_size=16, num_classes=1000,embed_dim=768, depth=12, num_heads=12):super(Vit, self).__init__()self.patch_embed = PatchLayer(img_size, patch_size, embed_dim)num_patches = self.patch_embed.num_patchesself.blocks = nn.Sequential(*[BaseTransformerLayer(attn_cfgs=[dict(embed_dim=embed_dim, num_heads=num_heads)],fnn_cfg=dict(embed_dim=embed_dim, feedforward_channels=4 * embed_dim, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'ffn', 'norm'))for _ in range(depth)])# 类别编码self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))# 固定位置编码和可学习位置编码# self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))# 分类头self.head = nn.Linear(embed_dim, num_classes)self.loss_class = nn.CrossEntropyLoss()  # 内置softmaxself.init_weights()''''''def forward(self, img):query = self.hand(img)query = self.extract_feature(query)cls_fea = query[:, -1, :]  # 刚刚class_token被cat到了dim1的最后一个数x = self.head(cls_fea)return x

Runner(参考mmlab)

建立优化前,设置学习率,根据指定的work_flow顺序进行训练的测试,并保留最优权重

class Runner:def __init__(self, arg, model, device):self.arg = arg# 建立优化器params = [p for p in model.parameters() if p.requires_grad]self.optimizer = torch.optim.SGD(params=params, lr=arg.lr, momentum=0.9, weight_decay=5E-5)lf = lambda x: ((1 + math.cos(x * math.pi / arg.epochs)) / 2) * (1 - arg.lrf) + arg.lrf  # cosineself.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)self.model = model.to(device)self.device = deviceif arg.load_from is not None and arg.load_from != '':weight_dict = torch.load(arg.load_from, map_location=device)model.load_state_dict(weight_dict)def run(self, dataloaders: dict):# 开始训练和验证assert 'train' in self.arg.work_flow.keys(), '必须要用训练任务'epoch_start = 0best_accuracy = 0.0while epoch_start < self.arg.epochs:for task, times in self.arg.work_flow.items():if task == 'train':  # 开始训练for _ in range(times):epoch_start += 1  # epoch只记录训练轮self.model.train()loss_sum = 0.0data_loader = tqdm(dataloaders['train'], file=sys.stdout)for step, data_dict in enumerate(data_loader):img, label = data_dictinstance = {'data': img.to(self.device),'label': label.to(self.device)}loss = self.model.loss(**instance)loss_sum += loss.detach()  # 要十分注意 避免往计算图中引入新的东西loss.backward()self.optimizer.step()self.optimizer.zero_grad()data_loader.desc = "[train epoch {}] loss: {:.3f}".\format(epoch_start,loss_sum.item() / (step + 1))self.scheduler.step()print('train: epoch={}, loss={}'.format(epoch_start, loss_sum / (step + 1.0)))elif task == 'val':  # 开始验证''''''else:raise ValueError('task must be in [train, val, test]')

可视化

读取单张图片,转换格式输入模型,输出的label,转化为class名和置信度,显示图像,class名和置信度.

if __name__ == '__main__':# 建立数据集data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])img = Image.open('*****daisy/21652746_cc379e0eea_m.jpg')input = data_transform(img).unsqueeze(0)label2class = Flowers(dataset_path='../datasets/flower_photos-mini').label2classdevice = torch.device('cuda:0')# 建立模型model = Vit(img_size=[224, 224],patch_size=16,embed_dim=768,depth=12,num_heads=12,num_classes=5).to(device)weight_dict = torch.load('weights/vit.pth', map_location=device)model.load_state_dict(weight_dict)model.eval()with torch.no_grad():output = model(input.to(device))output = output.detach().cpu()label = output[0].numpy().argmax()cnf = torch.softmax(output[0],dim=0).numpy().max()*100.0cnf = np.around(cnf, decimals=2) #保留2位小数plt.imshow(img)plt.title('{} : {}%'.format(label2class[label],cnf))plt.show()

总结

vit是视觉transfomer最经典的模型,复现一次代码十分有必要,中间会产生很多思考和问题.
后面章节将会更有价值,我将会:

1.利用本次的代码进行很多思考和trick的验证
2.总结本次代码的BUG们,及其产生的原理和解决方法

如需获取全套代码请参考

相关文章:

Vision Transfomer系列第一节---从0到1的源码实现

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考 这里写目录标题 准备逐步源码实现数据集读取VIt模型搭建hand类别和位置编码类别编码位置编码 blocksheadVIT整体 Runner(参考mmlab)可视化 总结 准备 本博客完成Vision Transfomer(VIT)模型的搭建和flowers数…...

【CSS + ElementUI】更改 el-carousel 指示器样式且隐藏左右箭头

需求 前三条数据以走马灯形式展现&#xff0c;指示器 hover 时可以切换到对应内容 实现 <template><div v-loading"latestLoading"><div class"upload-first" v-show"latestThreeList.length > 0"><el-carousel ind…...

Ubuntu 22.04 上安装和使用 Go

1.下载  All releases - The Go Programming Language //https://golang.google.cn/dl/wget https://golang.google.cn/dl/go1.21.6.linux-amd64.tar.gz 2.在下载目录下执行&#xff0c;现在&#xff0c;使用以下命令将文件提取到 “/usr/local ” 位置 sudo tar -C /usr/…...

ES6-const

一、基本用法 - 语法&#xff1a;const 标识符初始值;注意:const一旦声明变量&#xff0c;就必须立即初始化&#xff0c;不能留到以后赋值 - 规则&#xff1a;1.const 声明一个只读的常量&#xff0c;一旦声明&#xff0c;常量的值就不能改变2.const 其实保证的不是变量的值不…...

Android消息通知Notification

Notification 发送消息接收消息 #前言 最近在做消息通知类Notification的相关业务&#xff0c;利用闲暇时间总结一下。主要分为两部分来记录&#xff1a;发送消息和接收消息。 发送消息 发送消息利用NotificationManager类的notify方法来实现&#xff0c;现用最普通的方式发…...

2V2无人机红蓝对抗仿真

两架红方和蓝方无人机分别从不同位置起飞&#xff0c;蓝方无人机跟踪及击毁红方无人机 2020a可正常运行 2V2无人机红蓝对抗仿真资源-CSDN文库...

VUE3语法--computed计算属性中get和set使用案例

1、功能概述 计算属性computed是Vue3中一个响应式的属性,最大的用处是基于多依赖时的监听。也就是属性A的值可以根据其他数据的变化而响应式的变化。 在Vue3中,你可以使用computed函数来定义计算属性。computed函数接收两个参数:一个包含getter和setter函数的对象和可选的…...

Linux cd 和 df 命令执行异常

这篇记录一些奇奇怪怪的命令执行异常的情况&#xff0c;后续有新的发现也会补录进来 情况一 /tmp 目录权限导致 按 tab 补充报错 情况描述 cd 按 tab 自动补充文件报错&#xff08;普通用户&#xff09; bash: cannot create temp file for here-document: Permission denie…...

【计算机网络】物理层概述|通信基础|奈氏准则|香农定理|信道复用技术

目录 一、思维导图 二、 物理层概述 1.物理层概述 2.四大特性&#xff08;巧记"械气功程") 三、通信基础 1.数据通信基础 2.趁热打铁☞习题训练 3.信号の变身&#xff1a;编码与调制 4.极限数据传输率 5.趁热打铁☞习题训练 6.信道复用技术 推荐 前些天发…...

XXE基础知识整理(附加xml基础整理)

全称&#xff1a;XML External Entity 外部实体注入攻击 原理 利用xml进行读取数据时过滤不严导致嵌入了恶意的xml代码&#xff1b;和xss原理雷同 危害 外界攻击者可读取商户服务器上的任意文件&#xff1b; 执行系统命令&#xff1b; 探测内网端口&#xff1b; 攻击内网网站…...

【pytorch】anaconda使用及安装pytorch

https://zhuanlan.zhihu.com/p/348120084 https://blog.csdn.net/weixin_44110563/article/details/123324304 介绍 Conda创建环境相当于创建一个虚拟的空间将这些包都装在这个位置&#xff0c;不需要了可以直接打包放入垃圾箱&#xff0c;同时也可以针对不同程序的运行环境选…...

SpringBoot过滤器获取响应的参数

一、背景 在项目开发过程中&#xff0c;需要对于某些接口统一处理。 这时候就需要获取响应的报文&#xff0c;再对获取的报文进行统一处理。 二、了解过滤器 首先了解一下过滤器拦截器的区别&#xff1a; JAVA中的拦截器、过滤器&#xff1a;https://blog.csdn.net/qq_38254…...

数据挖掘实战-基于决策树算法构建北京市空气质量预测模型

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…...

SOLID原理:用Golang的例子来解释

随着软件系统变得越来越复杂&#xff0c;编写模块化、灵活和易于理解的代码非常重要。实现这一目标的方法之一是遵循SOLID原则。这些原则是由罗伯特-C-马丁&#xff08;Robert C. Martin&#xff09;提出的&#xff0c;以帮助开发人员创建更容易维护、测试和扩展的代码。 本文将…...

mysql是如何使用索引的?

摘自官网 MySQL使用索引进行以下操作: WHERE条件中,快速查找匹配的行。(快速查询数据) 从准备查询的数据中消除多余行。如果可以在多个索引之间进行选择,则MySQL通常会使用查找最少行数的索引。 如果表具有多列索引,那么优化器可以使用索引的任何最左前缀来查找行。 举例来…...

自动驾驶IPO第一股及商业化行业标杆 Mobileye

一、Mobileye 简介 Mobileye 是全球领先的自动驾驶技术公司&#xff0c;成立于 1999 年&#xff0c;总部位于以色列耶路撒冷。公司专注于开发视觉感知技术和辅助驾驶系统 (ADAS)&#xff0c;并在自动驾驶领域处于领先地位。Mobileye 是高级驾驶辅助系统&#xff08;ADAS&#…...

Linux前后端程序部署

1.总述 首先安装包类型分为 二进制发布包安装:找到对应自己的linux平台版本(CentOS还是redhat等),的具体压缩文件,解压修改配置 源码编译安装:需要自己进行编译 对于redhat安装包,可以使用rpm命令进行安装,但是rpm命令安装不能够解决依赖库的问题,常用的rpm命令,只用于卸载…...

手把手 S32K344移植FreeRTOS

版本信息 RTD:2.0.0.2022 S32DS:3.4.0.2020 下载 从S32K3参考软件下载FreeTROS FreeRTOS安装链接&#xff1a;https://www.nxp.com/webapp/swlicensing/sso/downloadSoftware.sp?catidSW32K3-REFSW-D 根据S32DS版本和S32K3 RTD 2.0.0 Package找到对应的FreeRTOS的zip安装…...

《云原生安全攻防》-- 云原生安全概述

从本节课程开始&#xff0c;我们将正式踏上云原生安全的学习之旅。在深入探讨云原生安全的相关概念之前&#xff0c;让我们先对云原生有一个全面的认识。 什么是云原生呢? 云原生&#xff08;Cloud Native&#xff09;是一个组合词&#xff0c;我们把它拆分为云和原生两个词来…...

综合分享1

VM及安装配置windows server 2008 1&#xff09;安装配置VM 确保是否正确安装&#xff1a; 1&#xff09;检查本地“网络与internal设置”中的虚拟网卡是否创建成功&#xff08;vmnet1和vmnet8&#xff09; 2&#xff09;必须通过services.msc查看vmware的所有是否已经…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码&#xff0c;专为学校招生场景量身打造&#xff0c;功能实用且操作便捷。 从技术架构来看&#xff0c;ThinkPHP提供稳定可靠的后台服务&#xff0c;FastAdmin加速开发流程&#xff0c;UniApp则保障小程序在多端有良好的兼…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的&#xff0c;比GNOME简单得多&#xff01; 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

Redis数据倾斜问题解决

Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中&#xff0c;部分节点存储的数据量或访问量远高于其他节点&#xff0c;导致这些节点负载过高&#xff0c;影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

Swagger和OpenApi的前世今生

Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章&#xff0c;二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑&#xff1a; &#x1f504; 一、起源与初创期&#xff1a;Swagger的诞生&#xff08;2010-2014&#xff09; 核心…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念&#xff0c;其实 Fiori当中还有 V4&#xff0c;咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务)&#xff0c;代理中间件&#xff08;ui5-middleware-simpleproxy&#xff09;-CSDN博客…...

Xen Server服务器释放磁盘空间

disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...