当前位置: 首页 > news >正文

Pytorch | 从零构建ResNet对CIFAR10进行分类

Pytorch | 从零构建ResNet对CIFAR10进行分类

  • CIFAR10数据集
  • ResNet
      • 核心思想
      • 网络结构
      • 创新点
      • 优点
      • 应用
  • ResNet结构代码详解
    • 结构代码
    • 代码详解
      • BasicBlock 类
      • ResNet 类
      • ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数
  • 训练过程和测试结果
  • 代码汇总
    • resnet.py
    • train.py
    • test.py

前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
这篇文章我们来构建ResNet.

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,它在ILSVRC 2015图像识别挑战赛中取得了优异成绩,在图像分类、目标检测、语义分割等计算机视觉任务中具有广泛应用。以下是对ResNet的详细介绍:

核心思想

  • 解决梯度消失和退化问题:随着神经网络层数的增加,会出现梯度消失或梯度爆炸问题,导致模型难以训练。同时,还会出现网络退化现象,即增加网络层数后,准确率反而下降。ResNet的核心思想是引入残差连接(Residual Connection),通过跨层的shortcut连接,将输入直接传递到后面的层,使得后面的层可以学习到输入的残差,从而缓解了梯度消失和网络退化问题。

网络结构

  • 基本残差块:ResNet的基本组成单元是残差块(Residual Block)。一个典型的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,并且在第二个卷积层之后也有一个ReLU激活函数。输入通过一个shortcut连接直接与残差块的输出相加,形成残差学习。
  • 不同层数的架构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。其中,数字表示网络中卷积层和全连接层的总层数。层数越深,模型的表示能力越强,但计算成本也越高。

创新点

  • 瓶颈结构:在ResNet-50及更深的网络中,采用了瓶颈结构(Bottleneck)的残差块。这种结构先使用1×1卷积层进行降维,然后使用3×3卷积层进行特征提取,最后再使用1×1卷积层进行升维,这样可以在减少计算量的同时增加网络的深度和宽度,提高模型的性能。
  • 全局平均池化:在网络的最后一层,ResNet采用了全局平均池化(Global Average Pooling)代替传统的全连接层进行分类。全局平均池化可以将每个特征图的空间维度压缩为一个值,得到一个固定长度的特征向量,然后直接输入到分类器中进行分类。

优点

  • 训练深度网络更容易:残差连接使得梯度能够更有效地在网络中传播,大大降低了训练深度网络的难度,使得可以成功训练上百层甚至上千层的网络。
  • 性能出色:在各种图像识别任务中,ResNet都取得了非常出色的性能,相比之前的网络结构,具有更高的准确率和更好的泛化能力。
  • 模型可扩展性强:可以方便地通过增加残差块的数量来扩展网络的深度,以适应不同的任务和数据集需求。

应用

  • 图像分类:ResNet在图像分类任务中取得了巨大成功,如在ImageNet数据集上达到了很高的准确率,成为了图像分类领域的主流模型之一。
  • 目标检测:与其他目标检测算法结合,如Faster R-CNN、YOLO等,通过提取图像的特征,提高目标检测的准确率和召回率。
  • 语义分割:用于对图像进行像素级的分类,将图像中的每个像素分配到不同的类别中,在城市景观分割、医学图像分割等领域有广泛应用。

ResNet结构代码详解

结构代码

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

代码详解

以下是对上述提供的PyTorch代码的详细解释,这段代码实现了经典的ResNet(残差网络)系列模型,包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等不同深度的网络架构:

