当前位置: 首页 > news >正文

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点输出特征并计算蒸馏 loss 的过程

知乎-mmrazor-模型蒸馏
知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。
也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD

  • feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
    而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法

mmrazor可以使用不同架构的student和teacher模型

  • 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
  • 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)

mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1

cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢
适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)
cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号

  • 现象:
    由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
_base_ = ['mmseg::_base_/datasets/pascal_voc12.py','mmseg::_base_/schedules/schedule_160k.py','mmseg::_base_/default_runtime.py'
]# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,timer=dict(type='IterTimerHook'),logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),param_scheduler=dict(type='ParamSchedulerHook'),# 使用了自定义的Student_CheckpointHookcheckpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),# checkpoint中,interval=-1则不会保存least.pthsampler_seed=dict(type='DistSamplerSeedHook'),visualization=dict(type='SegVisualizationHook'))teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth'  # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'  # noqa: E501student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py'  # noqa: E501
model = dict(_scope_='mmrazor',type='SingleTeacherDistill',architecture=dict(cfg_path=student_cfg_path, pretrained=False),teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),teacher_ckpt=teacher_ckpt,distiller=dict(type='ConfigurableDistiller',distill_losses=dict(loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),student_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),teacher_recorders=dict(logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),connectors=dict(loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,norm_cfg=dict(type='BN')),loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',norm_cfg=dict(type='BN'))),loss_forward_mappings=dict(loss_cwd=dict(preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据# 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))# 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器# 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称find_unused_parameters = Truetrain_cfg = dict(type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')train_dataloader = dict(batch_size=16)  # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batchwork_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'

模型蒸馏一

    1. 基于响应的KD(DKD ,FitNets):
      基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
      教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
      logits具备的含义为模型判断当前样本为各类别的信心为多少
      1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
      2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
      3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
    1. 基于特征的KD(AB,AT,ofd,Factor Transfer):
      蒸馏位点位于模型中途获得的特征
      通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
      1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
      2)知识定位,设计规则选出更为重要的教师特征
    1. 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
      1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本

总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。

相关文章:

模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点的输出特征并计算蒸馏 loss 的过程 知乎-mmrazor-模型蒸馏 知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。 也可为 data-free KD、online KD、self KD&#xff…...

总结Kibana DevTools如何操作elasticsearch的常用语句

一、操作es的工具 ElasticSearch HeadKibana DevToolsElasticHQ 本文主要是总结Kibana DevTools操作es的语句。 二、搜索文档 1、根据ID查询单个记录 GET /course_idx/_doc/course:202、term 匹配"name"字段的值为"6789999"的文档 类似于sql语句中的等…...

【QT】QT自定义C++类

在使用Qt的ui设计时,Qt为我们提供了标准的类,但是在很多复杂工程中,标准的类并不能满足所有的需求,这时就需要我们自定义C类。 下面以自定义的QPushButton作一个很简单的例子。 先新建默认Qt Widgets Application项目 一、自定义…...

【多媒体文件格式】AVI、WAV、RIFF

AVI、RIFF AVI:Audio/Video Interleaved(音频视频交织/交错),用于采集、编辑、播放的RIFF文件。由Microsoft公司1992年11月推出,是Microsoft公司开发的一种符合RIFF文件规范的数字音频与视频文件格式,原先…...

AI创作系统ChatGPT商业运营系统源码+支持GPT4/支持ai绘画

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…...

JWT简介 JWT结构 JWT示例 前端添加JWT令牌功能 后端程序

