深度学习中的知识蒸馏
大家好,我是小青
今天给大家分享神经网络中的一个关键概念,知识蒸馏
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单的模型(学生模型)中。通过这种方式,学生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。
知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储成本,同时尽可能保持模型性能。

核心思想
知识蒸馏的核心思想是利用教师模型的输出(通常是软标签,即概率分布)来指导学生模型的训练。与传统的监督学习不同,知识蒸馏不仅使用真实标签(硬标签),还利用教师模型生成的软标签来传递更多的信息。
通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。
关键技术与方法
知识蒸馏的核心在于让学生模型不仅仅学习真实标签,还学习教师模型提供的软标签,即教师模型输出的概率分布。这种方式可以让学生模型获得更丰富的信息。
传统神经网络的交叉熵损失
在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:



传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即 仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。
知识蒸馏的损失函数
在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。
这些软标签由温度化 Softmax 得到。


知识蒸馏的优势
-
模型压缩:学生模型通常比教师模型小得多,适合在资源受限的设备上部署。
-
性能保持:通过知识蒸馏,学生模型能够在保持较高性能的同时,显著减少计算资源和存储需求。
-
泛化能力:软标签提供了更多的信息,有助于学生模型更好地泛化。
知识蒸馏的变种
除了标准的知识蒸馏方法,研究人员还提出了多个改进版本。
-
自蒸馏(Self-Distillation):模型自身作为教师,将深层网络的知识蒸馏到浅层部分。
-
多教师蒸馏(Multi-Teacher Distillation):多个教师模型联合指导学生模型,融合不同教师的知识。
-
在线蒸馏(Online Distillation):教师模型和学生模型同步训练,而不是先训练教师模型再训练学生模型。
案例分享
下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。
这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。
首先,定义教师模型和学生模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 教师模型(较大的神经网络)
class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x) # 注意这里没有 Softmaxreturn x# 学生模型(较小的神经网络)
class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = self.fc2(x) # 注意这里没有 Softmaxreturn x
然后加载数据集。
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
训练教师模型
def train_teacher(model, train_loader, epochs=5, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()for epoch in range(epochs):model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()output = model(images)loss = criterion(output, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化并训练教师模型
teacher_model = TeacherModel()
train_teacher(teacher_model, train_loader)
知识蒸馏训练学生模型
def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):"""计算蒸馏损失,结合知识蒸馏损失和交叉熵损失"""soft_targets = F.softmax(teacher_logits / T, dim=1) # 教师模型的软标签soft_predictions = F.log_softmax(student_logits / T, dim=1) # 学生模型的预测distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)ce_loss = F.cross_entropy(student_logits, labels)return alpha * ce_loss + (1 - alpha) * distillation_lossdef train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):optimizer = optim.Adam(student_model.parameters(), lr=lr)teacher_model.eval() # 设定教师模型为评估模式for epoch in range(epochs):student_model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()student_logits = student_model(images)with torch.no_grad():teacher_logits = teacher_model(images) # 获取教师模型输出loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化学生模型
student_model = StudentModel()
train_student_with_distillation(student_model, teacher_model, train_loader)
评估模型
def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy = 100 * correct / totalreturn accuracy# 评估教师模型
teacher_acc = evaluate(teacher_model, test_loader)
print(f"教师模型准确率: {teacher_acc:.2f}%")# 评估知识蒸馏训练的学生模型
student_acc_distilled = evaluate(student_model, test_loader)
print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")

相关文章:
深度学习中的知识蒸馏
大家好,我是小青 今天给大家分享神经网络中的一个关键概念,知识蒸馏 知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单…...
【Windows软件 - HeidiSQL】导出数据库
HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...
苏剑林“闭门造车”之多模态思路浅谈思考
原文来自科学空间苏剑林 “闭门造车”之多模态思路浅谈(一):无损输入和“闭门造车”之多模态思路浅谈(二):自回归,学习后总结。 文章目录 “闭门造车”之多模态思路浅谈(一ÿ…...
绿联nas docker 安装 rocketmq 队列。亲测可用
首先拉取docker 镜像,所需镜像如下: 安装 nameserver docker run -d -p 9876:9876 \ -v ${HOME}/docker/software/rocketmq/data/namesrv/logs:/opt/logs \ -v ${HOME}/docker/software/rocketmq/data/namesrv/store:/opt/store \ --name rmqnamesrv \ …...
C++(23):unreachable
C++23在头文件 "><utility>定义了std::unreachable(),用于指示编译器,该段代码不应该被允许,因此编译器可以对该位置进行优化,如果一旦允许了该位置的代码,行为未定义: #include <utility> #include <iostream>using namespace std;int func(…...
初等数论--欧几里得算法
1. 定义 u 0 u 1 ∈ Z , u 1 ≠ 0 , u 1 ∤ u 0 u_0\ u_1\in Z,u_1 \ne0,u_1 \nmid u_0 u0 u1∈Z,u10,u1∤u0 根据带余除法可得下面一系列等式 u 0 q 0 u 1 u 2 0 < u 2 < ∣ u 1 ∣ u 1 q 0 u 2 u 3 0 < u 3 < u 2 ⋯ u k − 1 q k − 1 u k …...
阿里云前端自动化部署流程指南
本文详细介绍从前端代码开发到阿里云 OSS/CDN 自动化部署的完整流程。 一、流程概览 © ivwdcwso (ID: u012172506) 1.1 部署流程图 #mermaid-svg-H1LBBmwTHAAF3QTL {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermai…...
EXCEL解决IF函数“您已为此函数输入太多个参数”的报错
IF函数的基本结构是IF(条件, 值为真时的结果, 值为假时的结果),所以标准的IF函数最多只能有三个参数。当用户输入的参数超过三个时,Excel就会报这个错误。比如多个IF语句叠加,但可能在嵌套的过程中没有正确关闭每个IF函数的括号,导…...
CAS单点登录(第7版)18.日志和审计
如有疑问,请看视频:CAS单点登录(第7版) 日志和审计 Logging 概述 Logging CAS 提供了一个日志记录工具,用于记录重要信息事件,如身份验证成功和失败;可以对其进行自定义以生成用于故障排除的其他信息。…...
2025年软件测试面试题大全(附答案+文档)
🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、测试基础 1、测试策略或测试包括哪些,测试要覆盖哪些方面 UI、功能、性能、可靠性、易用性、兼容性、安全性、安装卸载 2、设计测试用例的办法 …...
太空飞船任务,生成一个地球发射、火星着陆以及下一次发射窗口返回地球的动画3D代码
import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from mpl_toolkits.mplot3d import Axes3D# 天体参数设置(简化模型) AU 1.5e8 # 天文单位(公里) earth_orbital_radius …...
IDEA——Mac版快捷键
目录 按键含义常用组合代码生成快捷键:代码追踪快捷键:高效编辑快捷键:代码重构快捷键:工具类快捷键:常规文件操作快捷键: 按键含义 ⌘ command Command键(⌘)相当于Windows中的Con…...
智能体系统(AI Agent System)是什么?——从概念解析到企业数字化转型的全景落地及投资视角
文章目录 一、 前言1.1 背景介绍1.2 写作目的 二、 智能体系统及相关概念解析2.1 智能体系统定义2.2 关键概念区分2.2.1 自主代理(Autonomous Agent)2.2.2 多智能体系统(MAS)2.2.3 人工智能/机器学习(AI/ML)…...
Vue 前端开发中的路由知识:从入门到精通
文章目录 引言1. Vue Router 简介1.1 安装 Vue Router1.2 配置 Vue Router1.3 在 Vue 实例中使用 Vue Router 2. 路由的基本用法2.1 路由映射2.2 路由视图2.3 路由链接 3. 动态路由3.1 动态路径参数3.2 访问动态参数3.3 响应路由参数的变化 4. 嵌套路由4.1 定义嵌套路由4.2 渲染…...
前端VUE+后端uwsgi 环境搭建
1整体架构 请求流程the web clinet--the web server->the socket->uwsgi--django 第一级的nginx并不是必须的,uwsgi完全可以完成整个的和浏览器交互的流程;在nginx上加上安全性或其他的限制,可以达到保护程序的作用;uWSGI本…...
I2C实践开发 ---【STM32-I2C-HDC1080温湿度采集系统】
I2C实践开发 — STM32-I2C-HDC1080温湿度采集系统 目录 I2C实践开发 --- STM32-I2C-HDC1080温湿度采集系统1. 引言2. 系统架构2.1 硬件架构2.2 软件架构 3. 代码分析3.1 I2C驱动文件 (i2c.h 和 i2c.c)3.2 HDC1080传感器驱动文件 (hdc1080.h 和 hdc1080.c) 4. 功能总结【HDC1080…...
【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调【完整教程】
文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-1微调2.2.3.2 zero-2微调2.2.3.3 zero-3微调2.2.3.4 单卡Lora微调 2.2.4 实验2.2.4.1 实验1:多GPU微调-zero12.2.4.2 实验2:…...
浏览器报错:无法访问此网站 无法找到xxx.xxx.net的DNS地址。正在诊断该问题。尝试运行Windows网络诊断。DNS_PROBE_STARTED
🤟致敬读者 🟩感谢阅读🟦希望我的文章能帮到您🟪如有兴趣可点关注了解更多内容 📘博主信息 点击标题👆有惊喜 📃文章前言 🔷文章均为学习和工作中整理的笔记,分享记录…...
【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})
代理模式 代理模式是一种结构型设计模式,它提供了一种替代访问的方法,即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下,增加额外的功能,如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…...
网络安全-攻击流程-用户层
用户层攻击主要针对操作系统中的用户空间应用程序及用户权限,利用软件漏洞、配置错误或用户行为弱点进行攻击。以下是常见的用户层攻击类型及其流程,以及防御措施: 1. 缓冲区溢出攻击 攻击流程: 目标识别:确定存在漏…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...
【论文笔记】若干矿井粉尘检测算法概述
总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...
Mac软件卸载指南,简单易懂!
刚和Adobe分手,它却总在Library里给你写"回忆录"?卸载的Final Cut Pro像电子幽灵般阴魂不散?总是会有残留文件,别慌!这份Mac软件卸载指南,将用最硬核的方式教你"数字分手术"࿰…...
相机从app启动流程
一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...
ServerTrust 并非唯一
NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