BasicBlock 类

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))
  • 类定义与属性
    • 定义了一个名为BasicBlock的类,继承自nn.Module,这是PyTorch中定义神经网络模块的基类。
    • expansion属性被设置为1,用于表示该基本块在通道维度上的扩展倍数,在BasicBlock中通道数不会进行额外的扩展(后续的Bottleneck块会有不同的扩展倍数)。
  • 初始化方法__init__
    • 首先调用父类nn.Module的初始化方法super(BasicBlock, self).__init__(),确保模块正确初始化。
    • 定义了两个卷积层conv1conv2
      • conv1:输入通道数为in_channels,输出通道数为out_channels,卷积核大小为3×3,步长为stride,填充为1,并且不使用偏置(bias=False),这是遵循ResNet论文中的实现方式,通常配合后续的BatchNorm使用。
      • conv2:输入通道数为out_channels,输出通道数为out_channels * BasicBlock.expansion(实际就是out_channels,因为expansion1),卷积核大小同样是3×3,填充为1,无偏置。
    • 定义了两个BatchNorm2dbn1bn2,分别对应两个卷积层之后,用于对卷积后的特征进行归一化处理,有助于加速训练和提高模型的稳定性。
    • 定义了一个ReLU激活函数relu,并且设置inplace=True,表示直接在原张量上进行激活操作,节省内存空间(但要注意使用不当可能导致梯度计算问题,如前面提到的错误情况)。
    • 定义了shortcut,初始化为一个空的nn.Sequential序列。当输入和输出的通道数不一致或者步长不为1时(意味着尺寸或通道数有变化),会重新构建shortcut,使其包含一个1×1卷积层(用于调整通道数)和一个BatchNorm2d层,以保证shortcut连接的特征维度能与主分支的输出特征维度相匹配,便于后续进行相加操作。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return out
  • 前向传播方法forward
    • 首先将输入x经过conv1卷积、bn1归一化后,再通过relu激活函数得到中间特征。
    • 接着将该中间特征再经过conv2卷积和bn2归一化。
    • 然后将主分支得到的特征outshortcut分支(直接连接输入x经过调整后的特征)进行逐元素相加,实现残差连接的操作。
    • 最后再经过一次relu激活函数后返回结果,作为该基本块的输出。

ResNet 类

class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)
  • 类定义与属性
    • 定义了ResNet类,同样继承自nn.Module,用于构建完整的ResNet网络架构。
    • 初始化了一个属性in_channels64,用于记录当前层的输入通道数,后续会动态更新。
    • 定义了网络的起始层,包括一个3×3卷积层conv1(输入通道为3,对应彩色图像的RGB三个通道,输出通道为64),一个BatchNorm2dbn1用于归一化,一个ReLU激活函数relu,以及一个最大池化层maxpool(其参数设置按照常规的ResNet结构配置)。
    • 分别定义了layer1layer2layer3layer4这四层网络结构,它们通过调用_make_layer方法来构建,每层的输出通道数以及重复的块数量由传入的参数决定,并且随着层数加深,步长会相应改变(从第二层开始步长为2,用于逐步减小特征图尺寸)。
    • 定义了一个自适应平均池化层avgpool,它能将输入的特征图尺寸自适应地变为(1, 1)大小,无论输入特征图的尺寸原本是多少,便于后续全连接层处理。最后定义了一个全连接层fc,用于将池化后的特征映射到指定的类别数num_classes上进行分类。
    def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)
  • _make_layer方法
    • 这个方法用于构建ResNet中的每一层网络结构(由多个基本块组成)。
    • 首先根据传入的stridenum_blocks生成一个步长列表strides,例如,如果传入stride=2num_blocks=3,那么strides会是[2, 1, 1],意味着第一个基本块可能会改变特征图的尺寸和通道数,后面的基本块保持步长为1
    • 然后循环遍历strides列表,每次创建一个指定的block(可以是BasicBlock或者后续定义的Bottleneck块),并传入当前的输入通道数、输出通道数以及对应的步长,将创建好的块添加到layers列表中。同时,更新self.in_channels为当前块输出的通道数(考虑了块的扩展倍数)。
    • 最后将layers列表中的所有块组合成一个nn.Sequential序列并返回,形成一层完整的网络结构。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
  • 前向传播方法forward
    • 首先将输入x依次经过网络起始层的卷积、归一化、激活和池化操作,得到初步的特征表示。
    • 然后将该特征依次通过layer1layer2layer3layer4这四层网络结构,不断提取和融合特征,每一层都会进一步加深特征的抽象程度并且改变特征图的尺寸和通道数。
    • 接着经过自适应平均池化层avgpool,将特征图变为(1, 1)大小的特征向量。
    • 通过out.view(out.size(0), -1)操作将特征向量展平为一维向量,使其能输入到全连接层fc中。
    • 最后将全连接层的输出作为整个网络的最终输出,返回分类结果。

ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数

# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
  • 这两个函数分别用于创建ResNet-18ResNet-34网络模型。它们通过调用ResNet类的构造函数,传入BasicBlock作为构建块类型,以及对应不同层数的重复块数量列表(如ResNet-18中每层分别重复2个基本块),还有指定的类别数num_classes,最终返回构建好的相应深度的ResNet模型实例,用于图像分类等任务。
# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))
  • Bottleneck类定义与初始化
    • 定义了Bottleneck类,同样继承自nn.Module,用于构建更深层的ResNet网络(如ResNet-50及以上)中的基本块。
    • expansion属性被设置为4,意味着该块在经过一系列操作后,输出通道数会是输入通道数的4倍,通过这种方式在增加网络深度的同时控制计算量。
    • 在初始化方法中,定义了三个卷积层conv1conv2conv3,分别是1×1卷积用于降维、3×3卷积进行主要的特征提取、1×1卷积用于升维,并且每个卷积层后都有对应的BatchNorm2d层进行归一化,还有ReLU激活函数用于激活中间特征。
    • 同样定义了shortcut,其构建逻辑和BasicBlock中类似,根据输入输出通道数和步长情况来决定是否需要构建包含1×1卷积和BatchNorm2d层的调整结构,以保证残差连接的维度匹配。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return out
  • Bottleneck块的前向传播方法
    • 前向传播过程与BasicBlock类似,只是中间经过了三个卷积层及对应的归一化和激活操作,最后同样是将主分支特征与shortcut分支特征相加后再经过ReLU激活函数输出,实现残差学习。
def ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
  • 这几个函数分别用于创建ResNet-50ResNet-101ResNet-152网络模型,它们与创建ResNet-18ResNet-34的函数类似,只是传入的构建块类型变为Bottleneck,以及对应不同层数的重复Bottleneck块数量列表,还有指定的类别数num_classes,最终返回相应深度的ResNet模型实例,用于更复杂的图像分类等任务,这些更深层的网络结构在处理大规模图像数据集时往往能取得更好的性能表现。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述

训练过程准确率变化曲线:

在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models|--__init__.py|-resnet.py|--...
|--results
|--weights
|--train.py
|--test.py

resnet.py

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

相关文章:

Pytorch | 从零构建ResNet对CIFAR10进行分类

Pytorch | 从零构建ResNet对CIFAR10进行分类 CIFAR10数据集ResNet核心思想网络结构创新点优点应用 ResNet结构代码详解结构代码代码详解BasicBlock 类ResNet 类ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数 训练过程和测试结果代码汇总resnet.pytrain.pytest.py 前…...

Spring Boot 配置Kafka

1 Kafka Kafka 是由 Linkedin 公司开发的,它是一个分布式的,支持多分区、多副本,基于 Zookeeper 的分布式消息流平台,它同时也是一款开源的基于发布订阅模式的消息引擎系统。 2 Maven依赖 <dependency><groupId>org.springframework.kafka</groupId><…...

基于单片机的火灾报警器 (论文+源码)

1.系统设计 本系统由火灾检测模块、A/D转换模块、信号处理模块、声光报警模块和灭火装置模块组成。火灾检测模块由温度检测和烟雾检测构成&#xff0c;其温度传感器选用DS18B20&#xff0c;烟雾传感器选用MQ-2烟雾传感器。A/D转换模块选用常用的模数转换芯片ADC0832。声光报警…...

分析excel硕士序列数据提示词——包含对特征的筛选,非0值的过滤

文章目录 1 分析出发点2 围绕出发点的文件分析3 功能模块计算重心相关性计算教学倾向百分比多列相关性计算结果封装证伪——过滤0后的交叉相关系数封装和总控——批量处理特征筛选——筛选提问倾向最大和最小的前五代码总的清洗1 分析出发点 写一个python代码,遍历"D:\Ba…...

MongoDB 更新文档

关于MongoDB更新文档的操作&#xff0c;可以通过多种方法实现。以下是一些常用的方法&#xff1a; updateOne() 方法&#xff1a;用于更新匹配过滤器的单个文档。其语法为 db.collection.updateOne(filter, update, options)。其中&#xff0c;filter 用于查找文档的查询条件&a…...

分布式协同 - 分布式事务_TCC解决方案

文章目录 导图Pre流程图2PC VS 3PC VS TCC2PC&#xff08;Two-Phase Commit&#xff0c;二阶段提交&#xff09;3PC&#xff08;Three-Phase Commit&#xff0c;三阶段提交&#xff09;TCC&#xff08;Try-Confirm-Cancel&#xff09;2PC、3PC与TCC的区别2PC、3PC与TCC的联系 导…...

MFC/C++学习系列之简单记录13

MFC/C学习系列之简单记录13 前言memsetList Control代码注意 总结 前言 今天记录一下memset和List control 的使用吧&#xff01; memset memset通常在初始化变量或清空内存区域的时候使用&#xff0c;可以对变量设定特定的值。 使用&#xff1a; 头文件&#xff1a; C&#…...

PostgreSQL表达式的类型

PostgreSQL表达式是数据库查询中非常重要的组成部分&#xff0c;它们由一个或多个值、运算符和PostgreSQL函数组合而成&#xff0c;用于计算出一个单一的结果。这些表达式类似于公式&#xff0c;可以用查询语言编写&#xff0c;并用于查询数据库中的特定数据集。 PostgreSQL表…...

