当前位置: 首页 > news >正文

Pytorch之MobileViT图像分类

文章目录

  • 前言
  • 一、Transformer存在的问题
  • 二、MobileViT
    • 1.MobileViT网络结构
      • 🍓 Vision Transformer结构
      • 🍉MobileViT结构
    • 2.MV2(MobileNet v2 block)
    • 3.MobileViT block
      • 🥇Local representations
      • 🥈Transformers as Convolutions (global representations)
      • 🥉Fusion
    • 4.模型配置
    • 5.MobileViT优势
  • 三、MobileViT网络实现
    • 1.构建网络模型
    • 2.训练和测试模型
  • 四、图像分类
  • 结束语


  • 💂 个人主页:风间琉璃
  • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
  • 💬 如果文章对你有帮助欢迎关注点赞收藏(一键三连)订阅专栏

前言

MobileViT是一种基于ViT(Vision Transformer)架构的轻量级视觉模型,旨在适用于移动设备和嵌入式系统。ViT是一种非常成功的深度学习模型,用于图像分类和其他计算机视觉任务,但通常需要大量的计算资源和参数。MobileViT的目标是在保持高性能的同时,减少模型的大小和计算需求,以便在移动设备上运行,据作者介绍,这是第一次基于轻量级CNN网络性能的轻量级ViT工作,性能SOTA。性能优于MobileNetV3、CrossviT等网络。


一、Transformer存在的问题

MobileVitV1是苹果公司2021年发表的一篇轻量型主干网络,它是CNNTransfomrer混合架构模型(CNN的轻量和高效+Transformer的自注意力机制和全局视野),这样的架构模型也是现在很多研究者们青睐的架构之一。

自Vision Transformer出现之后,人们发现Transfomrer也可以应用在计算机视觉领域,并且效果还是非常不错的。但是基于Transformer的网络模型存在着以下问题:

参数多,算力要求高
Transformer模型通常具有数十亿或数百亿个参数,这使得它们的模型文件非常大,不仅占用大量存储空间,而且在训练和部署过程中也需要更多的计算资源。

缺少空间归纳偏置
即纯Transformer对空间位置信息不敏感,但是,我们在进行视觉应用的时位置信息又比较重要,为了解决这个问题就引入了位置编码。

归纳 (Induction) 是自然科学中常用的两大方法之一 (归纳与演绎,Induction & Deduction),指从一些例子中寻找共性、泛化,形成一个较通用的规则的过程。偏置 (Bias) 则是指对模型的偏好,以下展示了 4 种解释:

∙ \bullet 通俗理解:归纳偏置可以理解为,从现实生活中观察到的现象中归纳出一定的 规则 (heuristics),然后对模型做一定的约束,从而可以起到 “模型选择” 的作用,类似贝叶斯学习中的 “先验”。
∙ \bullet 西瓜书解释:机器学习算法在学习过程中对某种类型假设的偏好,称为归纳偏好。归纳偏好可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或 “价值观”。
∙ \bullet 维基百科解释:如果学习器需要去预测 “其未遇到过的输入” 的结果时,则需要一些假设来帮助它做出选择。
∙ \bullet 广义解释:归纳偏置会促使学习算法优先考虑具有某些属性的解。

深度神经网络偏好性地认为,层次化处理信息有更好效果;卷积神经网络认为信息具有空间局部性,可用滑动卷积共享权重的方式降低参数空间;循环神经网络则将时序信息纳入考虑,强调顺序重要性;图网络则认为中心节点与邻居节点的相似性会更好地引导信息流动。通常,模型容量 (capacity) 很大但 Inductive Bias 匮乏则容易过拟合 (overfitting),如 Transformer

CNN的空间归纳偏差内容如下:

CNN 的 归纳偏置(Inductive Bias)局部性 (Locality) 空间不变性 (Spatial Invariance) / 平移等效性 (Translation Equivariance),即空间位置上的元素 (Grid Elements) 的联系/相关性近大远小,以及空间平移的不变性 (Kernel 权重共享)。

⋆ \star locality:CNN是以滑动窗口的形式一点一点地在图片上进行卷积的,所以假设图片上相邻的区域会有相邻的特征,靠得越近的东西相关性越强;

⋆ \star translation equivariance(平移等变性或平移同变性):用公式表示为f(g(x))=g(f(x)),不论是先经过g映射,还是先经过f映射,其结果是不变的;其中f代表卷积操作,g代表平移操作。因为在卷积神经网络中,卷积核相当于是一个模板,不论图片中同样的物体移动到哪里,只要是相同的输入,经过相同的卷积核,其输出是不变的。

一旦网络(CNN)模型有了这两个归纳偏置,它就拥有很多的先验信息所以只需要相对较少的数据就可以学习一个相对比较好的模型。但是对于transformer来说,它没有这些先验信息,所以它对视觉的感知全部需要从这些数据中自己学习。

因此transformer结构的网络模型需要大量的数据才能得到不错的效果,如果使用少量数据进行训练,那么会掉点很明显。这是因为Transformer缺少空间归纳偏置,空间归纳偏置允许CNN在不同的视觉任务中学习较少参数的表示

虽然Transformer缺少空间归纳偏置必须要大量数据来进行学习数据中的某种特性,从而导致无法很好的应用在这样的边缘设备。但是CNN也有缺点CNN在空间上获取的信息是局部的,因此一定程度上会制约着CNN网络结构的性能,而Transformer的自注意力机制能够获取全局信息。

模型迁移困难

这个问题核心是引入的位置编码导致的。 Transformer 网络需要先对原始的图像进行切片处理,一般来说训练好的 ViT 网络原始的输入图像大小 224×224,patch 大小为 16×16,那么得到的 patch个数也就固定了。由于 Transformer 网络缺少空间归纳偏置在计算某一个 token 时其他 token 位置顺序发生变化并不会影响到最终的实验结果,也即输出与位置信息无关。而我们知道对于图像来说,空间信息是非常重要且具有实际意义的,因此,Transformer 通过加上位置偏置ViT使用绝对位置偏置,Swin T引入相对位置偏置来解决位置信息的丢失问题

