当前位置: 首页 > news >正文

yolov8蒸馏(附代码-免费)

首先蒸馏是什么?

模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。

为什么需要蒸馏?

  1. 在不增加模型计算量和参数量的情况下提升精度,也即是可以无损提高精度。
  2. 配合剪枝一起使用,可以尽量达到无损降低模型参数量、计算量,提高FPS的情况下,还能保持模型精度没有下降甚至上升,这是改进网络结构无法达到的高度。
  3. 论文中的保底手段,因为剪枝和蒸馏的特殊性,其都不会增加参数量和计算量,可以在最后一个点上大幅度增加实验和工作量,因为本身蒸馏也需要做大量实验。

目录

一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

二.蒸馏步骤

(1) 训练教师模型

(2) 训练学生模型

(3) 蒸馏训练

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:


一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

本文代码已经上传到GitHub,链接:yolov8_蒸馏

使用不妨加个关注,后续还会加入Vit(vision transformer),替换loss等提升精度的方法。

二.蒸馏步骤

(1) 训练教师模型

打开文件中train.py,替换模型文件位置。开始训练,达到理想目标就停止。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model = YOLO("yolov8s.pt")model.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(2) 训练学生模型

打开文件中train.py,替换模型文件位置。我这边使用的是剪枝后的yolov8s模型,具体轻量化剪枝步骤可见本文最后。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(3) 蒸馏训练

打开文件中train_distillation.py,替换老师与学生模型文件位置。两种蒸馏方法可以选择:cwd和mgd。

import os
from ultralytics import YOLO
import torchos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_t = YOLO('runs/detect/yolov8s/weights/best.pt')  # the teacher modelmodel_s = YOLO('runs/detect/prune/weights/best.pt')  # the student model"""Attributes:Distillation: the distillation modelloss_type: mgd, cwdamp: Automatic Mixed Precision"""model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,batch=20, device=0, workers=0, lr0=0.001)if __name__ == '__main__':main()

现在先不进行训练,打开文件夹yolo_project_distillation\ultralytics\engine\trainer.py

在类FeatureLoss中,函数forward大概162行处打一个断点,进行调试。代码位置:

    def forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):# change ---if self.distiller == 'cwd':s = self.align_module[idx](s)s = self.norm[idx](s)else:s = self.norm1[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss

调试运行,查看变量中学生模型y_s和老师模型y_t的张量大小。把通道数记下来,写在类Distillation_loss的

        channels_s = [256, 480, 256, 64, 143, 229][-le:]channels_t = [256, 512, 256, 128, 256, 512][-le:]

这边总共有六个,刚好对应模型的六个层的通道数。