目录 1. JWT简述 1.1 什么是JWT 1.2 为什么使用JWT 1.3 JWT结构 1.4 验证过程 2. JWT示例 2.1 后台程序 2.2 前台加入jwt令牌功能 1. JWT简述 1.1 什么是JWT Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准((RFC 7…...

Rust核心功能之一(所有权)

目录 1、什么是所有权? 1.1 所有权规则 1.2 变量作用域 1.3 String 类型 1.4 内存与分配 变量与数据交互的方式(一):移动 变量与数据交互的方式(二):克隆 只在栈上的数据:拷贝…...

跨域(CORS)和JWT 详解

跨域 (CORS) 概念 同源策略 (Same-Origin Policy) 同源策略是一项浏览器安全特性,它限制了一个网页中的脚本如何与另一个来源(域名、协议、端口)的资源进行交互。这对于防止跨站点请求伪造和数据泄露非常重要。 为什么需要跨域? 跨域问题通…...

前端框架Vue学习 ——(二)Vue常用指令

文章目录 常用指令 常用指令 指令: HTML 标签上带有 “v-” 前缀的特殊属性&#xff0c;不同指令具有不同含义。例如: v-if, v-for… 常用指令&#xff1a; v-bind&#xff1a;为 HTML 标签绑定属性值&#xff0c;如设置 href&#xff0c;css 样式等 <a v-bind:href"…...

Linux 指令心法(十四)`flash_erase` 擦除Flash存储器

文章目录 flash_erase 作用flash_erase的主要特点和使用场景flash_erase命令应用方法flash_erase命令可以解决哪些问题?flash_erase命令使用时注意事项 flash_erase 作用 这是一个用于擦除Flash存储器的命令。它可以擦除指定的Flash块或扇区&#xff0c;以便在写入新数据之前…...

GoLong的学习之路(二十一)进阶,语法之并发(go最重要的特点)(协程的主要用法)

并发编程在当前软件领域是一个非常重要的概念&#xff0c;随着CPU等硬件的发展&#xff0c;我们无一例外的想让我们的程序运行的快一点、再快一点。Go语言在语言层面天生支持并发&#xff0c;充分利用现代CPU的多核优势&#xff0c;这也是Go语言能够大范围流行的一个很重要的原…...

加快网站收录 3小时百度收录新站方法

加快网站收录 3小时百度收录新站方法 3小时百度收录新站方法说起来大家可能不相信&#xff0c;但这确实是真实的(该方法是通过技术提交&#xff0c;让百度快速抓取收录您的网站&#xff0c;不管你网站有没有备案&#xff0c;都能在短时间内被收录&#xff0c;要是你的网站迟迟不…...

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

目录 一、ChatGLM3 模型 二、资源需求 三、部署安装 配置环境 安装过程 低成本配置部署方案 四、启动 ChatGLM3 五、功能测试 新鲜出炉&#xff0c;国产 GPT 版本迭代更新啦~清华团队刚刚发布ChatGLM3&#xff0c;恰逢云栖大会前百川也发布Baichuan2-192K&#xff0c;一…...

图片怎么转换成pdf?

图片怎么转换成pdf&#xff1f;图片可以转换成PDF格式文档吗&#xff1f;当然是可以的呀&#xff0c;当图片转换成PDF文件类型时&#xff0c;我们就会发现图片更加方便的打开分享和传播&#xff0c;而且还可以更加安全的保证我们的图片所有性。我们知道PDF文档是可以加密的&…...

【源码】医学影像PACS实现三维影像后处理等功能

医学影像诊断技术近年来取得了快速发展&#xff0c;包括高性能的影像检查设备的临床应用和数字信息技术的图像显示、存储、传输、处理、识别&#xff0c;这些技术使得计算机辅助检测和诊断成为可能&#xff0c;同时人工智能影像诊断也进入了人们的视野。这些技术进步提高了疾病…...

DOCTYPE是什么,有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别?

前言 持续学习总结输出中&#xff0c;今天分享的是DOCTYPE是什么&#xff0c;有何作用、 使用方式、渲染模式、严格模式和怪异模式的区别。 DOCTYPE是什么&#xff0c;有何作用&#xff1f; DOCTYPE是HTML5的文档声明&#xff0c;通过它可以告诉浏览器&#xff0c;使用那个H…...

Go语言实现HTTP正向代理

文章目录 前言实现思路代码实现 前言 正向代理&#xff08;Forward Proxy&#xff09;是一种代理服务器的部署方式&#xff0c;它位于客户端和目标服务器之间&#xff0c;代表客户端向目标服务器发送请求。正向代理可以用来隐藏客户端的真实身份&#xff0c;以及在不同网络环境…...

第11章_数据处理之增删改

文章目录 1 插入数据1.1 实际问题1.2 方式 1&#xff1a;VALUES的方式添加1.3 方式2&#xff1a;将查询结果插入到表中演示代码 2 更新数据演示代码 3 删除数据演示代码 4 MySQL8新特性&#xff1a;计算列演示代码 5 综合案例课后练习 1 插入数据 1.1 实际问题 解决方式&#…...

数据时代的新引擎:数据治理与开发,揭秘数据领域的黄金机遇!

文章目录 一、数据时代的需求二、数据治理与开发三、案例分析四、黄金机遇《数据要素安全流通》《Python数据挖掘&#xff1a;入门、进阶与实用案例分析》《数据保护&#xff1a;工作负载的可恢复性 》《Data Mesh权威指南》《分布式统一大数据虚拟文件系统 Alluxio原理、技术与…...

使用 Golang 实现基于时间的一次性密码 TOTP

上篇文章详细讲解了一次性密码 OTP 相关的知识&#xff0c;基于时间的一次性密码 TOTP 是 OTP 的一种实现方式。这种方法的优点是不依赖网络&#xff0c;因此即使在没有网络的情况下&#xff0c;用户也可以生成密码。所以这种方式被许多流行的网站使用到双因子或多因子认证中&a…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

1.3 VSCode安装与环境配置

进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件&#xff0c;然后打开终端&#xff0c;进入下载文件夹&#xff0c;键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、&#x1f44b;&#x1f3fb;前言 二、&#x1f608;sinx波动的基本原理 三、&#x1f608;波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、&#x1f30a;波动优化…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

tomcat入门

1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效&#xff0c;稳定&#xff0c;易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...

mac:大模型系列测试

0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何&#xff0c;是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试&#xff0c;是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...

​​企业大模型服务合规指南:深度解析备案与登记制度​​

伴随AI技术的爆炸式发展&#xff0c;尤其是大模型&#xff08;LLM&#xff09;在各行各业的深度应用和整合&#xff0c;企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者&#xff0c;还是积极拥抱AI转型的传统企业&#xff0c;在面向公众…...