但是,当输入图像的尺寸或者 patch 大小发生变化时,训练好的模型就会因为位置信息不准确而失效。目前常见的处理方法是将位置偏置信息进行插值,插值到所需要的序列长度从而匹配到图像的尺寸。这种方式需要对训练好的模型进行微调才能保证性能不出现大幅损失,每次改变输入图像的尺寸或者 patch 的尺寸均需要对位置编码进行插值和对网络进行微调,这提高了网络迁移的难度

Swin T网络使用了相对位置偏置,理论上来说序列的长度只与窗 windows 的大小有关而与输入图像的尺寸无关。但是,windows的大小一般被设定与输入尺寸匹配,当输入尺寸变大时,window 的大小也应该相应的增大,那么所使用的相对位置偏置序列也应该增大,这也会导致上述问题。这些问题将导致 Transformer 网络迁移时比 CNN 网络迁移得更加困难和繁琐。

模型训练困难

根据现有的一些经验,Transformer相比CNN要更难训练。Transformer需要更多的训练数据需要迭代更多的epoch需要更大的正则项(L2正则)需要更多的数据增强(且对数据增强很敏感)。

针对以上问题,采用CNN与Transformer的混合架构CNN能够提供空间归纳偏置所以可以解决位置偏置,而且加入CNN后能够加速网络的收敛,使网络训练过程更加的稳定

二、MobileViT

1.MobileViT网络结构

🍓 Vision Transformer结构

下图是MobileViT论文中绘制的Standard visual Transformer。首先将输入的图片划分成N个Patch,然后通过线性变化将每个Patch映射到一维向量中(Token),接着加上位置偏置信息(可学习参数),再通过一系列Transformer Block,最后通过一个全连接层得到最终预测输出。
在这里插入图片描述
首先将C,H,W的图片进行Patch处理成N个向量,然后经过线性层进行降低向量维度,再经过位置编码,然后再经过N个Transformer块,在通过class token来进行分类。

这个Standard visual Transformer和前面文章中ViT有一点不同,这里没有class token,class token只是针对分类才加上去的,上面这个网络才是最标准的视觉ViT网络。

由于VIT忽略了空间归纳偏差,所以它们需要更多的参数来学习视觉表征。此外,与CNN相比,VIT及其多种变体的优化性能不佳,这些模型对L2正则化很敏感,需要大量的数据增强以防止过拟合

🍉MobileViT结构

上面展示是标准视觉ViT模型,下面来看下本次介绍的重点:Mobile-ViT网路结构,如下图所示:
在这里插入图片描述通过上图可以看到MobileViT主要由普通卷积MV2(MobiletNetV2中的Inverted Residual block),MobileViT block全局池化以及全连接层共同组成。

其中,MobileViT块中的Convn × n表示一个标准的n × n卷积MV2指的是MobileNetv2块执行下采样的块用↓2标记

2.MV2(MobileNet v2 block)

MV2 块指MobileNet v2 block,是一个Inverted Residual Block(倒残差结构)。 在倒残差结构中,即特征图的维度是先升后降,据相关论文中描述说,更高的维度经过激活函数后,它损失的信息就会少一些。(注意倒残差结构中基本使用的都是ReLU6激活函数,但是最后一个1x1的卷积层使用的是线性激活函数)。具体网络结构如下图所示。
在这里插入图片描述
MobileViT结构图中标有向下箭头的MV2结构代表stride等于2的情况,即需要进行下采样

🌼Residual Block(残差结构):
①1x1卷积降维
②3x3卷积
③1x1卷积升维
🌻Inverted Residual Block(倒残差结构)
①1x1卷积升维
②3x3卷积DW
③1x1卷积降维

3.MobileViT block

MV2来源于mobilenetv2,所以Mobile-ViT的核心是MobileViT block模块。MobileViT block的结构如下图所示:
在这里插入图片描述
MobileViT Block旨在用更少的参数对输入张量中的局部全局信息进行建模。由上图可知MobileViT Block 整体由三部分组成分别为:Local representationsTransformers as Convolutions (global representations)Fusion

大致流程:首先将特征图通过一个卷积核大小为nxn(代码中是3x3)的卷积层进行局部的特征建模,然后通过一个卷积核大小为1x1的卷积层调整通道数。接着通过Unfold -> Transformer -> Fold结构进行全局的特征建模,然后再通过一个卷积核大小为1x1的卷积层将通道数调整回原始大小。接着通过shortcut分支(在V2版本中将该捷径分支取消了)与原始输入特征图进行Concat拼接(沿通道channel方向拼接),最后再通过一个卷积核大小为nxn(代码中是3x3)的卷积层做特征融合得到输出

Global representations它的具体计算过程如下图所示,
在这里插入图片描述
首先对特征图划分Patch(忽略了通道channels),图中的Patch大小为2x2,即每个Patch由4个Pixel组成。

在进行Self-Attention计算的时候,每个Token(图中的每个Pixel或者说每个小颜色块)只和颜色相同的Token进行Attention,可以减少参数计算量。对于原始的Self-Attention计算每个Token是需要和所有的Token进行Self-Attention。

假设特征图的高宽和通道数分别为H, W, C,在输入到Transformer中,在Self-Attention的时候,每个图中的每个像素和其他的像素进行计算,这样计算量就是:
P 1 = W ∗ H ∗ C P_1 = W*H*C P1=WHC

MobileViT中的是先对输入的特征图划分成多个的patch,但是在计算Self-Attention的时候只对相同位置的像素计算,即图中展示的颜色相同的位置,这样就可以相对的减少计算量,这个时候的计算量为:
P 2 = W ∗ H ∗ C 4 P_2 = \frac{W*H*C}{4} P2=4WHC即理论上的计算成本仅为原始的 1 4 \frac{1}{4} 41

在本次的自注意力机制中,只选择了位置相同的像素点进行点积操作。这样做的原因大概就是因为和所有的像素点都进行自注意力操作会带来信息冗余,毕竟不是所有的像素含有有用的信息对于图像数据本身就存在大量的数据冗余,一张图像的每个像素点的周围的像素值都差不多,并且分辨率越高相差越小,所以这样做并不会损失太多的信息。而且MobileViT在做全局表征之前已经做了一次局部表征(Local representations),进行全局建模时可以忽略一些信息。

