Yolov8 目标检测剪枝学习记录
最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能
这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考
首先训练一个base模型用于参考
- 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
- 训练代码
参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源
from ultralytics import YOLO
import osroot = os.getcwd()
## 配置文件路径
name_yaml = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train = os.path.join(root, "runs/detect/VOC")
name_train = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn = os.path.join(root, "runs/detect/VOC_finetune")def else_api():path_data = ""path_result = ""model = YOLO(name_pretrain) metrics = model.val() # evaluate model performance on the validation setmodel.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5) # 这里的imgsz为高宽def step1_train():model = YOLO(name_pretrain) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train) # train the model## 2024.3.4添加【amp=False】
def step2_Constraint_train():model = YOLO(name_train) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train) # train the modeldef step3_pruning():from LL_pruning import do_pruningdo_pruning(name_prune_before, name_prune_after)def step4_finetune():model = YOLO(name_prune_after) # load a pretrained model (recommended for training)model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn) # train the modelstep1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()
第一步,step1_train()
- 即训练一个base模型,用于最后性能好坏的重要参考
第二步,step2_Constraint_train()
训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏
- 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
# Backwardself.scaler.scale(self.loss).backward()## add new code=============================duj## add l1 regulation for step2_Constraint_train l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni
- 个人理解的稀疏化作用
- 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
- 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。
第三步,step3_pruning()剪枝操作
LL_pruning.py
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import osclass PRUNE():def __init__(self) -> None:self.threshold = Nonedef get_threshold(self, model, factor=0.8):ws = []bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())print()# keepws = torch.cat(ws)self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):## a. 根据BN中的参数,获取需要保留的index================gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = self.thresholdwhile len(keep_idxs) < 8: ## 若剩余卷积核<8, 则降低阈值重新筛选keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)# scale = len(idxs) / n## b. 利用index对BN进行剪枝============================conv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = n## c. 利用index对conv1进行剪枝=========================if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]## d. 利用index对conv2进行剪枝=========================if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is None: continueif isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(self, m1, m2):if isinstance(m1, C2f): # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1self.prune_conv(m1, m2)def do_pruning(modelpath, savepath):pruning = PRUNE()### 0. 加载模型yolo = YOLO(modelpath) # build a new model from scratchpruning.get_threshold(yolo.model, 0.8) # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。### 1. 剪枝c2f 中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):pruning.prune_conv(m.cv1, m.cv2)### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.modelfor i in [3,5,7,8]: pruning.prune(seq[i], seq[i+1])### 3. 对检测头进行剪枝# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] detect:Detect = seq[-1]last_inputs = [seq[15], seq[18], seq[21]]colasts = [seq[16], seq[19], None]for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])pruning.prune(cv2[0], cv2[1])pruning.prune(cv2[1], cv2[2])pruning.prune(cv3[0], cv3[1])pruning.prune(cv3[1], cv3[2])### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = Trueyolo.val()torch.save(yolo.ckpt, savepath)yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))yolo.export(format="onnx")## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath) # build a new model from scratchyolo.export(format="onnx")if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"savepath = "runs/detect1/14_Constraint/weights/last_prune.pt"do_pruning(modelpath, savepath)
- 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
- 查看剪枝前后模型大小
du -sh ./runs/detect/VOC_Constraint/weights/last*
(yolov8n模型)
微调
该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。
def setup_model(self):"""Load/create/download model for any task."""if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup neededreturn
- 例如ultralytics\engine\trainer.py
- v8…x添加代码:548行 参考这里
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
- 如果是v8.0.x 参考这里
在看到这篇中的修改1启发
- v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
- 我尝试在
ultralytics\engine\model.py
添加如下代码加载模型成功
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)if not args.get("resume"): # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model# dujiang edit self.trainer.model = self.model.train()if SETTINGS["hub"] is True and not self.session:
- 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
- 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
- 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
- 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。
总结
总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我
相关文章:

Yolov8 目标检测剪枝学习记录
最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能 这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8 下面只是记录一个…...

LeDeCo:AI自动化排版、设计、美化海报
1.简介 平面设计是一门艺术学科,致力于创造吸引注意力和有效传达信息的视觉内容。今天,创造视觉上吸引人的设计完全依赖于具有艺术创造力和技术专长的人类设计师,他们巧妙地整合多模态图形元素,这是一个复杂而耗时的过程…...

Flink CDC解决数据库同步,异常情况下增量、全量问题
Flink 1.11 引入了 Flink SQL CDC,CDC 能给我们数据和业务间能带来什么变化?本文由 Apache Flink PMC,阿里巴巴技术专家伍翀 (云邪)分享,内容将从传统的数据同步方案,基于 Flink CDC 同步的解决方案以及更多…...

