当前位置: 首页 > news >正文

Neuron Selectivity Transfer 原理与代码解析

paper:Like What You Like: Knowledge Distill via Neuron Selectivity Transfer

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/NST.py

本文的创新点

本文探索了一种新型的知识 - 神经元的选择性知识,并将其传递给学生模型。这个模型背后的直觉相当简单:每个神经元本质上从原始输入提取与特定任务相关的某种模式,因此,如果一个神经元在某些区域或样本中被激活,这意味着这些区域或样本共享一些与该任务相关的特性。这种聚类知识对学生模型非常有价值,因为它为教室模型的最终预测结果提供了一种解释。因此,作者提出对齐教师模型和学生模型神经元选择模式的分布。

背景

Notions

假定教师模型和学生模型都是卷积神经网络,并将教师模型表示为 \(T\),学生模型表示为 \(S\)。CNN中某一层的输出特征图表示为 \(\mathbf{F}\in \mathbb{R}^{C\times HW}\),\(\mathbf{F}\) 的每一行即每个通道的特征图表示为 \(\mathbf{f}^{k\cdot}\in \mathbb{R}^{HW}\),\(\mathbf{F}\) 的每一列即每个空间位置沿所有通道的激活表示为 \(\mathbf{f}^{\cdot k}\in \mathbb{R}^{C}\)。\(\mathbf{F}_{T}\) 和 \(\mathbf{F}_{S}\) 分别表示教师模型和学生模型中某一层的特征图,不失一般性,我们假设 \(\mathbf{F}_{T}\) 和 \(\mathbf{F}_{S}\) 的大小相等,如果不相等则可以通过插值使它们相等。

Maximum Mean Discrepancy

最大平均差异(Maximum Mean Discrepancy,MMD)可以看作是一种概率分布间的距离度量,基于从它们采样的样本。假设我们有两组分别从分布 \(p\) 和 \(q\) 中采样的样本 \(\mathcal{X}=\left \{ x^{i} \right \}^{N}_{i=1} \) 和 \(\mathcal{Y}=\left \{ y^{j} \right \}^{M}_{j=1} \),那么 \(p\) 和 \(q\) 之间的MMD距离的平方如下

其中 \(\phi \left ( \cdot \right ) \) 是一个显示映射函数,通过进一步扩展并应用核技巧(kernel trick),式(1)可以表示为

其中 \(k(\cdot,\cdot)\) 是一个核函数,它将样本向量投射到一个更高维或是无限维的特征空间中。

最小化MMD等价于最小化 \(p\) 和 \(q\) 之间的距离。

方法介绍

Motivation

下面是两张叠加了热力图(heat map)的图片,其中热力图是根据VGG16 Conv5_3中的某个神经元得到的。从图中很容易看出这两个神经元具有很强的选择性:左图的神经元对猴子的脸部非常敏感,右侧的神经元对字符非常敏感。这种激活实际上意味着神经元的选择性,即什么样的输入可以触发这些神经元。换句话说,一个神经元高激活的区域可能共享一些与任务相关的相似特性,尽管这些特性可能对于人类没有非常直观的解释。

为了捕获这些相似特性,在学生模型中也应该有神经元模仿这些激活模式。因此本文提出了一种新的知识类型:神经元选择性(neuron selectivities)或者叫做共激活(co-activations),并将其传递给学生模型。

Formulation

每个通道的特征图 \(\mathbf{f}^{k\cdot}\) 示一个特定神经元的selectivity知识,我们定义Neuron Selectivity Transfer,NST损失如下

其中 \(\mathcal{H}\) 是交叉熵损失,\(\mathbb{y}_{true}\) 是ground truth标签,\(p_{S}\) 是学生模型的输出概率。

MMD损失可以扩展如下

其中用 \(l_{2}\) 标准化后的 \(\frac{\mathbf{f}^{k\cdot} }{\left \|\mathbf{f}^{k\cdot} \right \|_{2} } \) 替代了 \(\mathbf{f}^{k\cdot}\),这是为了使每个样本具有相同的尺度。最小化MMD损失就等价于将神经元的选择性知识传递给学生模型

Choice of Kernels

本文选用以下三种核函数

对于多项式核,本文设置 \(d=2,c=0\)。对于高斯核,\(\sigma ^{2}\) 设置为对应距离的平方。

