当前位置: 首页 > news >正文

【知识蒸馏】多任务模型 logit-based 知识蒸馏实战

一、什么是逻辑(logit)知识蒸馏

Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。

具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。

这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。

需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。

综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。

二、如何进行多任务模型的知识蒸馏

(1)加载学生和教师模型
(2)定义分割蒸馏损失,定义检测蒸馏损失
(3)计算分割蒸馏损失,计算检测蒸馏损失
(4)计算学生模型的分割,检测损失
(5)计算总损失,反向传播

三、实现代码

(1)加载学生和教师模型

# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))

(2)定义分割蒸馏损失,定义检测蒸馏损失
分割损失,参考:【知识蒸馏】语义分割模型逻辑蒸馏实战,对剪枝的模型进行蒸馏训练

# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss

检测损失,参考:【知识蒸馏】yolov5逻辑蒸馏和特征蒸馏实战

# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature

(3)计算分割蒸馏loss,计算检测蒸馏损失

with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸馏loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 检测蒸馏loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)

(4)计算学生模型的分割,检测损失

det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)

(5)计算总损失,反向传播

seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()

相关文章:

【知识蒸馏】多任务模型 logit-based 知识蒸馏实战

一、什么是逻辑(logit)知识蒸馏 Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从…...

C:技术面试总结

1 变量的声明和定义: 定义:为变量分配地址和存储空间 声明:不分配地址。一个变量可以在多个地方声明,但只能在一个地方定义。extern修饰的变量声明,说明此变量将在文件以外或文件后面部分定义。 2 局部变量是否能与全局变量重名: 可以,局部变量会屏蔽全局变量 局部…...

OpenHarmony 实战开发——一文总结ACE代码框架

一、前言 ACE_Engine框架是OpenAtom OpenHarmony(简称“OpenHarmony”)的UI开发框架,为开发者提供在进行应用UI开发时所必需的各种组件,以及定义这些组件的属性、样式、事件及方法,通过这些组件可以方便进行OpenHarmo…...

【数据结构与算法】之堆的应用——堆排序及Top_K问题!

目录 1、堆排序 2、Top_K问题 3、完结散花 个人主页:秋风起,再归来~ 数据结构与算法 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 1、堆排序 对一个无序的数组…...

啊哈!算法-第2章-栈、队列、链表

啊哈!算法-第2章-栈、队列、链表 第1节 解密qq号——队列第2节 解密回文——栈第3节 纸牌游戏——小猫钓鱼第4节 链表第5节 模拟链表 第1节 解密qq号——队列 新学期开始了,小哈是小哼的新同桌(小哈是个大帅哥哦~),小哼向小哈询问 QQ 号, 小…...

简述 v-if 和 v-show 的区别

v-if 和 v-show 都是 Vue.js 中用于控制元素显示与隐藏的指令,但它们的工作方式有显著的差异。以下是它们之间的主要区别: 渲染方式: v-if:v-if 是“真正”的条件渲染,因为它会确保在切换过程中条件块内的事件监听器和…...

Linux驱动学习之模块化,参数传递,符号导出

1.模块化 1.1.模块化的基本概念: 模块化是指将特定的功能或组件独立出来,以便于开发、测试和维护。在Linux设备驱动中,模块化允许将驱动程序作为内核模块动态加载到系统中,从而提高了系统的灵活性和可扩展性。 1.2.Linux内核模…...

RabbitMQ02-RebbitMQ简介及交换器

一. AMQP协议 什么是AMQP协议 AMQP(Advanced Message Queuing Protocol,高级消息队列协议):它是进程之间传递异步消息的网络协议 AMQP工作过程 发布者通过发布消息,通过交换机,交换机根据路由规则将收到的消息分发交换机绑定的下消息队列,最…...

Matlab自学笔记三十:元胞数组的修改、添加、删除和连接

1.说明 元胞数组的子数组或元素也是元胞型的,其元素内容(值)是本身类型,因此,在添、删、改和连接处理时,必须明确每个元素的值的类型和大小,否则,编程报错是不可避免的了。看本文前…...

【LeetCode】数组——双指针法

1 双指针法 1.1 介绍 双指针法是一种常用的算法技巧,通常用于处理数组或链表中的问题。它使用两个指针,通常一个从数组的开始位置遍历,另一个从数组的末尾位置开始遍历,根据问题的不同,这两个指针可以同时移动&#…...

react 低代码平台方案汇总

React作为当前最流行的前端框架之一,其生态系统中孕育了多种低代码平台方案,旨在加速应用开发过程。以下是几款基于React的低代码平台或工具,它们通过可视化构建、预制组件、数据绑定等功能,帮助开发者快速构建应用程序&#xff1…...

