当前位置: 首页 > article >正文

【深度学习】宠物品种分类Pet Breeds Classifier

文章目录

    • 宠物品种数据集
      • 制作宠物品种标签
      • 图像预处理Presizing
    • 损失函数loss
    • 观察模型的性能
    • 提升模型的性能
      • learning rate finder
        • 使用CLR算法训练
        • 选择学习率的策略
        • 重新训练
      • 迁移学习
        • 微调fine_tune
        • fit_one_cycle
        • 有判别力的学习率
      • 选择epoch的数量
      • 更深的网络架构

宠物品种数据集

这里我们使用fastai深度学习库。

from fastai.vision.all import *

从fastai的官网下载Pets数据集,解压至本地文件夹内。

path = untar_data(URLs.PETS)
Path.BASE_PATH = path
path.ls()

在这里插入图片描述

制作宠物品种标签

annotations目录内的文件主要告诉了我们宠物在图像的具体位置,但我们今天要完成的任务是宠物分类。所以我们需要重新制作标签。

fname = (path/"images").ls()[0]
fname

在这里插入图片描述这是一张图片的文件名,格式为“宠物名_编号.jpg”
我们的目的是提取出下划线前面的宠物名,这时候就需要用到正则表达式来提取字符串了。(正则表达式的讲解)

  • (.+)可截取若干个任意字符,所以可以提取宠物品种名称
  • _匹配下划线
  • \d+匹配数字,在这里就是匹配了编号
  • .jpg匹配后缀名
  • $结束字符
re.findall(r'(.+)_\d+.jpg$', fname.name)

在这里插入图片描述

在fastai深度学习库中,我们有已经提前实现好的RegexLabeller类,同样实现了“根据正则表达式提取字符串”的功能,而且我们通常是在DataBlock代码块中使用。

pets = DataBlock(# Inputs类型:Image, Targets类型:Categoryblocks = (ImageBlock, CategoryBlock),# 从images目录下的子文件夹内获得数据get_items=get_image_files, splitter=RandomSplitter(seed=42),# 定义对targets所作的操作,然后得到yget_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),# 裁剪图像item_tfms=Resize(460),# 数据增强batch_tfms=aug_transforms(size=224, min_scale=0.75))
dls = pets.dataloaders(path/"images")

图像预处理Presizing

接下来我们来研究一下这两行代码,

# 裁剪图像
item_tfms=Resize(460),
# 数据增强
batch_tfms=aug_transforms(size=224, min_scale=0.75))

为什么先裁剪图像然后数据增强呢?
因为在fastai深度学习库中,许多数据增强方法会让图像的质量下降,我们再裁剪的就是已经损坏的图像,那样对模型训练没有益处。
因此我们先将图像统一裁剪成460*460的大小,然后再进行数据增强。

在这里插入图片描述

如上图所示,Presize通常有两步:
1、把图像裁剪成相对较大的尺寸(比训练时的图像尺寸大)
2、把所有常见的数据增强增强操作变成一个组合操作,然后在GPU上执行组合操作

损失函数loss

我们使用的是交叉熵损失函数(cross-entropy loss),它适用于多目标分类。(cross-entropy loss讲解)

观察模型的性能

我们通常使用混淆矩阵confusion matrix来观察模型的表现。
如果图像的类别有n个,那么混淆矩阵的大小就是nxn,但是宠物的类别有37种,那混淆矩阵就太大了。
在这里插入图片描述

因此我们使用most_confused方法来查看在有着最多不正确的预测值的单元格(至少有5个)。

interp.most_confused(min_val=5)

在这里插入图片描述根据研究,这两种情况下,即使是宠物专家也容易搞错,所以我们的模型效果还不错。
接下来就是考虑如何改进我们的模型了。

提升模型的性能

learning rate finder