01、flink的原理和安装部署
flink中主要有两个进程,分别是JobMManager和TaskManager,当然了根据flink的部署和运行环境不同,会有一些不同,但是主要的功能是类似的,下面我会讲下聊下,公司用的多的部署方式,基于yarn集群的部…...

美图脱掉“复古外衣”,在AI浪潮中蜕变
"人工智能就像电力一样,如果你的竞争对手正在使用它,你也需要使用它,否则你就会失去竞争力",斯坦福大学教授和谷歌前首席科学家安德鲁恩格尔曾这样说到。 而近日拉开序幕的消费电子风向标——科技贸易展国际消费电子展…...

sqlalchemy The transaction is active - has not been committed or rolled back.
连接池参考 参考:https://blog.csdn.net/SunJW_2017/article/details/129332393 1、因为使用了连接池,没有释放 2、解决方法: from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session from gree…...

47.数据绑定的PropertyChanged C#例子 WPF例子
[CallerMemberName] string propertyName null 这段代码中的 [CallerMemberName] 是一个特性(Attribute),它应用于 propertyName 参数。这个特性的作用是,在编译时,如果调用 OnPropertyChanged 方法时没有显式提供 pr…...

网络安全 | Web安全常见漏洞和防护经验策略
关注:CodingTechWork 引言 OWASP (Open Web Application Security Project) Top 10是Web应用最常见的安全风险集合,帮助开发人员和安全专家识别和防止最严重的网络安全问题。以下是基于OWASP Top 10的Web安全防护经验策略与规则集。Web开发者必须对潜在…...

Agent一键安装,快速上手Zabbix监控!
目录 一、Linux操作系统部署Agent环境配置1、防火墙配置2、永久关闭selinux yum方式安装1、配置zabbix仓库2、安装agent3、配置 Zabbix-Agent 指向 Zabbix-Server4、启动agent服务 二进制包安装1、下载二进制包2、创建用户和目录及更改属主(组)3、解压二…...

Edge Scdn是什么,它如何提升网站安全性与访问速度?
随着网络攻击的日益猖獗,尤其是分布式拒绝服务(DDoS)攻击的频繁发生,如何保护网站的安全性并确保用户的访问体验变得极为重要。Edge Scdn(内容分发网络)作为一种新兴的技术方案,逐渐被越来越多的…...

ubuntu20.04 docker安装
Ubuntu | Docker DocsPost-installation steps | Docker Docs # 创建目录 sudo mkdir -p /etc/docker # 写入配置文件 sudo tee /etc/docker/daemon.json <<-EOF { "registry-mirrors": [ "https://docker-0.unsee.tech", &qu…...

初始C#.
一.模板 using System; using System.Collections.Generic; using System.L。inq; using System.Text; using System.Threading.Tasks;//引用命名空间 namespace First_progream//项目名或者命名空间 { internal class Program //类名 { static void Main(string[] args)…...