速通Python 第四节——函数

一、函数 编程中的函数和数学中的函数有一定的相似之处. 数学上的函数, 比如 y sin x , x 取不同的值, y 就会得到不同的结果. 编程中的函数, 是一段 可以被重复使用的代码片段 代码示例 : 求一段范围的数的和 , 不使用函数 # 1. 求 1 - 100 的和 sum 0 for i in range(1, …...

如何在Windows系统上安装和配置Maven

Maven是一个强大的构建和项目管理工具&#xff0c;广泛应用于Java项目的自动化构建、依赖管理、项目构建生命周期控制等方面。在Windows系统上安装Maven并配置环境变量&#xff0c;是开发者开始使用Maven的第一步。本文将详细介绍如何在Windows系统上安装和配置Maven&#xff0…...

STM32之GPIO输出与输出

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 文章目录 一.GPIO输入1.1GPIP简介1.2GPIO基本结构1.3GPIO位结构1.4GPIO的八种模式1.4.1浮空/上拉/下拉输入1.4.2 模拟输入1.4.3 推挽输出\开漏输出 二.GPIO输入2.1.按键介绍2.2传感器模块介绍2.3按键电路 一.G…...

linux定时器操作

目录 1 简单示例2 timer_create方式2.1 SIGEV_SIGNAL信号方式通知2.2 SIGEV_THREAD启动线程方式通知2.3 参数 1 简单示例 #include <stdio.h> #include <stdlib.h> #include <sys/time.h> #include <signal.h> #include <unistd.h>void setup_t…...

重拾设计模式--观察者模式

文章目录 观察者模式&#xff08;Observer Pattern&#xff09;概述观察者模式UML图作用&#xff1a;实现对象间的解耦支持一对多的依赖关系易于维护和扩展 观察者模式的结构抽象主题&#xff08;Subject&#xff09;&#xff1a;具体主题&#xff08;Concrete Subject&#xf…...

Vue.js前端框架教程7:Vue计算属性和moment.js

文章目录 计算属性(Computed Properties)基本用法缓存机制计算属性 vs 方法使用场景计算属性的 setter 和 getter结论Moment.js 进行时间处理1. 安装 Moment.js2. 在 Vue 组件中引入 Moment.js3. 在全局使用 Moment.js4. 使用 Vue 插件的方式引入 Moment.js5. 常用日期格式化…...

【游戏设计原理】22 - 石头剪刀布

一、游戏基础&#xff1a;拳头、掌心、分指 首先&#xff0c;石头剪刀布&#xff08;又名“Roshambo”&#xff09;看似简单&#xff0c;实际上可是个“深藏玄机”的零和博弈&#xff08;听起来很高深&#xff0c;其实就是输赢相抵消的意思&#xff09;。游戏中有三种手势&…...

3-Gin 渲染 --[Gin 框架入门精讲与实战案例]

在 Gin 框架中&#xff0c;渲染指的是将数据传递给模板&#xff0c;并生成 HTML 或其他格式的响应内容。Gin 支持多种类型的渲染&#xff0c;包括 String HTML、JSON、XML 等。 String 渲染 在 Gin 框架中&#xff0c;String 渲染方法允许你直接返回一个字符串作为 HTTP 响应…...

python小课堂(一)

基础语法 1 常量和表达式2 变量和类型2.1 变量是什么2.2 变量语法 3 变量的类型3.1 动态类型特性 4 注释4.1注释是什么 5 输入输出5.1 print的介绍5.2 input 6 运算符6.1 算术运算符在这里插入图片描述6.2 关系运算符6.3 逻辑运算符6.4赋值运算符 1 常量和表达式 在print()中可…...

GESP202309 二级【小杨的 X 字矩阵】题解(AC)

》》》点我查看「视频」详解》》》 [GESP202309 二级] 小杨的 X 字矩阵 题目描述 小杨想要构造一个 的 X 字矩阵&#xff08; 为奇数&#xff09;&#xff0c;这个矩阵的两条对角线都是半角加号 &#xff0c;其余都是半角减号 - 。例如&#xff0c;一个 5 5 5 \times 5 5…...

@PostConstruct注解解释!!!!

PostConstruct 注解修饰的方法是在 Bean 完成初始化后自动调用的。它是 Java EE 和 Spring 中的一种机制&#xff0c;用于在 Bean 被创建并依赖注入完成后&#xff0c;执行一些初始化的操作。 具体触发时机&#xff1a; 依赖注入完成后&#xff1a;首先&#xff0c;Spring 容器…...

