当前位置: 首页 > news >正文

yolov8蒸馏(附代码-免费)

首先蒸馏是什么?

模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。

为什么需要蒸馏?

  1. 在不增加模型计算量和参数量的情况下提升精度,也即是可以无损提高精度。
  2. 配合剪枝一起使用,可以尽量达到无损降低模型参数量、计算量,提高FPS的情况下,还能保持模型精度没有下降甚至上升,这是改进网络结构无法达到的高度。
  3. 论文中的保底手段,因为剪枝和蒸馏的特殊性,其都不会增加参数量和计算量,可以在最后一个点上大幅度增加实验和工作量,因为本身蒸馏也需要做大量实验。

目录

一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

二.蒸馏步骤

(1) 训练教师模型

(2) 训练学生模型

(3) 蒸馏训练

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:


一.代码前提

(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s

(2)本文使用的蒸馏方法包括mgd,cwd

(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。

本文代码已经上传到GitHub,链接:yolov8_蒸馏

使用不妨加个关注,后续还会加入Vit(vision transformer),替换loss等提升精度的方法。

二.蒸馏步骤

(1) 训练教师模型

打开文件中train.py,替换模型文件位置。开始训练,达到理想目标就停止。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model = YOLO("yolov8s.pt")model.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(2) 训练学生模型

打开文件中train.py,替换模型文件位置。我这边使用的是剪枝后的yolov8s模型,具体轻量化剪枝步骤可见本文最后。

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

(3) 蒸馏训练

打开文件中train_distillation.py,替换老师与学生模型文件位置。两种蒸馏方法可以选择:cwd和mgd。

import os
from ultralytics import YOLO
import torchos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_t = YOLO('runs/detect/yolov8s/weights/best.pt')  # the teacher modelmodel_s = YOLO('runs/detect/prune/weights/best.pt')  # the student model"""Attributes:Distillation: the distillation modelloss_type: mgd, cwdamp: Automatic Mixed Precision"""model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,batch=20, device=0, workers=0, lr0=0.001)if __name__ == '__main__':main()

现在先不进行训练,打开文件夹yolo_project_distillation\ultralytics\engine\trainer.py

在类FeatureLoss中,函数forward大概162行处打一个断点,进行调试。代码位置:

    def forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):# change ---if self.distiller == 'cwd':s = self.align_module[idx](s)s = self.norm[idx](s)else:s = self.norm1[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss

调试运行,查看变量中学生模型y_s和老师模型y_t的张量大小。把通道数记下来,写在类Distillation_loss的

        channels_s = [256, 480, 256, 64, 143, 229][-le:]channels_t = [256, 512, 256, 128, 256, 512][-le:]

这边总共有六个,刚好对应模型的六个层的通道数。

替换完成后,应该就可以进行训练了。训练不好的话,再来评论区找我吧。

三.模型剪枝+蒸馏

(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝

(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Pruning rate
factor = 0.75yolo.info()
model = yolo.model
ws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)# print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())# keepws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)# print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f):  # C2f as a top convm1 = m1.cv2if not isinstance(m2, list):  # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i + 1])detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True#yolo.val(workers=0)  # 剪枝模型进行验证 yolo.val(workers=0)
yolo.info()
# yolo.export(format="onnx")  # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # 剪枝后直接训练微调
ckpt = {'epoch': -1,'best_fitness': None,'model': yolo.ckpt['ema'],'ema': None,'updates': None,'optimizer': None,'train_args': yolo.ckpt["train_args"],  # save as dict'date': None,'version': '8.0.142'}torch.save(yolo.ckpt, res_dir)

(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()

------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------

相关文章:

yolov8蒸馏(附代码-免费)

首先蒸馏是什么&#xff1f; 模型蒸馏&#xff08;Model Distillation&#xff09;是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中&#xff0c;通常存在两个模型&#xff0c;即“教师模型”和“学生模型”。 为什么需要蒸馏&#xff1f; 在不增加模型计算…...

Flink-StarRocks详解:第五部分查询数据湖(第55天)

系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...

【MySQL】常用数据类型

目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...

创建第一个rust tauri项目

安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...

【课程总结】day19(中):Transformer架构及注意力机制了解

前言 本章内容&#xff0c;我们将从注意力的基础概念入手&#xff0c;结合Transformer架构&#xff0c;由宏观理解其运行流程&#xff0c;然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制&#xff08;Attension&#xff09; 背景 深度学…...

4.4 标准正交基和格拉姆-施密特正交化

本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好&#xff1a;因为它们的点积为零&#xff1b; A T A A^TA ATA 是对角矩阵&#xff1b;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...

spring事务的8种失效的场景,7种传播行为

Spring事务大部分都是通过AOP实现的&#xff0c;所以事务失效的场景大部分都是因为AOP失效&#xff0c;AOP基于动态代理实现的 1.方法没有被public修饰 原因&#xff1a;Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...

进程的虚拟内存地址(C++程序的内存分区)

严谨的说法&#xff1a; 一个C、C程序实际就是一个进程&#xff0c;那么C的内存分区&#xff0c;实际上就是一个进程的内存分区&#xff0c;这样的话就可以分为两个大模块&#xff0c;从上往下&#xff0c;也就是0地址一直往下&#xff0c;假如是x86的32位Linux系统&#xff0c…...

英特尔移除超线程与AMD多线程性能对比

#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时&#xff0c;一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading&#xff08;超线程&#xff0c;简称SMT&#xff09;。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示&#xff0c;该特性依然为A…...

定期自动巡检,及时发现机房运维管理中的潜在问题

随着信息化技术的迅猛发展&#xff0c;机房作为企业数据处理与存储的核心场所&#xff0c;其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性&#xff0c;运维团队必须定期进行全面的巡检。然而&#xff0c;传统的手工巡检方式不仅效率低下&#…...

八股文(一)

1. 为什么不使用本地缓存&#xff0c;而使用Redis&#xff1f; Redis相比于本地缓存&#xff08;如JVM中的缓存&#xff09;有以下几个显著优势&#xff1a; 高性能与低延迟&#xff1a;Redis是一个基于内存的数据库&#xff0c;其读写性能非常高&#xff0c;通常可以达到几万…...

灵茶八题 - 子数组 ^w^

灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a&#xff0c;输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3]&#xff0c;异或和分别为 1 , 3 , …...

git clone private repo

Create personal access token Clone repo $ git clone https://<user_name>:<personal_access_tokens>github.com/<user_name>/<repo_name>.git...

vue3+ts+pinia+vant-项目搭建

1.pnpm介绍 npm和pnpm都是JavaScript的包管理工具&#xff0c;用于自动化安装、配置、更新和卸载npm包依赖。 pnpm节省了大量的磁盘空间并提高了安装速度&#xff1a;使用一个内容寻址的文件存储方式&#xff0c;如果多个项目使用相同的包版本&#xff0c;pnpm会存储单个副本…...

自动化测试概念篇

目录 一、自动化 1.1 自动化概念 1.2 自动化分类 1.3 自动化测试金字塔 二、web自动化测试 2.1 驱动 2.2 安装驱动管理 三、selenium 3.1 ⼀个简单的web自动化示例 3.2 selenium驱动浏览器的工作原理 一、自动化 1.1 自动化概念 在生活中&#xff1a; 自动洒水机&am…...

Mojo值的生命周期(Life of a value)详解

到目前为止,我们已经解释了 Mojo 如何允许您使用 Mojo 的所有权模型构建内存安全的高性能代码而无需手动管理内存。但是,Mojo 是为 系统编程而设计的,这通常需要对自定义数据类型进行手动内存管理。因此,Mojo 允许您根据需要执行此操作。需要明确的是,Mojo 没有引用计数器…...

java对接kimi详细说明,附完整项目

需求&#xff1a; 使用java封装kimi接口为http接口&#xff0c;并把调用kimi时的传参和返回数据&#xff0c;保存到mysql数据库中 自己记录一下&#xff0c;以做备忘。 具体步骤如下&#xff1a; 1.申请apiKey 访问&#xff1a;Moonshot AI - 开放平台使用手机号手机号验证…...

鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频

基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力&#xff0c;提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…...

django集成pytest进行自动化单元测试实战

文章目录 一、引入pytest相关的包二、配置pytest1、将django的配置区分测试环境、开发环境和生产环境2、配置pytest 三、编写测试用例1、业务测试2、接口测试 四、进行测试 在Django项目中集成Pytest进行单元测试可以提高测试的灵活性和效率&#xff0c;相比于Django自带的测试…...

48天笔试训练错题——day40

目录 选择题 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 编程题 1. 发邮件 2. 最长上升子序列 选择题 1. DNS 劫持又称域名劫持&#xff0c;是指在劫持的网络范围内拦截域名解析的请求&#xff0c;分析请求的域名&#xff0c;把审查范围以外的请求放行&#xff0c;否则返回…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了&#xff1a;一行…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容&#xff0c;使用AI&#xff08;2025&#xff09;可以参考以下方法&#xff1a; 四个洞见 模型已经比人聪明&#xff1a;以ChatGPT o3为代表的AI非常强大&#xff0c;能运用高级理论解释道理、引用最新学术论文&#xff0c;生成对顶尖科学家都有用的…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

用机器学习破解新能源领域的“弃风”难题

音乐发烧友深有体会&#xff0c;玩音乐的本质就是玩电网。火电声音偏暖&#xff0c;水电偏冷&#xff0c;风电偏空旷。至于太阳能发的电&#xff0c;则略显朦胧和单薄。 不知你是否有感觉&#xff0c;近两年家里的音响声音越来越冷&#xff0c;听起来越来越单薄&#xff1f; —…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

comfyui 工作流中 图生视频 如何增加视频的长度到5秒

comfyUI 工作流怎么可以生成更长的视频。除了硬件显存要求之外还有别的方法吗&#xff1f; 在ComfyUI中实现图生视频并延长到5秒&#xff0c;需要结合多个扩展和技巧。以下是完整解决方案&#xff1a; 核心工作流配置&#xff08;24fps下5秒120帧&#xff09; #mermaid-svg-yP…...

React父子组件通信:Props怎么用?如何从父组件向子组件传递数据?

系列回顾&#xff1a; 在上一篇《React核心概念&#xff1a;State是什么&#xff1f;》中&#xff0c;我们学习了如何使用useState让一个组件拥有自己的内部数据&#xff08;State&#xff09;&#xff0c;并通过一个计数器案例&#xff0c;实现了组件的自我更新。这很棒&#…...