当前位置: 首页 > news >正文

Pytorch | 从零构建EfficientNet对CIFAR10进行分类

Pytorch | 从零构建EfficientNet对CIFAR10进行分类

  • CIFAR10数据集
  • EfficientNet
    • 设计理念
    • 网络结构
    • 性能特点
    • 应用领域
    • 发展和改进
  • EfficientNet结构代码详解
    • 结构代码
    • 代码详解
      • MBConv 类
        • 初始化方法
        • 前向传播 forward 方法
      • EfficientNet 类
        • 初始化方法
        • 前向传播 forward 方法
  • 训练过程和测试结果
  • 代码汇总
    • efficientnet.py
    • train.py
    • test.py

前面文章我们构建了AlexNet、Vgg、GoogleNet、ResNet、MobileNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
这篇文章我们来构建EfficientNet.

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

EfficientNet

EfficientNet是由谷歌大脑团队在2019年提出的一种高效的卷积神经网络架构,在图像分类等任务上展现出了卓越的性能和效率,以下是对它的详细介绍:

设计理念

  • 平衡模型的深度、宽度和分辨率:传统的神经网络在提升性能时,往往只是单纯地增加网络的深度、宽度或输入图像的分辨率,而EfficientNet则通过一种系统的方法,同时对这三个维度进行优化调整,以达到在计算资源有限的情况下,模型性能的最大化。
    在这里插入图片描述

网络结构

  • MBConv模块:EfficientNet的核心模块是MBConv(Mobile Inverted Residual Bottleneck),它基于深度可分离卷积和倒置残差结构。这种结构在减少计算量的同时,能够有效提取图像特征,提高模型的表示能力。
  • Compound Scaling方法:使用该方法对网络的深度、宽度和分辨率进行统一缩放。通过一个固定的缩放系数,同时调整这三个维度,使得模型在不同的计算资源限制下,都能保持较好的性能和效率平衡。
    在这里插入图片描述
    上图中是EfficientNet-B0的结构.

性能特点

  • 高效性:在相同的计算资源下,EfficientNet能够取得比传统网络更好的性能。例如,与ResNet-50相比,EfficientNet-B0在ImageNet数据集上取得了相近的准确率,但参数量和计算量却大大减少。
  • 可扩展性:通过Compound Scaling方法,可以方便地调整模型的大小,以适应不同的应用场景和计算资源限制。从EfficientNet-B0到EfficientNet-B7,模型的复杂度逐渐增加,性能也相应提升,能够满足从移动端到服务器端的不同需求。

应用领域

  • 图像分类:在ImageNet等大规模图像分类数据集上,EfficientNet取得了领先的性能,成为图像分类任务的首选模型之一。
  • 目标检测:与Faster R-CNN等目标检测框架结合,EfficientNet作为骨干网络,能够提高目标检测的准确率和速度,在Pascal VOC、COCO等数据集上取得了不错的效果。
  • 语义分割:在语义分割任务中,EfficientNet也展现出了一定的优势,通过与U-Net等分割网络结合,能够对图像进行像素级的分类,在Cityscapes等数据集上有较好的表现。

发展和改进

  • EfficientNet v2:在EfficientNet基础上进行了进一步优化,主要改进包括改进了渐进式学习的方法,在训练过程中逐渐增加图像的分辨率,使得模型能够更好地学习到不同尺度的特征,同时优化了网络结构,提高了模型的训练速度和性能。
  • 其他改进:研究人员还在EfficientNet的基础上,结合其他技术如注意力机制、知识蒸馏等,进一步提高模型的性能和泛化能力。

EfficientNet结构代码详解

