【NLP】视觉变压器与卷积神经网络
一、说明
本篇是 变压器因其计算效率和可扩展性而成为NLP的首选模型。在计算机视觉中,卷积神经网络(CNN)架构仍然占主导地位,但一些研究人员已经尝试将CNN与自我注意相结合。作者尝试将标准变压器直接应用于图像,发现在中型数据集上训练时,与类似ResNet的架构相比,这些模型的准确性适中。然而,当在更大的数据集上进行训练时,视觉转换器(ViT)取得了出色的结果,并在多个图像识别基准上接近或超过了最先进的技术。本文记录这种结论,等有时机去验证。
二、CNN卷积网络transformer起源
这篇博文的灵感来自谷歌研究团队的一篇题为“图像价值16X16字:大规模图像识别的变形金刚”的论文。本文建议使用直接应用于图像补丁的纯转换器来完成图像分类任务。视觉转换器 (ViT) 在多个基准测试中优于最先进的卷积网络,同时在对大量数据进行预训练后,需要更少的计算资源进行训练。
变压器因其计算效率和可扩展性而成为NLP的首选模型。在计算机视觉中,卷积神经网络(CNN)架构仍然占主导地位,但一些研究人员已经尝试将CNN与自我注意相结合。作者尝试将标准变压器直接应用于图像,发现在中型数据集上训练时,与类似ResNet的架构相比,这些模型的准确性适中。然而,当在更大的数据集上进行训练时,视觉转换器(ViT)取得了出色的结果,并在多个图像识别基准上接近或超过了最先进的技术。
图 1(取自原始论文)描述了一个模型,该模型通过将 2D 图像转换为展平的 2D 补丁序列来处理 <>D 图像。然后将补丁映射到具有可训练线性投影的恒定潜在矢量大小。一个可学习的嵌入被附加到补丁序列之前,它在转换器编码器输出端的状态用作图像表示。然后将图像表示通过分类头进行预训练或微调。添加位置嵌入以保留位置信息,嵌入向量序列用作变压器编码器的输入,该编码器由多头自注意和 MLP 块的交替层组成。
过去,CNN长期以来一直是图像处理任务的首选。它们擅长通过卷积层捕获局部空间模式,从而实现分层特征提取。CNN擅长从大量图像数据中学习,并在图像分类,对象检测和分割等任务中取得了显着的成功。
虽然CNN在各种计算机视觉任务中拥有良好的记录,并且可以有效地处理大规模数据集,但视觉转换器在全局依赖关系和上下文理解至关重要的情况下具有优势。然而,视觉变压器通常需要大量的训练数据才能实现与CNN相当的性能。此外,CNN由于其可并行化的性质而具有计算效率,使其对于实时和资源受限的应用程序更加实用。
三、示例:CNN 与视觉转换器
在本节中,我们将使用 CNN 和视觉转换器方法,在 Kaggle 中可用的猫和狗数据集上训练视觉分类器。首先,我们将从 Kaggle 下载包含 25000 张 RGB 图像的猫和狗数据集。如果您还没有,可以阅读此处的说明,了解如何设置 Kaggle API 凭据。以下 Python 代码会将数据集下载到当前工作目录中。
from kaggle.api.kaggle_api_extended import KaggleApiapi = KaggleApi()
api.authenticate()# we write to the current directory with './'
api.dataset_download_files('karakaggle/kaggle-cat-vs-dog-dataset', path='./')
下载文件后,您可以使用以下命令解压缩文件。
!unzip -qq kaggle-cat-vs-dog-dataset.zip
!rm -r kaggle-cat-vs-dog-dataset.zip
使用以下命令克隆视觉转换器 GitHub 存储库。此存储库包含vision_tr目录下的视觉转换器所需的所有代码。
!git clone https://github.com/RustamyF/vision-transformer.git
!mv vision-transformer/vision_tr .
下载的数据需要清理并准备训练我们的图像分类器。创建以下实用程序函数以 Pytorch 的 DataLoader 格式清理和加载数据。
import torch.nn as nn
import torch
import torch.optim as optimfrom torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from sklearn.model_selection import train_test_splitimport osclass LoadData:def __init__(self):self.cat_path = 'kagglecatsanddogs_3367a/PetImages/Cat'self.dog_path = 'kagglecatsanddogs_3367a/PetImages/Dog'def delete_non_jpeg_files(self, directory):for filename in os.listdir(directory):if not filename.endswith('.jpg') and not filename.endswith('.jpeg'):file_path = os.path.join(directory, filename)try:if os.path.isfile(file_path) or os.path.islink(file_path):os.unlink(file_path)elif os.path.isdir(file_path):shutil.rmtree(file_path)print('deleted', file_path)except Exception as e:print('Failed to delete %s. Reason: %s' % (file_path, e))def data(self):self.delete_non_jpeg_files(self.dog_path)self.delete_non_jpeg_files(self.cat_path)dog_list = os.listdir(self.dog_path)dog_list = [(os.path.join(self.dog_path, i), 1) for i in dog_list]cat_list = os.listdir(self.cat_path)cat_list = [(os.path.join(self.cat_path, i), 0) for i in cat_list]total_list = cat_list + dog_listtrain_list, test_list = train_test_split(total_list, test_size=0.2)train_list, val_list = train_test_split(train_list, test_size=0.2)print('train list', len(train_list))print('test list', len(test_list))print('val list', len(val_list))return train_list, test_list, val_list# data Augumentation
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])class dataset(torch.utils.data.Dataset):def __init__(self, file_list, transform=None):self.file_list = file_listself.transform = transform# dataset lengthdef __len__(self):self.filelength = len(self.file_list)return self.filelength# load an one of imagesdef __getitem__(self, idx):img_path, label = self.file_list[idx]img = Image.open(img_path).convert('RGB')img_transformed = self.transform(img)return img_transformed, label
四、CNN方法
此图像分类器的 CNN 模型由三层 2D 卷积组成,内核大小为 3,步幅为 2,最大池化层为 2。在卷积层之后,有两个全连接层,每个层由 10 个节点组成。下面是说明此结构的代码片段:
class Cnn(nn.Module):def __init__(self):super(Cnn, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2))self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2))self.layer3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2))self.fc1 = nn.Linear(3 * 3 * 64, 10)self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(10, 2)self.relu = nn.ReLU()def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = out.view(out.size(0), -1)out = self.relu(self.fc1(out))out = self.fc2(out)return out
训练是用特斯拉T4(g4dn-xlarge)GPU机器进行的,训练了10个训练周期。Jupyter Notebook 在项目的 GitHub 存储库中可用,其中包含训练循环的代码。以下是每个纪元的训练循环的结果。
五、视觉转换器方法
视觉变压器架构设计有可定制的尺寸,可以根据特定要求进行调整。对于这种大小的图像数据集,此体系结构仍然很大。
from vision_tr.simple_vit import ViT
model = ViT(image_size=224,patch_size=32,num_classes=2,dim=128,depth=12,heads=8,mlp_dim=1024,dropout=0.1,emb_dropout=0.1,
).to(device)
视觉转换器中的每个参数都起着关键作用,如下所述:
image_size=224
:此参数指定模型输入图像的所需大小(宽度和高度)。在这种情况下,图像的大小应为 224x224 像素。patch_size=32
:图像被分成较小的补丁,此参数定义每个补丁的大小(宽度和高度)。在本例中,每个修补程序为 32x32 像素。num_classes=2
:此参数表示分类任务中的类数。在此示例中,模型旨在将输入分为两类(猫和狗)。dim=128
:它指定模型中嵌入向量的维数。嵌入捕获每个图像修补程序的表示形式。depth=12
:此参数定义视觉转换器模型(编码器模型)中的深度或层数。更高的深度允许更复杂的特征提取。heads=8
:此参数表示模型自注意机制中的注意力头数。mlp_dim=1024
:指定模型中多层感知器 (MLP) 隐藏层的维数。MLP 负责在自我注意后转换令牌表示。dropout=0.1
:此参数控制辍学率,这是一种用于防止过度拟合的正则化技术。它在训练期间将输入单位的一部分随机设置为 0。emb_dropout=0.1
:它定义了专门应用于令牌嵌入的辍学率。此丢弃有助于防止在训练期间过度依赖特定令牌。
使用Tesla T4(g4dn-xlarge)GPU机器对分类任务的视觉转换器进行了20个训练周期的训练。训练进行了20个epoch(而不是CNN使用的10个epoch),因为训练损失的收敛速度很慢。以下是每个纪元的训练循环的结果。
CNN 方法在 75 个时期内达到了 10% 的准确率,而视觉转换器模型的准确率达到了 69%,训练时间要长得多。
六、结论
总之,在比较CNN和Vision Transformer模型时,在模型大小,内存要求,准确性和性能方面存在显着差异。CNN 型号传统上以其紧凑的尺寸和高效的内存利用率而闻名,使其适用于资源受限的环境。事实证明,它们在图像处理任务中非常有效,并在各种计算机视觉应用中表现出出色的精度。另一方面,视觉变压器提供了一种强大的方法来捕获图像中的全局依赖关系和上下文理解,从而提高某些任务的性能。然而,与CNN相比,视觉变压器往往具有更大的模型尺寸和更高的内存要求。虽然它们可能会达到令人印象深刻的准确性,尤其是在处理较大的数据集时,但计算需求可能会限制它们在资源有限的场景中的实用性。最终,CNN 和 Vision Transformer 模型之间的选择取决于手头任务的特定要求,考虑可用资源、数据集大小以及模型复杂性、准确性和性能之间的权衡等因素。随着计算机视觉领域的不断发展,预计这两种架构将取得进一步进展,使研究人员和从业者能够根据他们的特定需求和限制做出更明智的选择。
相关文章:

【NLP】视觉变压器与卷积神经网络
一、说明 本篇是 变压器因其计算效率和可扩展性而成为NLP的首选模型。在计算机视觉中,卷积神经网络(CNN)架构仍然占主导地位,但一些研究人员已经尝试将CNN与自我注意相结合。作者尝试将标准变压器直接应用于图像,发现在…...

【redis】通过配置文件简述redis的rdb和aof
redis的持久化方式有2种,rdb,即通过快照的方式将全量数据以二进制记录在磁盘中,aof,仅追加文件,将增量的写命令追加在aof文件中。在恢复的时候,rdb要更快,但是会丢失一部分数据。aof丢失数据极少…...
Cypress 上传 pdf 变空白页问题
在使用cypress 上传文件时,上传正常,但是,pdf一直空白的,翻边了资料也没找到原因。最后在一个不起眼的地方发现了问题所在。 错误的代码: cy.fixture(CBKS.pdf).as(uploadFile)cy.get(.el-upload-dragger).selectFile…...

【ArcGIS Pro二次开发】(52):布局导出图片(批量)
在ArcGIS Pro中设定好布局后,可以直接导出为各种类型的图片。 这是很基本的功能,但是如果你的布局很多,一张一张导图就有点费劲。 之前有网友提出希望可以批量导图,要实现起来并不难,于是就做了这个工具。 一、要实现…...
Git拉取远程分支并创建本地分支
一、查看远程分支 使用如下git命令查看所有远程分支: git branch -r 查看远程和本地所有分支: git branch -a 查看本地分支: git branch 在输出结果中,前面带* 的是当前分支。 二、拉取远程分支并创建本地分支 方法一 使用…...

