Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
- CIFAR10数据集
- ResNet
- 核心思想
- 网络结构
- 创新点
- 优点
- 应用
- ResNet结构代码详解
- 结构代码
- 代码详解
- BasicBlock 类
- ResNet 类
- ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数
- 训练过程和测试结果
- 代码汇总
- resnet.py
- train.py
- test.py
前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
这篇文章我们来构建ResNet.
CIFAR10数据集
CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:
- 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
- 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
- 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。
下面是一些示例样本:
ResNet
ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,它在ILSVRC 2015图像识别挑战赛中取得了优异成绩,在图像分类、目标检测、语义分割等计算机视觉任务中具有广泛应用。以下是对ResNet的详细介绍:
核心思想
- 解决梯度消失和退化问题:随着神经网络层数的增加,会出现梯度消失或梯度爆炸问题,导致模型难以训练。同时,还会出现网络退化现象,即增加网络层数后,准确率反而下降。ResNet的核心思想是引入残差连接(Residual Connection),通过跨层的shortcut连接,将输入直接传递到后面的层,使得后面的层可以学习到输入的残差,从而缓解了梯度消失和网络退化问题。
网络结构
- 基本残差块:ResNet的基本组成单元是残差块(Residual Block)。一个典型的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,并且在第二个卷积层之后也有一个ReLU激活函数。输入通过一个shortcut连接直接与残差块的输出相加,形成残差学习。
- 不同层数的架构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。其中,数字表示网络中卷积层和全连接层的总层数。层数越深,模型的表示能力越强,但计算成本也越高。
创新点
- 瓶颈结构:在ResNet-50及更深的网络中,采用了瓶颈结构(Bottleneck)的残差块。这种结构先使用1×1卷积层进行降维,然后使用3×3卷积层进行特征提取,最后再使用1×1卷积层进行升维,这样可以在减少计算量的同时增加网络的深度和宽度,提高模型的性能。
- 全局平均池化:在网络的最后一层,ResNet采用了全局平均池化(Global Average Pooling)代替传统的全连接层进行分类。全局平均池化可以将每个特征图的空间维度压缩为一个值,得到一个固定长度的特征向量,然后直接输入到分类器中进行分类。
优点
- 训练深度网络更容易:残差连接使得梯度能够更有效地在网络中传播,大大降低了训练深度网络的难度,使得可以成功训练上百层甚至上千层的网络。
- 性能出色:在各种图像识别任务中,ResNet都取得了非常出色的性能,相比之前的网络结构,具有更高的准确率和更好的泛化能力。
- 模型可扩展性强:可以方便地通过增加残差块的数量来扩展网络的深度,以适应不同的任务和数据集需求。
应用
- 图像分类:ResNet在图像分类任务中取得了巨大成功,如在ImageNet数据集上达到了很高的准确率,成为了图像分类领域的主流模型之一。
- 目标检测:与其他目标检测算法结合,如Faster R-CNN、YOLO等,通过提取图像的特征,提高目标检测的准确率和召回率。
- 语义分割:用于对图像进行像素级的分类,将图像中的每个像素分配到不同的类别中,在城市景观分割、医学图像分割等领域有广泛应用。
ResNet结构代码详解
结构代码
import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
代码详解
以下是对上述提供的PyTorch代码的详细解释,这段代码实现了经典的ResNet(残差网络)系列模型,包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等不同深度的网络架构:
BasicBlock 类
class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))
- 类定义与属性:
- 定义了一个名为
BasicBlock
的类,继承自nn.Module
,这是PyTorch中定义神经网络模块的基类。 expansion
属性被设置为1
,用于表示该基本块在通道维度上的扩展倍数,在BasicBlock
中通道数不会进行额外的扩展(后续的Bottleneck
块会有不同的扩展倍数)。
- 定义了一个名为
- 初始化方法
__init__
:- 首先调用父类
nn.Module
的初始化方法super(BasicBlock, self).__init__()
,确保模块正确初始化。 - 定义了两个卷积层
conv1
和conv2
:conv1
:输入通道数为in_channels
,输出通道数为out_channels
,卷积核大小为3×3
,步长为stride
,填充为1
,并且不使用偏置(bias=False
),这是遵循ResNet论文中的实现方式,通常配合后续的BatchNorm
使用。conv2
:输入通道数为out_channels
,输出通道数为out_channels * BasicBlock.expansion
(实际就是out_channels
,因为expansion
为1
),卷积核大小同样是3×3
,填充为1
,无偏置。
- 定义了两个
BatchNorm2d
层bn1
和bn2
,分别对应两个卷积层之后,用于对卷积后的特征进行归一化处理,有助于加速训练和提高模型的稳定性。 - 定义了一个
ReLU
激活函数relu
,并且设置inplace=True
,表示直接在原张量上进行激活操作,节省内存空间(但要注意使用不当可能导致梯度计算问题,如前面提到的错误情况)。 - 定义了
shortcut
,初始化为一个空的nn.Sequential
序列。当输入和输出的通道数不一致或者步长不为1
时(意味着尺寸或通道数有变化),会重新构建shortcut
,使其包含一个1×1
卷积层(用于调整通道数)和一个BatchNorm2d
层,以保证shortcut
连接的特征维度能与主分支的输出特征维度相匹配,便于后续进行相加操作。
- 首先调用父类
def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return out
- 前向传播方法
forward
:- 首先将输入
x
经过conv1
卷积、bn1
归一化后,再通过relu
激活函数得到中间特征。 - 接着将该中间特征再经过
conv2
卷积和bn2
归一化。 - 然后将主分支得到的特征
out
与shortcut
分支(直接连接输入x
经过调整后的特征)进行逐元素相加,实现残差连接的操作。 - 最后再经过一次
relu
激活函数后返回结果,作为该基本块的输出。
- 首先将输入
ResNet 类
class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)
- 类定义与属性:
- 定义了
ResNet
类,同样继承自nn.Module
,用于构建完整的ResNet网络架构。 - 初始化了一个属性
in_channels
为64
,用于记录当前层的输入通道数,后续会动态更新。 - 定义了网络的起始层,包括一个
3×3
卷积层conv1
(输入通道为3
,对应彩色图像的RGB三个通道,输出通道为64
),一个BatchNorm2d
层bn1
用于归一化,一个ReLU
激活函数relu
,以及一个最大池化层maxpool
(其参数设置按照常规的ResNet结构配置)。 - 分别定义了
layer1
、layer2
、layer3
、layer4
这四层网络结构,它们通过调用_make_layer
方法来构建,每层的输出通道数以及重复的块数量由传入的参数决定,并且随着层数加深,步长会相应改变(从第二层开始步长为2
,用于逐步减小特征图尺寸)。 - 定义了一个自适应平均池化层
avgpool
,它能将输入的特征图尺寸自适应地变为(1, 1)
大小,无论输入特征图的尺寸原本是多少,便于后续全连接层处理。最后定义了一个全连接层fc
,用于将池化后的特征映射到指定的类别数num_classes
上进行分类。
- 定义了
def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)
_make_layer
方法:- 这个方法用于构建ResNet中的每一层网络结构(由多个基本块组成)。
- 首先根据传入的
stride
和num_blocks
生成一个步长列表strides
,例如,如果传入stride=2
和num_blocks=3
,那么strides
会是[2, 1, 1]
,意味着第一个基本块可能会改变特征图的尺寸和通道数,后面的基本块保持步长为1
。 - 然后循环遍历
strides
列表,每次创建一个指定的block
(可以是BasicBlock
或者后续定义的Bottleneck
块),并传入当前的输入通道数、输出通道数以及对应的步长,将创建好的块添加到layers
列表中。同时,更新self.in_channels
为当前块输出的通道数(考虑了块的扩展倍数)。 - 最后将
layers
列表中的所有块组合成一个nn.Sequential
序列并返回,形成一层完整的网络结构。
def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
- 前向传播方法
forward
:- 首先将输入
x
依次经过网络起始层的卷积、归一化、激活和池化操作,得到初步的特征表示。 - 然后将该特征依次通过
layer1
、layer2
、layer3
、layer4
这四层网络结构,不断提取和融合特征,每一层都会进一步加深特征的抽象程度并且改变特征图的尺寸和通道数。 - 接着经过自适应平均池化层
avgpool
,将特征图变为(1, 1)
大小的特征向量。 - 通过
out.view(out.size(0), -1)
操作将特征向量展平为一维向量,使其能输入到全连接层fc
中。 - 最后将全连接层的输出作为整个网络的最终输出,返回分类结果。
- 首先将输入
ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数
# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
- 这两个函数分别用于创建
ResNet-18
和ResNet-34
网络模型。它们通过调用ResNet
类的构造函数,传入BasicBlock
作为构建块类型,以及对应不同层数的重复块数量列表(如ResNet-18
中每层分别重复2
个基本块),还有指定的类别数num_classes
,最终返回构建好的相应深度的ResNet模型实例,用于图像分类等任务。
# ResNet50, ResNet101, ResNet152 需要 BottleNeck
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))
Bottleneck
类定义与初始化:- 定义了
Bottleneck
类,同样继承自nn.Module
,用于构建更深层的ResNet网络(如ResNet-50
及以上)中的基本块。 expansion
属性被设置为4
,意味着该块在经过一系列操作后,输出通道数会是输入通道数的4
倍,通过这种方式在增加网络深度的同时控制计算量。- 在初始化方法中,定义了三个卷积层
conv1
、conv2
、conv3
,分别是1×1
卷积用于降维、3×3
卷积进行主要的特征提取、1×1
卷积用于升维,并且每个卷积层后都有对应的BatchNorm2d
层进行归一化,还有ReLU
激活函数用于激活中间特征。 - 同样定义了
shortcut
,其构建逻辑和BasicBlock
中类似,根据输入输出通道数和步长情况来决定是否需要构建包含1×1
卷积和BatchNorm2d
层的调整结构,以保证残差连接的维度匹配。
- 定义了
def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return out
Bottleneck
块的前向传播方法:- 前向传播过程与
BasicBlock
类似,只是中间经过了三个卷积层及对应的归一化和激活操作,最后同样是将主分支特征与shortcut
分支特征相加后再经过ReLU
激活函数输出,实现残差学习。
- 前向传播过程与
def ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
- 这几个函数分别用于创建
ResNet-50
、ResNet-101
和ResNet-152
网络模型,它们与创建ResNet-18
、ResNet-34
的函数类似,只是传入的构建块类型变为Bottleneck
,以及对应不同层数的重复Bottleneck
块数量列表,还有指定的类别数num_classes
,最终返回相应深度的ResNet模型实例,用于更复杂的图像分类等任务,这些更深层的网络结构在处理大规模图像数据集时往往能取得更好的性能表现。
训练过程和测试结果
训练过程损失函数变化曲线:
训练过程准确率变化曲线:
测试结果:
代码汇总
项目github地址
项目结构:
|--data
|--models|--__init__.py|-resnet.py|--...
|--results
|--weights
|--train.py
|--test.py
resnet.py
import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()
test.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
相关文章:

Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类 CIFAR10数据集ResNet核心思想网络结构创新点优点应用 ResNet结构代码详解结构代码代码详解BasicBlock 类ResNet 类ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数 训练过程和测试结果代码汇总resnet.pytrain.pytest.py 前…...
Spring Boot 配置Kafka
1 Kafka Kafka 是由 Linkedin 公司开发的,它是一个分布式的,支持多分区、多副本,基于 Zookeeper 的分布式消息流平台,它同时也是一款开源的基于发布订阅模式的消息引擎系统。 2 Maven依赖 <dependency><groupId>org.springframework.kafka</groupId><…...