learning rate finder的代码思想源自CLR(cyclical learning rates)算法
learning rate finder方法的主要步骤:
1、将训练集分成batch,每次循环训练一个batch
2、初始学习率(例如1e-6)较小,然后每次将学习率扩大(例如乘以2)后训练下一个batch
3、记录training loss,直到training loss不再变小反而变大停止

使用CLR算法训练
learn = vision_learner(dls, resnet34, metrics=error_rate)lr_min, lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))

什么是minimum? 使training loss达到最小的点
什么是steep? 使training loss曲线下降坡度最大的点。

在这里插入图片描述

选择学习率的策略
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")

在这里插入图片描述

根据loss-lr图,选择学习率的策略有以下2种:
1、使training loss达到最小的学习率,在这个数值上/10
2、最后一个使training loss明显下降的点附近
根据策略2,已经知道最陡的点是2.09e-3,我们在这里取最后一个使training loss明显下降的点大约为3e-3

重新训练

现在我们使用3e-3作为学习率,重新训练模型。

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(2, base_lr=3e-3)

在这里插入图片描述

观察结果,相较于使用1e-1,使用3e-3作为学习率训练模型,使error_rate从降低到0.073!

迁移学习

微调fine_tune

预训练模型是在Imagenet数据集上训练好的模型。
当我们做迁移学习的时候,我们去除网络架构的最后一层,然后加入新的全连接层,输出数量是数据集的标签类别个数。

fine_tune方法帮助我们实现了迁移学习,非常方便实用。
当我们调用fine_tune方法时,有以下2种策略:
1)训练新添加的层指定的epoch数量,其余的层全部都冻住
2)所有层都解冻,然后再重新训练指定的epoch数量

预训练网络模型的前面几层提取大多数图片的共有特征,如边框等,因此我们不想要重新训练前面几层的权重,所以只需要冻住freeze, 而最后几个新添加的层希望被用来学习新数据集的独有特征,因此我们想要训练新添加的层的权重,需要解冻unfreeze.

要想查看某个函数的源代码,只需在后面加上??即可。
这是fine_tune方法的源代码。

Signature:
learn.fine_tune(epochs,base_lr=0.002,freeze_epochs=1,lr_mult=100,pct_start=0.3,div=5.0,*,lr_max=None,div_final=100000.0,wd=None,moms=None,cbs=None,reset_opt=False,start_epoch=0,
)
Source:   
@patch
@delegates(Learner.fit_one_cycle)
def fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,pct_start=0.3, div=5.0, **kwargs):"Fine tune with `Learner.freeze` for `freeze_epochs`, then with `Learner.unfreeze` for `epochs`, using discriminative LR."self.freeze()self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)base_lr /= 2self.unfreeze()self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div=div, **kwargs)
File:      /usr/local/lib/python3.10/dist-packages/fastai/callback/schedule.py
Type:      method

从上面可以看到,在源代码中使用了fit_one_cycle函数。
我们接下来自己手动模拟一个fine_tune函数。不过,在这之前,先来了解一下fit_one_cycle函数

fit_one_cycle

fit_one_cycle的作用是,先以较低的学习率开始训练,慢慢增加到指定最大学习率后,再逐渐减小学习率。

我们先训练模型3个epoch,创建好vision_learner以后,除新添加的层外,其余层默认是被冻住的。此时我们训练的就只有刚刚添加的层。

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(3, lr_max=3e-3)

在这里插入图片描述
我们接下来查看学习率的变化曲线,以及损失的变化曲线

learn.recorder.plot_sched()

在这里插入图片描述

learn.recorder.plot_loss()

在这里插入图片描述

我们可以看出:
1)学习率一开始很小,loss比较大
2)学习率增大并逐渐到达最大值,loss变小
3)学习率又逐渐变小,loss也在变小

因此fit_one_cycle函数的代码实现了learning rate annealing算法,它也被称为学习率退火算法。

