当前位置: 首页 > news >正文

Yolov8 目标检测剪枝学习记录

最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能
这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考

首先训练一个base模型用于参考

  • 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
  • 训练代码

参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源

from ultralytics import YOLO
import osroot = os.getcwd()
## 配置文件路径
name_yaml             = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain         = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train            = os.path.join(root, "runs/detect/VOC")
name_train            = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before     = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after      = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn         = os.path.join(root, "runs/detect/VOC_finetune")def else_api():path_data = ""path_result = ""model = YOLO(name_pretrain) metrics = model.val()  # evaluate model performance on the validation setmodel.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5)  # 这里的imgsz为高宽def step1_train():model = YOLO(name_pretrain) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train)  # train the model## 2024.3.4添加【amp=False】
def step2_Constraint_train():model = YOLO(name_train) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train)  # train the modeldef step3_pruning():from LL_pruning import do_pruningdo_pruning(name_prune_before, name_prune_after)def step4_finetune():model = YOLO(name_prune_after)     # load a pretrained model (recommended for training)model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn)  # train the modelstep1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()

第一步,step1_train()

  • 即训练一个base模型,用于最后性能好坏的重要参考
    在这里插入图片描述

第二步,step2_Constraint_train()

训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏

  • 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
     # Backwardself.scaler.scale(self.loss).backward()## add new code=============================duj## add l1 regulation for step2_Constraint_train               l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni

在这里插入图片描述

  • 个人理解的稀疏化作用
    • 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
    • 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。

第三步,step3_pruning()剪枝操作

LL_pruning.py

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import osclass PRUNE():def __init__(self) -> None:self.threshold = Nonedef get_threshold(self, model, factor=0.8):ws = []bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())print()# keepws = torch.cat(ws)self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):## a. 根据BN中的参数,获取需要保留的index================gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = self.thresholdwhile len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)# scale = len(idxs) / n## b. 利用index对BN进行剪枝============================conv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = n## c. 利用index对conv1进行剪枝=========================if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]## d. 利用index对conv2进行剪枝=========================if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is None: continueif isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(self, m1, m2):if isinstance(m1, C2f):      # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1self.prune_conv(m1, m2)def do_pruning(modelpath, savepath):pruning = PRUNE()### 0. 加载模型yolo = YOLO(modelpath)                  # build a new model from scratchpruning.get_threshold(yolo.model, 0.8)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。### 1. 剪枝c2f 中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):pruning.prune_conv(m.cv1, m.cv2)### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.modelfor i in [3,5,7,8]: pruning.prune(seq[i], seq[i+1])### 3. 对检测头进行剪枝# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] detect:Detect = seq[-1]last_inputs   = [seq[15], seq[18], seq[21]]colasts       = [seq[16], seq[19], None]for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])pruning.prune(cv2[0], cv2[1])pruning.prune(cv2[1], cv2[2])pruning.prune(cv3[0], cv3[1])pruning.prune(cv3[1], cv3[2])### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = Trueyolo.val()torch.save(yolo.ckpt, savepath)yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))yolo.export(format="onnx")## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath)  # build a new model from scratchyolo.export(format="onnx")if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"savepath  = "runs/detect1/14_Constraint/weights/last_prune.pt"do_pruning(modelpath, savepath)

在这里插入图片描述

  • 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
    在这里插入图片描述
  • 查看剪枝前后模型大小 du -sh ./runs/detect/VOC_Constraint/weights/last*yolov8n模型
    在这里插入图片描述

微调

该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。

 def setup_model(self):"""Load/create/download model for any task."""if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup neededreturn
  • 例如ultralytics\engine\trainer.py
  • v8…x添加代码:548行 参考这里
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
  • 如果是v8.0.x 参考这里

在看到这篇中的修改1启发

  • v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
  • 我尝试在ultralytics\engine\model.py添加如下代码加载模型成功
 self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)if not args.get("resume"):  # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model# dujiang edit self.trainer.model = self.model.train()if SETTINGS["hub"] is True and not self.session:
  • 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
  • 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
  • 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
    在这里插入图片描述
  • 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。
    在这里插入图片描述

总结

总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我

相关文章:

Yolov8 目标检测剪枝学习记录