OSI七层模型——物理层
OSI模型的物理层位于协议栈的底部。它是 TCP/IP 模型的网络接入层的一部分。如果没有物理层,就没有网络。本模块详细介绍了连接到物理层的三种方法。 1 物理层的用途 1.1 物理连接 不管是在家连接本地打印机还是将其连接到另一国家/地区的网站上,在进…...
【NLP】使用变压器(tranformer)和自动编码器
一、说明 自然语言处理 (NLP)中,trnsformer和编码器是至关重要的概念;本篇不是探讨原理,而是讲现实中,如何调用和使用transformer以及encoder,注意。本文中有时出现“变压器”,那是transormer的同义词,在此事先声明。 二、NLP及其重要性的简要概述 NLP是人工…...

广州华锐互动:水利数字孪生智能管理系统的特色
水利数字孪生智能管理系统是一种基于数字孪生的新型水利管理工具,它通过将现实世界中的水利设施和设备数字化,并在虚拟环境中进行模拟和分析,为水利管理者提供更加直观、精准的决策支持。该系统具有以下亮点: 首先,水利…...
php使用chatGPT生成一些东西做一个记录
好久没写了,这么长时间都去坐一些自己感兴趣的事情去了。 之前使用chatgpt-3,效果一直不咋好,这里我们来说说各个版本区别 gpt-3收费成本可以接受,生成的内容对话有点不太聪明的样子 git-3.5-turbo收费相对来说低,生成文本质量…...