好了,话说回来,我们继续手动搭建自己的fine_tune函数,来训练我们的模型。
一开始我们训练了3个epoch,现在我们将模型所有的层都解冻。

learn.unfreeze()

我们需要重新再使用learning rate finder, 因为模型的权重已经发生了改变,所有我们要找到新的最佳的learning rate

learn.lr_find()

在这里插入图片描述
基于选择学习率的策略2,从上图中我们可以看出最后一个明显下降的点在1e-4附近,但我们选择略小一点的,所以是1e-5,接下来我们继续训练6个epoch

learn.fit_one_cycle(6, lr_max=1e-5)

在这里插入图片描述
从上图中可以看出,我们模型的error_rate已经从刚刚的0.074降低到0.059了!

有判别力的学习率

因为预训练的层已经能够识别边缘等,但对于具体任务的特征还需要训练,所以新添加的层就相较于预训练层使用大一点的学习率。
新论文中提出较早的层使用较小的学习率训练,新添加的层用较大的学习率训练。
Python中的slice(,)对象指明了学习率。
第一个参数,是用于训练第一层的学习率;第二个参数,是用于训练最后一层的学习率;中间的层的学习率是在这个范围内等距增加的。
先使用学习率3e-3训练3个epoch,然后将所有的网络参数解冻。然后再继续训练12个epoch,但不同的层学习率也不同,第一层为1e-6, 最后一层为1e-4, 中间的层等距增加。

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(3, 3e-3)
learn.unfreeze()
learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))

在这里插入图片描述

从上面的表格中我们可以发现,error_rate降低到了0.0541

我们来对比一下源代码,和我们基于自己模型创建的fine_tune方法

## 这是源代码
def fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,pct_start=0.3, div=5.0, **kwargs):self.freeze()self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)base_lr /= 2self.unfreeze()self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div=div, **kwargs)## 这是我们自己的版本
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(3, 3e-3)
learn.unfreeze()
learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))

有人可能会想问“为什么源代码中就有self.freeze()呢?”
因为在初始化vision_learner的时候,其源函数中有这样一行代码,if pretrained: learn.freeze(),所有创建好learner对象后,其实已经处于冻结状态了。

选择epoch的数量

  • 选择你愿意等待的训练时间的epoch数量,开始训练
  • 观察training loss和validation loss图,也要注意metrics,如果在最后的几个epoch中,loss和metrics仍然能变得更好,那说明我们训练的时间还不够长
  • 在训练的最后几个epoch中,validation loss变得更差,而且metrics也变得更差,说明我们训练的时间太长了。
  • 此外,如果过拟合了,重新训练你的模型,并且基于之前的结果,重新设定epoch数。

在one cycle策略出现之前,我们通常使用early stopping早停策略。也就是,在训练的每一个epoch结束后,
我们会把权重保存下来,最后从中选择一个最佳的模型。 但有时候early stopping策略并不能使我们获得最好的结果,在学习率达到很小的值(同时让模型表现最佳)之前,中间的那些epoch已经发生了。
因此,如果你发现已经过拟合了,你要做的是重新训练你的模型,并且基于之前的结果,重新设定epoch数。

如果你有更多的时间训练更多的epoch,那么我们也可以选择把这么多的时间用于训练更多的参数,即更深的网络架构。

更深的网络架构

  • 好处:一般来说,更深的网络架构意味着更多的参数,这也使得我们的模型能学习更多与数据相关的特征,使得准确率更高。
  • 坏处:大量参数意味着需要占用很大的gpu内存,也需要很长的训练时间

因此,有一种方法叫作“半精度训练”,即在训练的时候,使用“半精度浮点数”(fp16),使得训练速度大大加快,也减少了内存使用。在fastai深度学习库中,我们直接让learner对象调用to_fp16()方法。

from fastai.callback.fp16 import *
learn = vision_learner(dls, resnet50, metrics=error_rate).to_fp16()
learn.fine_tune(6, freeze_epochs=3)