代码解析

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom ._base import Distillerdef nst_loss(g_s, g_t):return sum([single_stage_nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])def single_stage_nst_loss(f_s, f_t):s_H, t_H = f_s.shape[2], f_t.shape[2]if s_H > t_H:f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))elif s_H < t_H:f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)  # (64,64,32,32)->(64,64,1024)f_s = F.normalize(f_s, dim=2)f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)f_t = F.normalize(f_t, dim=2)return (poly_kernel(f_t, f_t).mean().detach()+ poly_kernel(f_s, f_s).mean()- 2 * poly_kernel(f_s, f_t).mean())def poly_kernel(a, b):a = a.unsqueeze(1)  # (64,64,1024)->(64,1,64,1024)b = b.unsqueeze(2)  # (64,64,1024)->(64,64,1,1024)res = (a * b).sum(-1).pow(2)  # (64,64,64,1024)->(64,64,64)return resclass NST(Distiller):"""Like What You Like: Knowledge Distill via Neuron Selectivity Transfer"""def __init__(self, student, teacher, cfg):super(NST, self).__init__(student, teacher)self.ce_loss_weight = cfg.NST.LOSS.CE_WEIGHTself.feat_loss_weight = cfg.NST.LOSS.FEAT_WEIGHTdef forward_train(self, image, target, **kwargs):logits_student, feature_student = self.student(image)  # (64,3,32,32)with torch.no_grad():_, feature_teacher = self.teacher(image)# lossesloss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)loss_feat = self.feat_loss_weight * nst_loss(feature_student["feats"][1:], feature_teacher["feats"][1:]# [torch.Size([64, 64, 32, 32]), torch.Size([64, 128, 16, 16]), torch.Size([64, 256, 8, 8])]# [torch.Size([64, 64, 32, 32]), torch.Size([64, 128, 16, 16]), torch.Size([64, 256, 8, 8])])losses_dict = {"loss_ce": loss_ce,"loss_kd": loss_feat,}return logits_student, losses_dict

相关文章:

Neuron Selectivity Transfer 原理与代码解析

paper&#xff1a;Like What You Like: Knowledge Distill via Neuron Selectivity Transfercode&#xff1a;https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/NST.py本文的创新点本文探索了一种新型的知识 - 神经元的选择性知识&#xff0c…...

vue项目关闭子页面,并更新父页面的数据

今天下午是一个非常痛苦的&#xff0c;想要实现一个功能&#xff1a; 父页面打开了一个新的页面&#xff08;浏览器打开一个新的窗口&#xff09;&#xff0c;并在子页面提交数据之后&#xff0c;父页面的数据要同步更新。 难点&#xff1a;父页面是一个表格列表&#xff0c;…...

第五次作业:修改redis的配置文件使得windows的图形界面客户端可以连接redis服务器

1. 安装 Redis 依赖 Redis 是基于 C语言编写的&#xff0c;因此首先需要安装 Redis 所需要的 gcc 依赖&#xff1a; yum install -y gcc tcl 2、上传安装文件 将下载好的 redis-6.2.7.tar.gz 安装包上传到虚拟机的任意目录&#xff08;一般推荐上传到 /usr/local/src目录&am…...

【11】FreeRTOS的延时函数

目录1.延时函数-介绍2.相对延时函数-解析2.1函数prvAddCurrentTaskToDelayedList-解析2.3滴答定时器中断服务函数xPortSysTickHandler()-解析2.4函数taskSWITCH_DELAYED_LISTS() -解析3.延时函数-实验4.总结1.延时函数-介绍 函数描述vTaskDelay()相对延时xTaskDelayUntil()绝对…...

Vue页面组成及常用属性

一、Vue页面组成 目前的项目中&#xff0c;Vue页面都是采用组件套娃的形式&#xff0c;由一个一个的组件拼接而成整个页面。一个组件就是一个.vue文件。组件通常由template和script两部分组成&#xff1a; template部分&#xff1a;页面展示的具体元素内容&#xff0c;比如文字…...

j6-IO流泛型集合多线程注解反射Socket

IO流 1 JDK API的使用 2 io简介 输入流用来读取in 输出流用来写出Out 在Java中&#xff0c;根据处理的数据单位不同&#xff0c;分为字节流和字符流 继承结构 java.io包&#xff1a; File 字节流&#xff1a;针对二进制文件 InputStream --FileInputStream --BufferedInputStre…...

创业能否成功?这几个因素很重要!

创业能否成功&#xff1f;这几个因素很重要&#xff01; 2023-02-22 19:06:53 大家好&#xff0c;我是你们熟悉而又陌生的好朋友梦龙&#xff0c;一个创业期的年轻人 上周末跟朋友一起钓鱼&#xff0c;他跟吐槽现在生意越来越难做。他是我身边可以说是创业很成功的例子&#…...

Bmp图片格式介绍

Bmp图片格式介绍 介绍 BMP是英文Bitmap&#xff08;位图&#xff09;的简写&#xff0c;它是Windows操作系统中的标准图像文件格式&#xff0c;能够被多种Windows应用程序所支持。随着Windows操作系统的流行与丰富的Windows应用程序的开发&#xff0c;BMP位图格式理所当然地被…...

