Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
- CIFAR10数据集
- MobileNet
- 设计理念
- 网络结构
- 技术优势
- 应用领域
- MobileNet结构代码详解
- 结构代码
- 代码详解
- DepthwiseSeparableConv 类
- 初始化方法
- 前向传播 forward 方法
- MobileNet 类
- 初始化方法
- 前向传播 forward 方法
- 训练和测试
- 训练代码train.py
- 测试代码test.py
- 训练过程和测试结果
- 代码汇总
- mobilenet.py
- train.py
- test.py
前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
这篇文章我们来构建MobileNet.
CIFAR10数据集
CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:
- 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
- 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
- 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。
下面是一些示例样本:

MobileNet
MobileNet是由谷歌在2017年提出的一种轻量级卷积神经网络,主要用于移动端和嵌入式设备等资源受限的环境中进行图像识别和分类任务,以下是对其的详细介绍:
设计理念
- 深度可分离卷积:其核心创新是采用了深度可分离卷积(Depthwise Separable Convolution)来替代传统的卷积操作。深度可分离卷积将标准卷积分解为一个深度卷积(Depthwise Convolution)和一个逐点卷积(Pointwise Convolution),大大减少了计算量和模型参数,同时保持了较好的性能。
网络结构
- 标准卷积层:输入层为3通道的彩色图像,首先经过一个普通的卷积层
conv1,将通道数从3变为32,同时进行了步长为2的下采样操作,以减小图像尺寸。 - 深度可分离卷积层:包含了一系列的深度可分离卷积层
dsconv1至dsconv13,这些层按照一定的规律进行排列,通道数逐渐增加,同时通过不同的步长进行下采样,以提取不同层次的特征。 - 池化层和全连接层:在深度可分离卷积层之后,通过一个自适应平均池化层
avgpool将特征图转换为1x1的大小,然后通过一个全连接层fc将特征映射到指定的类别数,完成分类任务。
技术优势
- 模型轻量化:通过深度可分离卷积的使用,大大减少了模型的参数量和计算量,使得模型更加轻量化,适合在移动设备和嵌入式设备上运行。
- 计算效率高:由于减少了计算量,MobileNet在推理时具有较高的计算效率,可以快速地对图像进行分类和识别,满足实时性要求较高的应用场景。
- 性能表现较好:尽管模型轻量化,但MobileNet在图像识别任务上仍然具有较好的性能表现,能够在保持较高准确率的同时,大大降低模型的复杂度。
应用领域
- 移动端视觉任务:广泛应用于各种移动端设备,如智能手机、平板电脑等,用于图像分类、目标检测、人脸识别等视觉任务。
- 嵌入式设备视觉:在嵌入式设备,如智能摄像头、自动驾驶汽车等领域,MobileNet可以为这些设备提供高效的视觉处理能力,实现实时的图像分析和决策。
- 物联网视觉应用:在物联网设备中,MobileNet可以帮助实现对图像数据的快速处理和分析,为智能家居、智能安防等应用提供支持。
MobileNet结构代码详解
结构代码
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
代码详解
以下是对上述代码的详细解释:
DepthwiseSeparableConv 类
这是一个自定义的深度可分离卷积层类,继承自 nn.Module。
初始化方法
- 参数说明:
in_channels:输入通道数,指定输入数据的通道数量。out_channels:输出通道数,即卷积操作后输出特征图的通道数量。kernel_size:卷积核大小,默认为3,用于定义卷积操作中卷积核的尺寸。stride:步长,默认为1,控制卷积核在输入特征图上滑动的步长。padding:填充大小,默认为1,在输入特征图周围添加的填充像素数量,以保持特征图尺寸在卷积过程中合适变化。bias:是否使用偏置,默认为False,决定卷积层是否添加偏置项。
- 构建的层及作用:
self.depthwise:这是一个深度卷积层(nn.Conv2d),通过设置groups=in_channels,实现了深度可分离卷积中的深度卷积部分,它对每个输入通道分别进行卷积操作,有效地减少了计算量。self.bn1:批归一化层(nn.BatchNorm2d),用于对深度卷积后的输出进行归一化处理,加速模型收敛并提升模型的泛化能力。self.relu1:激活函数层(nn.ReLU6),采用ReLU6激活函数(输出值限定在0到6之间),并且设置inplace=True,意味着直接在输入的张量上进行修改,节省内存空间,增加非线性特性。self.pointwise:逐点卷积层(nn.Conv2d),卷积核大小为1,用于将深度卷积后的特征图在通道维度上进行融合,改变通道数到指定的out_channels。self.bn2:又是一个批归一化层,对逐点卷积后的输出进行归一化处理。self.relu2:同样是ReLU6激活函数层,进一步增加非线性,处理逐点卷积归一化后的结果。
前向传播 forward 方法
定义了数据在该层的前向传播过程:
- 首先将输入
x通过深度卷积层self.depthwise进行深度卷积操作,得到输出特征图。 - 然后将深度卷积的输出依次经过批归一化层
self.bn1和激活函数层self.relu1。 - 接着把经过处理后的特征图通过逐点卷积层
self.pointwise进行逐点卷积,改变通道数等特征。 - 最后再经过批归一化层
self.bn2和激活函数层self.relu2,并返回最终的输出结果。
MobileNet 类
这是定义的 MobileNet 网络模型类,同样继承自 nn.Module。
初始化方法
- 参数说明:
num_classes:分类的类别数量,用于最后全连接层输出对应类别数的预测结果。
- 构建的层及作用:
self.conv1:普通的二维卷积层(nn.Conv2d),输入通道数为3(通常对应RGB图像的三个通道),输出通道数为32,卷积核大小为3,步长为2,用于对输入图像进行初步的特征提取和下采样,减少特征图尺寸同时增加通道数。self.bn1:批归一化层,对conv1卷积后的输出进行归一化。self.relu:激活函数层,采用ReLU6激活函数给特征图增加非线性。- 一系列的
self.dsconv层(从dsconv1到dsconv13):都是前面定义的深度可分离卷积层DepthwiseSeparableConv的实例,它们逐步对特征图进行更精细的特征提取、通道变换以及下采样等操作,不同的dsconv层有着不同的输入输出通道数以及步长设置,以此构建出 MobileNet 网络的主体结构,不断提取和融合特征,逐步降低特征图尺寸并增加通道数来获取更高级、更抽象的特征表示。 self.avgpool:自适应平均池化层(nn.AdaptiveAvgPool2d),将输入特征图转换为指定大小(1, 1)的输出,起到全局平均池化的作用,进一步压缩特征图信息,同时保持特征图的维度一致性,方便后续全连接层处理。self.fc:全连接层(nn.Linear),输入维度为1024(与前面网络结构最终输出的特征维度对应),输出维度为num_classes,用于将经过前面卷积和池化等操作得到的特征向量映射到对应类别数量的预测分数上,实现分类任务。
前向传播 forward 方法
定义了 MobileNet 模型整体的前向传播流程:
- 首先将输入
x通过conv1进行初始卷积、bn1进行归一化以及relu激活。 - 然后依次通过各个深度可分离卷积层(
dsconv1到dsconv13),逐步提取和变换特征。 - 接着经过自适应平均池化层
self.avgpool,将特征图压缩为(1, 1)大小。 - 再通过
out.view(out.size(0), -1)操作将特征图展平为一维向量(其中out.size(0)表示批量大小,-1表示自动计算剩余维度大小使其展平)。 - 最后将展平后的特征向量通过全连接层
self.fc得到最终的分类预测结果并返回。
训练和测试
训练代码train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()
测试代码test.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
训练过程和测试结果
训练过程损失函数变化曲线:

