「解析」YOLOv5 classify分类模板
学习深度学习有些时间了,相信很多小伙伴都已经接触 图像分类、目标检测甚至图像分割(语义分割)等算法了,相信大部分小伙伴都是从分类入门,接触各式各样的 Backbone算法开启自己的炼丹之路。
但是炼丹并非全是 Backbone,更多的是各种辅助代码,而这部分公开的并不多,特别是对于刚接触/入门的人来说就更难了,博主当时就苦于没有完善的辅助代码,走了很多弯路,好在YOLOv5提供了分类、目标检测的完整代码,不同于目标检测,因数据集不同,对应的数据辅助代码也不兼容,图像分类就不会有这方面的影响,只需要更换下模型,设置下输出类即可。可谓相当的成熟,学者必备!!!
官方代码:https://github.com/ultralytics/yolov5

分类任务有四部分组成:tutorial说明,train、val、predict 脚本
对于有一定基础的小伙伴可以直接查看 tutorial 自行运行,如果遇到一些暂无解决的问题时再往下阅读!
小插曲
对于调用官方库的模型时,不同数据集对应不同类别数,如果不想修改官方库的话,可以在创建好模型后,再修改最后的类别数
import torchvison.models as modelsmodel = models.__dict__["resnet50"]() # 加载官方模型库,默认是 imagesnet1000类
dataset_class = len(dataset.classes) # 计算训练数据集的类别数num_feature = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_feature, dataset_class),nn.LogSofmax(dim=1))model.load_state_dict(torch.load"model_parmas.pt")
train任务
整个 图像分类任务还是较为复杂的,内容略微庞大,一篇讲解不完,讲解不清的可以下方留言,较难问题博主再出新博文解释。

parse_opt() 函数
首先大家在学习代码时,一定要学会 debug 模型,这样才知道代码是如何运行的,一般从 if __name__ == "__main__": 开始进行。
首先是 def parse_opt(known=False): 解析配置参数
-
parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
–model 参数是配置模型类型,从下面的解析 --model参数可以看出,如果 --model的值是模型权重名称/路径的话,直接加载到模型model,如果–model是torchvision模型库的,将从torchvision库中读取, 如果都没有的话,将以错误输出。
所以 --model 一定要是 模型权重名称/路径,并且需要能够读取得到才可以。亦可以是torchvision模型库中的模型名称也可以(可以通过torchvision.models.__dict_查看安装的torchvision封装了哪些模型库)
此外torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None)代码并不适用于所有版本的 torchvision模型,还是需要进入 torchvision.model下的具体模型代码中查看 调用方法,否则会出现错误。

-
parser.add_argument('--data', type=str, default='mnist', help='cifar10, cifar100, mnist, imagenet, etc.')
–data 可以是数据集的路径,也可以是数据集而名称, 只是数据集名称必须是 ultralytics 公开的数据集才可以,比如:Classification:Caltech 101、Caltech 256、CIFAR-10、CIFAR-100、Fashion-MNIST、ImageNet、ImageNet-10、Imagenette、Imagewoof、MNIST

如果是自定义的数据集,需要注意的是每一类的所有数据需要放到同一个文件夹下面,如同 cifar10 数据集一样,在 train/val/test 文件夹下分别建立每一类的子文件夹,其中可以存放全部图片,也可以有多层嵌套路径,注意:train/val/test下的文件夹名称和数量 要保持一致,否则训练出来的指标会很差

