当前位置: 首页 > news >正文

【知识蒸馏】多任务模型 logit-based 知识蒸馏实战

一、什么是逻辑(logit)知识蒸馏

Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。

具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。

这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。

需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。

综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。

二、如何进行多任务模型的知识蒸馏

(1)加载学生和教师模型
(2)定义分割蒸馏损失,定义检测蒸馏损失
(3)计算分割蒸馏损失,计算检测蒸馏损失
(4)计算学生模型的分割,检测损失
(5)计算总损失,反向传播

三、实现代码

(1)加载学生和教师模型

# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YourModel(task="multi")
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device))

(2)定义分割蒸馏损失,定义检测蒸馏损失
分割损失,参考:【知识蒸馏】语义分割模型逻辑蒸馏实战,对剪枝的模型进行蒸馏训练

# ------------ seg logit distill loss -------------#
def seg_logit_distill_loss(t_pred, s_pred, tempature = 2):KD = nn.KLDivLoss(reduction='mean')t_p = F.softmax(t_pred / tempature, dim=1)s_p = F.log_softmax(s_pred / tempature, dim=1)loss = KD(s_p, t_p) * (tempature ** 2)return loss

检测损失,参考:【知识蒸馏】yolov5逻辑蒸馏和特征蒸馏实战

# ------------ det logit distill loss -------------#
def det_logit_distill_loss(t_pred,s_pred,tempature=1):L2 = nn.MSELoss(reduction="none")t_lobj = L2(s_pred[..., 4], t_pred[..., 4]).mean()t_lBox = L2(s_pred[..., :4], t_pred[..., :4]).mean()t_lcls = L2(s_pred[..., 5:], t_pred[..., 5:]).mean()return (t_lobj + t_lBox + t_lcls) * tempature

(3)计算分割蒸馏loss,计算检测蒸馏损失

with torch.no_grad():teacher_outputs = teacher_model(images)
# 分割蒸馏loss
teacher_seg_output = teacher_outputs.get("seg")
student_seg_output = predictions.get("seg")
seg_soft_loss = seg_logit_distill_loss(teacher_seg_output, student_seg_output)
# 检测蒸馏loss
teacher_det_output = teacher_outputs.get("det")
student_det_output = predictions.get("det")
det_soft_loss = det_logit_distill_loss(teacher_det_output, student_det_output)

(4)计算学生模型的分割,检测损失

det_loss = calc_det_loss(...)
seg_loss = CE_Loss(...)

(5)计算总损失,反向传播

seg_distill_loss = seg_loss * (1 - seg_alpha) + seg_soft_loss * seg_alpha
det_distill_loss = det_loss * (1 - det_alpha) + det_soft_loss * det_alpha
loss = det_distill_loss * Ratio_det + seg_distill_loss * Ratio_seg
loss.backward()

相关文章:

【知识蒸馏】多任务模型 logit-based 知识蒸馏实战

一、什么是逻辑(logit)知识蒸馏 Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从…...

C:技术面试总结

1 变量的声明和定义: 定义:为变量分配地址和存储空间 声明:不分配地址。一个变量可以在多个地方声明,但只能在一个地方定义。extern修饰的变量声明,说明此变量将在文件以外或文件后面部分定义。 2 局部变量是否能与全局变量重名: 可以,局部变量会屏蔽全局变量 局部…...

OpenHarmony 实战开发——一文总结ACE代码框架

一、前言 ACE_Engine框架是OpenAtom OpenHarmony(简称“OpenHarmony”)的UI开发框架,为开发者提供在进行应用UI开发时所必需的各种组件,以及定义这些组件的属性、样式、事件及方法,通过这些组件可以方便进行OpenHarmo…...

【数据结构与算法】之堆的应用——堆排序及Top_K问题!

目录 1、堆排序 2、Top_K问题 3、完结散花 个人主页:秋风起,再归来~ 数据结构与算法 个人格言:悟已往之不谏,知来者犹可追 克心守己,律己则安! 1、堆排序 对一个无序的数组…...

啊哈!算法-第2章-栈、队列、链表

啊哈!算法-第2章-栈、队列、链表 第1节 解密qq号——队列第2节 解密回文——栈第3节 纸牌游戏——小猫钓鱼第4节 链表第5节 模拟链表 第1节 解密qq号——队列 新学期开始了,小哈是小哼的新同桌(小哈是个大帅哥哦~),小哼向小哈询问 QQ 号, 小…...

简述 v-if 和 v-show 的区别

v-if 和 v-show 都是 Vue.js 中用于控制元素显示与隐藏的指令,但它们的工作方式有显著的差异。以下是它们之间的主要区别: 渲染方式: v-if:v-if 是“真正”的条件渲染,因为它会确保在切换过程中条件块内的事件监听器和…...

Linux驱动学习之模块化,参数传递,符号导出

1.模块化 1.1.模块化的基本概念: 模块化是指将特定的功能或组件独立出来,以便于开发、测试和维护。在Linux设备驱动中,模块化允许将驱动程序作为内核模块动态加载到系统中,从而提高了系统的灵活性和可扩展性。 1.2.Linux内核模…...

RabbitMQ02-RebbitMQ简介及交换器

一. AMQP协议 什么是AMQP协议 AMQP(Advanced Message Queuing Protocol,高级消息队列协议):它是进程之间传递异步消息的网络协议 AMQP工作过程 发布者通过发布消息,通过交换机,交换机根据路由规则将收到的消息分发交换机绑定的下消息队列,最…...

Matlab自学笔记三十:元胞数组的修改、添加、删除和连接