Day4 leetcode

Day4 啊啊啊啊&#xff0c;什么玩意&#xff0c;第一次因为测评没过&#xff0c;约好的面试取消了&#xff0c;好尴尬呀&#xff0c;还有一家厦门的C/C电话面&#xff0c;是一家我还挺喜欢的公司&#xff0c;面的稀烂&#xff0c;只能安慰自己我现在手上至少有一个offer 有效括…...

Java设计模式-原型模式

1、定义 原型模式是一种创建型模式&#xff0c;用于创建重复的对象&#xff0c;并且保证性能。原型模式创建的对象是由原型对象自身创建的&#xff0c;是原型对象的一个克隆&#xff0c;和原型对象具有相同的结构和相同的值。 2、适用场景 创建对象时我们不仅仅需要创建一个新…...

2023年度最新且最详细Ubuntu的安装教程

目录 准备ISO镜像 1.去官网下载镜像&#xff0c;或者找有镜像源的网站下载 阿里云镜像站 2. 如果服务器是打算直接把底层系统安装为Ubuntu的话还需制作系统U盘 安装 1.新建虚拟机调整基础配置 2.打开电源&#xff0c;进入安装界面&#xff08;到这一步就跟u盘安装步骤一致…...

unix高级编程-fork之后父子进程共享文件

~/.bash_profile:每个用户都可使用该文件输入专用于自己使用的shell信息,当用户登录时,该文件仅仅执行一次!默认情况下,他设置一些环境变量,执行用户的.bashrc文件. 这里我看到的是centos的操作&#xff0c;但我用的是debian系的ubuntu&#xff0c;百度了一下发现debian的在这里…...

vue+echarts:柱状图横向展示和竖向展示

第021个点击查看专栏目录本示例是显示柱状图&#xff0c;分别是横向展示和纵向展示。关键是X轴和Y轴的参数互换。 文章目录横向示例效果横向示例源代码&#xff08;共81行&#xff09;纵向示例效果纵向示例源代码&#xff08;共81行&#xff09;相关资料参考专栏介绍横向示例效…...

SealOS 一键安装 K8S

环境 # 查看系统发行版 $ cat /etc/os-release NAME"CentOS Linux" VERSION"7 (Core)" ID"centos" ID_LIKE"rhel fedora" VERSION_ID"7" PRETTY_NAME"CentOS Linux 7 (Core)" ANSI_COLOR"0;31" CPE_NA…...

python网络编程详解

最近在看《UNIX网络编程 卷1》和《FREEBSD操作系统设计与实现》这两本书&#xff0c;我重点关注了TCP协议相关的内容&#xff0c;结合自己后台开发的经验&#xff0c;写下这篇文章&#xff0c;一方面是为了帮助有需要的人&#xff0c;更重要的是方便自己整理思路&#xff0c;加…...

ICRA 2023 | 首个联合暗光增强和深度估计的自监督方法STEPS

原文链接&#xff1a;https://www.techbeat.net/article-info?id4629 作者&#xff1a;郑宇鹏 本文中&#xff0c;我们提出了STEPS&#xff0c;第一个自监督框架来联合学习图像增强和夜间深度估计的方法。它可以同时训练图像增强网络和深度估计网络&#xff0c;并利用了图像增…...

基于react+nodejs+mysql开发用户中心,用于项管理加入的项目的用户认证

基于reactnodejsmysql开发用户中心&#xff0c;用于项管理加入的项目的用户认证用户中心功能介绍页面截图后端采用架构user表projects表project_user表仓库地址用户中心功能介绍 用户中心项目&#xff0c;用于统一管理用户信息、登录、注册、鉴权等 功能如下&#xff1a; 用…...

mapreduce与yarn

文章目录一、MapReduce1.1、MapReduce思想1.2、MapReduce实例进程1.3、MapReduce阶段组成1.4、MapReduce数据类型1.5、MapReduce关键类1.6、MapReduce执行流程1.6.1、Map阶段执行流程1.6.2、Map的shuffle阶段执行流程1.6.3、Reduce阶段执行流程1.7、MapReduce实例WordCount二、…...

鲲鹏云服务器上使用 traceroute 命令跟踪路由

traceroute 命令跟踪路由 它由遍布全球的几万局域网和数百万台计算机组成&#xff0c;并通过用于异构网络的TCP/IP协议进行网间通信。互联网中&#xff0c;信息的传送是通过网中许多段的传输介质和设备&#xff08;路由器&#xff0c;交换机&#xff0c;服务器&#xff0c;网关…...

代码随想录算法训练营第47天 || 198.打家劫舍 || 213.打家劫舍II || 337.打家劫舍III