-
parser.add_argument('--epochs', type=int, default=10)
就是训练的迭代轮数 -
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=128, help='train, val image size (pixels)')
训练时 图片的尺寸大小 -
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
不保存中间每个epoch的权重,如果需要保存的话,将其设置为 False -
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
选择数据的读取方式,ram方式为一次性将所有的数据读取到内存里,以为内存与显存的传输速度高,因此训练市场可以极大降低,前提是内存够大,如果没有足够大的内存的话,可以算法disk硬盘读取,效率略低 -
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
选择训练设备,可以选择:“cup, mps, cuda”(MPS:Apple Metal Performance Shaders) -
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
数据集加载时的线程数 -
parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
项目保存路径及名称 -
parser.add_argument('--name', default='exp', help='save to project/name')
每次训练的子文件名 -
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
如果已经存在保存文件名/路径,可以覆盖保存 -
parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. --pretrained False')
是否使用预训练权重(前提是必须是torchvision中的模型,官方提供预训练接口的模型才有用) -
parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
优化器选择,此处官方配置好了 [‘SGD’, ‘Adam’, ‘AdamW’, ‘RMSProp’] 优化器,如果需要其他优化器,需用小伙伴自行配置 -
parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
优化器的初始学习率 -
parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
label-smoothing 方法,对 label进行 smoothing 处理 -
parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
裁切模型的 classify分支 的层数,model.model = model.model[:cutoff] -
parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
随机失效部分神经元,dropout处理 -
parser.add_argument('--verbose', action='store_true', help='Verbose mode')
冗余模式,记录中间的模型日志 -
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
全局随机种子 -
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
如果小伙伴有多卡,可以采用,此方法可以自动调用多个显卡的资源,即DDP 模式,-1 为不采用
train() 函数
train() 函数前面都是一些模型配置
- 模型训练保存路径,以及配置训练日志,默认情况下,模型训练保存 一个 last.pt 和 best.pt

- 数据集下载,如果是官方的数据集,直接 对
--data设置数据集名称即可(完成路径也是可以的),如果是自己的数据集,需要设置数据集路径,只需要给到 train 的上一级目录即可

- 数据集构建,此处将读取数据集的类别数以及加载数据集,此处默认是以 test 为验证集的,如果没有test 备份选择 val。如果需要用 val 当验证集,手动改为 val即可。再次提示:train 下的文件夹名称和数量需要和 验证集下的保持一致,否则模型性能很低,且无法提升(惨痛的教训!)

- 构建模型,此处需要注意一点,作为分类模型,模型的输出层必须和数据集的类别数量保持一致,必须!!!
如果不使用 torchvision中的模型,只需要将 model 赋值为自己的模型即可

- 日志保存模型等信息,以及加载 数据和标签,此处的数据加载器采用的是迭代器方式,因此采用 nest(iter());然后是优化器设置,学习率、调度器(scheduler)设置 和 EMA配置;最后是损失函数criterion。

- 进行完上面所有的参数配置,真正的模型训练还在下面这个循环里