Global representations中的​UnfoldFold只是为了将数据给reshape成计算Self-Attention时所需的数据格式。unfold就是将颜色相同的部分拼成一个序列输入到Transformer进行建模,最后再通过fold是调整为原始大小,如下图所示:
在这里插入图片描述
下面来简单的看下patch size对模型性能的影响,patch如果划分的比较大的话是可以减少计算量的,但是划分的太大的话又会忽略更多的语义信息,影响模型的性能。

下图从左到右对语义信息的要求逐渐递增。其中配置A的patch大小为{2, 2, 2},配置B的patch大小为{8, 4, 2},这三个数字分别对应下采样倍率为8,16,32的特征图所采用的patch大小。通过对比可以发现,在图像分类目标检测任务中(对语义细节要求不高的场景),配置A和配置B在Acc和mAP上没太大区别,但配置B要更快。但在语义分割任务中(对语义细节要求较高的场景)配置A的效果要更好。
在这里插入图片描述

🥇Local representations

Local representations 表示输入信息的局部表达。在这个部分,输入MobileViT Block 的数据会经过一个 n × n n \times n n×n的卷积块和一个 1 × 1 1 \times 1 1×1的卷积块。

从上文所述的CNN的空间归纳偏差就可以得知:经过 n × n n \times n n×n(n=3)的卷积块的输出获取到了输入模型的局部信息表达(因为卷积块是对一个整体块进行操作,但是这个卷积核的n是远远小于数据规模的,所以是局部信息表达,而不是全局信息表达)。另外, 1 × 1 1 \times 1 1×1的卷积块是为了线性投影将数据投影至高维空间。例如:对于 9 × 9 9\times 9 9×9的数据,使用 3 × 3 3\times 3 3×3的卷积层,获取到的每个数据都是对 9 × 9 9\times 9 9×9 数据的局部表达

🥈Transformers as Convolutions (global representations)

Transformers as Convolutions (global representations) 表示输入信息的全局表示。在Transformers as Convolutions 中首先通过Unfold 对数据进行转换,转化为 Transformer 可以接受的 1D 数据。然后将数据输入到Transformer 块中。最后通过Fold再将数据变换成原有的样子。

🥉Fusion

Fusion中,经过Transformers as Convolutions得到的信息原始输入信息 ( A ∈ R H × W × C ) (\mathrm{A} \in \mathrm{R^{\mathrm{H \times W \times C}}}) (ARH×W×C)进行合并,然后使用另一个 n × n n\times n n×n卷积层来融合这些连接的特征。这里,得到的信息指:全局表征 X F ∈ R H × W × d \mathrm{X_F} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFRH×W×d经过逐点卷积( 1 × 1 1\times 1 1×1卷积)得到的输出 X F u ∈ R H × W × d \mathrm{X_{Fu}} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFuRH×W×d ,并通过Concat操作与 X \mathrm{X} X组合。

4.模型配置

论文中总共给出了三组模型配置,即MobileViT-S(small)、MobileViT-XS(extra small)、MobileViT-XXS(extra extra small),三种配置是越来越轻量化,三者的主要区别在于特征图的通道数不同

下图为MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
在这里插入图片描述
对于MobileViT-XXS,Layer1~5的详细配置信息如下:
在这里插入图片描述
对于MobileViT-XS,Layer1~5的详细配置信息如下:
在这里插入图片描述
对于MobileViT-S,Layer1~5的详细配置信息如下:
在这里插入图片描述
参数说明:
⋆ \star out_channels表示该模块输出的通道数
⋆ \star mv2_exp表示Inverted Residual Block中的expansion ratio
⋆ \star transformer_channels表示Transformer模块输入Token的序列长度(特征图通道数)
⋆ \star num_heads表示多头自注意力机制中的head数
⋆ \star ffn_dim表示FFN中间层Token的序列长度
⋆ \star patch_h表示每个patch的高度
⋆ \star patch_w表示每个patch的宽度

5.MobileViT优势

🍄更好的性能: 对于给定的参数预算,MobileViT 在不同的移动视觉任务(图像分类、物体检测、语义分割)中取得了比现有的轻量级 CNN 更好的性能

🍄更好的泛化能力泛化能力是指训练和评价指标之间的差距。对于具有相似训练指标的2个模型,具有更好评价指标的模型更具有通用性,因为它可以更好地预测未知数据集。与CNN相比,即使有广泛的数据增强,其泛化能力也很差,MobileViT显示出更好的泛化能力。

🍄更好的鲁棒性:一个好的模型应该对超参数具有鲁棒性,因为调优这些超参数会消耗时间和资源。与大多数基于ViT的模型不同,MobileViT模型使用基本增强训练,对L2正则化不太敏感

总之,MobileViT使用CNNTransformer相融合的方案,在减少模型复杂度的同时,提高了模型的精度和鲁棒性

⋆ \star 对于一个模型,如果全都使用 CNN 结构。模型只能获取到数据的局部信息而获取不到全局信息
⋆ \star 对于一个模型,如果全部使用 Transformer 结构。模型可以获取到全局信息。但是,Transformer 结构会带来较大的复杂度,存在训练时间上升,模型容易过拟合等等问题。

因此,基于上述问题。作者先使用CNN获取局部信息,然后使用 Transformer 结构获取全局信息。通过上述的理解可以发现:在MobileViT 中的Transformer 结构中,复杂度相比于 ViT 结构 中复杂度降低了很多,因为输入数据复杂度的降低。最终实验结果同时表明:MobileViT 精度更高且鲁棒性更好

三、MobileViT网络实现

1.构建网络模型

首先要构建MobileViT block,其结构图如下所示:
在这里插入图片描述
Transformer实现:
在这里插入图片描述


