深度学习基础--ResNet网络的讲解,ResNet50的复现(pytorch)以及用复现的ResNet50做鸟类图像分类
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
前言
- 如果说最经典的神经网络,
ResNet肯定是一个,这篇文章是本人学习ResNet的学习笔记,并且用pytorch复现了ResNet50,后面用它做了一个鸟类图像分类demo; - 欢迎收藏 + 关注,本人将会持续更新
文章目录
- ResNet网络讲解
- 什么是ResNet?
- ResNet神经网络突出点
- 为什么采用残差连接
- 模型退化、梯度消失、梯度爆炸
- 解决方法
- 残差网络
- ResNet-50复现
- 1、导入数据
- 1、导入库
- 2、查看数据信息和导入数据
- 3、展示数据
- 4、数据导入
- 5、数据划分
- 6、动态加载数据
- 2、构建ResNet-50网络
- 3、模型训练
- 1、构建训练集
- 2、构建测试集
- 3、设置超参数
- 4、模型训练
- 5、结果可视化
- 参考资料
ResNet网络讲解
什么是ResNet?
ResNet网络是CNN的经典网络架构,是有大神何凯明提出的,主要为了解决随着网络的加深而引起的“ 退化 ”问题,主要用于图像分类。
可以说在如今的CV领域里面,大部分网络结构都有参考ResNet网络思想,无论是在图像分类、目标检测、图像识别上,甚至在Transformer网络模型中,也融合了ResNet网络的思想。
ResNet神经网络突出点
- 网络结构超过1000层:
- ❔ ❔ 超过1000层网络结构不是很容易么? 小编在学习深度学习的时候,曾经遇到过这样一个问题,有时候加深网络结构,反而在准确率、损失率上更差,这种现象称为模型“ 退化 ”现象,而ResNet的残差连接可以保证下一层的输出不会比输入差,从而可以加深网络结构。
- 提出残差模块(residual):这个是ResNet的核心;
- 采用大量的归一化在卷积层与激活函数之间.
为什么采用残差连接
模型退化、梯度消失、梯度爆炸
- 👉 模型退化:指随着网络层数的加深,其效果出现下降趋势,不如层数少的情况。如论文中图示,56层效果不如20层效果;

- 👉 梯度消失:这个是指随着网络层数的增加,反向传播,梯度更新的时候可能会造成前面几层的梯度很小、接近于0,这就会导致权重的更新会特别慢,效率低下。
- 👉 梯度爆炸:指随着网络层数的增加,在反向传播的时候,梯度变得非常大,从而在更新权重的时候,权重值发生大幅度变化,这可能导致网络不稳定,甚至是无法收敛。
解决方法
- 梯度消失、梯度爆炸:在数据预处理和网络层之间加入:BN层(Batch Normalization),从而对数据进行归一化;
- 模型退化:采用残差连接,如论文图,随着网络层数的增加,损失率更低了。

残差网络
在讲述前,这里先讲述一下恒等映射的概念:
- 恒等映射核心是复制,就是复制网络层,什么也不干。
➿ 可以这么理解:假设在一种网络A的后面添加几层形成新的网络B,如果A的输出经过新的层级变成B的输出没有发送变化,那么就可以说网络A和网络B的错误率是相等的,这样就确保了加深的网络层不会比之前的网络层效果差。
resent网络说明了,更深的网络结构可以有更好的效果,而解决这个的核心就是残差连接,网络结果如图所示:

上图就是何凯明提出的残差结构,这种结构实现了恒等映射,网络层的输出由两大模块组成:
- 其一:正常的卷积层;
- 其二:有一个分支输出到连接上,这个输出值就是输入的值;
最终结果就是:卷积层输出+分支输出,数学公式如下:

其中F(x)是卷积层的输出,x是分支的输入值。
极端情况:F(x)的网络层中,所有参数都为0,那么H(x)就是恒等映射。这样就确保了最后的错误率不会因为网络层的增加而导致变大。
在ResNet中有两个不同的ResNet模块,如图所示:

左边:
- 有两层残差单元,输出通道都是3*3
- 使用情况:用于较浅的ResNet网络。
右边:
- 三层残差单元,称为blottlenck模块,作用是:现用
1*1卷积进行降维,后用3*3卷积进行特征特权,最后用1*1卷积恢复原来的维度,这个可以很好的减少参数个数,用于较深的神经网络。
下图参考一个csdn大神笔记图:

CNN参数计算公式:卷积核尺寸 * 卷积核速度 * 卷积核组数 == 卷积核尺寸 * 输入特征矩阵深度 * 输出矩阵深度。
ResNet经典的网络结构有ResNet-50,ResNet-101等,本文将用pytorch复现ResNet-50,并用其做一个简单的实验–鸟类图片分类。
ResNet-50网络结果如下:

ResNet-50复现
1、导入数据
1、导入库
import torch
import torch.nn as nn
import torchvision
import numpy as np
import os, PIL, pathlib # 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"device
'cuda'
2、查看数据信息和导入数据
数据目录有两个文件:一个数据文件,一个权重。
data_dir = "./data/bird_photos"data_dir = pathlib.Path(data_dir)# 类别数量
classnames = [str(path).split('/')[0] for path in os.listdir(data_dir)]classnames
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
3、展示数据
import matplotlib.pylab as plt
from PIL import Image # 获取文件名称
data_path_name = "./data/bird_photos/Bananaquit/"
data_path_list = [f for f in os.listdir(data_path_name) if f.endswith(('jpg', 'png'))]# 创建画板
fig, axes = plt.subplots(2, 8, figsize=(16, 6))for ax, img_file in zip(axes.flat, data_path_list):path_name = os.path.join(data_path_name, img_file)img = Image.open(path_name) # 打开# 显示ax.imshow(img)ax.axis('off')plt.show()

4、数据导入
from torchvision import transforms, datasets # 数据统一格式
img_height = 224
img_width = 224 data_tranforms = transforms.Compose([transforms.Resize([img_height, img_width]),transforms.ToTensor(),transforms.Normalize( # 归一化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
])# 加载所有数据
total_data = datasets.ImageFolder(root="./data/bird_photos", transform=data_tranforms)
5、数据划分
# 大小 8 : 2
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])
6、动态加载数据
batch_size = 32 train_dl = torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True
)test_dl = torch.utils.data.DataLoader(test_data,batch_size=batch_size,shuffle=False
)
# 查看数据维度
for data, labels in train_dl:print("data shape[N, C, H, W]: ", data.shape)print("labels: ", labels)break
data shape[N, C, H, W]: torch.Size([32, 3, 224, 224])
labels: tensor([0, 1, 0, 1, 2, 1, 1, 0, 2, 2, 1, 2, 1, 3, 1, 2, 2, 2, 2, 1, 2, 1, 2, 2,0, 3, 3, 3, 3, 2, 3, 3])
2、构建ResNet-50网络

