PyTorch深度学习实战——基于ResNet模型实现猫狗分类
PyTorch深度学习实战——基于ResNet模型实现猫狗分类
- 0. 前言
- 1. ResNet 架构
- 2. 基于预训练 ResNet 模型实现猫狗分类
- 相关链接
0. 前言
从 VGG11
到 VGG19
,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非仅增加网络层数,就可以获得更准确的结果,随着网络层数的增加可能会出现以下问题:
- 梯度消失和爆炸:在网络层次过深的情况下,反向传播可能会面临梯度消失和爆炸的问题,导致训练网络时无法收敛
- 过拟合:增加网络深度会带来更多的参数,如果数据样本过少或网络过于复杂,会导致网络过拟合,降低模型的泛化能力
总之,在构建的神经网络过深时,有两个问题:前向传播中,网络的最后几层几乎没有学习到有关原始图像的任何信息;在反向传播中,由于梯度消失(梯度值几乎为零),靠近输入的前几层几乎没有任何梯度更新。
深度残差网络 (ResNet
) 的提出就是为了解决上述问题。在 ResNet
中,如果模型没有什么要学习的,那么卷积层可以什么也不做,只是将上一层的输出传递给下一层。但是,如果模型需要学习其他一些特征,则卷积层将前一层的输出作为输入,并学习完成目标任务所需的其它特征。
1. ResNet 架构
ResNet
通过残差结构解决网络过深时出现的问题,让模型能够训练得更深。经典的 ResNet
架构如下所示:
残差结构的基本思想是:每一个残差块都不是直接映射输入信号到输出信号,而是通过学习残差映射来实现:
F ( x ) = H ( x ) − x F(x)=H(x)−x F(x)=H(x)−x
其中, x x x 是输入, H ( x ) H(x) H(x) 是一个表示所需映射的基本块,而 F ( x ) F(x) F(x) 是残差块学习到的映射。换句话说,输入 x x x 通过卷积层,得到特征变换后的输出 F ( x ) F(x) F(x),与输入 x x x 进行逐元素的相加运算,得到最终输出 H ( x ) H(x) H(x):
H ( x ) = x + F ( x ) H(x) = x + F(x) H(x)=x+F(x)
如果某个基本块为恒等映射,则残差块的学习目标就变为学习 F ( x ) = 0 F(x)=0 F(x)=0,也就是让输入信号直接到达残差块的输出层。这样就可以解决梯度消失的问题,可以训练更深的神经网络。
实现过程中 ResNet
中使用 Shortcut Connection
(也称跳跃连接, Skip Connection
)在残差块中实现跨层连接,从而实现信息的直接传递,跨层连接可以绕过一个或多个卷积层,直接将网络中的浅层信息传递到深层中。
在 ResNet
的残差块中,Shortcut Connection
经常与卷积层或批归一化 (Batch Normalization
) 相结合。通过该连接,残差块的激活张量可以直接和下一层的输出相加,理论上,即使是最后一层可能拥有原始图像的全部信息,并且反向传播过程中梯度将可以在几乎没有修改的情况下自由地流向浅层。典型的残差块如下所示:
在传统中顺序堆叠的神经网络中,神经网络通常直接学习 F ( x ) F(x) F(x),其中 x 是来自前一层的输出值,而在残差网络中,利用跳跃连接,将残差信号 F ( x ) F(x) F(x) 加上恒等映射 x x x 得到最终的输出 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x。接下来,我们通过在 PyTorch
中构建残差块来深入了解残差网络。
2. 基于预训练 ResNet 模型实现猫狗分类
(1) 在 __init__
方法中定义一个带有卷积操作的类:
from torch import nnclass ResLayer(nn.Module):def __init__(self,ni,no,kernel_size,stride=1):super(ResLayer, self).__init__()padding = kernel_size - 2self.conv = nn.Sequential(nn.Conv2d(ni, no, kernel_size, stride, padding=padding),nn.ReLU())
在以上代码中,为了确保通过卷积后输出的尺寸保持不变,以便于将输入与卷结果相加,我们通过 padding
控制卷积时输出的尺寸。
(2) 定义 forward
方法:
def forward(self, x):return self.conv(x) + x
在以上代码中,得到的输出是通过卷积操作的输入和原始输入之和。
在 PyTorch
中预训练的基于残差块的 ResNet18
架构如下:
该架构有 18
个可训练网络层,因此被称为 ResNet18
架构。此外,需要注意的是,ResNet18
并不是每个卷积层都会添加跳跃连接,而是在每两层之后使用跳跃连接。
了解了 ResNet
架构之后,构建一个基于预训练 ResNet18
架构的模型来执行狗猫分类任务。构建分类器的流程可以参考在迁移学习中使用预训练 VGG16 模型构建的猫狗分类器。
(3) 加载预训练 ResNet18
模型并检查模型中的模块:
model = models.resnet18(pretrained=True)
ResNet18
模型架构包含以下组件:
- 卷积层
- 批归一化
ReLU
激活- 最大池化层
4
个ResNet
块- 平均池化 (
avgpool
) 层 - 全连接层 (
fc
) 层
冻结特征提取模块的网络权重,仅替换 avgpool
和 fc
层并更新其中的参数。
(4) 定义模型架构、损失函数和优化器:
def get_model():model = models.resnet18(pretrained=True)for param in model.parameters():param.requires_grad = Falsemodel.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))model.fc = nn.Sequential(nn.Flatten(),nn.Linear(512, 128),nn.ReLU(),nn.Dropout(0.2),nn.Linear(128, 1),nn.Sigmoid())loss_fn = nn.BCELoss()optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)return model.to(device), loss_fn, optimizer
在模型中,fc
模块的输入形状为 512
,因为 avgpool
的输出形状为 batch size x 512 x 1 x 1
。定义了模型后,训练模型,随着 epoch
的增加,模型训练和验证准确率的变化(对应模型分别为 ResNet18
、ResNet34
、ResNet50
、ResNet101
和 ResNet152
) 如下:
仅对 1000
张图像进行训练时,模型的准确率就可以达到 98%
左右,且准确率随着 ResNet
层数的增加而增加。
相关链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
相关文章:

PyTorch深度学习实战——基于ResNet模型实现猫狗分类
PyTorch深度学习实战——基于ResNet模型实现猫狗分类 0. 前言1. ResNet 架构2. 基于预训练 ResNet 模型实现猫狗分类相关链接 0. 前言 从 VGG11 到 VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非…...

机器学习第六课--朴素贝叶斯
朴素贝叶斯广泛地应用在文本分类任务中,其中最为经典的场景为垃圾文本分类(如垃圾邮件分类:给定一个邮件,把它自动分类为垃圾或者正常邮件)。这个任务本身是属于文本分析任务,因为对应的数据均为文本类型,所以对于此类任务我们首先…...

基于Java+SpringBoot+Vue的图书借还小程序的设计与实现(亮点:多角色、点赞评论、借书还书、在线支付)
图书借还管理小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 小…...

【校招VIP】前端计算机网络之UDP相关
考点介绍 UDP是一个简单的面向消息的传输层协议,尽管UDP提供标头和有效负载的完整性验证(通过校验和),但它不保证向上层协议提供消息传递,并且UDP层在发送后不会保留UDP 消息的状态。因此,UDP有时被称为不可…...

前缀和实例4(和可被k整除的子数组)
题目: 给定一个整数数组 nums 和一个整数 k ,返回其中元素之和可被 k 整除的(连续、非空) 子数组 的数目。 子数组 是数组的 连续 部分。 示例 1: 输入:nums [4,5,0,-2,-3,1], k 5 输出:7 …...

Android获取系统读取权限
第一步在Androidifest.xml文件中加上授权语句 <uses-permission android:name"android.permission.WRITE_EXTERNAL_STORAGE"/><uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE"/>并且在Application标签下添加 androi…...

输入学生成绩(最多不超过40),输入为负值时表示输入结束,统计成绩高于平均成绩的学生人数
#include<stdio.h> #define N 40 int scanfscore(int score[N]) {int i -1;do {i;printf("输入学生成绩:");scanf("%d", &score[i]);} while (score[i] > 0);return i; } int average(int score[N], int n) {int j 0;int k 0;double sum …...

