当前位置: 首页 > news >正文

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

    • 0. 前言
    • 1. ResNet 架构
    • 2. 基于预训练 ResNet 模型实现猫狗分类
    • 相关链接

0. 前言

VGG11VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非仅增加网络层数,就可以获得更准确的结果,随着网络层数的增加可能会出现以下问题:

  • 梯度消失和爆炸:在网络层次过深的情况下,反向传播可能会面临梯度消失和爆炸的问题,导致训练网络时无法收敛
  • 过拟合:增加网络深度会带来更多的参数,如果数据样本过少或网络过于复杂,会导致网络过拟合,降低模型的泛化能力

总之,在构建的神经网络过深时,有两个问题:前向传播中,网络的最后几层几乎没有学习到有关原始图像的任何信息;在反向传播中,由于梯度消失(梯度值几乎为零),靠近输入的前几层几乎没有任何梯度更新。
深度残差网络 (ResNet) 的提出就是为了解决上述问题。在 ResNet 中,如果模型没有什么要学习的,那么卷积层可以什么也不做,只是将上一层的输出传递给下一层。但是,如果模型需要学习其他一些特征,则卷积层将前一层的输出作为输入,并学习完成目标任务所需的其它特征。

1. ResNet 架构

ResNet 通过残差结构解决网络过深时出现的问题,让模型能够训练得更深。经典的 ResNet 架构如下所示:

ResNet架构
残差结构的基本思想是:每一个残差块都不是直接映射输入信号到输出信号,而是通过学习残差映射来实现:
F ( x ) = H ( x ) − x F(x)=H(x)−x F(x)=H(x)x

其中, x x x 是输入, H ( x ) H(x) H(x) 是一个表示所需映射的基本块,而 F ( x ) F(x) F(x) 是残差块学习到的映射。换句话说,输入 x x x 通过卷积层,得到特征变换后的输出 F ( x ) F(x) F(x),与输入 x x x 进行逐元素的相加运算,得到最终输出 H ( x ) H(x) H(x)

H ( x ) = x + F ( x ) H(x) = x + F(x) H(x)=x+F(x)

如果某个基本块为恒等映射,则残差块的学习目标就变为学习 F ( x ) = 0 F(x)=0 F(x)=0,也就是让输入信号直接到达残差块的输出层。这样就可以解决梯度消失的问题,可以训练更深的神经网络。
实现过程中 ResNet 中使用 Shortcut Connection (也称跳跃连接, Skip Connection )在残差块中实现跨层连接,从而实现信息的直接传递,跨层连接可以绕过一个或多个卷积层,直接将网络中的浅层信息传递到深层中。
ResNet 的残差块中,Shortcut Connection 经常与卷积层或批归一化 (Batch Normalization) 相结合。通过该连接,残差块的激活张量可以直接和下一层的输出相加,理论上,即使是最后一层可能拥有原始图像的全部信息,并且反向传播过程中梯度将可以在几乎没有修改的情况下自由地流向浅层。典型的残差块如下所示:

残差模块

在传统中顺序堆叠的神经网络中,神经网络通常直接学习 F ( x ) F(x) F(x),其中 x 是来自前一层的输出值,而在残差网络中,利用跳跃连接,将残差信号 F ( x ) F(x) F(x) 加上恒等映射 x x x 得到最终的输出 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x。接下来,我们通过在 PyTorch 中构建残差块来深入了解残差网络。

2. 基于预训练 ResNet 模型实现猫狗分类

(1)__init__ 方法中定义一个带有卷积操作的类:

from torch import nnclass ResLayer(nn.Module):def __init__(self,ni,no,kernel_size,stride=1):super(ResLayer, self).__init__()padding = kernel_size - 2self.conv = nn.Sequential(nn.Conv2d(ni, no, kernel_size, stride, padding=padding),nn.ReLU())

在以上代码中,为了确保通过卷积后输出的尺寸保持不变,以便于将输入与卷结果相加,我们通过 padding 控制卷积时输出的尺寸。

(2) 定义 forward 方法:

    def forward(self, x):return self.conv(x) + x

在以上代码中,得到的输出是通过卷积操作的输入和原始输入之和。

PyTorch 中预训练的基于残差块的 ResNet18 架构如下:

