【知识蒸馏】多任务模型 logit-based 知识蒸馏实战
一、什么是逻辑(logit)知识蒸馏
Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。
具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。
这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。
需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。
综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。
二、如何进行多任务模型的知识蒸馏
(1)加载学生和教师模型
(2)定义分割蒸馏损失,定义检测蒸馏损失
(3)计算分割蒸馏损失,计算检测蒸馏损失
(4)计算学生模型的分割,检测损失
(5)计算总损失,反向传播
三、实现代码
(1)加载学生和教师模型
# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))
(2)定义分割蒸馏损失,定义检测蒸馏损失
分割损失,参考:【知识蒸馏】语义分割模型逻辑蒸馏实战,对剪枝的模型进行蒸馏训练
# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss
检测损失,参考:【知识蒸馏】yolov5逻辑蒸馏和特征蒸馏实战
# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature
(3)计算分割蒸馏loss,计算检测蒸馏损失
with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸馏loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 检测蒸馏loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)
(4)计算学生模型的分割,检测损失
det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)
(5)计算总损失,反向传播
seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()
相关文章:
【知识蒸馏】多任务模型 logit-based 知识蒸馏实战
一、什么是逻辑(logit)知识蒸馏 Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从…...
C:技术面试总结
1 变量的声明和定义: 定义:为变量分配地址和存储空间 声明:不分配地址。一个变量可以在多个地方声明,但只能在一个地方定义。extern修饰的变量声明,说明此变量将在文件以外或文件后面部分定义。 2 局部变量是否能与全局变量重名: 可以,局部变量会屏蔽全局变量 局部…...
OpenHarmony 实战开发——一文总结ACE代码框架
一、前言 ACE_Engine框架是OpenAtom OpenHarmony(简称“OpenHarmony”)的UI开发框架,为开发者提供在进行应用UI开发时所必需的各种组件,以及定义这些组件的属性、样式、事件及方法,通过这些组件可以方便进行OpenHarmo…...
【数据结构与算法】之堆的应用——堆排序及Top_K问题!
目录 1、堆排序 2、Top_K问题 3、完结散花 个人主页:秋风起,再归来~ 数据结构与算法 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 1、堆排序 对一个无序的数组…...
啊哈!算法-第2章-栈、队列、链表
啊哈!算法-第2章-栈、队列、链表 第1节 解密qq号——队列第2节 解密回文——栈第3节 纸牌游戏——小猫钓鱼第4节 链表第5节 模拟链表 第1节 解密qq号——队列 新学期开始了,小哈是小哼的新同桌(小哈是个大帅哥哦~),小哼向小哈询问 QQ 号, 小…...
简述 v-if 和 v-show 的区别
v-if 和 v-show 都是 Vue.js 中用于控制元素显示与隐藏的指令,但它们的工作方式有显著的差异。以下是它们之间的主要区别: 渲染方式: v-if:v-if 是“真正”的条件渲染,因为它会确保在切换过程中条件块内的事件监听器和…...
Linux驱动学习之模块化,参数传递,符号导出
1.模块化 1.1.模块化的基本概念: 模块化是指将特定的功能或组件独立出来,以便于开发、测试和维护。在Linux设备驱动中,模块化允许将驱动程序作为内核模块动态加载到系统中,从而提高了系统的灵活性和可扩展性。 1.2.Linux内核模…...
RabbitMQ02-RebbitMQ简介及交换器
一. AMQP协议 什么是AMQP协议 AMQP(Advanced Message Queuing Protocol,高级消息队列协议):它是进程之间传递异步消息的网络协议 AMQP工作过程 发布者通过发布消息,通过交换机,交换机根据路由规则将收到的消息分发交换机绑定的下消息队列,最…...
Matlab自学笔记三十:元胞数组的修改、添加、删除和连接
1.说明 元胞数组的子数组或元素也是元胞型的,其元素内容(值)是本身类型,因此,在添、删、改和连接处理时,必须明确每个元素的值的类型和大小,否则,编程报错是不可避免的了。看本文前…...
【LeetCode】数组——双指针法
1 双指针法 1.1 介绍 双指针法是一种常用的算法技巧,通常用于处理数组或链表中的问题。它使用两个指针,通常一个从数组的开始位置遍历,另一个从数组的末尾位置开始遍历,根据问题的不同,这两个指针可以同时移动&#…...
react 低代码平台方案汇总
React作为当前最流行的前端框架之一,其生态系统中孕育了多种低代码平台方案,旨在加速应用开发过程。以下是几款基于React的低代码平台或工具,它们通过可视化构建、预制组件、数据绑定等功能,帮助开发者快速构建应用程序࿱…...
oss对象上传文件设置格式
PostMapping("upload")ApiOperation(value "上传文件")public Result<UploadDTO> upload(RequestParam("file") MultipartFile file) throws Exception {if (file.isEmpty()) {return new Result<UploadDTO>().error(ModuleErrorCo…...
【Linux学习】进程
下面是有关进程的相关介绍,希望对你有所帮助! 小海编程心语录-CSDN博客 目录 1. 进程的概念 1.1 进程与程序 1.2 进程号 2. 进程的状态 2.1 fork创建子进程 2.2 父子进程间的文件共享 3. 进程的诞生与终止 3.1 进程的诞生 3.2 进程的终止 1. 进…...
Python数据分析实验四:数据分析综合应用开发
目录 一、实验目的与要求二、主要实验过程1、加载数据集2、数据预处理3、划分数据集4、创建模型估计器5、模型拟合6、模型性能评估 三、主要程序清单和运行结果四、实验体会 一、实验目的与要求 1、目的: 综合运用所学知识,选取有实际背景的应用问题进行…...
基于51单片机的盆栽自动浇花系统
一.硬件方案 工作原理是湿度传感器将采集到的数据直接传送到ADC0832的IN端作为输入的模拟信号。选用湿度传感器和AD转换,电路内部包含有湿度采集、AD转换、单片机译码显示等功能。单片机需要采集数据时,发出指令启动A/D转换器工作,ADC0832根…...
SpirngMVC框架学习笔记(一):SpringMVC基本介绍
1 SpringMVC 特点&概述 SpringMVC 从易用性,效率上 比曾经流行的 Struts2 更好 SpringMVC 是 WEB 层框架,接管了 Web 层组件, 比如控制器, 视图, 视图解析, 返回给用户的数据格式, 同时支持 MVC 的开发模式/开发架构SpringMVC 通过注解,…...
实现信号发生控制
1. 信号发生器的基本原理 信号发生器是一种能够产生特定波形和频率的电子设备,常用于模拟信号的产生和测试。 信号发生器的基本原理 信号发生器的工作原理基于不同的技术,但最常见的类型包括模拟信号发生器和数字信号发生器(DDS࿰…...
二叉树基于队列实现的操作详解
一、队列知识补充 有关队列的知识请详见博主的另一篇博客:http://t.csdnimg.cn/3PwO4 本文仅仅附上需要的队列操作供读者参考 //结构体定义 typedef struct BinaryTreeNode* QDataType;typedef struct QueueNode {struct QueueNode* next;QDataType val; }QNode;…...
LabVIEW常用开发架构有哪些
LabVIEW常用开发架构有多种,每种架构都有其独特的特点和适用场合。以下是几种常用的开发架构及其特点和适用场合: 1. 单循环架构 特点: 简单易用适用于小型应用将所有代码放在一个循环中 适用场合: 简单的数据采集和处理任务…...
告别 Dart 中的 Future.wait([])
作为 Dart 开发人员,我们对异步编程和 Futures 的强大功能并不陌生。过去,当我们需要同时等待多个 future 时,我们依赖 Future.wait([]) 方法,该方法返回一个 List<T>。然而,这种方法有一个显着的缺点࿱…...
【位运算】消失的两个数字(hard)
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...
MMaDA: Multimodal Large Diffusion Language Models
CODE : https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA,它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构…...
2021-03-15 iview一些问题
1.iview 在使用tree组件时,发现没有set类的方法,只有get,那么要改变tree值,只能遍历treeData,递归修改treeData的checked,发现无法更改,原因在于check模式下,子元素的勾选状态跟父节…...
如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...
初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?
在大数据处理领域,Hive 作为 Hadoop 生态中重要的数据仓库工具,其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式,很多开发者常常陷入选择困境。本文将从底…...
云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...
