当前位置: 首页 > news >正文

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点输出特征并计算蒸馏 loss 的过程

知乎-mmrazor-模型蒸馏
知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。
也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD

  • feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
    而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法

mmrazor可以使用不同架构的student和teacher模型

  • 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
  • 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)

mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1

cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢
适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)
cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号

  • 现象:
    由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
_base_ = ['mmseg::_base_/datasets/pascal_voc12.py','mmseg::_base_/schedules/schedule_160k.py','mmseg::_base_/default_runtime.py'
]# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,timer=dict(type='IterTimerHook'),logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),param_scheduler=dict(type='ParamSchedulerHook'),# 使用了自定义的Student_CheckpointHookcheckpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),# checkpoint中,interval=-1则不会保存least.pthsampler_seed=dict(type='DistSamplerSeedHook'),visualization=dict(type='SegVisualizationHook'))teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth'  # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'  # noqa: E501student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py'  # noqa: E501
model = dict(_scope_='mmrazor',type='SingleTeacherDistill',architecture=dict(cfg_path=student_cfg_path, pretrained=False),teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),teacher_ckpt=teacher_ckpt,distiller=dict(type='ConfigurableDistiller',distill_losses=dict(loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),student_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),teacher_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),connectors=dict(loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,norm_cfg=dict(type='BN')),loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',norm_cfg=dict(type='BN'))),loss_forward_mappings=dict(loss_cwd=dict(preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据# 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))# 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器# 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称find_unused_parameters = Truetrain_cfg = dict(type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')train_dataloader = dict(batch_size=16)  # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batchwork_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'

模型蒸馏一

    1. 基于响应的KD(DKD ,FitNets):
      基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
      教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
      logits具备的含义为模型判断当前样本为各类别的信心为多少
      1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
      2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
      3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
    1. 基于特征的KD(AB,AT,ofd,Factor Transfer):
      蒸馏位点位于模型中途获得的特征
      通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
      1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
      2)知识定位,设计规则选出更为重要的教师特征
    1. 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
      1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本

总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。

相关文章:

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点的输出特征并计算蒸馏 loss 的过程 知乎-mmrazor-模型蒸馏 知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。 也可为 data-free KD、online KD、self KD&#xff…...

总结Kibana DevTools如何操作elasticsearch的常用语句

一、操作es的工具 ElasticSearch HeadKibana DevToolsElasticHQ 本文主要是总结Kibana DevTools操作es的语句。 二、搜索文档 1、根据ID查询单个记录 GET /course_idx/_doc/course:202、term 匹配"name"字段的值为"6789999"的文档 类似于sql语句中的等…...

【QT】QT自定义C++类

在使用Qt的ui设计时,Qt为我们提供了标准的类,但是在很多复杂工程中,标准的类并不能满足所有的需求,这时就需要我们自定义C类。 下面以自定义的QPushButton作一个很简单的例子。 先新建默认Qt Widgets Application项目 一、自定义…...

【多媒体文件格式】AVI、WAV、RIFF

AVI、RIFF AVI:Audio/Video Interleaved(音频视频交织/交错),用于采集、编辑、播放的RIFF文件。由Microsoft公司1992年11月推出,是Microsoft公司开发的一种符合RIFF文件规范的数字音频与视频文件格式,原先…...

AI创作系统ChatGPT商业运营系统源码+支持GPT4/支持ai绘画

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…...

JWT简介 JWT结构 JWT示例 前端添加JWT令牌功能 后端程序