结构代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 扩展通道数(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分离卷积self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 压缩和激励(SE)模块(可选,根据se_ratio判断是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷积,恢复到输出通道数self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 扩展通道数out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分离卷积out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模块操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷积out = self.bn3(self.project_conv(out))# 残差连接(如果满足条件)if self.use_res_connect:out += identityreturn outclass EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根据深度系数调整每个MBConv模块的重复次数,这里简单地向下取整,你也可以根据实际情况采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out

代码详解

以下是对上述EfficientNet代码的详细解释,代码整体定义了EfficientNet网络结构,主要由MBConv模块堆叠以及一些常规的卷积、池化和全连接层构成,下面按照类和方法分别进行剖析:

MBConv 类

这是EfficientNet中的核心模块,实现了MBConv(Mobile Inverted Residual Bottleneck Convolution)结构,其代码如下:

class MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 扩展通道数(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分离卷积self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 压缩和激励(SE)模块(可选,根据se_ratio判断是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷积,恢复到输出通道数self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 扩展通道数out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分离卷积out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模块操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷积out = self.bn3(self.project_conv(out))# 残差连接(如果满足条件)if self.use_res_connect:out += identityreturn out
初始化方法
  • 参数说明
    • in_channels:输入张量的通道数。
    • out_channels:输出张量的通道数。
    • expand_ratio:用于确定扩展通道数时的比例系数,决定是否对输入通道数进行扩展以及扩展的倍数。
    • kernel_size:卷积核的大小,用于深度可分离卷积等操作。
    • stride:卷积的步长,控制特征图在卷积过程中的下采样程度等。
    • padding:卷积操作时的填充大小,保证输入输出特征图尺寸在特定要求下的一致性等。
    • se_ratio(可选,默认值为0.25):用于控制压缩和激励(SE)模块中通道压缩的比例,若为0则不添加SE模块。
  • 初始化操作
    • 首先保存传入的stride参数,并根据stride和输入输出通道数判断是否使用残差连接(use_res_connect),只有当步长为1且输入输出通道数相等时才使用残差连接,这符合残差网络的基本原理,有助于梯度传播和特征融合。
    • 根据expand_ratio计算扩展后的通道数expanded_channels,并创建expand_conv卷积层用于扩展通道数,同时搭配对应的bn1批归一化层,对扩展后的特征进行归一化处理,有助于加速网络收敛和提升模型稳定性。
    • 定义depthwise_conv深度可分离卷积层,其分组数设置为expanded_channels,意味着每个通道单独进行卷积操作,这种方式可以在减少计算量的同时保持较好的特征提取能力,同时搭配bn2批归一化层。
    • 根据se_ratio判断是否添加压缩和激励(SE)模块。如果se_ratio大于0,则创建一个包含自适应平均池化、卷积、激活函数(ReLU)、卷积和Sigmoid激活的序列模块se,用于对特征进行通道维度上的重加权,增强模型对不同通道特征的关注度;若se_ratio为0,则将se设为None
    • 最后创建project_conv投影卷积层用于将扩展和处理后的特征恢复到指定的输出通道数,并搭配bn3批归一化层。
前向传播 forward 方法
  • 首先将输入张量x保存为identity,用于后续可能的残差连接。
  • 通过F.relu(self.bn1(self.expand_conv(x)))对输入进行通道扩展,并使用ReLU激活函数和批归一化进行处理,得到扩展后的特征表示。
  • 接着对扩展后的特征执行深度可分离卷积操作F.relu(self.bn2(self.depthwise_conv(out))),同样使用ReLU激活和批归一化处理。
  • 如果存在SE模块(self.se不为None),则将经过深度可分离卷积后的特征传入SE模块进行通道重加权,即se_out = self.se(out),然后将特征与重加权后的结果相乘out = out * se_out
  • 通过self.bn3(self.project_conv(out))进行投影卷积操作,将特征恢复到输出通道数,并进行批归一化处理。
  • 最后,如果满足残差连接条件(self.use_res_connectTrue),则将投影卷积后的特征与最初保存的输入特征identity相加,实现残差连接,最终返回处理后的特征张量。

EfficientNet 类

这是整体的EfficientNet网络模型类,代码如下:

class EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根据深度系数调整每个MBConv模块的重复次数,这里简单地向下取整,你也可以根据实际情况采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out
初始化方法
  • 参数说明
    • num_classes:最终分类任务的类别数量,用于确定全连接层的输出维度。
    • width_coefficient(默认值为1.0):用于控制模型各层的通道数,实现对模型宽度的缩放调整。
    • depth_coefficient(默认值为1.0):用于控制模型中MBConv模块的重复次数,实现对模型深度的缩放调整。
    • dropout_rate(默认值为0.2):在全连接层之前使用的Dropout概率,用于防止过拟合。
  • 初始化操作
    • 首先创建stem_conv卷积层,它将输入的图像数据(通常通道数为3,对应RGB图像)进行初始的卷积操作,步长为2,起到下采样的作用,同时不使用偏置(bias=False),并搭配bn1批归一化层对卷积后的特征进行归一化处理。
    • 定义mbconv_config列表,其中每个元素是一个元组,包含了MBConv模块的各项配置参数,如输入通道数、输出通道数、扩展比例、卷积核大小、步长和填充等,这是构建MBConv模块的基础配置信息。
    • 根据depth_coefficient计算每个MBConv模块的重复次数,通过列表推导式 repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config] 实现,这里简单地将每个配置对应的重复次数设置为与depth_coefficient成比例(向下取整且保证至少重复1次),你可以根据更精细的设计规则来调整这个计算方式。
    • 构建self.mbconv_layers,通过两层嵌套循环实现。外层循环遍历mbconv_config配置列表,内层循环根据对应的重复次数来多次添加同一个MBConv模块实例到layers列表中,最后将layers列表转换为nn.Sequential类型的模块,这样就实现了根据depth_coefficient对模型深度进行调整以及MBConv模块的堆叠搭建。
    • 创建last_conv卷积层,它将经过MBConv模块处理后的特征进行进一步的卷积操作,将通道数转换为1280,同样不使用偏置,搭配bn2批归一化层。
    • 定义avgpool自适应平均池化层,将特征图转换为固定大小(这里为1x1),方便后续全连接层处理。
    • 创建dropout Dropout层,按照指定的dropout_rate在全连接层之前进行随机失活操作,防止过拟合。
    • 最后定义fc全连接层,其输入维度为1280(经过池化后的特征维度),输出维度为num_classes,用于最终的分类预测。
前向传播 forward 方法
  • 首先将输入x传入stem_conv卷积层进行初始卷积,然后通过F.relu(self.bn1(self.stem_conv(x)))进行激活和批归一化处理,得到初始的特征表示。
  • 将初始特征传入self.mbconv_layers,即经过一系列堆叠的MBConv模块进行特征提取和变换,充分挖掘图像中的特征信息。
  • 接着对经过MBConv模块处理后的特征执行F.relu(self.bn2(self.last_conv(out)))操作,进行最后的卷积以及激活、批归一化处理。
  • 使用self.avgpool(out)进行自适应平均池化,将特征图尺寸变为1x1,实现特征的压缩和固定维度表示。
  • 通过out = out.view(out.size(0), -1)将池化后的特征张量展平为一维向量,方便全连接层处理,这里-1表示自动根据张量元素总数和已知的批量大小维度(out.size(0))来推断展平后的维度大小。
  • 然后将展平后的特征传入self.dropout(out)进行Dropout操作,随机丢弃一部分神经元,防止过拟合。
  • 最后将特征传入self.fc(out)全连接层进行分类预测,得到最终的输出结果,输出的维度与设定的num_classes一致,表示每个样本属于各个类别的预测概率(或得分等,具体取决于任务和后续处理),并返回该输出结果。

训练过程和测试结果

训练过程损失函数变化曲线:

在这里插入图片描述

训练过程准确率变化曲线:
在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models|--__init__.py|-efficientnet.py|--...
|--results
|--weights
|--train.py
|--test.py

efficientnet.py

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 扩展通道数(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分离卷积self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 压缩和激励(SE)模块(可选,根据se_ratio判断是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷积,恢复到输出通道数self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 扩展通道数out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分离卷积out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模块操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷积out = self.bn3(self.project_conv(out))# 残差连接(如果满足条件)if self.use_res_connect:out += identityreturn outclass EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根据深度系数调整每个MBConv模块的重复次数,这里简单地向下取整,你也可以根据实际情况采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'EfficientNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':model = EfficientNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'EfficientNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':model = EfficientNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

相关文章:

Pytorch | 从零构建EfficientNet对CIFAR10进行分类

Pytorch | 从零构建EfficientNet对CIFAR10进行分类 CIFAR10数据集EfficientNet设计理念网络结构性能特点应用领域发展和改进 EfficientNet结构代码详解结构代码代码详解MBConv 类初始化方法前向传播 forward 方法 EfficientNet 类初始化方法前向传播 forward 方法 训练过程和测…...

Python超能力:高级技巧让你的代码飞起来

文章一览 前言一、with1.1 基本用法1.2 示例自定义上下文管理器 二、条件表达式三、列表式推导式与 zip 结合 四、map() 函数(内置函数)map用于数据清洗1. 数据清洗:字母大小写规范2. filter() 函数 五、匿名函数 lambda5.1 lambda的参数&…...

熊军出席ACDU·中国行南京站,详解SQL管理之道

12月21日,2024 ACDU中国行在南京圆满收官,本次活动分为三个篇章——回顾历史、立足当下、展望未来,为线上线下与会观众呈现了一场跨越时空的技术盛宴,吸引了众多业内人士的关注。云和恩墨副总经理熊军出席此次活动并发表了主题演讲…...

FPGA实现MIPI转FPD-Link车载同轴视频传输方案,基于IMX327+FPD953架构,提供工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐本博主所有FPGA工程项目-->汇总目录我这里已有的 MIPI 编解码方案 3、本 MIPI CSI-RX IP 介绍4、详细设计方案设计原理框图IMX327 及其配置FPD-Link视频串化-解串方案MIPI CSI RX图像 ISP 处理图像缓存HDMI输出工程源码架构 5、…...

vue3动态绑定图片和使用阿里巴巴矢量图

矢量图 1。加购物车 2. 下载在本地 解压 (把以下文件放进项目文件夹里面) ├── font ├── iconfont.css ├── iconfont.json (font-class用法) ├── iconfont.js (symbol用法) ├─…...

‘vite‘ 不是内部或外部命令,也不是可运行的程序

报错:执行 npm run dev时,提示’vite’ 不是内部或外部命令,也不是可运行的程序 解决:执行 npm install -g vite 报错:导入vite后再次执行npm run dev,报错failed to load config from E:\eclipseWP\test1…...

2024年12月一区SCI-加权平均优化算法Weighted average algorithm-附Matlab免费代码

引言 本期介绍了一种基于加权平均位置概念的元启发式优化算法,称为加权平均优化算法Weighted average algorithm,WAA。该成果于2024年12月最新发表在中JCR1区、 中科院1区 SCI期刊 Knowledge-Based Systems。 在WAA算法中,加权平均位置代表当…...

如何获取 ABAP 内表中的重复项

要识别 ABAP 内表中的重复项,可以结合使用排序和循环。下面的示例展示了如何查找内部表中的重复条目: DATA: BEGIN OF itab OCCURS 0,field1 TYPE i,field2 TYPE c LENGTH 10,END OF itab,wa LIKE LINE OF itab.* Add sample data to internal table it…...

编译笔记:vs 中 正在从以下位置***加载符号 C# 中捕获C/C++抛出的异常

加载符号 解决方法: 进入VS—工具—选项----调试----符号,看右边有个“Microsoft符号服务器”,将前面的勾去掉,(可能还有删除下面的那个缓存)。 参考 C# 中捕获C/C抛出的异常 在需要捕捉破坏性异常的函数…...

ChatGPT与Postman协作完成接口测试(二)

ChatGPT生成的Postman接口测试用例脚本如下所示。 ChatGPT生成的Postman接口测试用例脚本 以下是符合Collection v2.1格式要求的 Postman 测试用例脚本,覆盖了正常注册和密码不匹配两种情况的测试: { "info": { "_postman_id": &qu…...

flask-admin modelview 中重写get_query函数

背景: flask-admin框架中提供的模型视图默认是显示表实体中的所有列表数据,如果想通过某种条件限制初始列表数据,那么久需要重写一些方法才能实现。 材料: 略 制作: 视图源码: def get_query(self):re…...

【python 逆向分析某有道翻译】分析有道翻译公开的密文内容,webpack类型,全程扣代码,最后实现接口调用翻译,仅供学习参考

文章日期:2024.12.24 使用工具:Python,Node.js 逆向类型:webpack类型 本章知识:sign模拟生成,密文的解密(webpack),全程扣代码,仅供学习参考 文章难度:低等(没…...

tensorflow_probability与tensorflow版本依赖关系

参考:Tensorflow Probability 与 TensorFlow 的版本依赖关系_tensorflow与tensorflow-probability对应版本的网址-CSDN博客 tensorflow2.10对应tensorflow_probability0.18.0,安装命令:pip install tensorflow_probability0.18.0 版本对应关…...

构建安全的用户认证系统:PHP实现

构建安全的用户认证系统:PHP实现 用户认证是任何Web应用的重要组成部分,确保只有授权用户才能访问特定资源。构建一个安全的用户认证系统需要考虑多种因素,包括密码存储、会话管理和防止常见gongji。本文将介绍如何使用PHP实现一个安全的用户…...

VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比

VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比 目录 VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比; 2.单变量时间序列预测 就是先vmd把变…...

天融信网络架构安全实践

1、医院客户想通过等保2.0三级,推荐哪几款网络安全产品?(至少6个) TopSAg(运维安全审计系统) TopNAC(网络准入系统) TopEDR(终端威胁防御系统) TDSM-SBU(存储备份一体机…...

腾讯云云开发 Copilot具有以下优势

与其他代码生成工具相比,腾讯云云开发 Copilot具有以下优势: 功能特性方面 自然语言处理能力更强:许多代码生成工具仅能实现简单的代码补全或根据特定模板生成代码,而云开发 Copilot可直接通过自然语言生成完整的小程序/web全栈…...

electron-vite【实战系列教程】

创建项目 安装必要的插件 UI 库 element-plus npm install element-plus --save 安装 element-plus 图标 npm install element-plus/icons-vue 安装插件 – 自动注册组件 vs 自动导入框架方法 npm install -D unplugin-vue-components unplugin-auto-import electron.vite.conf…...

【微信小程序】微信小程序中的异步函数是如何实现同步功能的

在微信小程序中,虽然很多 API 都是异步的,但可以通过一些方法来实现类似同步的功能。以下是几种常见的方法: 1. 使用 async/await async/await 是 ES2017 引入的语法糖,它基于 Promise 来实现异步操作的同步化写法。 示例代码 …...

贪心算法(三)

目录 一、k次取反后最大化的数组和 二、优势洗牌 三、最长回文串 四、增减字符串匹配 一、k次取反后最大化的数组和 k次取反后最大化的数组和 贪心策略&#xff1a; 解题代码&#xff1a; class Solution { public:int largestSumAfterKNegations(vector<int>&am…...

uniApp打包H5发布到服务器(docker)

使用docker部署uniApp打包后的H5项目记录&#xff0c;好像和VUE项目打包没什么区别... 用HX打开项目&#xff0c;首先调整manifest.json文件 开始用HX打包 填服务器域名和端口号~ 打包完成后可以看到控制台信息 我们可以在web文件夹下拿到下面打包好的静态文件 用FinalShell或…...

【AI落地应用实战】篡改检测技术前沿探索——从基于检测分割到大模型

在数字化洪流席卷全球的当下&#xff0c;视觉内容已成为信息交流与传播的核心媒介&#xff0c;然而&#xff0c;随着PS技术和AIGC技术的飞速发展&#xff0c;图像篡改给视觉内容安全带来了前所未有的挑战。 本文将探讨篡改检测技术的现实挑战&#xff0c;分享篡改检测技术前沿…...

使用 VSCode 学习与实践 LaTeX:从插件安装到排版技巧

文章目录 背景介绍编辑器编译文件指定输出文件夹 usepackagelatex 语法列表插入图片添加参考文献 背景介绍 最近在写文章&#xff0c;更喜欢latex的论文引用。然后开始学习 latex。 编辑器 本文选择vscode作为编辑器&#xff0c;当然大家也可以尝试overleaf。 overleaf 有网…...

使用scrapy框架爬取微博热搜榜

注&#xff1a;在使用爬虫抓取网站数据之前&#xff0c;非常重要的一点是确保遵守相关的法律、法规以及目标网站的使用条款。 &#xff08;最底下附下载链接&#xff09; 准备工作&#xff1a; 安装依赖&#xff1a; 确保已经安装了Python环境。 使用pip安装scrapy&#xff…...

瑞吉外卖项目学习笔记(七)新增菜品、(批量)删除菜品

瑞吉外卖项目学习笔记(一)准备工作、员工登录功能实现 瑞吉外卖项目学习笔记(二)Swagger、logback、表单校验和参数打印功能的实现 瑞吉外卖项目学习笔记(三)过滤器实现登录校验、添加员工、分页查询员工信息 瑞吉外卖项目学习笔记(四)TableField(fill FieldFill.INSERT)公共字…...

es快速扫描

介绍 Elasticsearch简称es&#xff0c;一款开源的分布式全文检索引擎 可组建一套上百台的服务器集群&#xff0c;处理PB级别数据 可满足近实时的存储和检索 倒排索引 跟正排索引相对&#xff0c;正排索引是根据id进行索引&#xff0c;所以查询效率非常高&#xff0c;但是模糊…...

前端对页面数据进行缓存

页面录入信息&#xff0c;退出且未提交状态下&#xff0c;前端对页面数据进行存储 前端做缓存&#xff0c;一般放在local、session和cookies里面&#xff0c;但是都有大小限制&#xff0c;如果页面东西多&#xff0c;比如有上传的图片、视频&#xff0c;浏览器会抛出一个Quota…...

leetCode322.零钱兑换

题目&#xff1a; 给你一个整数数组coins,表示不同面额的硬币&#xff1b;以及一个整数amount,表示总金额。 计算并返回可以凑成总金额所需的最少的硬币个数。如果没有任何一种硬币组合能组成总金额&#xff0c;返回-1。 你可以认为每种硬币的数量是无限的。 示例1&#xff1…...

jsp-servlet开发

STS中开发步骤 建普通jsp项目过程 1.建项目&#xff08;非Maven项目&#xff09; new----project----other----Web----Dynamic Web Project 2.下载包放到LIB目录中,如果是Maven项目可以自动导包&#xff08;pom.xml中设置好&#xff09; 3.设置工作空间&#xff0c;网页…...

从零玩转CanMV-K230(7)-I2C例程

文章目录 前言一、IIC API二、示例总结 前言 K230内部包含5个I2C硬件模块&#xff0c;支持标准100kb/s&#xff0c;快速400kb/s模式&#xff0c;高速模式3.4Mb/s。 通道输出IO配置参考IOMUX模块。 一、IIC API I2C类位于machine模块下。 i2c I2C(id, freq100000) 【参数】…...