【知识蒸馏】多任务模型 logit-based 知识蒸馏实战
一、什么是逻辑(logit)知识蒸馏
Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。
具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。
这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。
需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。
综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。
二、如何进行多任务模型的知识蒸馏
(1)加载学生和教师模型
(2)定义分割蒸馏损失,定义检测蒸馏损失
(3)计算分割蒸馏损失,计算检测蒸馏损失
(4)计算学生模型的分割,检测损失
(5)计算总损失,反向传播
三、实现代码
(1)加载学生和教师模型
# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))
(2)定义分割蒸馏损失,定义检测蒸馏损失
分割损失,参考:【知识蒸馏】语义分割模型逻辑蒸馏实战,对剪枝的模型进行蒸馏训练
# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss
检测损失,参考:【知识蒸馏】yolov5逻辑蒸馏和特征蒸馏实战
# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature
(3)计算分割蒸馏loss,计算检测蒸馏损失
with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸馏loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 检测蒸馏loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)
(4)计算学生模型的分割,检测损失
det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)
(5)计算总损失,反向传播
seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()
相关文章:
【知识蒸馏】多任务模型 logit-based 知识蒸馏实战
一、什么是逻辑(logit)知识蒸馏 Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从…...
C:技术面试总结
1 变量的声明和定义: 定义:为变量分配地址和存储空间 声明:不分配地址。一个变量可以在多个地方声明,但只能在一个地方定义。extern修饰的变量声明,说明此变量将在文件以外或文件后面部分定义。 2 局部变量是否能与全局变量重名: 可以,局部变量会屏蔽全局变量 局部…...
OpenHarmony 实战开发——一文总结ACE代码框架
一、前言 ACE_Engine框架是OpenAtom OpenHarmony(简称“OpenHarmony”)的UI开发框架,为开发者提供在进行应用UI开发时所必需的各种组件,以及定义这些组件的属性、样式、事件及方法,通过这些组件可以方便进行OpenHarmo…...
【数据结构与算法】之堆的应用——堆排序及Top_K问题!
目录 1、堆排序 2、Top_K问题 3、完结散花 个人主页:秋风起,再归来~ 数据结构与算法 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 1、堆排序 对一个无序的数组…...
啊哈!算法-第2章-栈、队列、链表
啊哈!算法-第2章-栈、队列、链表 第1节 解密qq号——队列第2节 解密回文——栈第3节 纸牌游戏——小猫钓鱼第4节 链表第5节 模拟链表 第1节 解密qq号——队列 新学期开始了,小哈是小哼的新同桌(小哈是个大帅哥哦~),小哼向小哈询问 QQ 号, 小…...
简述 v-if 和 v-show 的区别
v-if 和 v-show 都是 Vue.js 中用于控制元素显示与隐藏的指令,但它们的工作方式有显著的差异。以下是它们之间的主要区别: 渲染方式: v-if:v-if 是“真正”的条件渲染,因为它会确保在切换过程中条件块内的事件监听器和…...
Linux驱动学习之模块化,参数传递,符号导出
1.模块化 1.1.模块化的基本概念: 模块化是指将特定的功能或组件独立出来,以便于开发、测试和维护。在Linux设备驱动中,模块化允许将驱动程序作为内核模块动态加载到系统中,从而提高了系统的灵活性和可扩展性。 1.2.Linux内核模…...
RabbitMQ02-RebbitMQ简介及交换器
一. AMQP协议 什么是AMQP协议 AMQP(Advanced Message Queuing Protocol,高级消息队列协议):它是进程之间传递异步消息的网络协议 AMQP工作过程 发布者通过发布消息,通过交换机,交换机根据路由规则将收到的消息分发交换机绑定的下消息队列,最…...
Matlab自学笔记三十:元胞数组的修改、添加、删除和连接
1.说明 元胞数组的子数组或元素也是元胞型的,其元素内容(值)是本身类型,因此,在添、删、改和连接处理时,必须明确每个元素的值的类型和大小,否则,编程报错是不可避免的了。看本文前…...
【LeetCode】数组——双指针法
1 双指针法 1.1 介绍 双指针法是一种常用的算法技巧,通常用于处理数组或链表中的问题。它使用两个指针,通常一个从数组的开始位置遍历,另一个从数组的末尾位置开始遍历,根据问题的不同,这两个指针可以同时移动&#…...
react 低代码平台方案汇总
React作为当前最流行的前端框架之一,其生态系统中孕育了多种低代码平台方案,旨在加速应用开发过程。以下是几款基于React的低代码平台或工具,它们通过可视化构建、预制组件、数据绑定等功能,帮助开发者快速构建应用程序࿱…...
oss对象上传文件设置格式
PostMapping("upload")ApiOperation(value "上传文件")public Result<UploadDTO> upload(RequestParam("file") MultipartFile file) throws Exception {if (file.isEmpty()) {return new Result<UploadDTO>().error(ModuleErrorCo…...
【Linux学习】进程
下面是有关进程的相关介绍,希望对你有所帮助! 小海编程心语录-CSDN博客 目录 1. 进程的概念 1.1 进程与程序 1.2 进程号 2. 进程的状态 2.1 fork创建子进程 2.2 父子进程间的文件共享 3. 进程的诞生与终止 3.1 进程的诞生 3.2 进程的终止 1. 进…...
Python数据分析实验四:数据分析综合应用开发
目录 一、实验目的与要求二、主要实验过程1、加载数据集2、数据预处理3、划分数据集4、创建模型估计器5、模型拟合6、模型性能评估 三、主要程序清单和运行结果四、实验体会 一、实验目的与要求 1、目的: 综合运用所学知识,选取有实际背景的应用问题进行…...
基于51单片机的盆栽自动浇花系统
一.硬件方案 工作原理是湿度传感器将采集到的数据直接传送到ADC0832的IN端作为输入的模拟信号。选用湿度传感器和AD转换,电路内部包含有湿度采集、AD转换、单片机译码显示等功能。单片机需要采集数据时,发出指令启动A/D转换器工作,ADC0832根…...
SpirngMVC框架学习笔记(一):SpringMVC基本介绍
1 SpringMVC 特点&概述 SpringMVC 从易用性,效率上 比曾经流行的 Struts2 更好 SpringMVC 是 WEB 层框架,接管了 Web 层组件, 比如控制器, 视图, 视图解析, 返回给用户的数据格式, 同时支持 MVC 的开发模式/开发架构SpringMVC 通过注解,…...
实现信号发生控制
1. 信号发生器的基本原理 信号发生器是一种能够产生特定波形和频率的电子设备,常用于模拟信号的产生和测试。 信号发生器的基本原理 信号发生器的工作原理基于不同的技术,但最常见的类型包括模拟信号发生器和数字信号发生器(DDS࿰…...
二叉树基于队列实现的操作详解
一、队列知识补充 有关队列的知识请详见博主的另一篇博客:http://t.csdnimg.cn/3PwO4 本文仅仅附上需要的队列操作供读者参考 //结构体定义 typedef struct BinaryTreeNode* QDataType;typedef struct QueueNode {struct QueueNode* next;QDataType val; }QNode;…...
LabVIEW常用开发架构有哪些
LabVIEW常用开发架构有多种,每种架构都有其独特的特点和适用场合。以下是几种常用的开发架构及其特点和适用场合: 1. 单循环架构 特点: 简单易用适用于小型应用将所有代码放在一个循环中 适用场合: 简单的数据采集和处理任务…...
告别 Dart 中的 Future.wait([])
作为 Dart 开发人员,我们对异步编程和 Futures 的强大功能并不陌生。过去,当我们需要同时等待多个 future 时,我们依赖 Future.wait([]) 方法,该方法返回一个 List<T>。然而,这种方法有一个显着的缺点࿱…...
python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...
Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
Pinocchio 库详解及其在足式机器人上的应用
Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...
安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖
在Vuzix M400 AR智能眼镜的助力下,卢森堡罗伯特舒曼医院(the Robert Schuman Hospitals, HRS)凭借在无菌制剂生产流程中引入增强现实技术(AR)创新项目,荣获了2024年6月7日由卢森堡医院药剂师协会࿰…...
基于Java+MySQL实现(GUI)客户管理系统
客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...
MySQL 知识小结(一)
一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...