目录 1. JWT简述 1.1 什么是JWT 1.2 为什么使用JWT 1.3 JWT结构 1.4 验证过程 2. JWT示例 2.1 后台程序 2.2 前台加入jwt令牌功能 1. JWT简述 1.1 什么是JWT Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准((RFC 7…...

Rust核心功能之一(所有权)

目录 1、什么是所有权? 1.1 所有权规则 1.2 变量作用域 1.3 String 类型 1.4 内存与分配 变量与数据交互的方式(一):移动 变量与数据交互的方式(二):克隆 只在栈上的数据:拷贝…...

跨域(CORS)和JWT 详解

跨域 (CORS) 概念 同源策略 (Same-Origin Policy) 同源策略是一项浏览器安全特性,它限制了一个网页中的脚本如何与另一个来源(域名、协议、端口)的资源进行交互。这对于防止跨站点请求伪造和数据泄露非常重要。 为什么需要跨域? 跨域问题通…...

前端框架Vue学习 ——(二)Vue常用指令

文章目录 常用指令 常用指令 指令: HTML 标签上带有 “v-” 前缀的特殊属性&#xff0c;不同指令具有不同含义。例如: v-if, v-for… 常用指令&#xff1a; v-bind&#xff1a;为 HTML 标签绑定属性值&#xff0c;如设置 href&#xff0c;css 样式等 <a v-bind:href"…...

Linux 指令心法(十四)`flash_erase` 擦除Flash存储器

文章目录 flash_erase 作用flash_erase的主要特点和使用场景flash_erase命令应用方法flash_erase命令可以解决哪些问题?flash_erase命令使用时注意事项 flash_erase 作用 这是一个用于擦除Flash存储器的命令。它可以擦除指定的Flash块或扇区&#xff0c;以便在写入新数据之前…...

GoLong的学习之路(二十一)进阶,语法之并发(go最重要的特点)(协程的主要用法)

并发编程在当前软件领域是一个非常重要的概念&#xff0c;随着CPU等硬件的发展&#xff0c;我们无一例外的想让我们的程序运行的快一点、再快一点。Go语言在语言层面天生支持并发&#xff0c;充分利用现代CPU的多核优势&#xff0c;这也是Go语言能够大范围流行的一个很重要的原…...

加快网站收录 3小时百度收录新站方法

加快网站收录 3小时百度收录新站方法 3小时百度收录新站方法说起来大家可能不相信&#xff0c;但这确实是真实的(该方法是通过技术提交&#xff0c;让百度快速抓取收录您的网站&#xff0c;不管你网站有没有备案&#xff0c;都能在短时间内被收录&#xff0c;要是你的网站迟迟不…...

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

目录 一、ChatGLM3 模型 二、资源需求 三、部署安装 配置环境 安装过程 低成本配置部署方案 四、启动 ChatGLM3 五、功能测试 新鲜出炉&#xff0c;国产 GPT 版本迭代更新啦~清华团队刚刚发布ChatGLM3&#xff0c;恰逢云栖大会前百川也发布Baichuan2-192K&#xff0c;一…...

图片怎么转换成pdf?

图片怎么转换成pdf&#xff1f;图片可以转换成PDF格式文档吗&#xff1f;当然是可以的呀&#xff0c;当图片转换成PDF文件类型时&#xff0c;我们就会发现图片更加方便的打开分享和传播&#xff0c;而且还可以更加安全的保证我们的图片所有性。我们知道PDF文档是可以加密的&…...

【源码】医学影像PACS实现三维影像后处理等功能

医学影像诊断技术近年来取得了快速发展&#xff0c;包括高性能的影像检查设备的临床应用和数字信息技术的图像显示、存储、传输、处理、识别&#xff0c;这些技术使得计算机辅助检测和诊断成为可能&#xff0c;同时人工智能影像诊断也进入了人们的视野。这些技术进步提高了疾病…...

DOCTYPE是什么,有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别?

前言 持续学习总结输出中&#xff0c;今天分享的是DOCTYPE是什么&#xff0c;有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别。 DOCTYPE是什么&#xff0c;有何作用&#xff1f; DOCTYPE是HTML5的文档声明&#xff0c;通过它可以告诉浏览器&#xff0c;使用那个H…...

Go语言实现HTTP正向代理

文章目录 前言实现思路代码实现 前言 正向代理&#xff08;Forward Proxy&#xff09;是一种代理服务器的部署方式&#xff0c;它位于客户端和目标服务器之间&#xff0c;代表客户端向目标服务器发送请求。正向代理可以用来隐藏客户端的真实身份&#xff0c;以及在不同网络环境…...

第11章_数据处理之增删改

文章目录 1 插入数据1.1 实际问题1.2 方式 1&#xff1a;VALUES的方式添加1.3 方式2&#xff1a;将查询结果插入到表中演示代码 2 更新数据演示代码 3 删除数据演示代码 4 MySQL8新特性&#xff1a;计算列演示代码 5 综合案例课后练习 1 插入数据 1.1 实际问题 解决方式&#…...

数据时代的新引擎:数据治理与开发,揭秘数据领域的黄金机遇!

文章目录 一、数据时代的需求二、数据治理与开发三、案例分析四、黄金机遇《数据要素安全流通》《Python数据挖掘&#xff1a;入门、进阶与实用案例分析》《数据保护&#xff1a;工作负载的可恢复性 》《Data Mesh权威指南》《分布式统一大数据虚拟文件系统 Alluxio原理、技术与…...

使用 Golang 实现基于时间的一次性密码 TOTP

上篇文章详细讲解了一次性密码 OTP 相关的知识&#xff0c;基于时间的一次性密码 TOTP 是 OTP 的一种实现方式。这种方法的优点是不依赖网络&#xff0c;因此即使在没有网络的情况下&#xff0c;用户也可以生成密码。所以这种方式被许多流行的网站使用到双因子或多因子认证中&a…...

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性&#xff0c;不同版本的Docker对内核版本有不同要求。例如&#xff0c;Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本&#xff0c;Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版&#xff0c;莫兰迪调色板清新简约工作汇报PPT模版&#xff0c;莫兰迪时尚风极简设计PPT模版&#xff0c;大学生毕业论文答辩PPT模版&#xff0c;莫兰迪配色总结计划简约商务通用PPT模版&#xff0c;莫兰迪商务汇报PPT模版&#xff0c;…...

解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist

现象&#xff1a; android studio报错&#xff1a; [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决&#xff1a; 不要动CMakeLists.…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...

WEB3全栈开发——面试专业技能点P4数据库

一、mysql2 原生驱动及其连接机制 概念介绍 mysql2 是 Node.js 环境中广泛使用的 MySQL 客户端库&#xff0c;基于 mysql 库改进而来&#xff0c;具有更好的性能、Promise 支持、流式查询、二进制数据处理能力等。 主要特点&#xff1a; 支持 Promise / async-await&#xf…...

医疗AI模型可解释性编程研究:基于SHAP、LIME与Anchor

1 医疗树模型与可解释人工智能基础 医疗领域的人工智能应用正迅速从理论研究转向临床实践,在这一过程中,模型可解释性已成为确保AI系统被医疗专业人员接受和信任的关键因素。基于树模型的集成算法(如RandomForest、XGBoost、LightGBM)因其卓越的预测性能和相对良好的解释性…...