在这里插入图片描述

这里我自己采用resnet152,并且使用单精度训练的方式,冻住预训练权重并训练2个epoch;然后解冻模型并训练5个epoch

from fastai.callback.fp16 import *
learn = vision_learner(dls, resnet152, metrics=error_rate).to_fp16()
learn.fine_tune(5, freeze_epochs=2)

从下图可以看出,我的模型的error_rate降低到了0.045!
在这里插入图片描述

如果大家有更好的方法,欢迎大家评论,我很想学习。

相关文章:

【深度学习】宠物品种分类Pet Breeds Classifier

文章目录 宠物品种数据集制作宠物品种标签图像预处理Presizing 损失函数loss观察模型的性能提升模型的性能learning rate finder使用CLR算法训练选择学习率的策略重新训练 迁移学习微调fine_tunefit_one_cycle有判别力的学习率 选择epoch的数量更深的网络架构 宠物品种数据集 …...

【从零开始学习计算机科学】HLS算子调度

算子调度 调度是HLS 中的核心问题,为无时序或部分时序的输入指定时钟边界,其对最终结果质量具有很大的影响。调度会影响时钟频率、延时、吞吐率、面积、功耗等多种因素。 调度的输入是控制数据流图,其节点表示算子/操作,有向边表示数据依赖,控制依赖,优先依赖。如果没有…...

centos 安装composer 教程

打开命令行 php -r "copy(https://getcomposer.org/installer, composer-setup.php);" sudo php composer-setup.php --install-dir/usr/local/bin --filenamecomposer composer --version sudo chmod us /usr/local/bin/composer Super18120/article/details/14388…...

C语言_数据结构总结2:动态分配方式的顺序表

0——静态分配内存的顺序表和动态分配内存的顺序表的相同之处和不同之处 相同之处 基本操作逻辑相同:无论是静态分配还是动态分配的顺序表,其核心的操作逻辑是一致的。例如插入操作都需要将插入位置之后的元素依次后移,删除操作都需要将删除…...

嵌入式人工智能应用-第6章 人脸检测

嵌入式人工智能应用 人脸检测 嵌入式人工智能应用1 人脸检测1.1 CNN 介绍1.2 人脸检测原理1.3 MTCNN介绍1.4 NCNN介绍2 系统安装2.1 安装依赖库NCNN2.2 运行对应的库3 总结1 人脸检测 1.1 CNN 介绍 卷积神经网络。卷积是什么意思呢?从数学上说,卷积是一种运算。它是我们学习…...

关于无感方波启动预定位阶段

一、预定位的核心目标与原理 消除启动不确定性 无位置传感器下,转子初始位置未知,直接换相可能导致反转或失步。预定位通过施加固定方向磁场,强制转子对齐至预定角度(通常0或60电角度),建立初始位置基准。 …...

WSL安装及问题

1 概述 Windows Subsystem for Linux(简称WSL)是一个在Windows 10\11上能够运行原生Linux二进制可执行文件(ELF格式)的兼容层。它是由微软与Canonical公司合作开发,开发人员可以在 Windows 计算机上同时访问 Windows 和…...

MySQL中的脏读与幻读:概念、影响与解决方案

在数据库事务处理中,脏读和幻读是两种常见的并发问题,可能导致数据不一致或逻辑错误。本文将结合实际场景,深入解析两者的原理及解决方案。 一、脏读(Dirty Read) 1. 概念解析 脏读指一个事务读取了另一个事务未提交…...

基于SpringBoot的商城管理系统(源码+部署教程)

运行环境 数据库:MySql 编译器:Intellij IDEA 前端运行环境:node.js v12.13.0 JAVA版本:JDK 1.8 主要功能 基于Springboot的商城管理系统包含管理端和用户端两个部分,主要功能有: 管理端 首页商品列…...

HeidiSQL:一款免费的数据库管理工具