请添加图片描述
该架构有 18 个可训练网络层,因此被称为 ResNet18 架构。此外,需要注意的是,ResNet18 并不是每个卷积层都会添加跳跃连接,而是在每两层之后使用跳跃连接。
了解了 ResNet 架构之后,构建一个基于预训练 ResNet18 架构的模型来执行狗猫分类任务。构建分类器的流程可以参考在迁移学习中使用预训练 VGG16 模型构建的猫狗分类器。

(3) 加载预训练 ResNet18 模型并检查模型中的模块:

model = models.resnet18(pretrained=True)

ResNet18 模型架构包含以下组件:

  • 卷积层
  • 批归一化
  • ReLU 激活
  • 最大池化层
  • 4ResNet
  • 平均池化 (avgpool) 层
  • 全连接层 (fc) 层

冻结特征提取模块的网络权重,仅替换 avgpoolfc 层并更新其中的参数。

(4) 定义模型架构、损失函数和优化器:

def get_model():model = models.resnet18(pretrained=True)for param in model.parameters():param.requires_grad = Falsemodel.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))model.fc = nn.Sequential(nn.Flatten(),nn.Linear(512, 128),nn.ReLU(),nn.Dropout(0.2),nn.Linear(128, 1),nn.Sigmoid())loss_fn = nn.BCELoss()optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)return model.to(device), loss_fn, optimizer

在模型中,fc 模块的输入形状为 512,因为 avgpool 的输出形状为 batch size x 512 x 1 x 1。定义了模型后,训练模型,随着 epoch 的增加,模型训练和验证准确率的变化(对应模型分别为 ResNet18ResNet34ResNet50ResNet101ResNet152) 如下:

模型训练和验证准确率
仅对 1000 张图像进行训练时,模型的准确率就可以达到 98% 左右,且准确率随着 ResNet 层数的增加而增加。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习

相关文章:

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

PyTorch深度学习实战——基于ResNet模型实现猫狗分类 0. 前言1. ResNet 架构2. 基于预训练 ResNet 模型实现猫狗分类相关链接 0. 前言 从 VGG11 到 VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非…...

机器学习第六课--朴素贝叶斯

朴素贝叶斯广泛地应用在文本分类任务中,其中最为经典的场景为垃圾文本分类(如垃圾邮件分类:给定一个邮件,把它自动分类为垃圾或者正常邮件)。这个任务本身是属于文本分析任务,因为对应的数据均为文本类型,所以对于此类任务我们首先…...

基于Java+SpringBoot+Vue的图书借还小程序的设计与实现(亮点:多角色、点赞评论、借书还书、在线支付)

图书借还管理小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 小…...

【校招VIP】前端计算机网络之UDP相关

考点介绍 UDP是一个简单的面向消息的传输层协议,尽管UDP提供标头和有效负载的完整性验证(通过校验和),但它不保证向上层协议提供消息传递,并且UDP层在发送后不会保留UDP 消息的状态。因此,UDP有时被称为不可…...

前缀和实例4(和可被k整除的子数组)

题目: 给定一个整数数组 nums 和一个整数 k ,返回其中元素之和可被 k 整除的(连续、非空) 子数组 的数目。 子数组 是数组的 连续 部分。 示例 1: 输入:nums [4,5,0,-2,-3,1], k 5 输出:7 …...

Android获取系统读取权限

第一步在Androidifest.xml文件中加上授权语句 <uses-permission android:name"android.permission.WRITE_EXTERNAL_STORAGE"/><uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE"/>并且在Application标签下添加 androi…...

输入学生成绩(最多不超过40),输入为负值时表示输入结束,统计成绩高于平均成绩的学生人数

#include<stdio.h> #define N 40 int scanfscore(int score[N]) {int i -1;do {i;printf("输入学生成绩:");scanf("%d", &score[i]);} while (score[i] > 0);return i; } int average(int score[N], int n) {int j 0;int k 0;double sum …...

【力扣周赛】第 363 场周赛(完全平方数和质因数分解)

文章目录 竞赛链接Q1&#xff1a;100031. 计算 K 置位下标对应元素的和竞赛时代码写法2——手写二进制中1的数量 Q2&#xff1a;100040. 让所有学生保持开心的分组方法数&#xff08;排序后枚举分界&#xff09;竞赛时代码 Q3&#xff1a;100033. 最大合金数&#xff08;二分答…...

RocketMQ的介绍和环境搭建

一、介绍 我也不知道是啥&#xff0c;知道有什么用、怎么用就行了&#xff0c;说到mq&#xff08;MessageQueue&#xff09;就是消息队列&#xff0c;队列是先进先出的一种数据结构&#xff0c;但是RocketMQ不一定是这样&#xff0c;简单的理解一下&#xff0c;就是临时存储的…...