【力扣周赛】第 363 场周赛(完全平方数和质因数分解)
文章目录 竞赛链接Q1:100031. 计算 K 置位下标对应元素的和竞赛时代码写法2——手写二进制中1的数量 Q2:100040. 让所有学生保持开心的分组方法数(排序后枚举分界)竞赛时代码 Q3:100033. 最大合金数(二分答…...

RocketMQ的介绍和环境搭建
一、介绍 我也不知道是啥,知道有什么用、怎么用就行了,说到mq(MessageQueue)就是消息队列,队列是先进先出的一种数据结构,但是RocketMQ不一定是这样,简单的理解一下,就是临时存储的…...

【web开发】7、Django(2)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、部门列表二、部门管理(增删改)三、用户管理过渡到modelform组件四、modelform实例:靓号操作五、自定义分页组件六、datepick…...

Prometheus+Grafana可视化监控【Nginx状态】
文章目录 一、安装Docker二、安装Nginx(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装nginx_exporter七、Grafana添加Nginx监控模板 一、安装Docker 注意:我这里使用之前写好脚本进行安装Docker,如果已经有D…...

R 语言的安装教程
一、下载相关软件 1、R 下载 官网:R: The R Project for Statistical Computing 找到中国镜像,下载快 历史版本点击这里 2、Rtools 下载 进入镜像后,点击这里 然后选择与上面下载的R版本相对应的版本即可 3、Rstudio 下载 官网࿱…...

uniapp-提现功能(demo)
页面布局 提现页面 有一个输入框 一个提现按钮 一段提现全部的文字 首先用v-model 和data内的数据双向绑定 输入框逻辑分析 输入框的逻辑 为了符合日常输出 所以要对输入框加一些条件限制 因为是提现 所以对输入的字符做筛选,只允许出现小数点和数字 这里用正则实现的小数点…...

Spring 篇
1、什么是 Spring? Spring是一个轻量级的IOC和AOP容器框架。是为Java应用程序提供基础性服务的一套框架,目的是用于简化企业应用程序的开发,它使得开发者只需要关心业务需求。常见的配置方式有三种:基于XML的配置、基于注解的配置…...

three.js简单3D图形的使用
npm init vitelatest //创建一个vite的脚手架 选择 Vanilla 之后自己处理一下 在main.js中写入 // 导入three.js import * as THREE from three// 创建场景 const scene new THREE.Scene();// 创建相机 const camera new THREE.PerspectiveCamera(45, //视角window.inner…...
spark withColumn的使用(笔记)
目录 前言: spark withColumn的语法及使用: 准备源数据演示: 完整实例代码: 前言: withColumn():是Apache Spark中用于DataFrame操作的函数之一,它的作用是在DataFrame中添加或替换列ÿ…...

PTA:7-1 线性表的合并
线性表的合并 题目输入样例输出样例 代码解析 题目 输入样例 4 7 5 3 11 3 2 6 3输出样例 7 5 3 11 2 6 代码 #include<iostream> #include<vector> using namespace std;bool checkrep(const vector<int>& arr, int x) {for (int element : arr) {i…...

Spring 的创建和日志框架的整合
目录 一、第一个 Spring 项目 1、配置环境 2、Spring 的 jar 包 Maven 项目导入 jar 包和设置国内源的方法: 3、Spring 的配置文件 4、Spring 的核心 API ApplicationContext 4、程序开发 5、细节分析 (1)名词解释 (2&…...
11-集合和学生管理系统
1.ArrayList 集合和数组的优势对比: 长度可变添加数据的时候不需要考虑索引,默认将数据添加到末尾 1.1 ArrayList类概述 什么是集合 提供一种存储空间可变的存储模型,存储的数据容量可以发生改变 ArrayList集合的特点 长度可以变化…...

C语言进阶指针(3) ——qsort的实现
大家好,我们今天来学习回调函数qsort的实现。 首先让我们打开cplusplus.com找到qsort函数。 我们看到这个函数就可以看到它的头文件和参数信息。 #include<stdlib.h> void qsort (void* base, size_t num, size_t size, int (*compar)(const void*,const voi…...

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文通过代码驱动的方式,系统讲解PyTorch核心概念和实战技巧,涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...

stm32wle5 lpuart DMA数据不接收
配置波特率9600时,需要使用外部低速晶振...

Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...
如何配置一个sql server使得其它用户可以通过excel odbc获取数据
要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据,你需要完成以下配置步骤: ✅ 一、在 SQL Server 端配置(服务器设置) 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到:SQL Server 网络配…...
Java 与 MySQL 性能优化:MySQL 慢 SQL 诊断与分析方法详解
文章目录 一、开启慢查询日志,定位耗时SQL1.1 查看慢查询日志是否开启1.2 临时开启慢查询日志1.3 永久开启慢查询日志1.4 分析慢查询日志 二、使用EXPLAIN分析SQL执行计划2.1 EXPLAIN的基本使用2.2 EXPLAIN分析案例2.3 根据EXPLAIN结果优化SQL 三、使用SHOW PROFILE…...

Java数组Arrays操作全攻略
Arrays类的概述 Java中的Arrays类位于java.util包中,提供了一系列静态方法用于操作数组(如排序、搜索、填充、比较等)。这些方法适用于基本类型数组和对象数组。 常用成员方法及代码示例 排序(sort) 对数组进行升序…...
如何通过git命令查看项目连接的仓库地址?
要通过 Git 命令查看项目连接的仓库地址,您可以使用以下几种方法: 1. 查看所有远程仓库地址 使用 git remote -v 命令,它会显示项目中配置的所有远程仓库及其对应的 URL: git remote -v输出示例: origin https://…...
CppCon 2015 学习:Reactive Stream Processing in Industrial IoT using DDS and Rx
“Reactive Stream Processing in Industrial IoT using DDS and Rx” 是指在工业物联网(IIoT)场景中,结合 DDS(Data Distribution Service) 和 Rx(Reactive Extensions) 技术,实现 …...