yolov8蒸馏(附代码-免费)
首先蒸馏是什么?
模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。
为什么需要蒸馏?
- 在不增加模型计算量和参数量的情况下提升精度,也即是可以无损提高精度。
- 配合剪枝一起使用,可以尽量达到无损降低模型参数量、计算量,提高FPS的情况下,还能保持模型精度没有下降甚至上升,这是改进网络结构无法达到的高度。
- 论文中的保底手段,因为剪枝和蒸馏的特殊性,其都不会增加参数量和计算量,可以在最后一个点上大幅度增加实验和工作量,因为本身蒸馏也需要做大量实验。
目录
一.代码前提
(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s
(2)本文使用的蒸馏方法包括mgd,cwd
(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。
二.蒸馏步骤
(1) 训练教师模型
(2) 训练学生模型
(3) 蒸馏训练
三.模型剪枝+蒸馏
(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝
(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。
(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:
一.代码前提
(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s
(2)本文使用的蒸馏方法包括mgd,cwd
(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。
本文代码已经上传到GitHub,链接:yolov8_蒸馏
使用不妨加个关注,后续还会加入Vit(vision transformer),替换loss等提升精度的方法。
二.蒸馏步骤
(1) 训练教师模型
打开文件中train.py,替换模型文件位置。开始训练,达到理想目标就停止。
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model = YOLO("yolov8s.pt")model.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
(2) 训练学生模型
打开文件中train.py,替换模型文件位置。我这边使用的是剪枝后的yolov8s模型,具体轻量化剪枝步骤可见本文最后。
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
(3) 蒸馏训练
打开文件中train_distillation.py,替换老师与学生模型文件位置。两种蒸馏方法可以选择:cwd和mgd。
import os
from ultralytics import YOLO
import torchos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_t = YOLO('runs/detect/yolov8s/weights/best.pt') # the teacher modelmodel_s = YOLO('runs/detect/prune/weights/best.pt') # the student model"""Attributes:Distillation: the distillation modelloss_type: mgd, cwdamp: Automatic Mixed Precision"""model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,batch=20, device=0, workers=0, lr0=0.001)if __name__ == '__main__':main()
现在先不进行训练,打开文件夹yolo_project_distillation\ultralytics\engine\trainer.py
在类FeatureLoss中,函数forward大概162行处打一个断点,进行调试。代码位置:
def forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):# change ---if self.distiller == 'cwd':s = self.align_module[idx](s)s = self.norm[idx](s)else:s = self.norm1[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss
调试运行,查看变量中学生模型y_s和老师模型y_t的张量大小。把通道数记下来,写在类Distillation_loss的
channels_s = [256, 480, 256, 64, 143, 229][-le:]channels_t = [256, 512, 256, 128, 256, 512][-le:]
这边总共有六个,刚好对应模型的六个层的通道数。
替换完成后,应该就可以进行训练了。训练不好的话,再来评论区找我吧。
三.模型剪枝+蒸馏
(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝
(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Pruning rate
factor = 0.75yolo.info()
model = yolo.model
ws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)# print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())# keepws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)# print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f): # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i + 1])detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True#yolo.val(workers=0) # 剪枝模型进行验证 yolo.val(workers=0)
yolo.info()
# yolo.export(format="onnx") # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100) # 剪枝后直接训练微调
ckpt = {'epoch': -1,'best_fitness': None,'model': yolo.ckpt['ema'],'ema': None,'updates': None,'optimizer': None,'train_args': yolo.ckpt["train_args"], # save as dict'date': None,'version': '8.0.142'}torch.save(yolo.ckpt, res_dir)
(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------
相关文章:
yolov8蒸馏(附代码-免费)
首先蒸馏是什么? 模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。 为什么需要蒸馏? 在不增加模型计算…...
Flink-StarRocks详解:第五部分查询数据湖(第55天)
系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...
【MySQL】常用数据类型
目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...
创建第一个rust tauri项目
安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...
【课程总结】day19(中):Transformer架构及注意力机制了解
前言 本章内容,我们将从注意力的基础概念入手,结合Transformer架构,由宏观理解其运行流程,然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制(Attension) 背景 深度学…...
4.4 标准正交基和格拉姆-施密特正交化
本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好:因为它们的点积为零; A T A A^TA ATA 是对角矩阵;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...
spring事务的8种失效的场景,7种传播行为
Spring事务大部分都是通过AOP实现的,所以事务失效的场景大部分都是因为AOP失效,AOP基于动态代理实现的 1.方法没有被public修饰 原因:Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...
进程的虚拟内存地址(C++程序的内存分区)
严谨的说法: 一个C、C程序实际就是一个进程,那么C的内存分区,实际上就是一个进程的内存分区,这样的话就可以分为两个大模块,从上往下,也就是0地址一直往下,假如是x86的32位Linux系统,…...
英特尔移除超线程与AMD多线程性能对比
#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时,一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading(超线程,简称SMT)。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示,该特性依然为A…...
定期自动巡检,及时发现机房运维管理中的潜在问题
随着信息化技术的迅猛发展,机房作为企业数据处理与存储的核心场所,其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性,运维团队必须定期进行全面的巡检。然而,传统的手工巡检方式不仅效率低下&#…...
八股文(一)
1. 为什么不使用本地缓存,而使用Redis? Redis相比于本地缓存(如JVM中的缓存)有以下几个显著优势: 高性能与低延迟:Redis是一个基于内存的数据库,其读写性能非常高,通常可以达到几万…...
灵茶八题 - 子数组 ^w^
灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a,输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3],异或和分别为 1 , 3 , …...
git clone private repo
Create personal access token Clone repo $ git clone https://<user_name>:<personal_access_tokens>github.com/<user_name>/<repo_name>.git...
vue3+ts+pinia+vant-项目搭建
1.pnpm介绍 npm和pnpm都是JavaScript的包管理工具,用于自动化安装、配置、更新和卸载npm包依赖。 pnpm节省了大量的磁盘空间并提高了安装速度:使用一个内容寻址的文件存储方式,如果多个项目使用相同的包版本,pnpm会存储单个副本…...
自动化测试概念篇
目录 一、自动化 1.1 自动化概念 1.2 自动化分类 1.3 自动化测试金字塔 二、web自动化测试 2.1 驱动 2.2 安装驱动管理 三、selenium 3.1 ⼀个简单的web自动化示例 3.2 selenium驱动浏览器的工作原理 一、自动化 1.1 自动化概念 在生活中: 自动洒水机&am…...
Mojo值的生命周期(Life of a value)详解
到目前为止,我们已经解释了 Mojo 如何允许您使用 Mojo 的所有权模型构建内存安全的高性能代码而无需手动管理内存。但是,Mojo 是为 系统编程而设计的,这通常需要对自定义数据类型进行手动内存管理。因此,Mojo 允许您根据需要执行此操作。需要明确的是,Mojo 没有引用计数器…...
java对接kimi详细说明,附完整项目
需求: 使用java封装kimi接口为http接口,并把调用kimi时的传参和返回数据,保存到mysql数据库中 自己记录一下,以做备忘。 具体步骤如下: 1.申请apiKey 访问:Moonshot AI - 开放平台使用手机号手机号验证…...
鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频
基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力,提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…...
django集成pytest进行自动化单元测试实战
文章目录 一、引入pytest相关的包二、配置pytest1、将django的配置区分测试环境、开发环境和生产环境2、配置pytest 三、编写测试用例1、业务测试2、接口测试 四、进行测试 在Django项目中集成Pytest进行单元测试可以提高测试的灵活性和效率,相比于Django自带的测试…...
48天笔试训练错题——day40
目录 选择题 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 编程题 1. 发邮件 2. 最长上升子序列 选择题 1. DNS 劫持又称域名劫持,是指在劫持的网络范围内拦截域名解析的请求,分析请求的域名,把审查范围以外的请求放行,否则返回…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的一体化测试平台,覆盖应用全生命周期测试需求,主要提供五大核心能力: 测试类型检测目标关键指标功能体验基…...
深入理解JavaScript设计模式之单例模式
目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...
cf2117E
原题链接:https://codeforces.com/contest/2117/problem/E 题目背景: 给定两个数组a,b,可以执行多次以下操作:选择 i (1 < i < n - 1),并设置 或,也可以在执行上述操作前执行一次删除任意 和 。求…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
select、poll、epoll 与 Reactor 模式
在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。 一、I…...
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...
基于Springboot+Vue的办公管理系统
角色: 管理员、员工 技术: 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能: 该办公管理系统是一个综合性的企业内部管理平台,旨在提升企业运营效率和员工管理水…...
Web中间件--tomcat学习
Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机,它可以执行Java字节码。Java虚拟机是Java平台的一部分,Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...