基于单片机的火灾报警器 (论文+源码)
1.系统设计 本系统由火灾检测模块、A/D转换模块、信号处理模块、声光报警模块和灭火装置模块组成。火灾检测模块由温度检测和烟雾检测构成,其温度传感器选用DS18B20,烟雾传感器选用MQ-2烟雾传感器。A/D转换模块选用常用的模数转换芯片ADC0832。声光报警…...
分析excel硕士序列数据提示词——包含对特征的筛选,非0值的过滤
文章目录 1 分析出发点2 围绕出发点的文件分析3 功能模块计算重心相关性计算教学倾向百分比多列相关性计算结果封装证伪——过滤0后的交叉相关系数封装和总控——批量处理特征筛选——筛选提问倾向最大和最小的前五代码总的清洗1 分析出发点 写一个python代码,遍历"D:\Ba…...
MongoDB 更新文档
关于MongoDB更新文档的操作,可以通过多种方法实现。以下是一些常用的方法: updateOne() 方法:用于更新匹配过滤器的单个文档。其语法为 db.collection.updateOne(filter, update, options)。其中,filter 用于查找文档的查询条件&a…...

分布式协同 - 分布式事务_TCC解决方案
文章目录 导图Pre流程图2PC VS 3PC VS TCC2PC(Two-Phase Commit,二阶段提交)3PC(Three-Phase Commit,三阶段提交)TCC(Try-Confirm-Cancel)2PC、3PC与TCC的区别2PC、3PC与TCC的联系 导…...