最近在进行YOLOv8系列的轻量化&#xff0c;目前在网络结构方面的优化已经接近极限了&#xff0c;所以想要学习一下模型剪枝是否能够进一步优化模型的性能 这里主要参考了torch-pruning的基本使用&#xff0c;v8模型剪枝&#xff0c;Jetson nano部署剪枝YOLOv8 下面只是记录一个…...

LeDeCo:AI自动化排版、设计、美化海报

1.简介 平面设计是一门艺术学科&#xff0c;致力于创造吸引注意力和有效传达信息的视觉内容。今天&#xff0c;创造视觉上吸引人的设计完全依赖于具有艺术创造力和技术专长的人类设计师&#xff0c;他们巧妙地整合多模态图形元素&#xff0c;这是一个复杂而耗时的过程&#xf…...

Flink CDC解决数据库同步,异常情况下增量、全量问题

Flink 1.11 引入了 Flink SQL CDC&#xff0c;CDC 能给我们数据和业务间能带来什么变化&#xff1f;本文由 Apache Flink PMC&#xff0c;阿里巴巴技术专家伍翀 (云邪&#xff09;分享&#xff0c;内容将从传统的数据同步方案&#xff0c;基于 Flink CDC 同步的解决方案以及更多…...

01、flink的原理和安装部署

flink中主要有两个进程&#xff0c;分别是JobMManager和TaskManager&#xff0c;当然了根据flink的部署和运行环境不同&#xff0c;会有一些不同&#xff0c;但是主要的功能是类似的&#xff0c;下面我会讲下聊下&#xff0c;公司用的多的部署方式&#xff0c;基于yarn集群的部…...

美图脱掉“复古外衣”,在AI浪潮中蜕变

"人工智能就像电力一样&#xff0c;如果你的竞争对手正在使用它&#xff0c;你也需要使用它&#xff0c;否则你就会失去竞争力"&#xff0c;斯坦福大学教授和谷歌前首席科学家安德鲁恩格尔曾这样说到。 而近日拉开序幕的消费电子风向标——科技贸易展国际消费电子展…...

sqlalchemy The transaction is active - has not been committed or rolled back.

连接池参考 参考&#xff1a;https://blog.csdn.net/SunJW_2017/article/details/129332393 1、因为使用了连接池&#xff0c;没有释放 2、解决方法&#xff1a; from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session from gree…...

47.数据绑定的PropertyChanged C#例子 WPF例子

[CallerMemberName] string propertyName null 这段代码中的 [CallerMemberName] 是一个特性&#xff08;Attribute&#xff09;&#xff0c;它应用于 propertyName 参数。这个特性的作用是&#xff0c;在编译时&#xff0c;如果调用 OnPropertyChanged 方法时没有显式提供 pr…...

网络安全 | Web安全常见漏洞和防护经验策略

关注&#xff1a;CodingTechWork 引言 OWASP (Open Web Application Security Project) Top 10是Web应用最常见的安全风险集合&#xff0c;帮助开发人员和安全专家识别和防止最严重的网络安全问题。以下是基于OWASP Top 10的Web安全防护经验策略与规则集。Web开发者必须对潜在…...

Agent一键安装,快速上手Zabbix监控!

目录 一、Linux操作系统部署Agent环境配置1、防火墙配置2、永久关闭selinux yum方式安装1、配置zabbix仓库2、安装agent3、配置 Zabbix-Agent 指向 Zabbix-Server4、启动agent服务 二进制包安装1、下载二进制包2、创建用户和目录及更改属主&#xff08;组&#xff09;3、解压二…...

Edge Scdn是什么,它如何提升网站安全性与访问速度?

随着网络攻击的日益猖獗&#xff0c;尤其是分布式拒绝服务&#xff08;DDoS&#xff09;攻击的频繁发生&#xff0c;如何保护网站的安全性并确保用户的访问体验变得极为重要。Edge Scdn&#xff08;内容分发网络&#xff09;作为一种新兴的技术方案&#xff0c;逐渐被越来越多的…...

ubuntu20.04 docker安装

Ubuntu | Docker DocsPost-installation steps | Docker Docs # 创建目录 sudo mkdir -p /etc/docker # 写入配置文件 sudo tee /etc/docker/daemon.json <<-EOF { "registry-mirrors": [ "https://docker-0.unsee.tech", &qu…...