laya游戏引擎中打包之后图片模糊

如下图正常运行没问题&#xff0c;打包之后却模糊 纹理类型中的默认类型都是精灵纹理&#xff0c;改为默认值即可。注意&#xff1a;要点击“应用”才可有效。精灵纹理类型会对图片进行渲染处理&#xff0c;而默认值 平面类型不会处理图片。...

【数据结构练习题】链表与LinkedList

顺序表与链表LinkedList 选择题链表面试题1. 删除链表中等于给定值 val 的所有节点。2. 反转一个单链表。3. 给定一个带有头结点 head 的非空单链表&#xff0c;返回链表的中间结点。如果有两个中间结点&#xff0c;则返回第二个中间结点。4. 输入一个链表&#xff0c;输出该链…...

[项目代码] YOLOv8 遥感航拍飞机和船舶识别 [目标检测]

项目代码下载链接 &#xff1c;项目代码&#xff1e;YOLO 遥感航拍飞机和船舶识别&#xff1c;目标检测&#xff1e;https://download.csdn.net/download/qq_53332949/90163939YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为…...

移动魔百盒中的 OpenWrt作为旁路由 安装Tailscale并配置子网路由实现在外面通过家里的局域网ip访问内网设备

移动魔百盒中的 OpenWrt作为旁路由 安装Tailscale并配置子网路由实现在外面通过家里的局域网ip访问内网设备 一、前提条件 确保路由器硬件支持&#xff1a; OpenWrt 路由器需要足够的存储空间和 CPU 性能来运行 Tailscale。确保设备架构支持 Tailscale 二进制文件&#xff0c;例…...

JVM对象分配内存如何保证线程安全?

大家好&#xff0c;我是锋哥。今天分享关于【JVM对象分配内存如何保证线程安全&#xff1f;】面试题。希望对大家有帮助&#xff1b; JVM对象分配内存如何保证线程安全&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在JVM中&#xff0c;对象的内存分配…...

ArcGIS计算土地转移矩阵

在计算土地转移矩阵时&#xff0c;最常使用的方法就是在ArcGIS中将土地利用栅格数据转为矢量&#xff0c;然后采用叠加分析计算&#xff0c;但这种方法计算效率低。还有一种方法是采用ArcGIS中的栅格计算器&#xff0c;将一个年份的地类编号乘以个100或是1000再加上另一个年份的…...

数据库 MYSQL的概念

数据库的概念 数据库是按照数据结 构来组织、存储和管理数据的系统&#xff0c;它允许用户高效地存储、检索、更新和管理数据 database&#xff1a;用来组织&#xff0c;存储&#xff0c;管理数据的仓库 数据库的管理系统&#xff1a;DBMS&#xff0c;实现对数据的有效储值&am…...

Node.js后端程序打包问题汇总(webpack、rsbuild、fastify、knex、objection、sqlite3、svg-captcha)

背景说明 场景 使用 node.js 进行后端开发&#xff0c;部署时通常需要打包为单文件&#xff0c;然后放到服务器运行。 这里记录我在打包过程中&#xff0c;碰到的各类问题及解决方案&#xff0c;希望能够帮助到更多道友&#x1f604; 提示 此文持续更新&#xff0c;可以收藏⭐…...

部署 Apache Samza 和 Apache Kafka

部署 Apache Samza 和 Apache Kafka 的流处理系统可以分为以下几个步骤,涵盖环境准备、部署细节和生产环境的优化。 1. 环境准备 硬件要求 Kafka Broker:至少 3 台服务器,建议每台服务器配备 4 核 CPU、16GB 内存和高速磁盘。Samza 部署节点:根据任务规模,至少准备 2 台…...

xiaomiR4c openwrt

文章目录 openwrt 安装openwrt 配置开启WiFi 救砖minieap编译参数帮助 openwrt 安装 Router&#xff1a;xiaomi R4C官方固件&#xff1a;openwrt 23.05.5 &#xff08;下图标红处&#xff09;官方教程 下载 OpenWRTInvasionpython remote_command_execution_vulnerability.py …...

leetcode-128.最长连续序列-day14

为什么我感觉上述代码时间复杂度接近O(2n), 虽然有while循环&#xff0c;但是前面有个if判断&#xff0c;能进入while循环的也不多&#xff0c;while循环就相当于两个for循环&#xff0c;但不是嵌套类型的&#xff1a; 变量作用域问题&#xff1a;...