Week J9: Inception v3 — Hands-On Practice and Analysis (PyTorch Version)
>- **🍨 This post is a learning-log entry for the [🔗365天深度学习训练营]**
>- **🍖 Original author: [K同学啊]**
📌 This week's tasks:
- Understand what improvements Inception v3 introduces over Inception v1
- Use Inception v3 to complete the weather-recognition task
🏡 My environment:
- Language: Python 3.8
- Editor: Jupyter Notebook
- Deep-learning framework: PyTorch
  - torch==2.3.1+cu118
  - torchvision==0.18.1+cu118
I. Theoretical Background
Inception v3 paper: Rethinking the Inception Architecture for Computer Vision
Inception v3 was proposed by Google researcher Christian Szegedy and colleagues in the 2015 paper "Rethinking the Inception Architecture for Computer Vision". It is the third generation of the Inception family and achieved excellent results in the ImageNet image-recognition competition, performing especially well on large-scale image recognition.
The main characteristics of Inception v3 are as follows:
- Deeper network: Inception v3 is deeper than earlier Inception networks, roughly 48 layers deep. This allows the network to extract features at more levels and achieve better results on image-recognition tasks.
- Factorized convolutions: Inception v3 factorizes larger convolution kernels into several smaller ones, which reduces the number of parameters and the computational cost while maintaining good performance.
- Batch Normalization: Inception v3 adds Batch Normalization (BN) after every convolution, which helps the network converge and generalize. BN reduces internal covariate shift, speeds up training, and improves the model's robustness.
- Auxiliary classifier: Inception v3 attaches an auxiliary classifier to an intermediate layer to provide extra gradient signal during training and help the network learn better features. During training its loss is added to the main classifier's loss with a small weight, acting mainly as a regularizer; at inference time only the main classifier's output is used.
- RMSProp optimizer: Inception v3 is trained with RMSProp. Compared with plain stochastic gradient descent (SGD), RMSProp adapts the learning rate automatically, which makes training more stable and convergence faster.
Inception v3 achieves strong results on computer-vision tasks such as image classification, object detection, and image segmentation. Because of its large structure and computational cost, however, it may require relatively powerful hardware in practice.
Compared with the Inception module of Inception v1, Inception v3 makes the following changes:
● Factorize each 5×5 convolution into two stacked 3×3 convolutions to reduce computation: a 5×5 convolution costs about 2.78 times as much as a 3×3 convolution (25/9), whereas two stacked 3×3 convolutions cost only 18/9, so the replacement keeps the same receptive field while saving roughly 28% of the computation, as shown in the figure below:
● In addition, the authors factorize an n×n convolution into a 1×n convolution followed by an n×1 convolution. For example, a 3×3 convolution is equivalent to first applying a 1×3 convolution and then a 3×1 convolution. They found this to be about 33% cheaper than a single 3×3 convolution, as shown in the figure below:
Here, if n=3 this matches the previous figure: the leftmost 5×5 convolution can be expressed as two 3×3 convolutions, each of which can in turn be expressed as a 1×3 plus a 3×1 convolution.
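The savings can be checked with a short PyTorch sketch that simply counts parameters (illustrative only; the channel count C=64 is an arbitrary assumption, not a value taken from the network below):

import torch.nn as nn

C = 64  # assumed channel count, for illustration only
conv5x5 = nn.Conv2d(C, C, kernel_size=5, padding=2, bias=False)
two_3x3 = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
    nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
)
asym_1x3_3x1 = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=(1, 3), padding=(0, 1), bias=False),
    nn.Conv2d(C, C, kernel_size=(3, 1), padding=(1, 0), bias=False),
)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# 102400 vs 73728 vs 24576: 5x5 : (3x3)x2 = 25 : 18 (about 28% cheaper),
# and 1x3+3x1 is about 33% cheaper than a single 3x3 (36864 parameters).
print(n_params(conv5x5), n_params(two_3x3), n_params(asym_1x3_3x1))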
The filter banks in the module are also expanded (made wider rather than deeper) to avoid a representational bottleneck: if the module were made deeper instead of wider, the dimensionality would be reduced too aggressively and information would be lost, as shown below:
The overall structure of the final Inception v3 network is shown in the figure below:
II. Preparation
1. Set up the GPU
Use the GPU if one is available, otherwise fall back to the CPU.
import warnings
import torch

warnings.filterwarnings("ignore")   # suppress warnings

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cuda')
2. Import the data
import pathlib
data_dir=r'D:\THE MNIST DATABASE\weather_photos'
data_dir=pathlib.Path(data_dir)

img_count=len(list(data_dir.glob('*/*')))
print("Total number of images:",img_count)
Output:
Total number of images: 1125
3. Check the dataset classes
data_paths=list(data_dir.glob('*'))
classNames=[str(path).split('\\')[3] for path in data_paths]
classNames
Output:
['cloudy', 'rain', 'shine', 'sunrise']
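A note in passing: split('\\')[3] depends on the exact drive and folder depth of the path, so the index breaks as soon as the dataset is moved. A more robust (hypothetical) alternative is to read the folder names directly from the pathlib objects:

classNames=[path.name for path in data_dir.glob('*') if path.is_dir()]
print(classNames)   # ['cloudy', 'rain', 'shine', 'sunrise']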
4. View a few random images
Randomly pick 10 images from the dataset and display them.
import PIL,random
import matplotlib.pyplot as plt
from PIL import Image
plt.rcParams['font.sans-serif']=['SimHei']   # display CJK labels correctly
plt.rcParams['axes.unicode_minus']=False     # display minus signs correctly

data_paths2=list(data_dir.glob('*/*'))

plt.figure(figsize=(20,8))
#plt.suptitle("OreoCC's demo",fontsize=15)
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.axis("off")
    image=random.choice(data_paths2)          # pick a random image
    plt.title(image.parts[-2],fontsize=20)    # use the parent folder name as the class label
    plt.imshow(Image.open(str(image)))        # show the image
Output:
5. Image preprocessing
import torchvision.transforms as transforms
from torchvision import transforms,datasets

train_transforms=transforms.Compose([
    transforms.Resize([224,224]),        # resize every image to the same size
    transforms.RandomHorizontalFlip(),   # random horizontal flip
    transforms.RandomRotation(0.2),      # random rotation within ±0.2 degrees (the argument is in degrees)
    transforms.ToTensor(),               # convert the image to a tensor
    transforms.Normalize(                # normalize so the model converges more easily
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225])
])

total_data=datasets.ImageFolder(
    r'D:\THE MNIST DATABASE\weather_photos',
    transform=train_transforms
)
total_data
Output:
Dataset ImageFolder
    Number of datapoints: 1125
    Root location: D:\THE MNIST DATABASE\weather_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
Map the class names to their indices:
total_data.class_to_idx
Output:
{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
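Note that the training and test subsets share train_transforms, which includes random flips and rotations. A common alternative (not used in the rest of this post) is to define a separate, augmentation-free transform for the test set, for example:

from torchvision import transforms

test_transforms=transforms.Compose([
    transforms.Resize([224,224]),        # same input size as training
    transforms.ToTensor(),
    transforms.Normalize(                # same normalization statistics as training
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225])
])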
6. Split the dataset
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size]
)
train_dataset,test_dataset
Output:
(<torch.utils.data.dataset.Subset at 0x1f39bf80a90>,
 <torch.utils.data.dataset.Subset at 0x1f3bc462210>)
Check the number of samples in the training and test sets:
train_size,test_size
Output:
(900, 225)
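random_split draws from the global random number generator, so the split changes from run to run. If a reproducible split is needed, a fixed-seed generator can be passed in (an illustrative sketch; the seed value 42 is arbitrary):

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size],
    generator=torch.Generator().manual_seed(42)   # arbitrary example seed
)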
7. Load the data
batch_size=8
train_dl=torch.utils.data.DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)
test_dl=torch.utils.data.DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=1)
Take one batch from the training DataLoader and check its shape:
for x,y in train_dl:
    print("Shape of x [N,C,H,W]:",x.shape)
    print("Shape of y:",y.shape,y.dtype)
    break
Output:
Shape of x [N,C,H,W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
III. Building the Network by Hand
1. The BasicConv2d block
import torch.nn as nn
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,**kwargs):
        super(BasicConv2d,self).__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,bias=False,**kwargs)
        self.bn=nn.BatchNorm2d(out_channels,eps=0.001)
        self.relu=nn.ReLU(inplace=True)

    def forward(self,x):
        x=self.conv(x)
        x=self.bn(x)
        x=self.relu(x)
        return x
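A quick sanity check with a random tensor (illustrative only) confirms the shape produced by BasicConv2d:

x=torch.randn(1,3,224,224)                       # a fake batch with one 3-channel 224x224 image
layer=BasicConv2d(3,32,kernel_size=3,stride=2)
print(layer(x).shape)                            # torch.Size([1, 32, 111, 111])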
2. Inception-A
class InceptionA(nn.Module):
    def __init__(self,in_channels,pool_features):
        super(InceptionA,self).__init__()
        self.branch1x1=BasicConv2d(in_channels,64,kernel_size=1)

        self.branch5x5_1=BasicConv2d(in_channels,48,kernel_size=1)
        self.branch5x5_2=BasicConv2d(48,64,kernel_size=5,padding=2)

        self.branch3x3dbl_1=BasicConv2d(in_channels,64,kernel_size=1)
        self.branch3x3dbl_2=BasicConv2d(64,96,kernel_size=3,padding=1)
        self.branch3x3dbl_3=BasicConv2d(96,96,kernel_size=3,padding=1)

        self.branch_pool=BasicConv2d(in_channels,pool_features,kernel_size=1)

    def forward(self,x):
        branch1x1=self.branch1x1(x)

        branch5x5=self.branch5x5_1(x)
        branch5x5=self.branch5x5_2(branch5x5)

        branch3x3dbl=self.branch3x3dbl_1(x)
        branch3x3dbl=self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl=self.branch3x3dbl_3(branch3x3dbl)

        branch_pool=F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
        branch_pool=self.branch_pool(branch_pool)

        outputs=[branch1x1,branch5x5,branch3x3dbl,branch_pool]
        return torch.cat(outputs,1)
The InceptionA module has four parallel branches: branch1x1 applies a single 1x1 convolution; branch5x5 first reduces the channels with a 1x1 convolution and then applies a 5x5 convolution; branch3x3dbl stacks two 3x3 convolutions to obtain a receptive field comparable to 5x5 at lower cost; branch_pool applies average pooling followed by a 1x1 convolution for feature extraction and dimensionality reduction. The 1x1 reductions keep the parameter count and computation low. Finally, the outputs of the four branches are concatenated along the channel dimension to form the output of InceptionA.
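The number of output channels of InceptionA is the sum of its four branches: 64 + 64 + 96 + pool_features. Taking Mixed_5b (in_channels=192, pool_features=32) as an example, a quick check (illustrative only) gives 256 channels, which matches the channel plan used later in InceptionV3:

block=InceptionA(192,pool_features=32)
out=block(torch.randn(1,192,35,35))   # 35x35x192 is the feature-map size fed into Mixed_5b
print(out.shape)                      # torch.Size([1, 256, 35, 35]), i.e. 64+64+96+32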
3. Inception-B
class InceptionB(nn.Module):
    def __init__(self,in_channels,channels_7x7):
        super(InceptionB,self).__init__()
        self.branch1x1=BasicConv2d(in_channels,192,kernel_size=1)

        c7=channels_7x7
        self.branch7x7_1=BasicConv2d(in_channels,c7,kernel_size=1)
        self.branch7x7_2=BasicConv2d(c7,c7,kernel_size=(1,7),padding=(0,3))
        self.branch7x7_3=BasicConv2d(c7,192,kernel_size=(7,1),padding=(3,0))

        self.branch7x7dbl_1=BasicConv2d(in_channels,c7,kernel_size=1)
        self.branch7x7dbl_2=BasicConv2d(c7,c7,kernel_size=(7,1),padding=(3,0))
        self.branch7x7dbl_3=BasicConv2d(c7,c7,kernel_size=(1,7),padding=(0,3))
        self.branch7x7dbl_4=BasicConv2d(c7,c7,kernel_size=(7,1),padding=(3,0))
        self.branch7x7dbl_5=BasicConv2d(c7,192,kernel_size=(1,7),padding=(0,3))

        self.branch_pool=BasicConv2d(in_channels,192,kernel_size=1)

    def forward(self,x):
        branch1x1=self.branch1x1(x)

        branch7x7=self.branch7x7_1(x)
        branch7x7=self.branch7x7_2(branch7x7)
        branch7x7=self.branch7x7_3(branch7x7)

        branch7x7dbl=self.branch7x7dbl_1(x)
        branch7x7dbl=self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl=self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl=self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl=self.branch7x7dbl_5(branch7x7dbl)

        branch_pool=F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
        branch_pool=self.branch_pool(branch_pool)

        outputs=[branch1x1,branch7x7,branch7x7dbl,branch_pool]
        return torch.cat(outputs,1)
The InceptionB module also has four branches. The branch7x7 branch factorizes a 7x7 convolution into a 1x7 followed by a 7x1 convolution, and the branch7x7dbl branch stacks several alternating 7x1 and 1x7 convolutions; this enlarges the receptive field and improves the expressive power of the features while keeping the spatial size unchanged and greatly reducing the parameter count. branch_pool again uses average pooling plus a 1x1 convolution for feature extraction and dimensionality reduction. Finally, the outputs of the four branches are concatenated along the channel dimension to form the output of InceptionB.
4. Inception-C
class InceptionC(nn.Module):
    def __init__(self,in_channels):
        super(InceptionC,self).__init__()
        self.branch1x1=BasicConv2d(in_channels,320,kernel_size=1)

        self.branch3x3_1=BasicConv2d(in_channels,384,kernel_size=1)
        self.branch3x3_2a=BasicConv2d(384,384,kernel_size=(1,3),padding=(0,1))
        self.branch3x3_2b=BasicConv2d(384,384,kernel_size=(3,1),padding=(1,0))

        self.branch3x3dbl_1=BasicConv2d(in_channels,448,kernel_size=1)
        self.branch3x3dbl_2=BasicConv2d(448,384,kernel_size=3,padding=1)
        self.branch3x3dbl_3a=BasicConv2d(384,384,kernel_size=(1,3),padding=(0,1))
        self.branch3x3dbl_3b=BasicConv2d(384,384,kernel_size=(3,1),padding=(1,0))

        self.branch_pool=BasicConv2d(in_channels,192,kernel_size=1)

    def forward(self,x):
        branch1x1=self.branch1x1(x)

        branch3x3=self.branch3x3_1(x)
        branch3x3=[
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3=torch.cat(branch3x3,1)

        branch3x3dbl=self.branch3x3dbl_1(x)
        branch3x3dbl=self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl=[
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl=torch.cat(branch3x3dbl,1)

        branch_pool=F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
        branch_pool=self.branch_pool(branch_pool)

        outputs=[branch1x1,branch3x3,branch3x3dbl,branch_pool]
        return torch.cat(outputs,1)
The InceptionC module likewise has four branches. The branch3x3 branch first applies a 1x1 convolution and then splits into parallel 1x3 and 3x1 convolutions whose outputs are concatenated; the branch3x3dbl branch applies 1x1 and 3x3 convolutions before splitting into the same parallel 1x3/3x1 pair, which further improves the expressive power of the features while preserving the spatial size. branch_pool again uses average pooling plus a 1x1 convolution for feature extraction and dimensionality reduction. Finally, the outputs of the four branches are concatenated along the channel dimension to form the output of InceptionC.
5. Reduction-A
class ReductionA(nn.Module):
    def __init__(self,in_channels):
        super(ReductionA,self).__init__()
        self.branch3x3=BasicConv2d(in_channels,384,kernel_size=3,stride=2)

        self.branch3x3dbl_1=BasicConv2d(in_channels,64,kernel_size=1)
        self.branch3x3dbl_2=BasicConv2d(64,96,kernel_size=3,padding=1)
        self.branch3x3dbl_3=BasicConv2d(96,96,kernel_size=3,stride=2)

    def forward(self,x):
        branch3x3=self.branch3x3(x)

        branch3x3dbl=self.branch3x3dbl_1(x)
        branch3x3dbl=self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl=self.branch3x3dbl_3(branch3x3dbl)

        branch_pool=F.max_pool2d(x,kernel_size=3,stride=2)

        outputs=[branch3x3,branch3x3dbl,branch_pool]
        return torch.cat(outputs,1)
The ReductionA module has three branches: branch3x3 applies a 3x3 convolution with stride=2, extracting features while downsampling; branch3x3dbl_1, branch3x3dbl_2 and branch3x3dbl_3 apply 1x1, 3x3 and 3x3 (stride=2) convolutions in turn, extracting features over several layers while also downsampling and keeping the computation in check; branch_pool uses max pooling for feature extraction and downsampling. Finally, the outputs of the three branches are concatenated along the channel dimension to form the output of ReductionA.
6. Reduction-B
class ReductionB(nn.Module):
    def __init__(self,in_channels):
        super(ReductionB,self).__init__()
        self.branch3x3_1=BasicConv2d(in_channels,192,kernel_size=1)
        self.branch3x3_2=BasicConv2d(192,320,kernel_size=3,stride=2)

        self.branch7x7x3_1=BasicConv2d(in_channels,192,kernel_size=1)
        self.branch7x7x3_2=BasicConv2d(192,192,kernel_size=(1,7),padding=(0,3))
        self.branch7x7x3_3=BasicConv2d(192,192,kernel_size=(7,1),padding=(3,0))
        self.branch7x7x3_4=BasicConv2d(192,192,kernel_size=3,stride=2)

    def forward(self,x):
        branch3x3=self.branch3x3_1(x)
        branch3x3=self.branch3x3_2(branch3x3)

        branch7x7x3=self.branch7x7x3_1(x)
        branch7x7x3=self.branch7x7x3_2(branch7x7x3)
        branch7x7x3=self.branch7x7x3_3(branch7x7x3)
        branch7x7x3=self.branch7x7x3_4(branch7x7x3)

        branch_pool=F.max_pool2d(x,kernel_size=3,stride=2)

        outputs=[branch3x3,branch7x7x3,branch_pool]
        return torch.cat(outputs,1)
7. Auxiliary classifier
class InceptionAux(nn.Module):
    def __init__(self,in_channels,num_classes):
        super(InceptionAux,self).__init__()
        self.conv0=BasicConv2d(in_channels,128,kernel_size=1)
        self.conv1=BasicConv2d(128,768,kernel_size=5)
        self.conv1.stddev=0.01
        self.fc=nn.Linear(768,num_classes)
        self.fc.stddev=0.001

    def forward(self,x):
        # input: 17x17x768
        x=F.avg_pool2d(x,kernel_size=5,stride=3)
        # 5x5x768
        x=self.conv0(x)
        # 5x5x128
        x=self.conv1(x)
        # 1x1x768
        x=x.view(x.size(0),-1)
        # 768
        x=self.fc(x)
        # num_classes
        return x
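The auxiliary classifier is only invoked when aux_logits=True and the model is in training mode (see the forward method of InceptionV3 below). Its input/output shapes can be verified with a random tensor (illustrative sketch; a batch of 2 is used so that BatchNorm sees more than one value per channel on the final 1x1 feature map):

aux=InceptionAux(768,num_classes=4)
out=aux(torch.randn(2,768,17,17))   # 17x17x768 is the output size of Mixed_6e
print(out.shape)                    # torch.Size([2, 4])

When the auxiliary branch is enabled, its loss is usually multiplied by a small weight (for example 0.3, an empirical value) and added to the main loss during training; at inference time the branch is simply discarded.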
8. The full InceptionV3 model
class InceptionV3(nn.Module):
    def __init__(self,num_classes=1000,aux_logits=False,transform_input=False):
        super(InceptionV3,self).__init__()
        self.aux_logits=aux_logits
        self.transform_input=transform_input
        self.Conv2d_1a_3x3=BasicConv2d(3,32,kernel_size=3,stride=2)
        self.Conv2d_2a_3x3=BasicConv2d(32,32,kernel_size=3)
        self.Conv2d_2b_3x3=BasicConv2d(32,64,kernel_size=3,padding=1)
        self.Conv2d_3b_1x1=BasicConv2d(64,80,kernel_size=1)
        self.Conv2d_4a_3x3=BasicConv2d(80,192,kernel_size=3)
        self.Mixed_5b=InceptionA(192,pool_features=32)
        self.Mixed_5c=InceptionA(256,pool_features=64)
        self.Mixed_5d=InceptionA(288,pool_features=64)
        self.Mixed_6a=ReductionA(288)
        self.Mixed_6b=InceptionB(768,channels_7x7=128)
        self.Mixed_6c=InceptionB(768,channels_7x7=160)
        self.Mixed_6d=InceptionB(768,channels_7x7=160)
        self.Mixed_6e=InceptionB(768,channels_7x7=192)
        if aux_logits:
            self.AuxLogits=InceptionAux(768,num_classes)
        self.Mixed_7a=ReductionB(768)
        self.Mixed_7b=InceptionC(1280)
        self.Mixed_7c=InceptionC(2048)
        self.fc=nn.Linear(2048,num_classes)

    def forward(self,x):
        if self.transform_input:
            x=x.clone()
            x[:,0]=x[:,0]*(0.229/0.5)+(0.485-0.5)/0.5
            x[:,1]=x[:,1]*(0.224/0.5)+(0.456-0.5)/0.5
            x[:,2]=x[:,2]*(0.225/0.5)+(0.406-0.5)/0.5
        # 299x299x3
        x=self.Conv2d_1a_3x3(x)
        # 149x149x32
        x=self.Conv2d_2a_3x3(x)
        # 147x147x32
        x=self.Conv2d_2b_3x3(x)
        # 147x147x64
        x=F.max_pool2d(x,kernel_size=3,stride=2)
        # 73x73x64
        x=self.Conv2d_3b_1x1(x)
        # 73x73x80
        x=self.Conv2d_4a_3x3(x)
        # 71x71x192
        x=F.max_pool2d(x,kernel_size=3,stride=2)
        # 35x35x192
        x=self.Mixed_5b(x)
        # 35x35x256
        x=self.Mixed_5c(x)
        # 35x35x288
        x=self.Mixed_5d(x)
        # 35x35x288
        x=self.Mixed_6a(x)
        # 17x17x768
        x=self.Mixed_6b(x)
        # 17x17x768
        x=self.Mixed_6c(x)
        # 17x17x768
        x=self.Mixed_6d(x)
        # 17x17x768
        x=self.Mixed_6e(x)
        # 17x17x768
        if self.training and self.aux_logits:
            aux=self.AuxLogits(x)
        # 17x17x768
        x=self.Mixed_7a(x)
        # 8x8x1280
        x=self.Mixed_7b(x)
        # 8x8x2048
        x=self.Mixed_7c(x)
        # 8x8x2048
        x=F.avg_pool2d(x,kernel_size=5)
        # 1x1x2048
        x=F.dropout(x,training=self.training)
        # 1x1x2048
        x=x.view(x.size(0),-1)
        # 2048
        x=self.fc(x)
        # num_classes
        if self.training and self.aux_logits:
            return x,aux
        return x
model=InceptionV3(num_classes=4).to(device)
model
Output:
InceptionV3((Conv2d_1a_3x3): BasicConv2d((conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)(bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(Conv2d_2a_3x3): BasicConv2d((conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)(bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(Conv2d_2b_3x3): BasicConv2d((conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(Conv2d_3b_1x1): BasicConv2d((conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(Conv2d_4a_3x3): BasicConv2d((conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(Mixed_5b): InceptionA((branch1x1): BasicConv2d((conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_1): BasicConv2d((conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_2): BasicConv2d((conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3): BasicConv2d((conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_5c): InceptionA((branch1x1): BasicConv2d((conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_1): BasicConv2d((conv): Conv2d(256, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_2): BasicConv2d((conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(64, 96, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3): BasicConv2d((conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_5d): InceptionA((branch1x1): BasicConv2d((conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_1): BasicConv2d((conv): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch5x5_2): BasicConv2d((conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3): BasicConv2d((conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_6a): ReductionA((branch3x3): BasicConv2d((conv): Conv2d(288, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3): BasicConv2d((conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_6b): InceptionB((branch1x1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_1): BasicConv2d((conv): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_2): BasicConv2d((conv): Conv2d(128, 128, kernel_size=(1, 7), 
stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_3): BasicConv2d((conv): Conv2d(128, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_1): BasicConv2d((conv): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_2): BasicConv2d((conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_3): BasicConv2d((conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_4): BasicConv2d((conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_5): BasicConv2d((conv): Conv2d(128, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_6c): InceptionB((branch1x1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_1): BasicConv2d((conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_2): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_3): BasicConv2d((conv): Conv2d(160, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_1): BasicConv2d((conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_2): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_3): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_4): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_5): 
BasicConv2d((conv): Conv2d(160, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_6d): InceptionB((branch1x1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_1): BasicConv2d((conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_2): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_3): BasicConv2d((conv): Conv2d(160, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_1): BasicConv2d((conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_2): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_3): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_4): BasicConv2d((conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_5): BasicConv2d((conv): Conv2d(160, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_6e): InceptionB((branch1x1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_2): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7_3): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): 
ReLU(inplace=True))(branch7x7dbl_1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_2): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_3): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_4): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7dbl_5): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_7a): ReductionB((branch3x3_1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_2): BasicConv2d((conv): Conv2d(192, 320, kernel_size=(3, 3), stride=(2, 2), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7x3_1): BasicConv2d((conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7x3_2): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7x3_3): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch7x7x3_4): BasicConv2d((conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_7b): InceptionC((branch1x1): BasicConv2d((conv): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_1): BasicConv2d((conv): Conv2d(1280, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_2a): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_2b): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, 
track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(1280, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(448, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(448, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3a): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3b): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(1280, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(Mixed_7c): InceptionC((branch1x1): BasicConv2d((conv): Conv2d(2048, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_1): BasicConv2d((conv): Conv2d(2048, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_2a): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3_2b): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_1): BasicConv2d((conv): Conv2d(2048, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(448, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_2): BasicConv2d((conv): Conv2d(448, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3a): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch3x3dbl_3b): BasicConv2d((conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)(bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(branch_pool): BasicConv2d((conv): Conv2d(2048, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(fc): Linear(in_features=2048, out_features=4, bias=True)
)
9. Model summary
import torchsummary as summary
summary.summary(model,(3,299,299))
Output:
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 32, 149, 149] 864BatchNorm2d-2 [-1, 32, 149, 149] 64ReLU-3 [-1, 32, 149, 149] 0BasicConv2d-4 [-1, 32, 149, 149] 0Conv2d-5 [-1, 32, 147, 147] 9,216BatchNorm2d-6 [-1, 32, 147, 147] 64ReLU-7 [-1, 32, 147, 147] 0BasicConv2d-8 [-1, 32, 147, 147] 0Conv2d-9 [-1, 64, 147, 147] 18,432BatchNorm2d-10 [-1, 64, 147, 147] 128ReLU-11 [-1, 64, 147, 147] 0BasicConv2d-12 [-1, 64, 147, 147] 0Conv2d-13 [-1, 80, 73, 73] 5,120BatchNorm2d-14 [-1, 80, 73, 73] 160ReLU-15 [-1, 80, 73, 73] 0BasicConv2d-16 [-1, 80, 73, 73] 0Conv2d-17 [-1, 192, 71, 71] 138,240BatchNorm2d-18 [-1, 192, 71, 71] 384ReLU-19 [-1, 192, 71, 71] 0BasicConv2d-20 [-1, 192, 71, 71] 0Conv2d-21 [-1, 64, 35, 35] 12,288BatchNorm2d-22 [-1, 64, 35, 35] 128ReLU-23 [-1, 64, 35, 35] 0BasicConv2d-24 [-1, 64, 35, 35] 0Conv2d-25 [-1, 48, 35, 35] 9,216BatchNorm2d-26 [-1, 48, 35, 35] 96ReLU-27 [-1, 48, 35, 35] 0BasicConv2d-28 [-1, 48, 35, 35] 0Conv2d-29 [-1, 64, 35, 35] 76,800BatchNorm2d-30 [-1, 64, 35, 35] 128ReLU-31 [-1, 64, 35, 35] 0BasicConv2d-32 [-1, 64, 35, 35] 0Conv2d-33 [-1, 64, 35, 35] 12,288BatchNorm2d-34 [-1, 64, 35, 35] 128ReLU-35 [-1, 64, 35, 35] 0BasicConv2d-36 [-1, 64, 35, 35] 0Conv2d-37 [-1, 96, 35, 35] 55,296BatchNorm2d-38 [-1, 96, 35, 35] 192ReLU-39 [-1, 96, 35, 35] 0BasicConv2d-40 [-1, 96, 35, 35] 0Conv2d-41 [-1, 96, 35, 35] 82,944BatchNorm2d-42 [-1, 96, 35, 35] 192ReLU-43 [-1, 96, 35, 35] 0BasicConv2d-44 [-1, 96, 35, 35] 0Conv2d-45 [-1, 32, 35, 35] 6,144BatchNorm2d-46 [-1, 32, 35, 35] 64ReLU-47 [-1, 32, 35, 35] 0BasicConv2d-48 [-1, 32, 35, 35] 0InceptionA-49 [-1, 256, 35, 35] 0Conv2d-50 [-1, 64, 35, 35] 16,384BatchNorm2d-51 [-1, 64, 35, 35] 128ReLU-52 [-1, 64, 35, 35] 0BasicConv2d-53 [-1, 64, 35, 35] 0Conv2d-54 [-1, 48, 35, 35] 12,288BatchNorm2d-55 [-1, 48, 35, 35] 96ReLU-56 [-1, 48, 35, 35] 0BasicConv2d-57 [-1, 48, 35, 35] 0Conv2d-58 [-1, 64, 35, 35] 76,800BatchNorm2d-59 [-1, 64, 35, 35] 128ReLU-60 [-1, 64, 35, 35] 0BasicConv2d-61 [-1, 64, 35, 35] 0Conv2d-62 [-1, 64, 35, 35] 16,384BatchNorm2d-63 [-1, 64, 35, 35] 128ReLU-64 [-1, 64, 35, 35] 0BasicConv2d-65 [-1, 64, 35, 35] 0Conv2d-66 [-1, 96, 35, 35] 55,296BatchNorm2d-67 [-1, 96, 35, 35] 192ReLU-68 [-1, 96, 35, 35] 0BasicConv2d-69 [-1, 96, 35, 35] 0Conv2d-70 [-1, 96, 35, 35] 82,944BatchNorm2d-71 [-1, 96, 35, 35] 192ReLU-72 [-1, 96, 35, 35] 0BasicConv2d-73 [-1, 96, 35, 35] 0Conv2d-74 [-1, 64, 35, 35] 16,384BatchNorm2d-75 [-1, 64, 35, 35] 128ReLU-76 [-1, 64, 35, 35] 0BasicConv2d-77 [-1, 64, 35, 35] 0InceptionA-78 [-1, 288, 35, 35] 0Conv2d-79 [-1, 64, 35, 35] 18,432BatchNorm2d-80 [-1, 64, 35, 35] 128ReLU-81 [-1, 64, 35, 35] 0BasicConv2d-82 [-1, 64, 35, 35] 0Conv2d-83 [-1, 48, 35, 35] 13,824BatchNorm2d-84 [-1, 48, 35, 35] 96ReLU-85 [-1, 48, 35, 35] 0BasicConv2d-86 [-1, 48, 35, 35] 0Conv2d-87 [-1, 64, 35, 35] 76,800BatchNorm2d-88 [-1, 64, 35, 35] 128ReLU-89 [-1, 64, 35, 35] 0BasicConv2d-90 [-1, 64, 35, 35] 0Conv2d-91 [-1, 64, 35, 35] 18,432BatchNorm2d-92 [-1, 64, 35, 35] 128ReLU-93 [-1, 64, 35, 35] 0BasicConv2d-94 [-1, 64, 35, 35] 0Conv2d-95 [-1, 96, 35, 35] 55,296BatchNorm2d-96 [-1, 96, 35, 35] 192ReLU-97 [-1, 96, 35, 35] 0BasicConv2d-98 [-1, 96, 35, 35] 0Conv2d-99 [-1, 96, 35, 35] 82,944BatchNorm2d-100 [-1, 96, 35, 35] 192ReLU-101 [-1, 96, 35, 35] 0BasicConv2d-102 [-1, 96, 35, 35] 0Conv2d-103 [-1, 64, 35, 35] 18,432BatchNorm2d-104 [-1, 64, 35, 35] 128ReLU-105 [-1, 64, 35, 35] 0BasicConv2d-106 [-1, 64, 35, 35] 0InceptionA-107 [-1, 288, 35, 35] 0Conv2d-108 [-1, 384, 17, 17] 995,328BatchNorm2d-109 [-1, 384, 17, 17] 
768ReLU-110 [-1, 384, 17, 17] 0BasicConv2d-111 [-1, 384, 17, 17] 0Conv2d-112 [-1, 64, 35, 35] 18,432BatchNorm2d-113 [-1, 64, 35, 35] 128ReLU-114 [-1, 64, 35, 35] 0BasicConv2d-115 [-1, 64, 35, 35] 0Conv2d-116 [-1, 96, 35, 35] 55,296BatchNorm2d-117 [-1, 96, 35, 35] 192ReLU-118 [-1, 96, 35, 35] 0BasicConv2d-119 [-1, 96, 35, 35] 0Conv2d-120 [-1, 96, 17, 17] 82,944BatchNorm2d-121 [-1, 96, 17, 17] 192ReLU-122 [-1, 96, 17, 17] 0BasicConv2d-123 [-1, 96, 17, 17] 0ReductionA-124 [-1, 768, 17, 17] 0Conv2d-125 [-1, 192, 17, 17] 147,456BatchNorm2d-126 [-1, 192, 17, 17] 384ReLU-127 [-1, 192, 17, 17] 0BasicConv2d-128 [-1, 192, 17, 17] 0Conv2d-129 [-1, 128, 17, 17] 98,304BatchNorm2d-130 [-1, 128, 17, 17] 256ReLU-131 [-1, 128, 17, 17] 0BasicConv2d-132 [-1, 128, 17, 17] 0Conv2d-133 [-1, 128, 17, 17] 114,688BatchNorm2d-134 [-1, 128, 17, 17] 256ReLU-135 [-1, 128, 17, 17] 0BasicConv2d-136 [-1, 128, 17, 17] 0Conv2d-137 [-1, 192, 17, 17] 172,032BatchNorm2d-138 [-1, 192, 17, 17] 384ReLU-139 [-1, 192, 17, 17] 0BasicConv2d-140 [-1, 192, 17, 17] 0Conv2d-141 [-1, 128, 17, 17] 98,304BatchNorm2d-142 [-1, 128, 17, 17] 256ReLU-143 [-1, 128, 17, 17] 0BasicConv2d-144 [-1, 128, 17, 17] 0Conv2d-145 [-1, 128, 17, 17] 114,688BatchNorm2d-146 [-1, 128, 17, 17] 256ReLU-147 [-1, 128, 17, 17] 0BasicConv2d-148 [-1, 128, 17, 17] 0Conv2d-149 [-1, 128, 17, 17] 114,688BatchNorm2d-150 [-1, 128, 17, 17] 256ReLU-151 [-1, 128, 17, 17] 0BasicConv2d-152 [-1, 128, 17, 17] 0Conv2d-153 [-1, 128, 17, 17] 114,688BatchNorm2d-154 [-1, 128, 17, 17] 256ReLU-155 [-1, 128, 17, 17] 0BasicConv2d-156 [-1, 128, 17, 17] 0Conv2d-157 [-1, 192, 17, 17] 172,032BatchNorm2d-158 [-1, 192, 17, 17] 384ReLU-159 [-1, 192, 17, 17] 0BasicConv2d-160 [-1, 192, 17, 17] 0Conv2d-161 [-1, 192, 17, 17] 147,456BatchNorm2d-162 [-1, 192, 17, 17] 384ReLU-163 [-1, 192, 17, 17] 0BasicConv2d-164 [-1, 192, 17, 17] 0InceptionB-165 [-1, 768, 17, 17] 0Conv2d-166 [-1, 192, 17, 17] 147,456BatchNorm2d-167 [-1, 192, 17, 17] 384ReLU-168 [-1, 192, 17, 17] 0BasicConv2d-169 [-1, 192, 17, 17] 0Conv2d-170 [-1, 160, 17, 17] 122,880BatchNorm2d-171 [-1, 160, 17, 17] 320ReLU-172 [-1, 160, 17, 17] 0BasicConv2d-173 [-1, 160, 17, 17] 0Conv2d-174 [-1, 160, 17, 17] 179,200BatchNorm2d-175 [-1, 160, 17, 17] 320ReLU-176 [-1, 160, 17, 17] 0BasicConv2d-177 [-1, 160, 17, 17] 0Conv2d-178 [-1, 192, 17, 17] 215,040BatchNorm2d-179 [-1, 192, 17, 17] 384ReLU-180 [-1, 192, 17, 17] 0BasicConv2d-181 [-1, 192, 17, 17] 0Conv2d-182 [-1, 160, 17, 17] 122,880BatchNorm2d-183 [-1, 160, 17, 17] 320ReLU-184 [-1, 160, 17, 17] 0BasicConv2d-185 [-1, 160, 17, 17] 0Conv2d-186 [-1, 160, 17, 17] 179,200BatchNorm2d-187 [-1, 160, 17, 17] 320ReLU-188 [-1, 160, 17, 17] 0BasicConv2d-189 [-1, 160, 17, 17] 0Conv2d-190 [-1, 160, 17, 17] 179,200BatchNorm2d-191 [-1, 160, 17, 17] 320ReLU-192 [-1, 160, 17, 17] 0BasicConv2d-193 [-1, 160, 17, 17] 0Conv2d-194 [-1, 160, 17, 17] 179,200BatchNorm2d-195 [-1, 160, 17, 17] 320ReLU-196 [-1, 160, 17, 17] 0BasicConv2d-197 [-1, 160, 17, 17] 0Conv2d-198 [-1, 192, 17, 17] 215,040BatchNorm2d-199 [-1, 192, 17, 17] 384ReLU-200 [-1, 192, 17, 17] 0BasicConv2d-201 [-1, 192, 17, 17] 0Conv2d-202 [-1, 192, 17, 17] 147,456BatchNorm2d-203 [-1, 192, 17, 17] 384ReLU-204 [-1, 192, 17, 17] 0BasicConv2d-205 [-1, 192, 17, 17] 0InceptionB-206 [-1, 768, 17, 17] 0Conv2d-207 [-1, 192, 17, 17] 147,456BatchNorm2d-208 [-1, 192, 17, 17] 384ReLU-209 [-1, 192, 17, 17] 0BasicConv2d-210 [-1, 192, 17, 17] 0Conv2d-211 [-1, 160, 17, 17] 122,880BatchNorm2d-212 [-1, 160, 17, 17] 320ReLU-213 [-1, 160, 17, 17] 0BasicConv2d-214 [-1, 160, 17, 17] 
0Conv2d-215 [-1, 160, 17, 17] 179,200BatchNorm2d-216 [-1, 160, 17, 17] 320ReLU-217 [-1, 160, 17, 17] 0BasicConv2d-218 [-1, 160, 17, 17] 0Conv2d-219 [-1, 192, 17, 17] 215,040BatchNorm2d-220 [-1, 192, 17, 17] 384ReLU-221 [-1, 192, 17, 17] 0BasicConv2d-222 [-1, 192, 17, 17] 0Conv2d-223 [-1, 160, 17, 17] 122,880BatchNorm2d-224 [-1, 160, 17, 17] 320ReLU-225 [-1, 160, 17, 17] 0BasicConv2d-226 [-1, 160, 17, 17] 0Conv2d-227 [-1, 160, 17, 17] 179,200BatchNorm2d-228 [-1, 160, 17, 17] 320ReLU-229 [-1, 160, 17, 17] 0BasicConv2d-230 [-1, 160, 17, 17] 0Conv2d-231 [-1, 160, 17, 17] 179,200BatchNorm2d-232 [-1, 160, 17, 17] 320ReLU-233 [-1, 160, 17, 17] 0BasicConv2d-234 [-1, 160, 17, 17] 0Conv2d-235 [-1, 160, 17, 17] 179,200BatchNorm2d-236 [-1, 160, 17, 17] 320ReLU-237 [-1, 160, 17, 17] 0BasicConv2d-238 [-1, 160, 17, 17] 0Conv2d-239 [-1, 192, 17, 17] 215,040BatchNorm2d-240 [-1, 192, 17, 17] 384ReLU-241 [-1, 192, 17, 17] 0BasicConv2d-242 [-1, 192, 17, 17] 0Conv2d-243 [-1, 192, 17, 17] 147,456BatchNorm2d-244 [-1, 192, 17, 17] 384ReLU-245 [-1, 192, 17, 17] 0BasicConv2d-246 [-1, 192, 17, 17] 0InceptionB-247 [-1, 768, 17, 17] 0Conv2d-248 [-1, 192, 17, 17] 147,456BatchNorm2d-249 [-1, 192, 17, 17] 384ReLU-250 [-1, 192, 17, 17] 0BasicConv2d-251 [-1, 192, 17, 17] 0Conv2d-252 [-1, 192, 17, 17] 147,456BatchNorm2d-253 [-1, 192, 17, 17] 384ReLU-254 [-1, 192, 17, 17] 0BasicConv2d-255 [-1, 192, 17, 17] 0Conv2d-256 [-1, 192, 17, 17] 258,048BatchNorm2d-257 [-1, 192, 17, 17] 384ReLU-258 [-1, 192, 17, 17] 0BasicConv2d-259 [-1, 192, 17, 17] 0Conv2d-260 [-1, 192, 17, 17] 258,048BatchNorm2d-261 [-1, 192, 17, 17] 384ReLU-262 [-1, 192, 17, 17] 0BasicConv2d-263 [-1, 192, 17, 17] 0Conv2d-264 [-1, 192, 17, 17] 147,456BatchNorm2d-265 [-1, 192, 17, 17] 384ReLU-266 [-1, 192, 17, 17] 0BasicConv2d-267 [-1, 192, 17, 17] 0Conv2d-268 [-1, 192, 17, 17] 258,048BatchNorm2d-269 [-1, 192, 17, 17] 384ReLU-270 [-1, 192, 17, 17] 0BasicConv2d-271 [-1, 192, 17, 17] 0Conv2d-272 [-1, 192, 17, 17] 258,048BatchNorm2d-273 [-1, 192, 17, 17] 384ReLU-274 [-1, 192, 17, 17] 0BasicConv2d-275 [-1, 192, 17, 17] 0Conv2d-276 [-1, 192, 17, 17] 258,048BatchNorm2d-277 [-1, 192, 17, 17] 384ReLU-278 [-1, 192, 17, 17] 0BasicConv2d-279 [-1, 192, 17, 17] 0Conv2d-280 [-1, 192, 17, 17] 258,048BatchNorm2d-281 [-1, 192, 17, 17] 384ReLU-282 [-1, 192, 17, 17] 0BasicConv2d-283 [-1, 192, 17, 17] 0Conv2d-284 [-1, 192, 17, 17] 147,456BatchNorm2d-285 [-1, 192, 17, 17] 384ReLU-286 [-1, 192, 17, 17] 0BasicConv2d-287 [-1, 192, 17, 17] 0InceptionB-288 [-1, 768, 17, 17] 0Conv2d-289 [-1, 192, 17, 17] 147,456BatchNorm2d-290 [-1, 192, 17, 17] 384ReLU-291 [-1, 192, 17, 17] 0BasicConv2d-292 [-1, 192, 17, 17] 0Conv2d-293 [-1, 320, 8, 8] 552,960BatchNorm2d-294 [-1, 320, 8, 8] 640ReLU-295 [-1, 320, 8, 8] 0BasicConv2d-296 [-1, 320, 8, 8] 0Conv2d-297 [-1, 192, 17, 17] 147,456BatchNorm2d-298 [-1, 192, 17, 17] 384ReLU-299 [-1, 192, 17, 17] 0BasicConv2d-300 [-1, 192, 17, 17] 0Conv2d-301 [-1, 192, 17, 17] 258,048BatchNorm2d-302 [-1, 192, 17, 17] 384ReLU-303 [-1, 192, 17, 17] 0BasicConv2d-304 [-1, 192, 17, 17] 0Conv2d-305 [-1, 192, 17, 17] 258,048BatchNorm2d-306 [-1, 192, 17, 17] 384ReLU-307 [-1, 192, 17, 17] 0BasicConv2d-308 [-1, 192, 17, 17] 0Conv2d-309 [-1, 192, 8, 8] 331,776BatchNorm2d-310 [-1, 192, 8, 8] 384ReLU-311 [-1, 192, 8, 8] 0BasicConv2d-312 [-1, 192, 8, 8] 0ReductionB-313 [-1, 1280, 8, 8] 0Conv2d-314 [-1, 320, 8, 8] 409,600BatchNorm2d-315 [-1, 320, 8, 8] 640ReLU-316 [-1, 320, 8, 8] 0BasicConv2d-317 [-1, 320, 8, 8] 0Conv2d-318 [-1, 384, 8, 8] 491,520BatchNorm2d-319 [-1, 384, 8, 8] 
768ReLU-320 [-1, 384, 8, 8] 0BasicConv2d-321 [-1, 384, 8, 8] 0Conv2d-322 [-1, 384, 8, 8] 442,368BatchNorm2d-323 [-1, 384, 8, 8] 768ReLU-324 [-1, 384, 8, 8] 0BasicConv2d-325 [-1, 384, 8, 8] 0Conv2d-326 [-1, 384, 8, 8] 442,368BatchNorm2d-327 [-1, 384, 8, 8] 768ReLU-328 [-1, 384, 8, 8] 0BasicConv2d-329 [-1, 384, 8, 8] 0Conv2d-330 [-1, 448, 8, 8] 573,440BatchNorm2d-331 [-1, 448, 8, 8] 896ReLU-332 [-1, 448, 8, 8] 0BasicConv2d-333 [-1, 448, 8, 8] 0Conv2d-334 [-1, 384, 8, 8] 1,548,288BatchNorm2d-335 [-1, 384, 8, 8] 768ReLU-336 [-1, 384, 8, 8] 0BasicConv2d-337 [-1, 384, 8, 8] 0Conv2d-338 [-1, 384, 8, 8] 442,368BatchNorm2d-339 [-1, 384, 8, 8] 768ReLU-340 [-1, 384, 8, 8] 0BasicConv2d-341 [-1, 384, 8, 8] 0Conv2d-342 [-1, 384, 8, 8] 442,368BatchNorm2d-343 [-1, 384, 8, 8] 768ReLU-344 [-1, 384, 8, 8] 0BasicConv2d-345 [-1, 384, 8, 8] 0Conv2d-346 [-1, 192, 8, 8] 245,760BatchNorm2d-347 [-1, 192, 8, 8] 384ReLU-348 [-1, 192, 8, 8] 0BasicConv2d-349 [-1, 192, 8, 8] 0InceptionC-350 [-1, 2048, 8, 8] 0Conv2d-351 [-1, 320, 8, 8] 655,360BatchNorm2d-352 [-1, 320, 8, 8] 640ReLU-353 [-1, 320, 8, 8] 0BasicConv2d-354 [-1, 320, 8, 8] 0Conv2d-355 [-1, 384, 8, 8] 786,432BatchNorm2d-356 [-1, 384, 8, 8] 768ReLU-357 [-1, 384, 8, 8] 0BasicConv2d-358 [-1, 384, 8, 8] 0Conv2d-359 [-1, 384, 8, 8] 442,368BatchNorm2d-360 [-1, 384, 8, 8] 768ReLU-361 [-1, 384, 8, 8] 0BasicConv2d-362 [-1, 384, 8, 8] 0Conv2d-363 [-1, 384, 8, 8] 442,368BatchNorm2d-364 [-1, 384, 8, 8] 768ReLU-365 [-1, 384, 8, 8] 0BasicConv2d-366 [-1, 384, 8, 8] 0Conv2d-367 [-1, 448, 8, 8] 917,504BatchNorm2d-368 [-1, 448, 8, 8] 896ReLU-369 [-1, 448, 8, 8] 0BasicConv2d-370 [-1, 448, 8, 8] 0Conv2d-371 [-1, 384, 8, 8] 1,548,288BatchNorm2d-372 [-1, 384, 8, 8] 768ReLU-373 [-1, 384, 8, 8] 0BasicConv2d-374 [-1, 384, 8, 8] 0Conv2d-375 [-1, 384, 8, 8] 442,368BatchNorm2d-376 [-1, 384, 8, 8] 768ReLU-377 [-1, 384, 8, 8] 0BasicConv2d-378 [-1, 384, 8, 8] 0Conv2d-379 [-1, 384, 8, 8] 442,368BatchNorm2d-380 [-1, 384, 8, 8] 768ReLU-381 [-1, 384, 8, 8] 0BasicConv2d-382 [-1, 384, 8, 8] 0Conv2d-383 [-1, 192, 8, 8] 393,216BatchNorm2d-384 [-1, 192, 8, 8] 384ReLU-385 [-1, 192, 8, 8] 0BasicConv2d-386 [-1, 192, 8, 8] 0InceptionC-387 [-1, 2048, 8, 8] 0Linear-388 [-1, 4] 8,196
================================================================
Total params: 21,793,764
Trainable params: 21,793,764
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.02
Forward/backward pass size (MB): 292.53
Params size (MB): 83.14
Estimated Total Size (MB): 376.69
----------------------------------------------------------------
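The totals reported by torchsummary can also be cross-checked directly in PyTorch (illustrative only):

total_params=sum(p.numel() for p in model.parameters())
trainable_params=sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_params,trainable_params)   # both should match the 21,793,764 reported above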
IV. Training the Model
1. The training function
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset)   # size of the training set
    num_batches=len(dataloader)    # number of batches
    train_loss,train_acc=0,0       # initialize training loss and accuracy

    for x,y in dataloader:         # fetch images and labels
        x,y=x.to(device),y.to(device)

        # compute the prediction error
        pred=model(x)              # network output
        loss=loss_fn(pred,y)       # difference between the network output and the ground truth

        # backpropagation
        optimizer.zero_grad()      # reset gradients
        loss.backward()            # backpropagate
        optimizer.step()           # update the weights

        # accumulate accuracy and loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()

    train_acc/=size
    train_loss/=num_batches

    return train_acc,train_loss
2. The test function
The test function is much the same as the training function, but since no backpropagation or weight update is performed, no optimizer needs to be passed in.
# test function
def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset)   # size of the test set
    num_batches=len(dataloader)    # number of batches
    test_loss,test_acc=0,0

    # disable gradient tracking to save memory and computation
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)

            # compute the loss
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)

            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()

    test_acc/=size
    test_loss/=num_batches

    return test_acc,test_loss
3. Train the model
import copy
opt=torch.optim.Adam(model.parameters(),lr=1e-4)   # create the optimizer and set the learning rate
loss_fn=nn.CrossEntropyLoss()                      # create the loss function

epochs=50
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0    # best accuracy so far, used to select the best model

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)

    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)

    # keep a copy of the best model in J9_model
    if epoch_test_acc>best_acc:
        best_acc=epoch_test_acc
        J9_model=copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # current learning rate
    lr=opt.state_dict()['param_groups'][0]['lr']

    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

# save the weights of the best model to a file
PATH=r'D:\THE MNIST DATABASE\J-series\J9_model.pth'
torch.save(J9_model.state_dict(),PATH)
Output:
Epoch: 1,Train_acc:77.2%,Train_loss:0.634,Test_acc:89.3%,Test_loss:0.443,Lr:1.00E-04
Epoch: 2,Train_acc:78.3%,Train_loss:0.596,Test_acc:88.9%,Test_loss:0.346,Lr:1.00E-04
Epoch: 3,Train_acc:82.3%,Train_loss:0.477,Test_acc:68.9%,Test_loss:1.772,Lr:1.00E-04
Epoch: 4,Train_acc:83.1%,Train_loss:0.510,Test_acc:88.0%,Test_loss:0.358,Lr:1.00E-04
Epoch: 5,Train_acc:84.8%,Train_loss:0.443,Test_acc:91.1%,Test_loss:0.213,Lr:1.00E-04
Epoch: 6,Train_acc:83.9%,Train_loss:0.432,Test_acc:88.4%,Test_loss:0.301,Lr:1.00E-04
Epoch: 7,Train_acc:87.8%,Train_loss:0.368,Test_acc:88.4%,Test_loss:0.300,Lr:1.00E-04
Epoch: 8,Train_acc:87.9%,Train_loss:0.371,Test_acc:90.2%,Test_loss:0.358,Lr:1.00E-04
Epoch: 9,Train_acc:88.6%,Train_loss:0.361,Test_acc:92.9%,Test_loss:0.183,Lr:1.00E-04
Epoch:10,Train_acc:88.9%,Train_loss:0.310,Test_acc:91.6%,Test_loss:0.277,Lr:1.00E-04
Epoch:11,Train_acc:89.8%,Train_loss:0.306,Test_acc:94.2%,Test_loss:0.174,Lr:1.00E-04
Epoch:12,Train_acc:88.0%,Train_loss:0.357,Test_acc:92.0%,Test_loss:0.301,Lr:1.00E-04
Epoch:13,Train_acc:92.1%,Train_loss:0.246,Test_acc:92.9%,Test_loss:0.266,Lr:1.00E-04
Epoch:14,Train_acc:91.2%,Train_loss:0.269,Test_acc:94.2%,Test_loss:0.163,Lr:1.00E-04
Epoch:15,Train_acc:90.8%,Train_loss:0.294,Test_acc:87.1%,Test_loss:0.331,Lr:1.00E-04
Epoch:16,Train_acc:91.4%,Train_loss:0.251,Test_acc:93.8%,Test_loss:0.177,Lr:1.00E-04
Epoch:17,Train_acc:92.9%,Train_loss:0.214,Test_acc:90.7%,Test_loss:0.276,Lr:1.00E-04
Epoch:18,Train_acc:88.9%,Train_loss:0.315,Test_acc:92.9%,Test_loss:0.215,Lr:1.00E-04
Epoch:19,Train_acc:92.2%,Train_loss:0.231,Test_acc:91.6%,Test_loss:0.251,Lr:1.00E-04
Epoch:20,Train_acc:92.9%,Train_loss:0.228,Test_acc:91.1%,Test_loss:0.229,Lr:1.00E-04
Epoch:21,Train_acc:93.0%,Train_loss:0.211,Test_acc:90.7%,Test_loss:0.245,Lr:1.00E-04
Epoch:22,Train_acc:93.2%,Train_loss:0.235,Test_acc:94.2%,Test_loss:0.150,Lr:1.00E-04
Epoch:23,Train_acc:91.6%,Train_loss:0.259,Test_acc:83.1%,Test_loss:0.375,Lr:1.00E-04
Epoch:24,Train_acc:93.2%,Train_loss:0.209,Test_acc:93.3%,Test_loss:0.219,Lr:1.00E-04
Epoch:25,Train_acc:93.2%,Train_loss:0.183,Test_acc:94.2%,Test_loss:0.181,Lr:1.00E-04
Epoch:26,Train_acc:91.4%,Train_loss:0.242,Test_acc:92.9%,Test_loss:0.249,Lr:1.00E-04
Epoch:27,Train_acc:92.0%,Train_loss:0.220,Test_acc:93.3%,Test_loss:0.234,Lr:1.00E-04
Epoch:28,Train_acc:94.0%,Train_loss:0.195,Test_acc:94.7%,Test_loss:0.141,Lr:1.00E-04
Epoch:29,Train_acc:93.1%,Train_loss:0.189,Test_acc:95.1%,Test_loss:0.191,Lr:1.00E-04
Epoch:30,Train_acc:93.2%,Train_loss:0.226,Test_acc:95.6%,Test_loss:0.204,Lr:1.00E-04
Epoch:31,Train_acc:94.3%,Train_loss:0.147,Test_acc:93.8%,Test_loss:0.219,Lr:1.00E-04
Epoch:32,Train_acc:95.4%,Train_loss:0.150,Test_acc:93.3%,Test_loss:0.225,Lr:1.00E-04
Epoch:33,Train_acc:94.9%,Train_loss:0.165,Test_acc:95.6%,Test_loss:0.169,Lr:1.00E-04
Epoch:34,Train_acc:95.4%,Train_loss:0.150,Test_acc:94.7%,Test_loss:0.193,Lr:1.00E-04
Epoch:35,Train_acc:95.7%,Train_loss:0.100,Test_acc:93.8%,Test_loss:0.145,Lr:1.00E-04
Epoch:36,Train_acc:94.4%,Train_loss:0.198,Test_acc:92.9%,Test_loss:0.167,Lr:1.00E-04
Epoch:37,Train_acc:95.3%,Train_loss:0.163,Test_acc:94.7%,Test_loss:0.110,Lr:1.00E-04
Epoch:38,Train_acc:96.1%,Train_loss:0.120,Test_acc:93.3%,Test_loss:0.177,Lr:1.00E-04
Epoch:39,Train_acc:94.6%,Train_loss:0.197,Test_acc:94.2%,Test_loss:0.196,Lr:1.00E-04
Epoch:40,Train_acc:95.4%,Train_loss:0.132,Test_acc:96.0%,Test_loss:0.117,Lr:1.00E-04
Epoch:41,Train_acc:96.6%,Train_loss:0.115,Test_acc:96.9%,Test_loss:0.116,Lr:1.00E-04
Epoch:42,Train_acc:96.1%,Train_loss:0.113,Test_acc:95.6%,Test_loss:0.119,Lr:1.00E-04
Epoch:43,Train_acc:97.1%,Train_loss:0.103,Test_acc:93.3%,Test_loss:0.218,Lr:1.00E-04
Epoch:44,Train_acc:94.9%,Train_loss:0.168,Test_acc:89.3%,Test_loss:0.251,Lr:1.00E-04
Epoch:45,Train_acc:97.3%,Train_loss:0.094,Test_acc:93.3%,Test_loss:0.180,Lr:1.00E-04
Epoch:46,Train_acc:97.4%,Train_loss:0.086,Test_acc:94.7%,Test_loss:0.210,Lr:1.00E-04
Epoch:47,Train_acc:95.3%,Train_loss:0.125,Test_acc:95.1%,Test_loss:0.200,Lr:1.00E-04
Epoch:48,Train_acc:95.9%,Train_loss:0.131,Test_acc:94.7%,Test_loss:0.159,Lr:1.00E-04
Epoch:49,Train_acc:95.0%,Train_loss:0.147,Test_acc:93.8%,Test_loss:0.218,Lr:1.00E-04
Epoch:50,Train_acc:97.6%,Train_loss:0.076,Test_acc:95.6%,Test_loss:0.172,Lr:1.00E-04
V. Visualizing the Results
1. Loss and accuracy curves
import matplotlib.pyplot as plt
# hide warnings
import warnings
warnings.filterwarnings("ignore")            # suppress warnings
plt.rcParams['font.sans-serif']=['SimHei']   # display CJK labels correctly
plt.rcParams['axes.unicode_minus']=False     # display minus signs correctly
plt.rcParams['figure.dpi']=300               # figure resolution

epochs_range=range(epochs)

plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
Output:
2. Predict a specified image
from PIL import Image

classes=list(total_data.class_to_idx)

def predict_one_image(image_path,model,transform,classes):
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)                 # show the image being predicted

    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)

    model.eval()
    output=model(img)

    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'Predicted class: {pred_class}')
Run a prediction on one image from the dataset:
# predict one image from the dataset
predict_one_image(image_path=r'D:\THE MNIST DATABASE\weather_photos\shine\shine15.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)
Output:
Predicted class: shine
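The weights saved earlier can also be loaded in a fresh session and used for inference (an illustrative sketch that reuses the path and helpers defined above):

best_model=InceptionV3(num_classes=4).to(device)
best_model.load_state_dict(torch.load(r'D:\THE MNIST DATABASE\J-series\J9_model.pth',map_location=device))
best_model.eval()

predict_one_image(image_path=r'D:\THE MNIST DATABASE\weather_photos\shine\shine15.jpg',
                  model=best_model,
                  transform=train_transforms,
                  classes=classes)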
VI. Reflections
In this week's project I built the Inception v3 model by hand, which deepened my understanding of how the architecture is put together.