初始C#.

一.模板 using System; using System.Collections.Generic; using System.L。inq; using System.Text; using System.Threading.Tasks;//引用命名空间 namespace First_progream//项目名或者命名空间 { internal class Program //类名 { static void Main(string[] args)…...

js高亮文本

高亮文本 const inputs ["这是一个普通文本&#xff0c;包含关键字测试。",<p style"font-size: 10px">这是一个<span>GVM</span> <strong>测试</strong>内容。</p>, ];const keywords ["测试", "G…...

解决SpringBoot 健康检测接口 actuator/health 访问一直卡着,但 actuator/info等其他接口能正常访问的问题

背景 最近在做服务迁移&#xff0c;迁移完后的新服务&#xff0c;直接将pod的配置丢到新的K8S集群中&#xff0c;健康监测一直失败 Debug思路 先看日志&#xff0c;日志显示SpringBoot已成功启动&#xff0c;按理说不应该无法访问其/actuator/health接口 拉长健康监测延时时…...

KVM创建ubuntu20.04虚机,部署K8S,再克隆出二份,做为Worker节点加入集群,通过Helm创建2个Pod,让它们之间通过域名互访

KVM创建ubuntu20.04虚机,部署K8S,再克隆出二份,做为Worker节点加入集群,通过Helm创建2个Pod,让它们之间通过域名互访 一.背景二.操作步骤1.安装KVMA.在BIOS中开启VT-dB.修改grub,开启iommu在/etc/default/grub 中 GRUB_CMDLINE_LINUX行 添加 intel_iommuon iommupt重新创建引导…...

GaussDB中的Vacuum和Analyze

GaussDB中的Vacuum和Analyze 基本概念与区别手动Vacuum和Analyze查看Vacuum和Analyze记录Autovacuum配置参数 基本概念与区别 使用VACUUM、VACUUM FULL和ANALYZE命令定期对每个表进行维护&#xff0c;主要有以下原因&#xff1a; VACUUM FULL可回收已更新或已删除的数据所占据…...

IvorySQL 4.2 发布

IvorySQL 4.2 已于 2025 年 1 月 13 日正式发布。新版本全面支持 PostgreSQL 17.2&#xff0c;并修复了多项 bug。 增强功能 PostgreSQL 17.1 增强功能 确保当 RLS 应用于非顶级表引用时&#xff0c;缓存的计划会标记为依赖于调用角色使 libpq 在 SSL 或 GSS 协议协商期间丢…...

浅谈云计算20 | OpenStack管理模块(下)

OpenStack管理模块&#xff08;下&#xff09; 五、存储管理5.1 存储管理概述 5.2 架构设计5.2.1 Cinder块存储架构5.2.2 Swift对象存储架构 六、网络管理6.1 网络管理概述6.2 架构解析6.2.1 Neutron网络服务架构6.2.2 网络拓扑架构 6.3 原理与流程6.3.1 网络创建原理6.3.2 网络…...

去年社融增量超32万亿 货币信贷平稳增长-乐享数科

数据显示&#xff0c;2024年全年&#xff0c;社会融资规模增量为32.26万亿元&#xff1b;去年12月末&#xff0c;社会融资规模同比增长8.0%&#xff0c;广义货币供应量&#xff08;M_[2]&#xff09;同比增长7.3%&#xff0c;人民币贷款同比增长7.6%&#xff0c;增速相比上月有…...

STM32 HAL库函数入门指南:从原理到实践

1 STM32 HAL库概述 STM32 HAL(Hardware Abstraction Layer)库是ST公司专门为STM32系列微控制器开发的一套硬件抽象层函数库。它的核心设计理念是在应用层与硬件层之间建立一个抽象层&#xff0c;这个抽象层屏蔽了底层硬件的具体实现细节&#xff0c;为开发者提供了一套统一的、…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查

在对接支付宝API的时候&#xff0c;遇到了一些问题&#xff0c;记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

线程同步:确保多线程程序的安全与高效!

全文目录&#xff1a; 开篇语前序前言第一部分&#xff1a;线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分&#xff1a;synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分&#xff…...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...