HeidiSQL 是一款免费的图形化数据库管理工具,支持 MySQL、MariaDB、Microsoft SQL、PostgreSQL、SQLite、Interbase 以及 Firebird,目前只能在 Windows 平台使用。 HeidiSQL 的核心功能包括: 免费且开源,所有功能都可以直接使用。…...

Ae 效果详解:VR 色差

Ae菜单:效果/沉浸式视频/VR 色差 Immersive Video/VR Chromatic Aberrations VR 色差 VR Chromatic Aberrations效果用于模拟镜头色散现象,在 VR 视频中制造 RGB 通道错位的色彩偏移,以增强视觉风格或创造数字失真效果。 本效果适用于所有色深…...

计算机毕业设计SpringBoot+Vue.js制造装备物联及生产管理ERP系统(源码+文档+PPT+讲解)

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…...

Ubuntu 安装docker docker-compose

Docker 通过提供轻量级、可移植且高效的解决方案,简化了软件开发和部署。“docker build”命令是 Docker 镜像创建过程的核心。本文将探讨 Docker 构建命令、用法以及 Docker 构建的优化。 Docker 构建有什么作用? Docker build 是一个命令行界面 CLI命…...

【Linux内核系列】:深入解析输出以及输入重定向

🔥 本文专栏:Linux 🌸作者主页:努力努力再努力wz ★★★ 本文前置知识: 文件系统以及文件系统调用接口 用c语言简单实现一个shell外壳程序 内容回顾 那么在此前的学习中,我们对于Linux的文件系统已经有了…...

【linux网络编程】端口

一、端口(Port)概述 在计算机网络中,端口(Port) 是用来标识不同进程或服务的逻辑通信端点。它类似于一座大楼的房间号,帮助操作系统和网络协议区分不同的应用程序,以便正确地传输数据。 1. 端口…...

PyTorch系列教程:Tensor.view() 方法详解

这篇简明扼要的文章是关于PyTorch中的tensor.view()方法的介绍与应用,与reshape()方法的区别,同时给出示例进行详细解释。 Tensor基础 Tensor(张量)的视图是一个新的Tensor,它与原始Tensor共享相同的底层数据,但具有不同的形状或…...

软件测试的基础入门(二)

文章目录 一、软件(开发)的生命周期什么是生命周期软件(开发)的生命周期需求分析计划设计编码测试运行维护 二、常见的开发模型瀑布模型流程优点缺点适应的场景 螺旋模型流程优点缺点适应的场景 增量模型和迭代模型流程适应的场景…...

Springboot + minio

参考&#xff1a; SpringBoot整合Minio_springboot minio-CSDN博客 <!--minio 依赖--><dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.5.11</version></dependency> applicaio…...

地下变电站如何实现安全智能运营-以110kV站为例看环境监测与设备联控

1、地下变电站简介 在经济发达的地区&#xff0c;由于城市中心土地资源紧张、征地拆迁费用昂贵&#xff0c;因此采用地下变电站来解决这些问题不失为一个好的途径和思路。地下变电站一般采用室内全封闭式组合电气设备&#xff0c;&#xff12;&#xff12;&#xff10;&#x…...

windows无界面后台定时任务 (重启自启动,ODBS为例)

一、前言 mdb(Microsoft Database)是Microsoft Access中使用的一种数据存储格式,可以通过ODBC驱动程序进行访问和操作,在Python中也可以安装相应模块打开。 这是我在项目中更新bs数据的一个实践记录,结合windows定时一起记录一下,方便以后照搬~ 二、安装 Python安装库…...

FPGA 实验报告:四位全加器与三八译码器仿真实现

目录 安装Quartus软件 四位全加器 全加器、半加器 半加器&#xff1a; 全加器&#xff1a; 四位全加器电路图 创建项目 半加器 全加器 四位全加器 代码实现 半加器 全加器 四位全加器 三八译码器 创建项目 代码展示 modelsim仿真波形图 四位全加器 三八译码…...