js高亮文本
高亮文本 const inputs ["这是一个普通文本,包含关键字测试。",<p style"font-size: 10px">这是一个<span>GVM</span> <strong>测试</strong>内容。</p>, ];const keywords ["测试", "G…...

解决SpringBoot 健康检测接口 actuator/health 访问一直卡着,但 actuator/info等其他接口能正常访问的问题
背景 最近在做服务迁移,迁移完后的新服务,直接将pod的配置丢到新的K8S集群中,健康监测一直失败 Debug思路 先看日志,日志显示SpringBoot已成功启动,按理说不应该无法访问其/actuator/health接口 拉长健康监测延时时…...

KVM创建ubuntu20.04虚机,部署K8S,再克隆出二份,做为Worker节点加入集群,通过Helm创建2个Pod,让它们之间通过域名互访
KVM创建ubuntu20.04虚机,部署K8S,再克隆出二份,做为Worker节点加入集群,通过Helm创建2个Pod,让它们之间通过域名互访 一.背景二.操作步骤1.安装KVMA.在BIOS中开启VT-dB.修改grub,开启iommu在/etc/default/grub 中 GRUB_CMDLINE_LINUX行 添加 intel_iommuon iommupt重新创建引导…...

GaussDB中的Vacuum和Analyze
GaussDB中的Vacuum和Analyze 基本概念与区别手动Vacuum和Analyze查看Vacuum和Analyze记录Autovacuum配置参数 基本概念与区别 使用VACUUM、VACUUM FULL和ANALYZE命令定期对每个表进行维护,主要有以下原因: VACUUM FULL可回收已更新或已删除的数据所占据…...

IvorySQL 4.2 发布
IvorySQL 4.2 已于 2025 年 1 月 13 日正式发布。新版本全面支持 PostgreSQL 17.2,并修复了多项 bug。 增强功能 PostgreSQL 17.1 增强功能 确保当 RLS 应用于非顶级表引用时,缓存的计划会标记为依赖于调用角色使 libpq 在 SSL 或 GSS 协议协商期间丢…...

浅谈云计算20 | OpenStack管理模块(下)
OpenStack管理模块(下) 五、存储管理5.1 存储管理概述 5.2 架构设计5.2.1 Cinder块存储架构5.2.2 Swift对象存储架构 六、网络管理6.1 网络管理概述6.2 架构解析6.2.1 Neutron网络服务架构6.2.2 网络拓扑架构 6.3 原理与流程6.3.1 网络创建原理6.3.2 网络…...

去年社融增量超32万亿 货币信贷平稳增长-乐享数科
数据显示,2024年全年,社会融资规模增量为32.26万亿元;去年12月末,社会融资规模同比增长8.0%,广义货币供应量(M_[2])同比增长7.3%,人民币贷款同比增长7.6%,增速相比上月有…...

STM32 HAL库函数入门指南:从原理到实践
1 STM32 HAL库概述 STM32 HAL(Hardware Abstraction Layer)库是ST公司专门为STM32系列微控制器开发的一套硬件抽象层函数库。它的核心设计理念是在应用层与硬件层之间建立一个抽象层,这个抽象层屏蔽了底层硬件的具体实现细节,为开发者提供了一套统一的、…...

React封装倒计时按钮
背景 在开发过程中,经常需要使用到倒计时的场景,当用户点击后,按钮进行倒计时,然后等待邮件或者短信发送,每次都写重复代码,会让代码显得臃肿,所以封装一个组件来减少耦合 创建一个倒计时组件…...

深入探究Linux树状目录结构
Linux 作为一款广泛使用的开源操作系统,其目录结构采用了树状设计,这种结构清晰、有条理,便于用户和系统进行文件管理与操作。 一、根目录(/) 根目录是整个 Linux 文件系统的起始点,就像一棵大树的根部&…...

Realsense相机驱动安装及其ROS通讯配置——机器人抓取系统基础系列(四)
文章目录 概要1 Realsense相机驱动安装Method1: 使用Intel服务器预编译包Method2: 使用ROS服务器预编译包Method3: 使用SDK源代码方法对比总结 2 Realsense-ROS通讯配置与使用2.1 Realsense-ROS包安装2.2 ROS节点启动 小结Reference 概要 本文首先阐述了Realsense相机驱动安装…...

linux安装nvm
下载命令 wget https://github.com/nvm-sh/nvm/archive/refs/tags/v0.39.1.tar.gz当前盘打开终端后的nvm文件夹中 mkdir -p /nvm/.nvm如果树根不够就用加sudo 解压文件 tar xvf v0.39.1.tar.gz输入pwd 确定当前文件完成路径 在当前文件中写入。bashrc文件及代码回车进入编辑…...

图论1-问题 C: 算法7-6:图的遍历——广度优先搜索
题目描述 广度优先搜索遍历类似于树的按层次遍历的过程。其过程为:假设从图中的某顶点v出发,在访问了v之后依次访问v的各个未曾被访问过的邻接点,然后分别从这些邻接点出发依次访问它们的邻接点,并使“先被访问的顶点的邻接点”先…...

基于 STM32 的多功能时间管理器项目
引言 在快节奏的生活中,时间管理显得尤为重要。本项目旨在通过 STM32 开发一个多功能时间管理器,功能包括计时器、闹钟和日历。用户可以方便地设置不同的提醒和计时任务,以更好地管理日常生活和工作。 项目名称 多功能时间管理器 环境准备 …...

Java工程结构:二方库依赖规约
文章目录 I jar 包分类一方库:二方库:三方库:II 专有名词GAV(GroupId、ArtifactId、Version):Maven 坐标III GAV 规则GroupId 格式ArtifactId 格式二方库版本号命名方式:主版本号.次版本号.修订号I jar 包分类 一方库: 本工程内部子项目模块依赖的库(jar 包)。 二…...

Django自带admin管理系统使用
1、admin路径地址 localhost:8000/admin 2、使用命令行创建超级管理员 python manage.py createsuperuser 之后按照提示一步一步往下走就好了。 3、修改管理员密码 python manage.py changepassword admin admin是超级管理员的账号 4、后台管理系统注册模型,…...

Jmeter 简单使用、生成测试报告(一)
一、下载Jmter 去官网下载,我下载的是apache-jmeter-5.6.3.zip,解压后就能用。 二、安装java环境 JMeter是基于Java开发的,运行JMeter需要Java环境。 1.下载JDK、安装Jdk 2.配置java环境变量 3.验证安装是否成功(java -versio…...

手摸手实战前端项目CI CD
由于图片和格式解析问题,为了更好阅读体验可前往 阅读原文 CI/CD 是 持续集成(Continuous Integration) 和 持续交付/部署(Continuous Delivery/Continuous Deployment) 的缩写,是现代软件开发中的一种自动…...