1.说明 元胞数组的子数组或元素也是元胞型的,其元素内容(值)是本身类型,因此,在添、删、改和连接处理时,必须明确每个元素的值的类型和大小,否则,编程报错是不可避免的了。看本文前…...

【LeetCode】数组——双指针法

1 双指针法 1.1 介绍 双指针法是一种常用的算法技巧,通常用于处理数组或链表中的问题。它使用两个指针,通常一个从数组的开始位置遍历,另一个从数组的末尾位置开始遍历,根据问题的不同,这两个指针可以同时移动&#…...

react 低代码平台方案汇总

React作为当前最流行的前端框架之一,其生态系统中孕育了多种低代码平台方案,旨在加速应用开发过程。以下是几款基于React的低代码平台或工具,它们通过可视化构建、预制组件、数据绑定等功能,帮助开发者快速构建应用程序&#xff1…...

oss对象上传文件设置格式

PostMapping("upload")ApiOperation(value "上传文件")public Result<UploadDTO> upload(RequestParam("file") MultipartFile file) throws Exception {if (file.isEmpty()) {return new Result<UploadDTO>().error(ModuleErrorCo…...

【Linux学习】进程

下面是有关进程的相关介绍&#xff0c;希望对你有所帮助&#xff01; 小海编程心语录-CSDN博客 目录 1. 进程的概念 1.1 进程与程序 1.2 进程号 2. 进程的状态 2.1 fork创建子进程 2.2 父子进程间的文件共享 3. 进程的诞生与终止 3.1 进程的诞生 3.2 进程的终止 1. 进…...

Python数据分析实验四:数据分析综合应用开发

目录 一、实验目的与要求二、主要实验过程1、加载数据集2、数据预处理3、划分数据集4、创建模型估计器5、模型拟合6、模型性能评估 三、主要程序清单和运行结果四、实验体会 一、实验目的与要求 1、目的&#xff1a; 综合运用所学知识&#xff0c;选取有实际背景的应用问题进行…...

基于51单片机的盆栽自动浇花系统

一.硬件方案 工作原理是湿度传感器将采集到的数据直接传送到ADC0832的IN端作为输入的模拟信号。选用湿度传感器和AD转换&#xff0c;电路内部包含有湿度采集、AD转换、单片机译码显示等功能。单片机需要采集数据时&#xff0c;发出指令启动A/D转换器工作&#xff0c;ADC0832根…...

SpirngMVC框架学习笔记(一):SpringMVC基本介绍

1 SpringMVC 特点&概述 SpringMVC 从易用性&#xff0c;效率上 比曾经流行的 Struts2 更好 SpringMVC 是 WEB 层框架&#xff0c;接管了 Web 层组件, 比如控制器, 视图, 视图解析, 返回给用户的数据格式, 同时支持 MVC 的开发模式/开发架构SpringMVC 通过注解&#xff0c;…...

实现信号发生控制

1. 信号发生器的基本原理 信号发生器是一种能够产生特定波形和频率的电子设备&#xff0c;常用于模拟信号的产生和测试。 信号发生器的基本原理 信号发生器的工作原理基于不同的技术&#xff0c;但最常见的类型包括模拟信号发生器和数字信号发生器&#xff08;DDS&#xff0…...

二叉树基于队列实现的操作详解

一、队列知识补充 有关队列的知识请详见博主的另一篇博客&#xff1a;http://t.csdnimg.cn/3PwO4 本文仅仅附上需要的队列操作供读者参考 //结构体定义 typedef struct BinaryTreeNode* QDataType;typedef struct QueueNode {struct QueueNode* next;QDataType val; }QNode;…...

LabVIEW常用开发架构有哪些

LabVIEW常用开发架构有多种&#xff0c;每种架构都有其独特的特点和适用场合。以下是几种常用的开发架构及其特点和适用场合&#xff1a; 1. 单循环架构 特点&#xff1a; 简单易用适用于小型应用将所有代码放在一个循环中 适用场合&#xff1a; 简单的数据采集和处理任务…...

告别 Dart 中的 Future.wait([])

作为 Dart 开发人员&#xff0c;我们对异步编程和 Futures 的强大功能并不陌生。过去&#xff0c;当我们需要同时等待多个 future 时&#xff0c;我们依赖 Future.wait([]) 方法&#xff0c;该方法返回一个 List<T>。然而&#xff0c;这种方法有一个显着的缺点&#xff1…...

django filter 统计数量 按属性去重

在Django中&#xff0c;如果你想要根据某个属性对查询集进行去重并统计数量&#xff0c;你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求&#xff1a; 方法1&#xff1a;使用annotate()和Count 假设你有一个模型Item&#xff0c;并且你想…...

基于数字孪生的水厂可视化平台建设:架构与实践

分享大纲&#xff1a; 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年&#xff0c;数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段&#xff0c;基于数字孪生的水厂可视化平台的…...

【Go】3、Go语言进阶与依赖管理

前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课&#xff0c;做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程&#xff0c;它的核心机制是 Goroutine 协程、Channel 通道&#xff0c;并基于CSP&#xff08;Communicating Sequential Processes&#xff0…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南

文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/55aefaea8a9f477e86d065227851fe3d.pn…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局&#xff1a;刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断"&#xff0c;医生需通过显微镜观察组织切片&#xff0c;在细胞迷宫中捕捉癌变信号。某省病理质控报告显示&#xff0c;基层医院误诊率达12%-15%&#xff0c;专家会诊…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

省略号和可变参数模板

本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...