替换完成后,应该就可以进行训练了。训练不好的话,再来评论区找我吧。

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Pruning rate
factor = 0.75yolo.info()
model = yolo.model
ws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)# print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())# keepws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)# print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f):  # C2f as a top convm1 = m1.cv2if not isinstance(m2, list):  # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i + 1])detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True#yolo.val(workers=0)  # 剪枝模型进行验证 yolo.val(workers=0)
yolo.info()
# yolo.export(format="onnx")  # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # 剪枝后直接训练微调
ckpt = {'epoch': -1,'best_fitness': None,'model': yolo.ckpt['ema'],'ema': None,'updates': None,'optimizer': None,'train_args': yolo.ckpt["train_args"],  # save as dict'date': None,'version': '8.0.142'}torch.save(yolo.ckpt, res_dir)

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------

相关文章:

yolov8蒸馏(附代码-免费)

首先蒸馏是什么&#xff1f; 模型蒸馏&#xff08;Model Distillation&#xff09;是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中&#xff0c;通常存在两个模型&#xff0c;即“教师模型”和“学生模型”。 为什么需要蒸馏&#xff1f; 在不增加模型计算…...

Flink-StarRocks详解:第五部分查询数据湖(第55天)

系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...

【MySQL】常用数据类型

目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...

创建第一个rust tauri项目

安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...

【课程总结】day19(中):Transformer架构及注意力机制了解

前言 本章内容&#xff0c;我们将从注意力的基础概念入手&#xff0c;结合Transformer架构&#xff0c;由宏观理解其运行流程&#xff0c;然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制&#xff08;Attension&#xff09; 背景 深度学…...

4.4 标准正交基和格拉姆-施密特正交化

本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好&#xff1a;因为它们的点积为零&#xff1b; A T A A^TA ATA 是对角矩阵&#xff1b;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...

spring事务的8种失效的场景,7种传播行为

Spring事务大部分都是通过AOP实现的&#xff0c;所以事务失效的场景大部分都是因为AOP失效&#xff0c;AOP基于动态代理实现的 1.方法没有被public修饰 原因&#xff1a;Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...

进程的虚拟内存地址(C++程序的内存分区)

严谨的说法&#xff1a; 一个C、C程序实际就是一个进程&#xff0c;那么C的内存分区&#xff0c;实际上就是一个进程的内存分区&#xff0c;这样的话就可以分为两个大模块&#xff0c;从上往下&#xff0c;也就是0地址一直往下&#xff0c;假如是x86的32位Linux系统&#xff0c…...

英特尔移除超线程与AMD多线程性能对比

#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时&#xff0c;一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading&#xff08;超线程&#xff0c;简称SMT&#xff09;。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示&#xff0c;该特性依然为A…...

定期自动巡检,及时发现机房运维管理中的潜在问题

随着信息化技术的迅猛发展&#xff0c;机房作为企业数据处理与存储的核心场所&#xff0c;其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性&#xff0c;运维团队必须定期进行全面的巡检。然而&#xff0c;传统的手工巡检方式不仅效率低下&#…...

八股文(一)

1. 为什么不使用本地缓存&#xff0c;而使用Redis&#xff1f; Redis相比于本地缓存&#xff08;如JVM中的缓存&#xff09;有以下几个显著优势&#xff1a; 高性能与低延迟&#xff1a;Redis是一个基于内存的数据库&#xff0c;其读写性能非常高&#xff0c;通常可以达到几万…...

灵茶八题 - 子数组 ^w^

灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a&#xff0c;输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3]&#xff0c;异或和分别为 1 , 3 , …...

git clone private repo

Create personal access token Clone repo $ git clone https://<user_name>:<personal_access_tokens>github.com/<user_name>/<repo_name>.git...

vue3+ts+pinia+vant-项目搭建

1.pnpm介绍 npm和pnpm都是JavaScript的包管理工具&#xff0c;用于自动化安装、配置、更新和卸载npm包依赖。 pnpm节省了大量的磁盘空间并提高了安装速度&#xff1a;使用一个内容寻址的文件存储方式&#xff0c;如果多个项目使用相同的包版本&#xff0c;pnpm会存储单个副本…...

自动化测试概念篇

目录 一、自动化 1.1 自动化概念 1.2 自动化分类 1.3 自动化测试金字塔 二、web自动化测试 2.1 驱动 2.2 安装驱动管理 三、selenium 3.1 ⼀个简单的web自动化示例 3.2 selenium驱动浏览器的工作原理 一、自动化 1.1 自动化概念 在生活中&#xff1a; 自动洒水机&am…...

Mojo值的生命周期(Life of a value)详解

到目前为止,我们已经解释了 Mojo 如何允许您使用 Mojo 的所有权模型构建内存安全的高性能代码而无需手动管理内存。但是,Mojo 是为 系统编程而设计的,这通常需要对自定义数据类型进行手动内存管理。因此,Mojo 允许您根据需要执行此操作。需要明确的是,Mojo 没有引用计数器…...

java对接kimi详细说明,附完整项目

需求&#xff1a; 使用java封装kimi接口为http接口&#xff0c;并把调用kimi时的传参和返回数据&#xff0c;保存到mysql数据库中 自己记录一下&#xff0c;以做备忘。 具体步骤如下&#xff1a; 1.申请apiKey 访问&#xff1a;Moonshot AI - 开放平台使用手机号手机号验证…...

鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频

基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力&#xff0c;提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…...

django集成pytest进行自动化单元测试实战

文章目录 一、引入pytest相关的包二、配置pytest1、将django的配置区分测试环境、开发环境和生产环境2、配置pytest 三、编写测试用例1、业务测试2、接口测试 四、进行测试 在Django项目中集成Pytest进行单元测试可以提高测试的灵活性和效率&#xff0c;相比于Django自带的测试…...

48天笔试训练错题——day40

目录 选择题 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 编程题 1. 发邮件 2. 最长上升子序列 选择题 1. DNS 劫持又称域名劫持&#xff0c;是指在劫持的网络范围内拦截域名解析的请求&#xff0c;分析请求的域名&#xff0c;把审查范围以外的请求放行&#xff0c;否则返回…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes&#xff08;简称K8s&#xff09;中&#xff0c;Ingress是一个API对象&#xff0c;它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress&#xff0c;你可…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

12.找到字符串中所有字母异位词

&#x1f9e0; 题目解析 题目描述&#xff1a; 给定两个字符串 s 和 p&#xff0c;找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义&#xff1a; 若两个字符串包含的字符种类和出现次数完全相同&#xff0c;顺序无所谓&#xff0c;则互为…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)

漏洞概览 漏洞名称&#xff1a;Apache Flink REST API 任意文件读取漏洞CVE编号&#xff1a;CVE-2020-17519CVSS评分&#xff1a;7.5影响版本&#xff1a;Apache Flink 1.11.0、1.11.1、1.11.2修复版本&#xff1a;≥ 1.11.3 或 ≥ 1.12.0漏洞类型&#xff1a;路径遍历&#x…...

解读《网络安全法》最新修订,把握网络安全新趋势

《网络安全法》自2017年施行以来&#xff0c;在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂&#xff0c;网络攻击、数据泄露等事件频发&#xff0c;现行法律已难以完全适应新的风险挑战。 2025年3月28日&#xff0c;国家网信办会同相关部门起草了《网络安全…...

2025年- H71-Lc179--39.组合总和(回溯,组合)--Java版

1.题目描述 2.思路 当前的元素可以重复使用。 &#xff08;1&#xff09;确定回溯算法函数的参数和返回值&#xff08;一般是void类型&#xff09; &#xff08;2&#xff09;因为是用递归实现的&#xff0c;所以我们要确定终止条件 &#xff08;3&#xff09;单层搜索逻辑 二…...

HTTPS证书一年多少钱?

HTTPS证书作为保障网站数据传输安全的重要工具&#xff0c;成为众多网站运营者的必备选择。然而&#xff0c;面对市场上种类繁多的HTTPS证书&#xff0c;其一年费用究竟是多少&#xff0c;又受哪些因素影响呢&#xff1f; 首先&#xff0c;HTTPS证书通常在PinTrust这样的专业平…...