class MultiHeadAttention(nn.Module):"""This layer applies a multi-head self- or cross-attention as described in`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paperArgs:embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`num_heads (int): Number of heads in multi-head attentionattn_dropout (float): Attention dropout. Default: 0.0bias (bool): Use bias or not. Default: ``True``Shape:- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,and :math:`C_{in}` is input embedding dim- Output: same shape as the input"""def __init__(self,embed_dim: int,num_heads: int,attn_dropout: float = 0.0,bias: bool = True,*args,**kwargs) -> None:super().__init__()if embed_dim % num_heads != 0:raise ValueError("Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(self.__class__.__name__, embed_dim, num_heads))self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)self.attn_dropout = nn.Dropout(p=attn_dropout)self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)self.head_dim = embed_dim // num_headsself.scaling = self.head_dim ** -0.5self.softmax = nn.Softmax(dim=-1)self.num_heads = num_headsself.embed_dim = embed_dimdef forward(self, x_q: Tensor) -> Tensor:# [N, P, C]b_sz, n_patches, in_channels = x_q.shape# self-attention# [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hcqkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)# [N, P, 3, h, c] -> [N, h, 3, P, C]qkv = qkv.transpose(1, 3).contiguous()# [N, h, 3, P, C] -> [N, h, P, C] x 3query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]query = query * self.scaling# [N h, P, c] -> [N, h, c, P]key = key.transpose(-1, -2)# QK^T# [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]attn = torch.matmul(query, key)attn = self.softmax(attn)attn = self.attn_dropout(attn)# weighted sum# [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]out = torch.matmul(attn, value)# [N, h, P, c] -> [N, P, h, c] -> [N, P, C]out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)out = self.out_proj(out)return outclass TransformerEncoder(nn.Module):"""This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_Args:embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`ffn_latent_dim (int): Inner dimension of the FFNnum_heads (int) : Number of heads in multi-head attention. Default: 8attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0dropout (float): Dropout rate. Default: 0.0ffn_dropout (float): Dropout between FFN layers. Default: 0.0Shape:- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,and :math:`C_{in}` is input embedding dim- Output: same shape as the input"""def __init__(self,embed_dim: int,ffn_latent_dim: int,num_heads: Optional[int] = 8,attn_dropout: Optional[float] = 0.0,dropout: Optional[float] = 0.0,ffn_dropout: Optional[float] = 0.0,*args,**kwargs) -> None:super().__init__()attn_unit = MultiHeadAttention(embed_dim,num_heads,attn_dropout=attn_dropout,bias=True)self.pre_norm_mha = nn.Sequential(nn.LayerNorm(embed_dim),attn_unit,nn.Dropout(p=dropout))self.pre_norm_ffn = nn.Sequential(nn.LayerNorm(embed_dim),nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),nn.SiLU(),nn.Dropout(p=ffn_dropout),nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),nn.Dropout(p=dropout))self.embed_dim = embed_dimself.ffn_dim = ffn_latent_dimself.ffn_dropout = ffn_dropoutself.std_dropout = dropoutdef forward(self, x: Tensor) -> Tensor:# multi-head attentionres = xx = self.pre_norm_mha(x)x = x + res# feed forward networkx = x + self.pre_norm_ffn(x)return x

MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
在这里插入图片描述

def get_config(mode: str = "xxs") -> dict:if mode == "xx_small":mv2_exp_mult = 2config = {"layer1": {"out_channels": 16,"expand_ratio": mv2_exp_mult,"num_blocks": 1,"stride": 1,"block_type": "mv2",},"layer2": {"out_channels": 24,"expand_ratio": mv2_exp_mult,"num_blocks": 3,"stride": 2,"block_type": "mv2",},"layer3": {  # 28x28"out_channels": 48,"transformer_channels": 64,"ffn_dim": 128,"transformer_blocks": 2,"patch_h": 2,  # 8,"patch_w": 2,  # 8,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer4": {  # 14x14"out_channels": 64,"transformer_channels": 80,"ffn_dim": 160,"transformer_blocks": 4,"patch_h": 2,  # 4,"patch_w": 2,  # 4,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer5": {  # 7x7"out_channels": 80,"transformer_channels": 96,"ffn_dim": 192,"transformer_blocks": 3,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"last_layer_exp_factor": 4,"cls_dropout": 0.1}elif mode == "x_small":mv2_exp_mult = 4config = {"layer1": {"out_channels": 32,"expand_ratio": mv2_exp_mult,"num_blocks": 1,"stride": 1,"block_type": "mv2",},"layer2": {"out_channels": 48,"expand_ratio": mv2_exp_mult,"num_blocks": 3,"stride": 2,"block_type": "mv2",},"layer3": {  # 28x28"out_channels": 64,"transformer_channels": 96,"ffn_dim": 192,"transformer_blocks": 2,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer4": {  # 14x14"out_channels": 80,"transformer_channels": 120,"ffn_dim": 240,"transformer_blocks": 4,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer5": {  # 7x7"out_channels": 96,"transformer_channels": 144,"ffn_dim": 288,"transformer_blocks": 3,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"last_layer_exp_factor": 4,"cls_dropout": 0.1}elif mode == "small":mv2_exp_mult = 4config = {"layer1": {"out_channels": 32,"expand_ratio": mv2_exp_mult,"num_blocks": 1,"stride": 1,"block_type": "mv2",},"layer2": {"out_channels": 64,"expand_ratio": mv2_exp_mult,"num_blocks": 3,"stride": 2,"block_type": "mv2",},"layer3": {  # 28x28"out_channels": 96,"transformer_channels": 144,"ffn_dim": 288,"transformer_blocks": 2,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer4": {  # 14x14"out_channels": 128,"transformer_channels": 192,"ffn_dim": 384,"transformer_blocks": 4,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"layer5": {  # 7x7"out_channels": 160,"transformer_channels": 240,"ffn_dim": 480,"transformer_blocks": 3,"patch_h": 2,"patch_w": 2,"stride": 2,"mv_expand_ratio": mv2_exp_mult,"num_heads": 4,"block_type": "mobilevit",},"last_layer_exp_factor": 4,"cls_dropout": 0.1}else:raise NotImplementedErrorfor k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})return config

现在开始构建MobileVit网络模型

def make_divisible(v: Union[float, int],divisor: Optional[int] = 8,min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:"""This function is taken from the original tf repo.It ensures that all layers have a channel number that is divisible by 8It can be seen here:https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py:param v::param divisor::param min_value::return:"""if min_value is None:min_value = divisornew_v = max(min_value, int(v + divisor / 2) // divisor * divisor)# Make sure that round down does not go down by more than 10%.if new_v < 0.9 * v:new_v += divisorreturn new_v# 卷积层
class ConvLayer(nn.Module):"""Applies a 2D convolution over an inputArgs:in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1groups (Optional[int]): Number of groups in convolution. Default: 1bias (Optional[bool]): Use bias. Default: ``False``use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).Default: ``True``Shape:- Input: :math:`(N, C_{in}, H_{in}, W_{in})`- Output: :math:`(N, C_{out}, H_{out}, W_{out})`.. note::For depth-wise convolution, `groups=C_{in}=C_{out}`."""def __init__(self,in_channels: int,out_channels: int,kernel_size: Union[int, Tuple[int, int]],stride: Optional[Union[int, Tuple[int, int]]] = 1,groups: Optional[int] = 1,bias: Optional[bool] = False,use_norm: Optional[bool] = True,use_act: Optional[bool] = True,) -> None:super().__init__()if isinstance(kernel_size, int):kernel_size = (kernel_size, kernel_size)if isinstance(stride, int):stride = (stride, stride)assert isinstance(kernel_size, Tuple)assert isinstance(stride, Tuple)padding = (int((kernel_size[0] - 1) / 2),int((kernel_size[1] - 1) / 2),)block = nn.Sequential()conv_layer = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,groups=groups,padding=padding,bias=bias)block.add_module(name="conv", module=conv_layer)if use_norm:norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)block.add_module(name="norm", module=norm_layer)if use_act:act_layer = nn.SiLU()block.add_module(name="act", module=act_layer)self.block = blockdef forward(self, x: Tensor) -> Tensor:return self.block(x)# MV2
class InvertedResidual(nn.Module):"""This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paperArgs:in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`stride (int): Use convolutions with a stride. Default: 1expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise convskip_connection (Optional[bool]): Use skip-connection. Default: TrueShape:- Input: :math:`(N, C_{in}, H_{in}, W_{in})`- Output: :math:`(N, C_{out}, H_{out}, W_{out})`.. note::If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`"""def __init__(self,in_channels: int,out_channels: int,stride: int,expand_ratio: Union[int, float],skip_connection: Optional[bool] = True,) -> None:assert stride in [1, 2]hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)super().__init__()block = nn.Sequential()if expand_ratio != 1:block.add_module(name="exp_1x1",module=ConvLayer(in_channels=in_channels,out_channels=hidden_dim,kernel_size=1),)block.add_module(name="conv_3x3",module=ConvLayer(in_channels=hidden_dim,out_channels=hidden_dim,stride=stride,kernel_size=3,groups=hidden_dim),)block.add_module(name="red_1x1",module=ConvLayer(in_channels=hidden_dim,out_channels=out_channels,kernel_size=1,use_act=False,use_norm=True,),)self.block = blockself.in_channels = in_channelsself.out_channels = out_channelsself.exp = expand_ratioself.stride = strideself.use_res_connect = (self.stride == 1 and in_channels == out_channels and skip_connection)def forward(self, x: Tensor, *args, **kwargs) -> Tensor:if self.use_res_connect:return x + self.block(x)else:return self.block(x)class MobileViTBlock(nn.Module):"""This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_Args:opts: command line argumentsin_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`transformer_dim (int): Input dimension to the transformer unitffn_dim (int): Dimension of the FFN blockn_transformer_blocks (int): Number of transformer blocks. Default: 2head_dim (int): Head dimension in the multi-head attention. Default: 32attn_dropout (float): Dropout in multi-head attention. Default: 0.0dropout (float): Dropout rate. Default: 0.0ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0patch_h (int): Patch height for unfolding operation. Default: 8patch_w (int): Patch width for unfolding operation. Default: 8transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_normconv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False"""def __init__(self,in_channels: int,transformer_dim: int,ffn_dim: int,n_transformer_blocks: int = 2,head_dim: int = 32,attn_dropout: float = 0.0,dropout: float = 0.0,ffn_dropout: float = 0.0,patch_h: int = 8,patch_w: int = 8,conv_ksize: Optional[int] = 3,*args,**kwargs) -> None:super().__init__()# 下面两个卷积层:Local representationsconv_3x3_in = ConvLayer(in_channels=in_channels,out_channels=in_channels,kernel_size=conv_ksize,stride=1)conv_1x1_in = ConvLayer(in_channels=in_channels,out_channels=transformer_dim,kernel_size=1,stride=1,use_norm=False,use_act=False)# 下面两个卷积层:Fusionconv_1x1_out = ConvLayer(in_channels=transformer_dim,out_channels=in_channels,kernel_size=1,stride=1)conv_3x3_out = ConvLayer(in_channels=2 * in_channels,out_channels=in_channels,kernel_size=conv_ksize,stride=1)# Local representationsself.local_rep = nn.Sequential()self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)assert transformer_dim % head_dim == 0num_heads = transformer_dim // head_dim# global representationsglobal_rep = [TransformerEncoder(embed_dim=transformer_dim,ffn_latent_dim=ffn_dim,num_heads=num_heads,attn_dropout=attn_dropout,dropout=dropout,ffn_dropout=ffn_dropout)for _ in range(n_transformer_blocks)]global_rep.append(nn.LayerNorm(transformer_dim))self.global_rep = nn.Sequential(*global_rep)# Fusionself.conv_proj = conv_1x1_outself.fusion = conv_3x3_outself.patch_h = patch_hself.patch_w = patch_wself.patch_area = self.patch_w * self.patch_hself.cnn_in_dim = in_channelsself.cnn_out_dim = transformer_dimself.n_heads = num_headsself.ffn_dim = ffn_dimself.dropout = dropoutself.attn_dropout = attn_dropoutself.ffn_dropout = ffn_dropoutself.n_blocks = n_transformer_blocksself.conv_ksize = conv_ksizedef unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:patch_w, patch_h = self.patch_w, self.patch_hpatch_area = patch_w * patch_hbatch_size, in_channels, orig_h, orig_w = x.shapenew_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)interpolate = Falseif new_w != orig_w or new_h != orig_h:# Note: Padding can be done, but then it needs to be handled in attention function.x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)interpolate = True# number of patches along width and heightnum_patch_w = new_w // patch_w  # n_wnum_patch_h = new_h // patch_h  # n_hnum_patches = num_patch_h * num_patch_w  # N# [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)# [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]x = x.transpose(1, 2)# [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_wx = x.reshape(batch_size, in_channels, num_patches, patch_area)# [B, C, N, P] -> [B, P, N, C]x = x.transpose(1, 3)# [B, P, N, C] -> [BP, N, C]x = x.reshape(batch_size * patch_area, num_patches, -1)info_dict = {"orig_size": (orig_h, orig_w),"batch_size": batch_size,"interpolate": interpolate,"total_patches": num_patches,"num_patches_w": num_patch_w,"num_patches_h": num_patch_h,}return x, info_dictdef folding(self, x: Tensor, info_dict: Dict) -> Tensor:n_dim = x.dim()assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(x.shape)# [BP, N, C] --> [B, P, N, C]x = x.contiguous().view(info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1)batch_size, pixels, num_patches, channels = x.size()num_patch_h = info_dict["num_patches_h"]num_patch_w = info_dict["num_patches_w"]# [B, P, N, C] -> [B, C, N, P]x = x.transpose(1, 3)# [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)# [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]x = x.transpose(1, 2)# [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)if info_dict["interpolate"]:x = F.interpolate(x,size=info_dict["orig_size"],mode="bilinear",align_corners=False,)return xdef forward(self, x: Tensor) -> Tensor:res = xfm = self.local_rep(x)# convert feature map to patchespatches, info_dict = self.unfolding(fm)# learn global representationsfor transformer_layer in self.global_rep:patches = transformer_layer(patches)# [B x Patch x Patches x C] -> [B x C x Patches x Patch]fm = self.folding(x=patches, info_dict=info_dict)fm = self.conv_proj(fm)fm = self.fusion(torch.cat((res, fm), dim=1))return fmclass MobileViT(nn.Module):"""This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_"""def __init__(self, model_cfg: Dict, num_classes: int = 1000):super().__init__()image_channels = 3out_channels = 16self.conv_1 = ConvLayer(in_channels=image_channels,out_channels=out_channels,kernel_size=3,stride=2)self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)self.conv_1x1_exp = ConvLayer(in_channels=out_channels,out_channels=exp_channels,kernel_size=1)self.classifier = nn.Sequential()self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))self.classifier.add_module(name="flatten", module=nn.Flatten())if 0.0 < model_cfg["cls_dropout"] < 1.0:self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))# weight initself.apply(self.init_parameters)def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:block_type = cfg.get("block_type", "mobilevit")if block_type.lower() == "mobilevit":return self._make_mit_layer(input_channel=input_channel, cfg=cfg)else:return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)@staticmethoddef _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:output_channels = cfg.get("out_channels")num_blocks = cfg.get("num_blocks", 2)expand_ratio = cfg.get("expand_ratio", 4)block = []for i in range(num_blocks):stride = cfg.get("stride", 1) if i == 0 else 1layer = InvertedResidual(in_channels=input_channel,out_channels=output_channels,stride=stride,expand_ratio=expand_ratio)block.append(layer)input_channel = output_channelsreturn nn.Sequential(*block), input_channel@staticmethoddef _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:stride = cfg.get("stride", 1)block = []if stride == 2:layer = InvertedResidual(in_channels=input_channel,out_channels=cfg.get("out_channels"),stride=stride,expand_ratio=cfg.get("mv_expand_ratio", 4))block.append(layer)input_channel = cfg.get("out_channels")transformer_dim = cfg["transformer_channels"]ffn_dim = cfg.get("ffn_dim")num_heads = cfg.get("num_heads", 4)head_dim = transformer_dim // num_headsif transformer_dim % head_dim != 0:raise ValueError("Transformer input dimension should be divisible by head dimension. ""Got {} and {}.".format(transformer_dim, head_dim))block.append(MobileViTBlock(in_channels=input_channel,transformer_dim=transformer_dim,ffn_dim=ffn_dim,n_transformer_blocks=cfg.get("transformer_blocks", 1),patch_h=cfg.get("patch_h", 2),patch_w=cfg.get("patch_w", 2),dropout=cfg.get("dropout", 0.1),ffn_dropout=cfg.get("ffn_dropout", 0.0),attn_dropout=cfg.get("attn_dropout", 0.1),head_dim=head_dim,conv_ksize=3))return nn.Sequential(*block), input_channel@staticmethoddef init_parameters(m):if isinstance(m, nn.Conv2d):if m.weight is not None:nn.init.kaiming_normal_(m.weight, mode="fan_out")if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):if m.weight is not None:nn.init.ones_(m.weight)if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, (nn.Linear,)):if m.weight is not None:nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)if m.bias is not None:nn.init.zeros_(m.bias)else:passdef forward(self, x: Tensor) -> Tensor:x = self.conv_1(x)x = self.layer_1(x)x = self.layer_2(x)x = self.layer_3(x)x = self.layer_4(x)x = self.layer_5(x)x = self.conv_1x1_exp(x)x = self.classifier(x)return xdef mobile_vit_xx_small(num_classes: int = 1000):# pretrain weight link# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.ptconfig = get_config("xx_small")m = MobileViT(config, num_classes=num_classes)return mdef mobile_vit_x_small(num_classes: int = 1000):# pretrain weight link# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.ptconfig = get_config("x_small")m = MobileViT(config, num_classes=num_classes)return mdef mobile_vit_small(num_classes: int = 1000):# pretrain weight link# https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.ptconfig = get_config("small")m = MobileViT(config, num_classes=num_classes)return m

2.训练和测试模型

def main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("./weights") is False:os.makedirs("./weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = 224data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict# 删除有关分类类别的权重for k in list(weights_dict.keys()):if "classifier" in k:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "classifier" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)best_acc = 0.for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), "./weights/best_model.pth")torch.save(model.state_dict(), "./weights/latest_model.pth")if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=50)parser.add_argument('--batch-size', type=int, default=8)parser.add_argument('--lr', type=float, default=0.0002)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default="F:/NN/Learn_Pytorch/flower_photos")# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

这里使用了预训练权重,在其基础上训练自己的数据集。训练50epoch的准确率能到达94%左右。
在这里插入图片描述

四、图像分类

这里使用花朵数据集,下载连接:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")img_size = 224data_transform = transforms.Compose([transforms.Resize(int(img_size * 1.14)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载图片img_path = 'daisy2.jpg'assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)image = Image.open(img_path)# image.show()# [N, C, H, W]img = data_transform(image)# 扩展维度img = torch.unsqueeze(img, dim=0)# 获取标签json_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)with open(json_path, 'r') as f:# 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中class_indict = json.load(f)# create modelmodel = create_model(num_classes=5).to(device)# load model weightsmodel_weight_path = "./weights/best_model.pth"model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# 对输入图像进行预测output = torch.squeeze(model(img.to(device))).cpu()# 对模型的输出进行 softmax 操作,将输出转换为类别概率predict = torch.softmax(output, dim=0)# 得到高概率的类别的索引predict_cla = torch.argmax(predict).numpy()res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())draw = ImageDraw.Draw(image)# 文本的左上角位置position = (10, 10)# fill 指定文本颜色draw.text(position, res, fill='green')image.show()for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))

预测结果:
在这里插入图片描述

结束语

感谢阅读吾之文章,今已至此次旅程之终站 🛬。

吾望斯文献能供尔以宝贵之信息与知识也 🎉。

学习者之途,若藏于天际之星辰🍥,吾等皆当努力熠熠生辉,持续前行。

然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 💞。

相关文章:

Pytorch之MobileViT图像分类

文章目录 前言一、Transformer存在的问题二、MobileViT1.MobileViT网络结构&#x1f353; Vision Transformer结构&#x1f349;MobileViT结构 2.MV2(MobileNet v2 block)3.MobileViT block&#x1f947;Local representations&#x1f948;Transformers as Convolutions (glob…...

03在命令行环境中创建Maven版的Java工程,了解pom.xml文件的结构,了解Java工程的目录结构并编写代码,执行Maven相关的构建命令

创建Maven版的Java工程 Maven工程的坐标 数学中使用x、y、z三个向量可以在空间中唯一的定位一个点, Maven中也可以使用groupId,artifactId,version三个向量在Maven的仓库中唯一的定位到一个jar包 groupId: 公司或组织域名的倒序, 通常也会加上项目名称代表公司或组织开发的一…...

论文阅读:CenterFormer: Center-based Transformer for 3D Object Detection

目录 概要 Motivation 整体架构流程 技术细节 Multi-scale Center Proposal Network Multi-scale Center Transformer Decoder Multi-frame CenterFormer 小结 论文地址&#xff1a;[2209.05588] CenterFormer: Center-based Transformer for 3D Object Detection (arx…...

Arduino驱动BNO055九轴绝对定向传感器(惯性测量传感器篇)

目录 1、传感器特性 2、硬件原理图 3、控制器和传感器连线图 4、驱动程序 BNO055是实现智能9轴绝对定向的新型传感器IC,它将整个传感器系统级封装在一起,集成了三轴14位加速度计,三轴16位陀螺仪,三轴地磁传感器和一个自带算法处理的32位微控制器。...

MQTT测试工具及使用教程

一步一步来&#xff1a;MQTT服务器搭建、MQTT客户端使用-CSDN博客 MQTT X 使用指南_mqttx使用教程-CSDN博客...

yolov7改进优化之蒸馏(一)

最近比较忙&#xff0c;有一段时间没更新了&#xff0c;最近yolov7用的比较多&#xff0c;总结一下。上一篇yolov5及yolov7实战之剪枝_CodingInCV的博客-CSDN博客 我们讲了通过剪枝来裁剪我们的模型&#xff0c;达到在精度损失不大的情况下&#xff0c;提高模型速度的目的。上一…...

视频美颜SDK,提升企业视频通话质量与形象

在今天的数字时代&#xff0c;视频通话已经成为企业与客户、员工之间不可或缺的沟通方式。然而&#xff0c;由于网络环境、设备性能等因素的影响&#xff0c;视频通话中的画面质量往往难以达到预期效果。为了提升视频通话的质量与形象&#xff0c;美摄美颜SDK应运而生&#xff…...

webmin远程命令执行漏洞

文章目录 漏洞编号&#xff1a;漏洞描述&#xff1a;影响版本&#xff1a;利用方法&#xff08;利用案例&#xff09;&#xff1a;安装环境漏洞复现 附带文件&#xff1a;加固建议&#xff1a;参考信息&#xff1a;漏洞分类&#xff1a; Webmin 远程命令执行漏洞&#xff08;CV…...

docker离线安装和使用

通过修改daemon配置文件/etc/docker/daemon.json来使用加速器sudo mkdir -p /etc/docker sudo tee /etc/docker/daemon.json <<-EOF {"registry-mirrors": ["https://ullx9uta.mirror.aliyuncs.com"] } EOF sudo systemctl daemon-reload sudo syste…...

解决 MyBatis 一对多查询中,出现每组元素只有一个,总组数与元素数总数相等的问题

文章目录 问题简述场景描述问题描述问题原因解决办法 问题简述 笔者在使用 MyBatis 进行一对多查询的时候遇到一个奇怪的问题。对于笔者的一对多的查询结果&#xff0c;出现了这样的一个现象&#xff1a;原来每个组里有多个元素&#xff0c;查询目标是查询所查的组&#xff0c;…...

这应该是关于回归模型最全的总结了(附原理+代码)

本文将继续修炼回归模型算法&#xff0c;并总结了一些常用的除线性回归模型之外的模型&#xff0c;其中包括一些单模型及集成学习器。 保序回归、多项式回归、多输出回归、多输出K近邻回归、决策树回归、多输出决策树回归、AdaBoost回归、梯度提升决策树回归、人工神经网络、随…...

基于闪电连接过程优化的BP神经网络(分类应用) - 附代码

基于闪电连接过程优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码 文章目录 基于闪电连接过程优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码1.鸢尾花iris数据介绍2.数据集整理3.闪电连接过程优化BP神经网络3.1 BP神经网络参数设置3.2 闪电连接过程算…...

Linux性能优化--性能工具:网络

7.0 概述 本章介绍一些在Linux上可用的网络性能工具。我们主要关注分析单个设备/系统网络流量的工具&#xff0c;而非全网管理工具。虽然在完全隔离的情况下评估网络性能通常是无意义的(节点不会与自己通信),但是&#xff0c;调查单个系统在网络上的行为对确定本地配置和应用程…...

【Linux】线程互斥与同步

文章目录 一.Linux线程互斥1.进程线程间的互斥相关背景概念2互斥量mutex3.互斥量的接口4.互斥量实现原理探究 二.可重入VS线程安全1.概念2.常见的线程不安全的情况3.常见的线程安全的情况4.常见的不可重入的情况5.常见的可重入的情况6.可重入与线程安全联系7.可重入与线程安全区…...

敏捷开发中,Sprint回顾会的目的

Sprint回顾会的主要目的是促进Scrum团队的学习和持续改进。在每个Sprint结束后&#xff0c;团队聚集在一起进行回顾&#xff0c;以达到以下目标&#xff1a; 识别问题&#xff1a; 回顾会允许团队识别在Sprint&#xff08;迭代&#xff09;期间遇到的问题、挑战和障碍。这有助于…...

排序【七大排序】

文章目录 1. 排序的概念及引用1.1 排序的概念1.2 常见的排序算法 2. 常见排序算法的实现2.1 插入排序2.1.1基本思想&#xff1a;2.1.2 直接插入排序2.1.3 希尔排序( 缩小增量排序 ) 2.2 选择排序2.2.1基本思想&#xff1a;2.2.2 直接选择排序:2.2.3 堆排序 2.3 交换排序2.3.1冒…...

人大与加拿大女王大学金融硕士项目——立即行动,才是缓解焦虑的解药

!在这个经济飞速的发展的时代&#xff0c;我国焦虑症的患病率为7%&#xff0c;焦虑已经超越个体范畴&#xff0c;成为整个社会与时代的课题。焦虑&#xff0c;往往源于我们想要达到的&#xff0c;与自己拥有的所产生的差距。任何事情&#xff0c;开始做远比准备做更会给人带来成…...

老卫带你学---leetcode刷题(46. 全排列)

46. 全排列 问题&#xff1a; 给定一个不含重复数字的数组 nums &#xff0c;返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1&#xff1a;输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2&#x…...

6.6 图的应用

思维导图&#xff1a; 6.6.1 最小生成树 ### 6.6 图的应用 #### 主旨&#xff1a;图的概念可应用于现实生活中的许多问题&#xff0c;如网络构建、路径查询、任务排序等。 --- #### 6.6.1 最小生成树 **概念**&#xff1a;要在n个城市中建立通信联络网&#xff0c;则最少需…...

100问GPT4与大语言模型的关系以及LLMs的重要性

你现在是一个AI专家&#xff0c;语言学家和教师&#xff0c;你目标是让我理解语言模型的概念&#xff0c;理解ChatGPT 跟语言模型之间的关系。你的工作是以一种易于理解的方式解释这些概念。这可能包括提供 例子&#xff0c;提出问题或将复杂的想法分解成更容易理解的小块。现在…...

Linux:mongodb数据逻辑备份与恢复(3.4.5版本)

我在数据库aaa的里创建了一个名为tarro的集合&#xff0c;其中有三条数据 备份语法 mongodump –h server_ip –d database_name –o dbdirectory 恢复语法 mongorestore -d database_name --dirdbdirectory 备份 现在我要将aaa.tarro进行备份 mongodump --host 192.168.254…...

凉鞋的 Godot 笔记 109. 专题一 小结

109. 专题一 小结 在这一篇&#xff0c;我们来对第一个专题做一个小的总结。 到目前为止&#xff0c;大家应该能够感受到此教程的基调。 内容的难度非常简单&#xff0c;接近于零基础的程度&#xff0c;不过通过这些零基础内容所介绍的通识内容其实是笔者好多年的时间一点点…...

数据结构 - 4(栈和队列6000字详解)

一&#xff1a;栈 1.1 栈的概念 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&#xff09;的原…...

MySQL InnoDB引擎深入学习的一天(InnoDB架构 + 事务底层原理 + MVCC)

目录 逻辑存储引擎 架构 概述 内存架构 Buffer Pool Change Buffe Adaptive Hash Index Log Buffer 磁盘结构 System Tablespace File-Per-Table Tablespaces General Tablespaces Undo Tablespaces Temporary Tablespaces Doublewrite Buffer Files Redo Log 后台线程 事务原…...

TX Text Control .NET Server for ASP.NET 32.0 Crack

TX Text Control .NET Server for ASP.NET 是VISUAL STUDIO 2022、ASP.NET CORE .NET 6 和 .NET 7 支持&#xff0c;将文档处理集成到 Web 应用程序中&#xff0c;为您的 ASP.NET Core、ASP.NET 和 Angular 应用程序添加强大的文档处理功能。 客户端用户界面 文档编辑器 将功能…...

Leetcode刷题详解——将x减到0的最小操作数

1. 题目链接&#xff1a;1658. 将 x 减到 0 的最小操作数 2. 题目描述: 给你一个整数数组 nums 和一个整数 x 。每一次操作时&#xff0c;你应当移除数组 nums 最左边或最右边的元素&#xff0c;然后从 x 中减去该元素的值。请注意&#xff0c;需要 修改 数组以供接下来的操作…...

精选免费热门api接口分享

IP归属地-IPv4城市级&#xff1a;根据IP地址查询归属地信息&#xff0c;支持到城市级&#xff0c;包含国家、省、市、和运营商等信息。IP归属地-IPv6城市级&#xff1a;根据IP地址&#xff08;IPv6版本&#xff09;查询归属地信息&#xff0c;支持到中国大陆地区&#xff08;不…...

androidx.appcompat.widget.Toolbar最右边设置控件不能仅靠最右边

androidx.appcompat.widget.Toolbar最右边设置控件不能仅靠最右边 Android Toolbar左、中、右对齐-CSDN博客&#xfeff;&#xfeff;Android Toolbar左、中、右对齐默认的Android Toolbar中添加子元素view是从左到右依次添加。需要注意的是&#xff0c;Android Toolbar为自身的…...

Springboot整合WebSocket实现浏览器和服务器交互

Websocket定义 代码实现 引入maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency>配置类 import org.springframework.context.annotation.Bean;i…...

这些 channel 用法你都用起来了吗?

channel 是什么&#xff1f; channel 是GO语言中一种特殊的类型&#xff0c;是连接并发goroutine的管道 channel 通道是可以让一个 goroutine 协程发送特定值到另一个 goroutine 协程的通信机制。 关于 channel 的原理&#xff0c;channel通道需要注意的地方&#xff0c;之前…...