【web开发】7、Django(2)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、部门列表二、部门管理&#xff08;增删改&#xff09;三、用户管理过渡到modelform组件四、modelform实例&#xff1a;靓号操作五、自定义分页组件六、datepick…...

Prometheus+Grafana可视化监控【Nginx状态】

文章目录 一、安装Docker二、安装Nginx(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装nginx_exporter七、Grafana添加Nginx监控模板 一、安装Docker 注意&#xff1a;我这里使用之前写好脚本进行安装Docker&#xff0c;如果已经有D…...

R 语言的安装教程

一、下载相关软件 1、R 下载 官网&#xff1a;R: The R Project for Statistical Computing 找到中国镜像&#xff0c;下载快 历史版本点击这里 2、Rtools 下载 进入镜像后&#xff0c;点击这里 然后选择与上面下载的R版本相对应的版本即可 3、Rstudio 下载 官网&#xff1…...

uniapp-提现功能(demo)

页面布局 提现页面 有一个输入框 一个提现按钮 一段提现全部的文字 首先用v-model 和data内的数据双向绑定 输入框逻辑分析 输入框的逻辑 为了符合日常输出 所以要对输入框加一些条件限制 因为是提现 所以对输入的字符做筛选,只允许出现小数点和数字 这里用正则实现的小数点…...

Spring 篇

1、什么是 Spring&#xff1f; Spring是一个轻量级的IOC和AOP容器框架。是为Java应用程序提供基础性服务的一套框架&#xff0c;目的是用于简化企业应用程序的开发&#xff0c;它使得开发者只需要关心业务需求。常见的配置方式有三种&#xff1a;基于XML的配置、基于注解的配置…...

three.js简单3D图形的使用

npm init vitelatest //创建一个vite的脚手架 选择 Vanilla 之后自己处理一下 在main.js中写入 // 导入three.js import * as THREE from three// 创建场景 const scene new THREE.Scene();// 创建相机 const camera new THREE.PerspectiveCamera(45, //视角window.inner…...

spark withColumn的使用(笔记)

目录 前言&#xff1a; spark withColumn的语法及使用&#xff1a; 准备源数据演示&#xff1a; 完整实例代码&#xff1a; 前言&#xff1a; withColumn()&#xff1a;是Apache Spark中用于DataFrame操作的函数之一&#xff0c;它的作用是在DataFrame中添加或替换列&#xff…...

PTA:7-1 线性表的合并

线性表的合并 题目输入样例输出样例 代码解析 题目 输入样例 4 7 5 3 11 3 2 6 3输出样例 7 5 3 11 2 6 代码 #include<iostream> #include<vector> using namespace std;bool checkrep(const vector<int>& arr, int x) {for (int element : arr) {i…...

Spring 的创建和日志框架的整合

目录 一、第一个 Spring 项目 1、配置环境 2、Spring 的 jar 包 Maven 项目导入 jar 包和设置国内源的方法&#xff1a; 3、Spring 的配置文件 4、Spring 的核心 API ApplicationContext 4、程序开发 5、细节分析 &#xff08;1&#xff09;名词解释 &#xff08;2&…...

11-集合和学生管理系统

1.ArrayList 集合和数组的优势对比&#xff1a; 长度可变添加数据的时候不需要考虑索引&#xff0c;默认将数据添加到末尾 1.1 ArrayList类概述 什么是集合 ​ 提供一种存储空间可变的存储模型&#xff0c;存储的数据容量可以发生改变 ArrayList集合的特点 ​ 长度可以变化…...

C语言进阶指针(3) ——qsort的实现

大家好&#xff0c;我们今天来学习回调函数qsort的实现。 首先让我们打开cplusplus.com找到qsort函数。 我们看到这个函数就可以看到它的头文件和参数信息。 #include<stdlib.h> void qsort (void* base, size_t num, size_t size, int (*compar)(const void*,const voi…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

线程同步:确保多线程程序的安全与高效!

全文目录&#xff1a; 开篇语前序前言第一部分&#xff1a;线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分&#xff1a;synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分&#xff…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机&#xff08;Finite Automaton, FA&#xff09;到正规文法&#xff08;Regular Grammar&#xff09;转换器&#xff0c;它配备了一个直观且完整的图形用户界面&#xff0c;使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...