oss对象上传文件设置格式

PostMapping("upload")ApiOperation(value "上传文件")public Result<UploadDTO> upload(RequestParam("file") MultipartFile file) throws Exception {if (file.isEmpty()) {return new Result<UploadDTO>().error(ModuleErrorCo…...

【Linux学习】进程

下面是有关进程的相关介绍&#xff0c;希望对你有所帮助&#xff01; 小海编程心语录-CSDN博客 目录 1. 进程的概念 1.1 进程与程序 1.2 进程号 2. 进程的状态 2.1 fork创建子进程 2.2 父子进程间的文件共享 3. 进程的诞生与终止 3.1 进程的诞生 3.2 进程的终止 1. 进…...

Python数据分析实验四:数据分析综合应用开发

目录 一、实验目的与要求二、主要实验过程1、加载数据集2、数据预处理3、划分数据集4、创建模型估计器5、模型拟合6、模型性能评估 三、主要程序清单和运行结果四、实验体会 一、实验目的与要求 1、目的&#xff1a; 综合运用所学知识&#xff0c;选取有实际背景的应用问题进行…...

基于51单片机的盆栽自动浇花系统

一.硬件方案 工作原理是湿度传感器将采集到的数据直接传送到ADC0832的IN端作为输入的模拟信号。选用湿度传感器和AD转换&#xff0c;电路内部包含有湿度采集、AD转换、单片机译码显示等功能。单片机需要采集数据时&#xff0c;发出指令启动A/D转换器工作&#xff0c;ADC0832根…...

SpirngMVC框架学习笔记(一):SpringMVC基本介绍

1 SpringMVC 特点&概述 SpringMVC 从易用性&#xff0c;效率上 比曾经流行的 Struts2 更好 SpringMVC 是 WEB 层框架&#xff0c;接管了 Web 层组件, 比如控制器, 视图, 视图解析, 返回给用户的数据格式, 同时支持 MVC 的开发模式/开发架构SpringMVC 通过注解&#xff0c;…...

实现信号发生控制

1. 信号发生器的基本原理 信号发生器是一种能够产生特定波形和频率的电子设备&#xff0c;常用于模拟信号的产生和测试。 信号发生器的基本原理 信号发生器的工作原理基于不同的技术&#xff0c;但最常见的类型包括模拟信号发生器和数字信号发生器&#xff08;DDS&#xff0…...

二叉树基于队列实现的操作详解

一、队列知识补充 有关队列的知识请详见博主的另一篇博客&#xff1a;http://t.csdnimg.cn/3PwO4 本文仅仅附上需要的队列操作供读者参考 //结构体定义 typedef struct BinaryTreeNode* QDataType;typedef struct QueueNode {struct QueueNode* next;QDataType val; }QNode;…...

LabVIEW常用开发架构有哪些

LabVIEW常用开发架构有多种&#xff0c;每种架构都有其独特的特点和适用场合。以下是几种常用的开发架构及其特点和适用场合&#xff1a; 1. 单循环架构 特点&#xff1a; 简单易用适用于小型应用将所有代码放在一个循环中 适用场合&#xff1a; 简单的数据采集和处理任务…...

告别 Dart 中的 Future.wait([])

作为 Dart 开发人员&#xff0c;我们对异步编程和 Futures 的强大功能并不陌生。过去&#xff0c;当我们需要同时等待多个 future 时&#xff0c;我们依赖 Future.wait([]) 方法&#xff0c;该方法返回一个 List<T>。然而&#xff0c;这种方法有一个显着的缺点&#xff1…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎&#xff08;Physics Engine&#xff09; 物理引擎 是一种通过计算机模拟物理规律&#xff08;如力学、碰撞、重力、流体动力学等&#xff09;的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互&#xff0c;广泛应用于 游戏开发、动画制作、虚…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能&#xff0c;我们需要对它的功能特点进行分析&#xff1a; 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具&#xff1a; mysql&#xff1a;关系型数据库&am…...

STM32+rt-thread判断是否联网

一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述&#xff0c;后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作&#xff0c;其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

大数据学习(132)-HIve数据分析

​​​​&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4…...

如何在网页里填写 PDF 表格?

有时候&#xff0c;你可能希望用户能在你的网站上填写 PDF 表单。然而&#xff0c;这件事并不简单&#xff0c;因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件&#xff0c;但原生并不支持编辑或填写它们。更糟的是&#xff0c;如果你想收集表单数据&#xff…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...