win11 Visual Studio 17 2022源码编译 opencv4.11.0 + cuda12.6.3 启用GPU加速

win11 Visual Studio 17 2022 源码编译 opencv4.11.0 cuda12.6.3 启用GPU加速 配置: 生成 opencv 生成 opencv-python 1 下载源码和安装软件 win11 x64 系统 安装Visual Studio 17 2022 下载opencv4.11.0 源码 https://github.com/opencv/opencv/releases/tag/4.11.0 下载…...

Ribbon实现原理

文章目录 概要什么是Ribbon客户端负载均衡 RestTemplate核心方法GET 请求getForEntitygetForObject POST 请求postForEntitypostForObjectpostForLocation PUT请求DELETE请求 源码分析类图关系 与Eureka结合重试机制 概要 什么是Ribbon Spring Cloud Ribbon是一个基于HTTP和T…...

MuMu-LLaMA:通过大型语言模型进行多模态音乐理解和生成(Python代码实现+论文)

MuMu-LLaMA 模型是一种音乐理解和生成模型&#xff0c;能够进行音乐问答以及从文本、图像、视频和音频生成音乐&#xff0c;以及音乐编辑。该模型利用了用于音乐理解的 MERT、用于图像理解的 ViT 和用于视频理解的 ViViT 等编码器&#xff0c;以及作为音乐生成模型&#xff08;…...

高效Android MQTT封装工具:简化物联网开发,提升性能与稳定性

在Android开发中&#xff0c;封装MQTT工具可以帮助简化与MQTT服务器的通信。MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种轻量级的发布/订阅消息传输协议&#xff0c;常用于物联网&#xff08;IoT&#xff09;设备之间的通信。 以下是一个简单的MQ…...

数据库原理7

1.“数据库系统运行与维护工具”的研究属于数据库管理系统软件 2.1970年IBM公司的高级研究员E.F.Codd提出了关系数据模型 3.每个属性的属性值是不可分解的&#xff0c;即关系的每个分量必须是一个不可分的数据项。属性值的取值应满足域完整性约束。 4.视图作用&#xff1a;简…...

2025最新比较使用的ai工具都有哪些,分别主要用于哪些方面?

文章目录 一、AI对话与交互工具二、AI写作与内容生成工具三、AI绘画与设计工具四、AI视频生成工具五、办公与效率工具六、其他实用工具选择建议 根据2025年最新行业动态和用户反馈&#xff0c;以下AI工具在多个领域表现突出&#xff0c;覆盖对话、写作、设计、视频生成等场景&a…...

什么是 MyBatis? 它的优点和缺点是什么?

一、 什么是 MyBatis&#xff1f; 定义&#xff1a; MyBatis 是一款优秀的持久层框架&#xff0c;用于简化 Java 应用程序与数据库之间的交互。MyBatis 通过 XML 或注解 的方式&#xff0c;将 SQL 语句与 Java 代码分离&#xff0c;提供了一种灵活的、易于维护的数据访问解决方…...

在ArcMap中通过Python编写自定义工具(Python Toolbox)实现点转线工具

文章目录 一、需求二、实现过程2.1、创建Python工具箱&#xff08;.pyt&#xff09;2.2、使用catalog测试代码2.3、在ArcMap中使用工具 三、测试 一、需求 通过插件的形式将点转线功能嵌入ArcMap界面&#xff0c;如何从零开始创建一个插件&#xff0c;包括按钮的添加、工具的实…...

Array and string offset access syntax with curly braces is deprecated

警告信息 “Array and string offset access syntax with curly braces is deprecated” 是 PHP 中的一个弃用警告&#xff08;Deprecation Notice&#xff09;&#xff0c;表明在 PHP 中使用花括号 {} 来访问数组或字符串的偏移量已经被标记为过时。 背景 在 PHP 的早期版本…...