轻量级Web报表工具ActiveReportsJS全新发布v4.0,支持集成更多前端框架!
ActiveReportsJS 是一款基于 JavaScript 和 HTML5 的轻量级Web报表工具,采用拖拽式设计模式,不需任何服务器和组件支持,即可在 Mac、Linux 和 Windows 操作系统中,设计多种类型的报表。ActiveReportsJS 同时提供跨平台报表设计、纯…...

听GPT 讲K8s源代码--pkg(七)
k8s项目中 pkg/kubelet/config,pkg/kubelet/configmap,pkg/kubelet/container,pkg/kubelet/cri 这几个目录处理与 kubelet 配置、ConfigMap、容器管理和容器运行时交互相关的功能。它们共同构成了 kubelet 的核心功能,使其能够在 …...

STM32MP157驱动开发——按键驱动(线程化处理)
文章目录 “线程化处理”机制:内核函数线程化处理方式的按键驱动程序(stm32mp157)编程思路button_test.cgpio_key_drv.cMakefile修改设备树文件编译测试 “线程化处理”机制: 工作队列是在内核的线程的上下文中执行的 工作队列中有多个 work࿰…...
探究HTTP代理爬虫的反爬虫策略
在当前信息爆炸的时代,海量的数据成为了企业发展和决策的关键资源。然而,越来越多的网站为了保护数据和用户隐私的安全,采取了各种反爬虫策略。作为一家专业的HTTP代理产品供应商,我们一直在研究和优化反爬虫策略,为用…...

短视频去水印小程序,一键部署你的小程序,可开流量主,实现睡后收入
插件地址 短视频去水印小程序,一键部署你的小程序,可开流量主,实现睡后收入 插件说明 本插件包含以下两部分: 短视频去水印插件,仅为一个接口,可以集成到自己的任意程序中。短视频去水印插件配套小程序…...
通讯录系统
目录 通讯录系统头文件: 通讯录系统Test: 通讯录系统函数源代码: 通讯录系统头文件: #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <assert…...

14:00面试,14:06就出来了,问的问题有点变态。。。
从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到5月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%,…...

F5 LTM 知识点和实验 3-负载均衡中的负载算法
第三章:负载均衡中的负载算法 负载算法分为静态的和动态的。静态的连接分布模式是预先设置的,流量处理中是不会变化的,动态的连接分布模式也是预先设置的,但是连接分布会根据某些因素的改变而调整。 轮询(round robi…...

多线程(JavaEE初阶系列2)
目录 前言: 1.什么是线程 2.为什么要有线程 3.进程与线程的区别与联系 4.Java的线程和操作系统线程的关系 5.多线程编程示例 6.创建线程 6.1继承Thread类 6.2实现Runnable接口 6.3继承Thread,使用匿名内部类 6.4实现Runnable接口,使…...
Ubuntu20.04点Ubuntu software没反应,打不开的解决方案(Ubuntu笔记)
首先检查Ubuntu Software的状态,在终端输入:systemctl status snap.ubuntu-software.ubuntu-software.service 如果状态显示为inactive,则需要启动snap.ubuntu-software.ubuntu-software.service,在终端输入:sudo sys…...

力扣1114.按序打印-----题目解析
题目描述 解析: class Foo {public int a 0;public Foo() {}public void first(Runnable printFirst) throws InterruptedException {// printFirst.run() outputs "first". Do not change or remove this line.printFirst.run();a;}public void second…...

手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?
Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...
MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例
一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...
STM32+rt-thread判断是否联网
一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

Netty从入门到进阶(二)
二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架,用于…...
Java数值运算常见陷阱与规避方法
整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...
CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝
目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...
为什么要创建 Vue 实例
核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...