import torch.nn.functional as F# 定义残差模块一,这个用于处理输入和输出通道一样的情况
'''
卷积核大小:1 3 1
核心特点:尺寸不变:输入和输出的尺寸保持一致。 没有下采样:没有使用步长大于1的卷积操作,因此没有改变特征图的空间尺寸
'''
class Identity_block(nn.Module):def __init__(self, in_channels, kernel_size, filters):super(Identity_block, self).__init__()# 输出通道filter1, filter2, filter3 = filters# 卷积层一self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=1)self.bn1 = nn.BatchNorm2d(filter1)# 卷积层2self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1) # 通过卷积输入输出公式发现,padding=1,可以保证输入和输出尺寸相同self.bn2 = nn.BatchNorm2d(filter2)# 卷积层3self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)self.bn3 = nn.BatchNorm2d(filter3)def forward(self, x):# 记录原始值xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))# 残差连接,输入、输出维度不变x += xxx = F.relu(x)return x # 定义卷积模块二:用于处理输入和输出不一样的情况
'''
* 卷积核还是:1 3 1
* stride=2
* 这里的分支是采用一个Conv2D,和一个归一化BN层,也是为了处理数据维度吧, 这种维度的变化,可以用ai举例子核心特点:尺寸变化,stride=2降维
'''
class ConvBlock(nn.Module):def __init__(self, in_channels, kernel_size, filters, stride=2):super(ConvBlock, self).__init__()filter1, filter2, filter3= filters# 卷积层1self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride)self.bn1 = nn.BatchNorm2d(filter1)# 卷积2self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1) # 需要维持维度不变self.bn2 = nn.BatchNorm2d(filter2)# 卷积3self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1) # stride = 1,维持通道不变self.bn3 = nn.BatchNorm2d(filter3)# 用于匹配维度的shortcut卷积,这个就是上面Identity_block的x分支self.shortcut = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride)self.shortcut_bn = nn.BatchNorm2d(filter3)def forward(self, x):xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))temp = self.shortcut_bn(self.shortcut(xx))x += tempx = F.relu(x)return x # 定义ResNet50
class ResNet50(nn.Module):def __init__(self, classes): # 类别数量super().__init__()# 头顶self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一部分self.part1_1 = ConvBlock(64, 3, [64, 64, 256], stride=1)self.part1_2 = Identity_block(256, 3, [64, 64, 256])self.part1_3 = Identity_block(256, 3, [64, 64, 256])# 第二部分self.part2_1 = ConvBlock(256, 3, [128, 128, 512])self.part2_2 = Identity_block(512, 3, [128, 128, 512])self.part2_3 = Identity_block(512, 3, [128, 128, 512])self.part2_4 = Identity_block(512, 3, [128, 128, 512])# 第三部分self.part3_1 = ConvBlock(512, 3, [256, 256, 1024])self.part3_2 = Identity_block(1024, 3, [256, 256, 1024])self.part3_3 = Identity_block(1024, 3, [256, 256, 1024])self.part3_4 = Identity_block(1024, 3, [256, 256, 1024])self.part3_5 = Identity_block(1024, 3, [256, 256, 1024])self.part3_6 = Identity_block(1024, 3, [256, 256, 1024])# 第四部分self.part4_1 = ConvBlock(1024, 3, [512, 512, 2048])self.part4_2 = Identity_block(2048, 3, [512, 512, 2048])self.part4_3 = Identity_block(2048, 3, [512, 512, 2048])# 平均池化self.avg_pool = nn.AvgPool2d(kernel_size=7)# 全连接self.fn1 = nn.Linear(2048, classes)def forward(self, x):# 头部x = F.relu(self.bn1(self.conv1(x)))x = self.max_pool(x)x = self.part1_1(x)x = self.part1_2(x)x = self.part1_3(x)x = self.part2_1(x)x = self.part2_2(x)x = self.part2_3(x)x = self.part2_4(x)x = self.part3_1(x)x = self.part3_2(x)x = self.part3_3(x)x = self.part3_4(x)x = self.part3_5(x)x = self.part3_6(x)x = self.part4_1(x)x = self.part4_2(x)x = self.part4_3(x)x = self.avg_pool(x)x = x.view(x.size(0), -1) # 扁平化x = self.fn1(x)return x model = ResNet50(classes=len(classnames)).to(device)model
ResNet50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(part1_1): ConvBlock((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_2): Identity_block((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_3): Identity_block((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_1): ConvBlock((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_2): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_3): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_4): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_1): ConvBlock((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_2): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_3): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_4): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_5): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_6): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_1): ConvBlock((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_2): Identity_block((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_3): Identity_block((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(avg_pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fn1): Linear(in_features=2048, out_features=4, bias=True)
)
model(torch.randn(32, 3, 224, 224).to(device)).shape
torch.Size([32, 4])
3、模型训练
1、构建训练集
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)batch_size = len(dataloader)train_acc, train_loss = 0, 0 for X, y in dataloader:X, y = X.to(device), y.to(device)# 训练pred = model(X)loss = loss_fn(pred, y)# 梯度下降法optimizer.zero_grad()loss.backward()optimizer.step()# 记录train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_acc /= sizetrain_loss /= batch_sizereturn train_acc, train_loss
2、构建测试集
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)batch_size = len(dataloader)test_acc, test_loss = 0, 0 with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= batch_sizereturn test_acc, test_loss
3、设置超参数
loss_fn = nn.CrossEntropyLoss() # 损失函数
learn_lr = 1e-4 # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr) # 优化器
4、模型训练
train_acc = []
train_loss = []
test_acc = []
test_loss = []epoches = 80for i in range(epoches):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 输出template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))print("Done")

