当前位置: 首页 > news >正文

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点输出特征并计算蒸馏 loss 的过程

知乎-mmrazor-模型蒸馏
知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。
也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD

  • feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
    而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法

mmrazor可以使用不同架构的student和teacher模型

  • 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
  • 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)

mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1

cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢
适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)
cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号

  • 现象:
    由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
_base_ = ['mmseg::_base_/datasets/pascal_voc12.py','mmseg::_base_/schedules/schedule_160k.py','mmseg::_base_/default_runtime.py'
]# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,timer=dict(type='IterTimerHook'),logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),param_scheduler=dict(type='ParamSchedulerHook'),# 使用了自定义的Student_CheckpointHookcheckpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),# checkpoint中,interval=-1则不会保存least.pthsampler_seed=dict(type='DistSamplerSeedHook'),visualization=dict(type='SegVisualizationHook'))teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth'  # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'  # noqa: E501student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py'  # noqa: E501
model = dict(_scope_='mmrazor',type='SingleTeacherDistill',architecture=dict(cfg_path=student_cfg_path, pretrained=False),teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),teacher_ckpt=teacher_ckpt,distiller=dict(type='ConfigurableDistiller',distill_losses=dict(loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),student_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),teacher_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),connectors=dict(loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,norm_cfg=dict(type='BN')),loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',norm_cfg=dict(type='BN'))),loss_forward_mappings=dict(loss_cwd=dict(preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据# 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))# 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器# 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称find_unused_parameters = Truetrain_cfg = dict(type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')train_dataloader = dict(batch_size=16)  # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batchwork_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'

模型蒸馏一

    1. 基于响应的KD(DKD ,FitNets):
      基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
      教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
      logits具备的含义为模型判断当前样本为各类别的信心为多少
      1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
      2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
      3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
    1. 基于特征的KD(AB,AT,ofd,Factor Transfer):
      蒸馏位点位于模型中途获得的特征
      通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
      1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
      2)知识定位,设计规则选出更为重要的教师特征
    1. 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
      1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本

总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。

相关文章:

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点的输出特征并计算蒸馏 loss 的过程 知乎-mmrazor-模型蒸馏 知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。 也可为 data-free KD、online KD、self KD&#xff…...

总结Kibana DevTools如何操作elasticsearch的常用语句

一、操作es的工具 ElasticSearch HeadKibana DevToolsElasticHQ 本文主要是总结Kibana DevTools操作es的语句。 二、搜索文档 1、根据ID查询单个记录 GET /course_idx/_doc/course:202、term 匹配"name"字段的值为"6789999"的文档 类似于sql语句中的等…...

【QT】QT自定义C++类

在使用Qt的ui设计时,Qt为我们提供了标准的类,但是在很多复杂工程中,标准的类并不能满足所有的需求,这时就需要我们自定义C类。 下面以自定义的QPushButton作一个很简单的例子。 先新建默认Qt Widgets Application项目 一、自定义…...

【多媒体文件格式】AVI、WAV、RIFF

AVI、RIFF AVI:Audio/Video Interleaved(音频视频交织/交错),用于采集、编辑、播放的RIFF文件。由Microsoft公司1992年11月推出,是Microsoft公司开发的一种符合RIFF文件规范的数字音频与视频文件格式,原先…...

AI创作系统ChatGPT商业运营系统源码+支持GPT4/支持ai绘画

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…...

JWT简介 JWT结构 JWT示例 前端添加JWT令牌功能 后端程序