代码随想录算法训练营第47天 || 198.打家劫舍 || 213.打家劫舍II || 337.打家劫舍III 198.打家劫舍 题目介绍 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&…...

Flink学习笔记:窗口

简介 langchain中提供的chain链组件&#xff0c;能够帮助我门快速的实现各个组件的流水线式的调用&#xff0c;和模型的问答 Chain链的组成 根据查阅的资料&#xff0c;langchain的chain链结构如下&#xff1a; $$Input \rightarrow Prompt \rightarrow Model \rightarrow Outp…...

qmc-decoder:快速解锁QQ音乐加密文件的终极指南

qmc-decoder&#xff1a;快速解锁QQ音乐加密文件的终极指南 【免费下载链接】qmc-decoder Fastest & best convert qmc 2 mp3 | flac tools 项目地址: https://gitcode.com/gh_mirrors/qm/qmc-decoder 你是否曾经从QQ音乐下载了心爱的歌曲&#xff0c;却发现只能在特…...

CogVideoX-2b CSDN专用版:5分钟部署你的本地AI视频导演

CogVideoX-2b CSDN专用版&#xff1a;5分钟部署你的本地AI视频导演 1. 从想法到画面&#xff0c;只差一个启动按钮 想象一下这样的场景&#xff1a;你脑子里闪过一个绝妙的视频创意——也许是“一只戴着宇航员头盔的柴犬在月球表面蹦跳”&#xff0c;也许是“赛博朋克都市的雨…...

【大语言模型基础(2)】自注意力与多头机制:QKV、缩放与因果掩码

文章目录摘要1. 为什么需要自注意力2. Q、K、V 到底是什么一个具体例子3. Attention 公式在干什么第一步&#xff1a;计算相似度第二步&#xff1a;做缩放第三步&#xff1a;softmax\mathrm{softmax}softmax 归一化第四步&#xff1a;对 ValueValueValue 做加权平均4. 为什么 G…...

小程序毕业设计基于微信小程序的校园跑腿小程序

前言 在校园生活节奏紧凑、同学们事务繁忙的当下&#xff0c;Spring Boot 基于微信小程序的校园跑腿小程序应运而生&#xff0c;为师生们提供了便捷高效的代劳服务&#xff0c;让校园生活更加从容有序。借助 Spring Boot 强大的后端支撑以及微信小程序无需安装、触手可及的优势…...

Youtu-Parsing开源模型实战:ONNX导出+TensorRT加速部署全流程

Youtu-Parsing开源模型实战&#xff1a;ONNX导出TensorRT加速部署全流程 1. 引言 如果你处理过大量的扫描文档、PDF文件或者图片资料&#xff0c;一定遇到过这样的烦恼&#xff1a;想把图片里的文字、表格、公式提取出来&#xff0c;手动操作不仅费时费力&#xff0c;还容易出…...

OpenClaw+Qwen3-32B科研助手:文献综述自动化实践

OpenClawQwen3-32B科研助手&#xff1a;文献综述自动化实践 1. 为什么需要自动化文献综述 作为一名计算机视觉方向的博士生&#xff0c;我每周需要阅读数十篇论文。传统的工作流程是&#xff1a;手动下载PDF→逐篇阅读→摘录关键观点→整理成表格。这个过程不仅耗时&#xff…...

SEO_新手必看的SEO优化入门教程与基础操作指南

<h2>SEO优化入门&#xff1a;为新手量身打造的指南</h2> <p>SEO优化&#xff0c;也就是搜索引擎优化&#xff0c;是一个让你的网站在搜索引擎结果中获得更高排名的过程。对于新手来说&#xff0c;SEO可能看起来有点复杂&#xff0c;但只要掌握了一些基础的操…...

OpenClaw多模型调度方案:GLM-4.7-Flash与本地小模型协同工作

OpenClaw多模型调度方案&#xff1a;GLM-4.7-Flash与本地小模型协同工作 1. 为什么需要多模型协同 去年冬天&#xff0c;当我第一次尝试用OpenClaw自动化处理周报时&#xff0c;发现一个尴尬的现象&#xff1a;用GLM-4.7-Flash这样的大模型处理简单表格整理&#xff0c;就像用…...

面向生产的Chatgpt5.4:系统集成、架构模式与成本优化深度拆解

对于计划将顶级AI能力深度集成至自身产品与工作流的团队而言&#xff0c;理解Gemini 3.1 Pro的系统级特性、集成模式与全生命周期成本至关重要。国内开发者可通过RskAi&#xff08;www.rsk.cn&#xff09;等聚合平台&#xff0c;以零成本、国内直访的方式完成前期技术验证与原型…...