5、结果可视化
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息epochs_range = range(epoches)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()

参考资料
【深度学习】ResNet网络讲解-CSDN博客
K同学啊,训练营文档
相关文章:
深度学习基础--ResNet网络的讲解,ResNet50的复现(pytorch)以及用复现的ResNet50做鸟类图像分类
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 前言 如果说最经典的神经网络,ResNet肯定是一个,这篇文章是本人学习ResNet的学习笔记,并且用pytorch复现了ResNet50&…...
TMDS视频编解码算法
因为使用的是DDR进行传输,即双倍频率采样,故时钟只用是并行数据数据的5倍,而不是10倍。 TMDS算法流程: 视频编码TMDS算法流程实现: timescale 1 ps / 1ps //DVI编码通常用于视频传输,将并行数据转换为适合…...
深度解析SmartGBD助力Android音视频数据接入GB28181平台
在当今数字化时代,视频监控与音视频通信技术在各行各业的应用愈发广泛。GB28181协议作为中国国家标准,为视频监控设备的互联互通提供了规范,但在实际应用中,许多Android终端设备并不具备国标音视频能力,这限制了其在相…...
前端兼容处理接口返回的文件流或json数据
参考文档:JavaScript | MDN 参考链接:Blob格式转json格式,拿到后端返回的json数据_blob转json-CSDN博客 参考链接:https://juejin.cn/post/7117939029567340557 场景:导入上传文件,导入成功,…...
Eclipse 透视图 (Perspective)
Eclipse 透视图 (Perspective) Eclipse 是一款强大的集成开发环境(IDE),广泛应用于 Java 开发领域。其中,透视图(Perspective)是 Eclipse 中的一个核心概念,它将不同的工具和视图组合在一起,以便开发者能够更高效地完成特定的开发任务。本文将详细介绍 Eclipse 透视图…...
嵌入式硬件篇---滤波器
文章目录 前言一、模拟电子技术中的滤波器1. 基本概念功能实现方式 2. 分类按频率响应低通滤波器高通滤波器带通滤波器带阻滤波器 按实现方式无源滤波器有源滤波器 3. 设计方法巴特沃斯滤波器(Butterworth)切比雪夫滤波器(Chebyshevÿ…...
从零到一学习c++(基础篇--筑基期十一-类)
从零到一学习C(基础篇) 作者:羡鱼肘子 温馨提示1:本篇是记录我的学习经历,会有不少片面的认知,万分期待您的指正。 温馨提示2:本篇会尽量用更加通俗的语言介绍c的基础,用通俗的语言去…...
Java基础常见的面试题(易错!!)
面试题一:为什么 Java 不支持多继承 Java 不支持多继承主要是为避免 “菱形继承问题”(又称 “钻石问题”),即一个子类从多个父类继承到同名方法或属性时,编译器无法确定该调用哪个父类的成员。同时,多继承…...
DPVS-2:单臂负载均衡测试
上一篇编译安装了DPVS,这一篇开启DPVS的负载均衡测试 : 单臂 FULL NAT模式 拓扑-单臂 单臂模式 DPVS 单独物理机 CLINET,和两个RS都是另一个物理机的虚拟机,它们网卡都绑定在一个桥上br0 , 二层互通。 启动DPVS …...
C#中提供的多种集合类以及适用场景
在 C# 中,有多种集合类可供使用,它们分别适用于不同的场景,部分代码示例提供了LeetCode相关的代码应用。 1. 数组(Array) 特点 固定大小:在创建数组时需要指定其长度,之后无法动态改变。连续存储…...
【蓝桥杯集训·每日一题2025】 AcWing 6135. 奶牛体检 python
6135. 奶牛体检 Week 1 2月21日 农夫约翰的 N N N 头奶牛站成一行,奶牛 1 1 1 在队伍的最前面,奶牛 N N N 在队伍的最后面。 农夫约翰的奶牛也有许多不同的品种。 他用从 1 1 1 到 N N N 的整数来表示每一品种。 队伍从前到后第 i i i 头奶牛的…...
【为什么用pg数据库用 != null 过滤不出null值】
为什么用pg数据库用 ! null 过滤不出null值 1. NULL 的特殊性质2. 为什么 ! null 无效3. 正确的过滤 NULL 的方式示例 4. 为什么 IS NULL 和 IS NOT NULL 有效5. 示例对比6. 总结 在 PostgreSQL 中,使用 ! null 过滤不出 NULL 值的原因与 SQL 标准中 NULL 的特殊性质…...
Classic Control Theory | 12 Real Poles or Zeros (第12课笔记-中文版)
笔记链接:https://m.tb.cn/h.Tt876SW?tkQaITejKxnFLhttps://m.tb.cn/h.Tt876SW?tkQaITejKxnFL...
Kubernetes开发环境minikube | 开发部署MySQL单节点应用
minikube是一个主要用于开发与测试Kubernetes应用的运行环境 本文主要描述在minikube运行环境中部署MySQL单节点应用 minikube start --force kubectl get nodes 如上所示,启动minikube单节点运行环境 minikube ssh docker pull 如上所示,从MySQL官…...
大厂数据仓库数仓建模面试题及参考答案
目录 什么是数据仓库,和数据库有什么区别? 数据仓库的基本原理是什么? 数据仓库架构是怎样的? 数据仓库分层(层级划分),每层做什么?分层的好处是什么?数据分层是根据什么?数仓分层的原则与思路是什么? 数仓建模常用模型有哪些?区别、优缺点是什么?星型模型和雪…...
腾讯SQL面试题解析:如何找出连续5天涨幅超过5%的股票
腾讯SQL面试题解析:如何找出连续5天涨幅超过5%的股票 作者:某七年数据开发工程师 | 2025年02月23日 关键词:SQL窗口函数、连续问题、股票分析、腾讯面试题 一、问题背景与难点拆解 在股票量化分析场景中,"连续N天满足条件"是高频面试题类型。本题要求在单表stoc…...
安装可视化jar包部署平台JarManage
一、下载 下载地址:JarManage 发行版 - Gitee.com 🚒 下载 最新发行版 下载zip的里面linux和windows版本都有 二、运行 上传到服务器,解压进入目录 🚚 执行java -jar jarmanage-depoly.jar 命令运行 java -jar jarmanage-dep…...
基于数据可视化+SpringBoot+安卓端的数字化OA公司管理平台设计和实现
博主介绍:硕士研究生,专注于信息化技术领域开发与管理,会使用java、标准c/c等开发语言,以及毕业项目实战✌ 从事基于java BS架构、CS架构、c/c 编程工作近16年,拥有近12年的管理工作经验,拥有较丰富的技术架…...
输入搜索、分组展示选项、下拉选取,全局跳转页,el-select 实现 —— 后端数据处理代码,抛砖引玉展思路
详细前端代码写于上一篇:输入搜索、分组展示选项、下拉选取,el-select 实现:即输入关键字检索,返回分组选项,选取跳转到相应内容页 —— VUE项目-全局模糊检索 【效果图】:分组展示选项 >【去界面操作体…...
性能巅峰对决:Rust vs C++ —— 速度、安全与权衡的艺术
??关注,带你探索Java的奥秘!?? ??超萌技术攻略,轻松晋级编程高手!?? ??技术宝库已备好,就等你来挖掘!?? ??订阅,智趣学习不孤单!?? ??即刻启航,编…...
unity学习53:UI的子容器:面板panel
目录 1 UI的最底层容器:canvas 1.1 UI的最底层容器:canvas 1.2 UI的合理结构 2 UI的子容器:面板panel 2.1 创建panel 2.2 面板的本质: image ,就是一个透明的图片,1个空容器 3 面板的属性 4 面板的…...
4-知识图谱的抽取与构建-4_2实体识别与分类
🌟 知识图谱的实体识别与分类🔥 🔍 什么是实体识别与分类? 实体识别(Entity Recognition)是从文本中提取出具体的事物,如人名、地名、组织名等。分类(Entity Classification&#x…...
elasticsearch在windows上的配置
写在最前面: 上资源 第一步 解压: 第二步 配置两个环境变量 第三步 如果是其他资源需要将标蓝的文件中的内容加一句 xpack.security.enabled: false 不同版本的yaml文件可能配置不同,末尾加这个 xpack.security.enabled: true打开bin目…...
详解分布式ID实践
引言 分布式ID,所谓的分布式ID,就是针对整个系统而言,任何时刻获取一个ID,无论系统处于何种情况,该值不会与之前产生的值重复,之后获取分布式ID时,也不会再获取到与其相同的值,它是…...
如何在 Vue 项目中为 `el-pagination` 设置中文
文章目录 前言1. 安装 Element Plus2. 引入中文语言包3. 配置中文语言环境4. 使用 el-pagination 组件5. 确保其他组件支持中文6. 语言切换(可选)总结 前言 在 Vue 项目中,Element Plus 是一个流行的 UI 组件库,它提供了许多常用…...
PostgreSQL:更新字段慢
目录标题 PostgreSQL 慢查询优化与 pg_stat_statements 使用1. 启用慢查询日志2. 使用 pg_stat_statements 扩展收集查询统计信息3. 查找执行时间较长的查询4. 分析慢查询的执行计划5. 优化查询6. 检查并发连接和系统资源7. 进一步优化8. 查看某条SQL1. **如何生成 query_id**2…...
【Rust中级教程】2.8. API设计原则之灵活性(flexible) Pt.4:显式析构函数的问题及3种解决方案
喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 说句题外话,这篇文章一共5721个字,是我截至目前写的最长的一篇文章&a…...
【复习】Redis
数据结构 Redis常见的数据结构 String:缓存对象Hash:缓存对象、购物车List:消息队列Set:点赞、共同关注ZSet:排序 Zset底层? Zset底层的数据结构是由压缩链表或跳表实现的 如果有序集合的元素 < 12…...
STM32使用NRF2401进行数据传送
NRF2401是一款由Nordic Semiconductor公司生产的单片射频收发芯片,以下是关于它的详细介绍: 一、主要特点 工作频段:NRF2401工作于2.4~2.5GHz的ISM(工业、科学和医疗)频段,该频段无需申请即可使用…...
Fetch API 与 XMLHttpRequest:深入剖析异步请求的利器
Hi,我是布兰妮甜 !在现代 Web 开发中,异步通信是实现动态和交互式用户体验的基石。XMLHttpRequest (XHR) 作为老牌劲旅,曾一度统治着这一领域。然而,随着 Fetch API 的横空出世,开发者们迎来了一个更现代、…...