MFC/C++学习系列之简单记录13
MFC/C学习系列之简单记录13 前言memsetList Control代码注意 总结 前言 今天记录一下memset和List control 的使用吧! memset memset通常在初始化变量或清空内存区域的时候使用,可以对变量设定特定的值。 使用: 头文件: C&#…...
PostgreSQL表达式的类型
PostgreSQL表达式是数据库查询中非常重要的组成部分,它们由一个或多个值、运算符和PostgreSQL函数组合而成,用于计算出一个单一的结果。这些表达式类似于公式,可以用查询语言编写,并用于查询数据库中的特定数据集。 PostgreSQL表…...

速通Python 第四节——函数
一、函数 编程中的函数和数学中的函数有一定的相似之处. 数学上的函数, 比如 y sin x , x 取不同的值, y 就会得到不同的结果. 编程中的函数, 是一段 可以被重复使用的代码片段 代码示例 : 求一段范围的数的和 , 不使用函数 # 1. 求 1 - 100 的和 sum 0 for i in range(1, …...
如何在Windows系统上安装和配置Maven
Maven是一个强大的构建和项目管理工具,广泛应用于Java项目的自动化构建、依赖管理、项目构建生命周期控制等方面。在Windows系统上安装Maven并配置环境变量,是开发者开始使用Maven的第一步。本文将详细介绍如何在Windows系统上安装和配置Maven࿰…...