目录 1. JWT简述 1.1 什么是JWT 1.2 为什么使用JWT 1.3 JWT结构 1.4 验证过程 2. JWT示例 2.1 后台程序 2.2 前台加入jwt令牌功能 1. JWT简述 1.1 什么是JWT Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准((RFC 7…...

Rust核心功能之一(所有权)

目录 1、什么是所有权? 1.1 所有权规则 1.2 变量作用域 1.3 String 类型 1.4 内存与分配 变量与数据交互的方式(一):移动 变量与数据交互的方式(二):克隆 只在栈上的数据:拷贝…...

跨域(CORS)和JWT 详解

跨域 (CORS) 概念 同源策略 (Same-Origin Policy) 同源策略是一项浏览器安全特性,它限制了一个网页中的脚本如何与另一个来源(域名、协议、端口)的资源进行交互。这对于防止跨站点请求伪造和数据泄露非常重要。 为什么需要跨域? 跨域问题通…...

前端框架Vue学习 ——(二)Vue常用指令

文章目录 常用指令 常用指令 指令: HTML 标签上带有 “v-” 前缀的特殊属性&#xff0c;不同指令具有不同含义。例如: v-if, v-for… 常用指令&#xff1a; v-bind&#xff1a;为 HTML 标签绑定属性值&#xff0c;如设置 href&#xff0c;css 样式等 <a v-bind:href"…...

Linux 指令心法(十四)`flash_erase` 擦除Flash存储器

文章目录 flash_erase 作用flash_erase的主要特点和使用场景flash_erase命令应用方法flash_erase命令可以解决哪些问题?flash_erase命令使用时注意事项 flash_erase 作用 这是一个用于擦除Flash存储器的命令。它可以擦除指定的Flash块或扇区&#xff0c;以便在写入新数据之前…...

GoLong的学习之路(二十一)进阶,语法之并发(go最重要的特点)(协程的主要用法)

并发编程在当前软件领域是一个非常重要的概念&#xff0c;随着CPU等硬件的发展&#xff0c;我们无一例外的想让我们的程序运行的快一点、再快一点。Go语言在语言层面天生支持并发&#xff0c;充分利用现代CPU的多核优势&#xff0c;这也是Go语言能够大范围流行的一个很重要的原…...

加快网站收录 3小时百度收录新站方法

加快网站收录 3小时百度收录新站方法 3小时百度收录新站方法说起来大家可能不相信&#xff0c;但这确实是真实的(该方法是通过技术提交&#xff0c;让百度快速抓取收录您的网站&#xff0c;不管你网站有没有备案&#xff0c;都能在短时间内被收录&#xff0c;要是你的网站迟迟不…...

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

目录 一、ChatGLM3 模型 二、资源需求 三、部署安装 配置环境 安装过程 低成本配置部署方案 四、启动 ChatGLM3 五、功能测试 新鲜出炉&#xff0c;国产 GPT 版本迭代更新啦~清华团队刚刚发布ChatGLM3&#xff0c;恰逢云栖大会前百川也发布Baichuan2-192K&#xff0c;一…...

图片怎么转换成pdf?

图片怎么转换成pdf&#xff1f;图片可以转换成PDF格式文档吗&#xff1f;当然是可以的呀&#xff0c;当图片转换成PDF文件类型时&#xff0c;我们就会发现图片更加方便的打开分享和传播&#xff0c;而且还可以更加安全的保证我们的图片所有性。我们知道PDF文档是可以加密的&…...

【源码】医学影像PACS实现三维影像后处理等功能

医学影像诊断技术近年来取得了快速发展&#xff0c;包括高性能的影像检查设备的临床应用和数字信息技术的图像显示、存储、传输、处理、识别&#xff0c;这些技术使得计算机辅助检测和诊断成为可能&#xff0c;同时人工智能影像诊断也进入了人们的视野。这些技术进步提高了疾病…...

DOCTYPE是什么,有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别?

前言 持续学习总结输出中&#xff0c;今天分享的是DOCTYPE是什么&#xff0c;有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别。 DOCTYPE是什么&#xff0c;有何作用&#xff1f; DOCTYPE是HTML5的文档声明&#xff0c;通过它可以告诉浏览器&#xff0c;使用那个H…...

Go语言实现HTTP正向代理

文章目录 前言实现思路代码实现 前言 正向代理&#xff08;Forward Proxy&#xff09;是一种代理服务器的部署方式&#xff0c;它位于客户端和目标服务器之间&#xff0c;代表客户端向目标服务器发送请求。正向代理可以用来隐藏客户端的真实身份&#xff0c;以及在不同网络环境…...

第11章_数据处理之增删改

文章目录 1 插入数据1.1 实际问题1.2 方式 1&#xff1a;VALUES的方式添加1.3 方式2&#xff1a;将查询结果插入到表中演示代码 2 更新数据演示代码 3 删除数据演示代码 4 MySQL8新特性&#xff1a;计算列演示代码 5 综合案例课后练习 1 插入数据 1.1 实际问题 解决方式&#…...

数据时代的新引擎:数据治理与开发,揭秘数据领域的黄金机遇!

文章目录 一、数据时代的需求二、数据治理与开发三、案例分析四、黄金机遇《数据要素安全流通》《Python数据挖掘&#xff1a;入门、进阶与实用案例分析》《数据保护&#xff1a;工作负载的可恢复性 》《Data Mesh权威指南》《分布式统一大数据虚拟文件系统 Alluxio原理、技术与…...

使用 Golang 实现基于时间的一次性密码 TOTP

上篇文章详细讲解了一次性密码 OTP 相关的知识&#xff0c;基于时间的一次性密码 TOTP 是 OTP 的一种实现方式。这种方法的优点是不依赖网络&#xff0c;因此即使在没有网络的情况下&#xff0c;用户也可以生成密码。所以这种方式被许多流行的网站使用到双因子或多因子认证中&a…...

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06&#xff08;十亿美元&#xff09;。漏洞扫描服务市场行业预计将从 2024 年的 3.48&#xff08;十亿美元&#xff09;增长到 2032 年的 9.54&#xff08;十亿美元&#xff09;。预测期内漏洞扫描服务市场 CAGR&#xff08;增长率&…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作

一、上下文切换 即使单核CPU也可以进行多线程执行代码&#xff0c;CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短&#xff0c;所以CPU会不断地切换线程执行&#xff0c;从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...

自然语言处理——循环神经网络

自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元&#xff08;GRU&#xff09;长短期记忆神经网络&#xff08;LSTM&#xff09…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计&#xff0c;聪明的码友立马就知道了&#xff0c;该到数据访问模块了&#xff0c;要不就这俩玩个6啊&#xff0c;查库势在必行&#xff0c;至此&#xff0c;它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据&#xff08;数据库、No…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...