训练过程准确率变化曲线:

测试结果:

代码汇总
项目github地址
项目结构:
|--data
|--models|--__init__.py|-mobilenet.py|--...
|--results
|--weights
|--train.py
|--test.py
mobilenet.py
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()
test.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
相关文章:
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类 CIFAR10数据集MobileNet设计理念网络结构技术优势应用领域 MobileNet结构代码详解结构代码代码详解DepthwiseSeparableConv 类初始化方法前向传播 forward 方法 MobileNet 类初始化方法前向传播 forward 方法 训练和测试训练代码…...
CSS系列(18)-- 工程化实践详解
前端技术探索系列:CSS 工程化实践详解 🏗️ 致读者:探索 CSS 工程化之路 👋 前端开发者们, 今天我们将深入探讨 CSS 工程化实践,学习如何在大型项目中管理 CSS。 工程化配置 🚀 项目结构 …...
日拱一卒(18)——leetcode学习记录:二叉树中的伪回文路径
一、题目 给你一棵二叉树,每个节点的值为 1 到 9 。我们称二叉树中的一条路径是 「伪回文」的,当它满足:路径经过的所有节点值的排列中,存在一个回文序列。 请你返回从根到叶子节点的所有路径中 伪回文 路径的数目。 二、思路 …...
hive—炸裂函数explode/posexplode
1、Explode炸裂函数 将hive某列一行中复杂的 array 或 map 结构拆分成多行(只能输入array或map) 语法: select explode(字段) as 字段命名 from 表名; 举例: 1)explode(array)使得结果中将array列表里的每个元素生…...
SpringBoot 新特性
优质博文:IT-BLOG-CN 2.1.0新特性最低支持jdk8,支持tomcat9 对响应式编程的支持,spring-boot-starter-webflux starter POM可以快速开始使用Spring WebFlux,它由嵌入式Netty服务器支持 1.5.8 2.1.0/2.7.0/3.0.0 Configuration propertie…...
鸿蒙app封装 axios post请求失败问题
这个问题是我的一个疏忽大意,在这里记录一下。如果有相同问题的朋友,可以借鉴。 当我 ohpm install ohos/axios 后,进行简单post请求验证,可以请求成功。 然后,我对axios 进行了封装。对axios 添加请求拦截器/添加响…...
消息队列 Kafka 架构组件及其特性
Kafka 人们通常有时会将 Kafka 中的 Topic 比作队列; 在 Kafka 中,数据是以主题(Topic)的形式组织的,每个 Topic 可以被分为多个分区(Partition)。每个 Partition 是一个有序的、不可变的消息…...
网络攻击与防范
目录 选填 第一章 1、三种网络模式 2、几种创建网络拓扑结构 NAT模式 VPN模式 软路由模式1 软路由模式2 3、Linux网络配置常用指令 4、常见网络服务配置 DHCP DNS Web服务与FTP服务 FTP用户隔离 第二章 DNS信息收集(dnsenum、dnsmap) 路…...
文献研读|基于像素语义层面图像重建的AI生成图像检测
前言:本篇文章主要对基于重建的AI生成图像检测的四篇相关工作进行介绍,分别为基于像素层面重建的检测方法 DIRE 和 Aeroblade,以及基于语义层面重建的检测方法 SimGIR 和 Zerofake;并对相应方法进行比较。 相关文章:论…...
【操作系统】为什么需要架构裁剪?
为什么需要架构裁剪? 原因 减小核心大小提高架构初始化速度降低内存占用提高系统性能移除不需要的功能,增加安全性 裁剪方法 初始化配置设置功能模块化移除不需要的驱动底层 一般裁剪对象(以操作系统为例) 文件系统的支持网…...
LSTM长短期记忆网络
LSTM(长短期记忆网络)数学原理 LSTM(Long Short-Term Memory)是一种特殊的递归神经网络(RNN),解决了标准RNN中存在的梯度消失(Vanishing Gradient) 和**梯度爆炸&#x…...
基于前端技术UniApp和后端技术Node.js的电影购票系统
文章目录 摘要Abstruct第一章 绪论1.1 研究背景与意义1.2 国内外研究现状 第二章 需求分析2.1 功能需求分析2.2 非功能性需求分析 第二章系统设计3.1 系统架构设计3.1.1 总体架构3.1.2 技术选型 3.2 功能架构 第四章 系统实现4.1 用户端系统实现4.1.1 用户认证模块实现4.1.2 电…...
数据结构与算法:稀疏数组
前言 此文以整型元素的二维数组为例,阐述稀疏数组的思想。其他类型或许有更适合压缩算法或者其他结构的稀疏数组,此文暂不扩展。 稀疏数组的定义 在一个二维数据数组里,由于大量的元素的值为同一个值,比如 0或者其他已知的默认值…...
Meta重磅发布Llama 3.3 70B:开源AI模型的新里程碑
在人工智能领域,Meta的最新动作再次引起了全球的关注。今天,我们见证了Meta发布的Llama 3.3 70B模型,这是一个开源的人工智能模型,它不仅令人印象深刻,而且在性能上达到了一个新的高度。 一,技术突破&#…...
VSCode中的Black Formatter没有生效的解决办法
说明 如果正常按照配置进行的话,理论上是可以生效的。 "[python]": {"editor.defaultFormatter": "ms-python.black-formatter","editor.formatOnSave": true }但我在一种情况下发现不能生效,应为其本身的bug…...
【潜意识Java】蓝桥杯算法有关的动态规划求解背包问题
目录 背包问题简介 问题描述 输入: 输出: 动态规划解法 动态规划状态转移 代码实现 代码解释 动态规划的时间复杂度 例子解析 输出: 总结 作者我蓝桥杯:2023第十四届蓝桥杯国赛C/C大学B组一等奖,所以请听我…...
Odoo:免费开源ERP的AI技术赋能出海企业电子商务应用介绍
概述 伴随电子商务的持续演进,客户对于便利性、速度以及个性化服务的期许急剧攀升。企业务必要探寻创新之途径,以强化自身运营,并优化购物体验。达成此目标的最为行之有效的方式之一,便是将 AI 呼叫助手融入您的电子商务平台。我们…...
微信小程序苹果手机自带的数字键盘老是弹出收起,影响用户体验,100%解决
文章目录 1、index.wxml2、index.js3、index.wxss1、index.wxml <!--index.wxml--> <view class="container"><view class="code-input-container"><view class="code-input-boxes"><!-- <block wx:for="{{…...
sql中case when若条件重复 执行的顺序
sql case when若条件重复 执行的顺序 在 SQL 中,如果你在 CASE 表达式中定义了多个 WHEN 子句,并且这些条件有重叠,那么 CASE 表达式的执行顺序遵循以下规则: (1)从上到下:SQL 引擎会按照 CASE …...
压力测试Jmeter简介
前提条件:要安装JDK 若不需要了解,请直接定位到左侧目录的安装环节。 1.引言 在现代软件开发中,性能和稳定性是衡量系统质量的重要指标。为了确保应用程序在高负载情况下仍能正常运行,压力测试变得尤为重要。Apache JMeter 是一…...
变量 varablie 声明- Rust 变量 let mut 声明与 C/C++ 变量声明对比分析
一、变量声明设计:let 与 mut 的哲学解析 Rust 采用 let 声明变量并通过 mut 显式标记可变性,这种设计体现了语言的核心哲学。以下是深度解析: 1.1 设计理念剖析 安全优先原则:默认不可变强制开发者明确声明意图 let x 5; …...
C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...
【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力
引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...
DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”
目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...
HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
AGain DB和倍数增益的关系
我在设置一款索尼CMOS芯片时,Again增益0db变化为6DB,画面的变化只有2倍DN的增益,比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析: 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...