相关文章:
「解析」YOLOv5 classify分类模板
学习深度学习有些时间了,相信很多小伙伴都已经接触 图像分类、目标检测甚至图像分割(语义分割)等算法了,相信大部分小伙伴都是从分类入门,接触各式各样的 Backbone算法开启自己的炼丹之路。 但是炼丹并非全是 Backbone,更多的是各…...
交换排序——冒泡排序、快速排序
交换排序就是通过比较交换实现排序。分冒泡排序和快速排序两种。 一、冒泡排序: 1、简述 顾名思义就是大的就冒头,换位置。 通过多次重复比较、交换相邻记录而实现排序;每一趟的效果都是将当前键值最大的记录换到最后。 冒泡排序算法的原…...
Android 10.0 禁用adb shell input输入功能
1.前言 在10.0的产品开发中,在进行一些定制开发中,对于一些adb shell功能需要通过属性来控制禁止使用input 等输入功能,比如adb shell input keyevent 响应输入事件等,所以就需要 熟悉adb shell input的输入事件流程,然后来禁用adb shell input的输入事件功能,接下来分…...
cuda显存访问耗时
背景: 项目中有个数据量大小为5195 * 512 * 128float 1.268G的显存,发现有个函数调用很耗时,函数里面就是对这个显存进行128个元素求和,得到一个5195 * 512的图像 分析 1. 为什么耗时 直观上感觉这个流程应该不怎么耗时才对&a…...
【HTML5高级第三篇】drag拖拽、音频视频、defer/async属性、dialog应用
文章目录 一、拖拽事件1.1 拖拽事件1.2 案例:拖拽丢弃图片 二、音频和视频三、defer 与 async 属性3.1 概述3.2 示例一:3.3 示例二: 四、dialog 元素 一、拖拽事件 原生JavaScipt案例合集 JavaScript DOM基础 JavaScript 基础到高级 Canvas…...
独享IP vs. 共享IP:哪种更适合你?
无论是个人用户还是企业组织,在互联网上都需要一个唯一标识来与其他设备进行通信。这就涉及到使用独立分配给自己或多个用户分享的公共 IP 地址(也称为共享 IP)。那么,究竟应该选择独占一个专用地址还是与他人分享相同地址呢&…...
【Arduino27】DHT11温湿度传感器模拟值实验
硬件准备 DHT11温湿度:1个 面包板:1个 杜邦线:3根 硬件连线 VDD引脚接 5V 电源 DATE引脚接 4号 接口 GND引脚接 GND 接口 软件程序 #include<DHT.h>#define DHT11_pin 4 //温湿度传感器引脚DHT dht(DHT11_pin,DHT11);float tem…...
dockerfile基于apline将JDK20打包成镜像
dockerfile基于apline将JDK20打包成镜像 今天就来和大家聊聊如何把最新出版的JDK20打包成docker镜像,很多uu都会采用centos作为基础镜像,这么做会有一个问题,centos系统会含有很多库文件,这些库文件JDK程序并不是完全需要的&a…...
MATLAB基础-MAT文件的读写操作
简介 MAT文件是MATLAB格式的双精度二进制数据文件,由MATLAB软件创建,可以使用MATLAB软件再其他计算机上以其他浮点格式读取,同时也可以使用其他软件通过MATLAB的应用程序接口来进行读写操作。如果只是再MATLAB环境中处理数据,使用…...
PostgreSQL PG15 新功能 PG_WALINSPECT
开头还是介绍一下群,如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis ,Oracle ,Oceanbase 等有问题,有需求都可以加群群内有各大数据库行业大咖,CTO,可以解决你的问题。加群请加微信号 liuaustin3 (…...
时序预测 | MATLAB实现TCN-BiLSTM时间卷积双向长短期记忆神经网络时间序列预测
时序预测 | MATLAB实现TCN-BiLSTM时间卷积双向长短期记忆神经网络时间序列预测 目录 时序预测 | MATLAB实现TCN-BiLSTM时间卷积双向长短期记忆神经网络时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现TCN-BiLSTM时间卷积双向长短期记忆神…...
数据结构和算法(2):向量
抽象数据类型 数组到向量 C/C 中,数组A[]中的元素与[0,n)内的编号一一对应,A[0],A[1],...,A[n-1];反之,每个元素均由(非负)编号唯一指代,并可直接访问A[i] 的物理地址 Ai s,s 为单…...
mysql 大表如何ddl
大家好,我是蓝胖子,mysql对大表(千万级数据)的ddl语句,在生产上执行时一定要千万小心,一不小心就有可能造成业务阻塞,数据库io和cpu飙高的情况。今天我们就来看看如何针对大表执行ddl语句。 通过这篇文章,…...
C++新特性:智能指针
一 、为什么需要智能指针 智能指针主要解决以下问题: 1)内存泄漏:内存手动释放,使用智能指针可以自动释放 2)共享所有权指针的传播和释放,比如多线程使用同一个对象时析构问题,例如同样的数据…...
SAP FI之批量修改财务凭证的BAPI
文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 一般涉及修改财务凭证,或者其它凭证,不应直接更新数据库,而是使用系统提供的function module,或者BAPI,或者使用BDC。 一、 示例…...
Spring Boot + Vue的网上商城之商品分类
Spring Boot Vue的网上商城之商品分类 在网上商城中,商品分类是非常重要的一个功能,它可以帮助用户更方便地浏览和筛选商品。本文将介绍如何使用Spring Boot和Vue来实现商品分类的功能,包括一级分类和二级分类的管理以及前台按分类浏览商品…...
Docker 容器逃逸漏洞 (CVE-2020-15257)复现
漏洞概述 containerd是行业标准的容器运行时,可作为Linux和Windows的守护程序使用。在版本1.3.9和1.4.3之前的容器中,容器填充的API不正确地暴露给主机网络容器。填充程序的API套接字的访问控制验证了连接过程的有效UID为0,但没有以其他方式…...
Python 如何使用 csv、openpyxl 库进行读写 Excel 文件详细教程(更新中)
csv 基本概述 首先介绍下 csv (comma separated values),即逗号分隔值(也称字符分隔值,因为分隔符可以不是逗号),是一种常用的文本格式,用以存储表格数据,包括数字或者字符。 程序在处理数据时…...
$nextTick属性使用与介绍
属性介绍 $nextTick 是 Vue.js 中的一个重要方法,之前我们也说过$ref 等一些重要的属性,这次我们说$nextTick,$nextTick用于在 DOM 更新后执行回调函数。它通常用于处理 DOM 更新后的操作,因为 Vue 在更新 DOM 后不会立即触发回调…...
【群智能算法改进】一种改进的鹈鹕优化算法 IPOA算法[2]【Matlab代码#58】
文章目录 【获取资源请见文章第5节:资源获取】1. 原始POA算法2. 改进后的IPOA算法2.1 随机对立学习种群初始化2.2 动态权重系数2.3 透镜成像折射方向学习 3. 部分代码展示4. 仿真结果展示5. 资源获取 【获取资源请见文章第5节:资源获取】 1. 原始POA算法…...
ESP32读取DHT11温湿度数据
芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...
论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)
笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...
Linux离线(zip方式)安装docker
目录 基础信息操作系统信息docker信息 安装实例安装步骤示例 遇到的问题问题1:修改默认工作路径启动失败问题2 找不到对应组 基础信息 操作系统信息 OS版本:CentOS 7 64位 内核版本:3.10.0 相关命令: uname -rcat /etc/os-rele…...
MySQL 主从同步异常处理
阅读原文:https://www.xiaozaoshu.top/articles/mysql-m-s-update-pk MySQL 做双主,遇到的这个错误: Could not execute Update_rows event on table ... Error_code: 1032是 MySQL 主从复制时的经典错误之一,通常表示ÿ…...
【C++】纯虚函数类外可以写实现吗?
1. 答案 先说答案,可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...
GraphQL 实战篇:Apollo Client 配置与缓存
GraphQL 实战篇:Apollo Client 配置与缓存 上一篇:GraphQL 入门篇:基础查询语法 依旧和上一篇的笔记一样,主实操,没啥过多的细节讲解,代码具体在: https://github.com/GoldenaArcher/graphql…...
es6+和css3新增的特性有哪些
一:ECMAScript 新特性(ES6) ES6 (2015) - 革命性更新 1,记住的方法,从一个方法里面用到了哪些技术 1,let /const块级作用域声明2,**默认参数**:函数参数可以设置默认值。3&#x…...
若依登录用户名和密码加密
/*** 获取公钥:前端用来密码加密* return*/GetMapping("/getPublicKey")public RSAUtil.RSAKeyPair getPublicKey() {return RSAUtil.rsaKeyPair();}新建RSAUti.Java package com.ruoyi.common.utils;import org.apache.commons.codec.binary.Base64; im…...
如何在Windows本机安装Python并确保与Python.NET兼容
✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…...
node.js的初步学习
那什么是node.js呢? 和JavaScript又是什么关系呢? node.js 提供了 JavaScript的运行环境。当JavaScript作为后端开发语言来说, 需要在node.js的环境上进行当JavaScript作为前端开发语言来说,需要在浏览器的环境上进行 Node.js 可…...