STM32之GPIO输出与输出
欢迎来到 破晓的历程的 博客 ⛺️不负时光,不负己✈️ 文章目录 一.GPIO输入1.1GPIP简介1.2GPIO基本结构1.3GPIO位结构1.4GPIO的八种模式1.4.1浮空/上拉/下拉输入1.4.2 模拟输入1.4.3 推挽输出\开漏输出 二.GPIO输入2.1.按键介绍2.2传感器模块介绍2.3按键电路 一.G…...
linux定时器操作
目录 1 简单示例2 timer_create方式2.1 SIGEV_SIGNAL信号方式通知2.2 SIGEV_THREAD启动线程方式通知2.3 参数 1 简单示例 #include <stdio.h> #include <stdlib.h> #include <sys/time.h> #include <signal.h> #include <unistd.h>void setup_t…...

重拾设计模式--观察者模式
文章目录 观察者模式(Observer Pattern)概述观察者模式UML图作用:实现对象间的解耦支持一对多的依赖关系易于维护和扩展 观察者模式的结构抽象主题(Subject):具体主题(Concrete Subject…...
Vue.js前端框架教程7:Vue计算属性和moment.js
文章目录 计算属性(Computed Properties)基本用法缓存机制计算属性 vs 方法使用场景计算属性的 setter 和 getter结论Moment.js 进行时间处理1. 安装 Moment.js2. 在 Vue 组件中引入 Moment.js3. 在全局使用 Moment.js4. 使用 Vue 插件的方式引入 Moment.js5. 常用日期格式化…...

【游戏设计原理】22 - 石头剪刀布
一、游戏基础:拳头、掌心、分指 首先,石头剪刀布(又名“Roshambo”)看似简单,实际上可是个“深藏玄机”的零和博弈(听起来很高深,其实就是输赢相抵消的意思)。游戏中有三种手势&…...

3-Gin 渲染 --[Gin 框架入门精讲与实战案例]
在 Gin 框架中,渲染指的是将数据传递给模板,并生成 HTML 或其他格式的响应内容。Gin 支持多种类型的渲染,包括 String HTML、JSON、XML 等。 String 渲染 在 Gin 框架中,String 渲染方法允许你直接返回一个字符串作为 HTTP 响应…...

python小课堂(一)
基础语法 1 常量和表达式2 变量和类型2.1 变量是什么2.2 变量语法 3 变量的类型3.1 动态类型特性 4 注释4.1注释是什么 5 输入输出5.1 print的介绍5.2 input 6 运算符6.1 算术运算符在这里插入图片描述6.2 关系运算符6.3 逻辑运算符6.4赋值运算符 1 常量和表达式 在print()中可…...

GESP202309 二级【小杨的 X 字矩阵】题解(AC)
》》》点我查看「视频」详解》》》 [GESP202309 二级] 小杨的 X 字矩阵 题目描述 小杨想要构造一个 的 X 字矩阵( 为奇数),这个矩阵的两条对角线都是半角加号 ,其余都是半角减号 - 。例如,一个 5 5 5 \times 5 5…...
@PostConstruct注解解释!!!!
PostConstruct 注解修饰的方法是在 Bean 完成初始化后自动调用的。它是 Java EE 和 Spring 中的一种机制,用于在 Bean 被创建并依赖注入完成后,执行一些初始化的操作。 具体触发时机: 依赖注入完成后:首先,Spring 容器…...

laya游戏引擎中打包之后图片模糊
如下图正常运行没问题,打包之后却模糊 纹理类型中的默认类型都是精灵纹理,改为默认值即可。注意:要点击“应用”才可有效。精灵纹理类型会对图片进行渲染处理,而默认值 平面类型不会处理图片。...
Android Wi-Fi 连接失败日志分析
1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...

C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)
一、OpenBCI_GUI 项目概述 (一)项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台,其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言,首次接触 OpenBCI 设备时,往…...
十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建
【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...
第14节 Node.js 全局对象
JavaScript 中有一个特殊的对象,称为全局对象(Global Object),它及其所有属性都可以在程序的任何地方访问,即全局变量。 在浏览器 JavaScript 中,通常 window 是全局对象, 而 Node.js 中的全局…...
python打卡day47
昨天代码中注意力热图的部分顺移至今天 知识点回顾: 热力图 作业:对比不同卷积层热图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import D…...