深度学习中的知识蒸馏
大家好,我是小青
今天给大家分享神经网络中的一个关键概念,知识蒸馏
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单的模型(学生模型)中。通过这种方式,学生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。
知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储成本,同时尽可能保持模型性能。
核心思想
知识蒸馏的核心思想是利用教师模型的输出(通常是软标签,即概率分布)来指导学生模型的训练。与传统的监督学习不同,知识蒸馏不仅使用真实标签(硬标签),还利用教师模型生成的软标签来传递更多的信息。
通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。
关键技术与方法
知识蒸馏的核心在于让学生模型不仅仅学习真实标签,还学习教师模型提供的软标签,即教师模型输出的概率分布。这种方式可以让学生模型获得更丰富的信息。
传统神经网络的交叉熵损失
在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:
传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即 仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。
知识蒸馏的损失函数
在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。
这些软标签由温度化 Softmax 得到。
知识蒸馏的优势
-
模型压缩:学生模型通常比教师模型小得多,适合在资源受限的设备上部署。
-
性能保持:通过知识蒸馏,学生模型能够在保持较高性能的同时,显著减少计算资源和存储需求。
-
泛化能力:软标签提供了更多的信息,有助于学生模型更好地泛化。
知识蒸馏的变种
除了标准的知识蒸馏方法,研究人员还提出了多个改进版本。
-
自蒸馏(Self-Distillation):模型自身作为教师,将深层网络的知识蒸馏到浅层部分。
-
多教师蒸馏(Multi-Teacher Distillation):多个教师模型联合指导学生模型,融合不同教师的知识。
-
在线蒸馏(Online Distillation):教师模型和学生模型同步训练,而不是先训练教师模型再训练学生模型。
案例分享
下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。
这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。
首先,定义教师模型和学生模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 教师模型(较大的神经网络)
class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x) # 注意这里没有 Softmaxreturn x# 学生模型(较小的神经网络)
class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = self.fc2(x) # 注意这里没有 Softmaxreturn x
然后加载数据集。
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
训练教师模型
def train_teacher(model, train_loader, epochs=5, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()for epoch in range(epochs):model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()output = model(images)loss = criterion(output, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化并训练教师模型
teacher_model = TeacherModel()
train_teacher(teacher_model, train_loader)
知识蒸馏训练学生模型
def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):"""计算蒸馏损失,结合知识蒸馏损失和交叉熵损失"""soft_targets = F.softmax(teacher_logits / T, dim=1) # 教师模型的软标签soft_predictions = F.log_softmax(student_logits / T, dim=1) # 学生模型的预测distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)ce_loss = F.cross_entropy(student_logits, labels)return alpha * ce_loss + (1 - alpha) * distillation_lossdef train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):optimizer = optim.Adam(student_model.parameters(), lr=lr)teacher_model.eval() # 设定教师模型为评估模式for epoch in range(epochs):student_model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()student_logits = student_model(images)with torch.no_grad():teacher_logits = teacher_model(images) # 获取教师模型输出loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化学生模型
student_model = StudentModel()
train_student_with_distillation(student_model, teacher_model, train_loader)
评估模型
def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy = 100 * correct / totalreturn accuracy# 评估教师模型
teacher_acc = evaluate(teacher_model, test_loader)
print(f"教师模型准确率: {teacher_acc:.2f}%")# 评估知识蒸馏训练的学生模型
student_acc_distilled = evaluate(student_model, test_loader)
print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")
相关文章:

深度学习中的知识蒸馏
大家好,我是小青 今天给大家分享神经网络中的一个关键概念,知识蒸馏 知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单…...

【Windows软件 - HeidiSQL】导出数据库
HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...

苏剑林“闭门造车”之多模态思路浅谈思考
原文来自科学空间苏剑林 “闭门造车”之多模态思路浅谈(一):无损输入和“闭门造车”之多模态思路浅谈(二):自回归,学习后总结。 文章目录 “闭门造车”之多模态思路浅谈(一ÿ…...

绿联nas docker 安装 rocketmq 队列。亲测可用
首先拉取docker 镜像,所需镜像如下: 安装 nameserver docker run -d -p 9876:9876 \ -v ${HOME}/docker/software/rocketmq/data/namesrv/logs:/opt/logs \ -v ${HOME}/docker/software/rocketmq/data/namesrv/store:/opt/store \ --name rmqnamesrv \ …...
C++(23):unreachable
C++23在头文件 "><utility>定义了std::unreachable(),用于指示编译器,该段代码不应该被允许,因此编译器可以对该位置进行优化,如果一旦允许了该位置的代码,行为未定义: #include <utility> #include <iostream>using namespace std;int func(…...
初等数论--欧几里得算法
1. 定义 u 0 u 1 ∈ Z , u 1 ≠ 0 , u 1 ∤ u 0 u_0\ u_1\in Z,u_1 \ne0,u_1 \nmid u_0 u0 u1∈Z,u10,u1∤u0 根据带余除法可得下面一系列等式 u 0 q 0 u 1 u 2 0 < u 2 < ∣ u 1 ∣ u 1 q 0 u 2 u 3 0 < u 3 < u 2 ⋯ u k − 1 q k − 1 u k …...
阿里云前端自动化部署流程指南
本文详细介绍从前端代码开发到阿里云 OSS/CDN 自动化部署的完整流程。 一、流程概览 © ivwdcwso (ID: u012172506) 1.1 部署流程图 #mermaid-svg-H1LBBmwTHAAF3QTL {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermai…...

EXCEL解决IF函数“您已为此函数输入太多个参数”的报错
IF函数的基本结构是IF(条件, 值为真时的结果, 值为假时的结果),所以标准的IF函数最多只能有三个参数。当用户输入的参数超过三个时,Excel就会报这个错误。比如多个IF语句叠加,但可能在嵌套的过程中没有正确关闭每个IF函数的括号,导…...

CAS单点登录(第7版)18.日志和审计
如有疑问,请看视频:CAS单点登录(第7版) 日志和审计 Logging 概述 Logging CAS 提供了一个日志记录工具,用于记录重要信息事件,如身份验证成功和失败;可以对其进行自定义以生成用于故障排除的其他信息。…...

2025年软件测试面试题大全(附答案+文档)
🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、测试基础 1、测试策略或测试包括哪些,测试要覆盖哪些方面 UI、功能、性能、可靠性、易用性、兼容性、安全性、安装卸载 2、设计测试用例的办法 …...

太空飞船任务,生成一个地球发射、火星着陆以及下一次发射窗口返回地球的动画3D代码
import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from mpl_toolkits.mplot3d import Axes3D# 天体参数设置(简化模型) AU 1.5e8 # 天文单位(公里) earth_orbital_radius …...
IDEA——Mac版快捷键
目录 按键含义常用组合代码生成快捷键:代码追踪快捷键:高效编辑快捷键:代码重构快捷键:工具类快捷键:常规文件操作快捷键: 按键含义 ⌘ command Command键(⌘)相当于Windows中的Con…...

智能体系统(AI Agent System)是什么?——从概念解析到企业数字化转型的全景落地及投资视角
文章目录 一、 前言1.1 背景介绍1.2 写作目的 二、 智能体系统及相关概念解析2.1 智能体系统定义2.2 关键概念区分2.2.1 自主代理(Autonomous Agent)2.2.2 多智能体系统(MAS)2.2.3 人工智能/机器学习(AI/ML)…...
Vue 前端开发中的路由知识:从入门到精通
文章目录 引言1. Vue Router 简介1.1 安装 Vue Router1.2 配置 Vue Router1.3 在 Vue 实例中使用 Vue Router 2. 路由的基本用法2.1 路由映射2.2 路由视图2.3 路由链接 3. 动态路由3.1 动态路径参数3.2 访问动态参数3.3 响应路由参数的变化 4. 嵌套路由4.1 定义嵌套路由4.2 渲染…...

前端VUE+后端uwsgi 环境搭建
1整体架构 请求流程the web clinet--the web server->the socket->uwsgi--django 第一级的nginx并不是必须的,uwsgi完全可以完成整个的和浏览器交互的流程;在nginx上加上安全性或其他的限制,可以达到保护程序的作用;uWSGI本…...

I2C实践开发 ---【STM32-I2C-HDC1080温湿度采集系统】
I2C实践开发 — STM32-I2C-HDC1080温湿度采集系统 目录 I2C实践开发 --- STM32-I2C-HDC1080温湿度采集系统1. 引言2. 系统架构2.1 硬件架构2.2 软件架构 3. 代码分析3.1 I2C驱动文件 (i2c.h 和 i2c.c)3.2 HDC1080传感器驱动文件 (hdc1080.h 和 hdc1080.c) 4. 功能总结【HDC1080…...

【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调【完整教程】
文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-1微调2.2.3.2 zero-2微调2.2.3.3 zero-3微调2.2.3.4 单卡Lora微调 2.2.4 实验2.2.4.1 实验1:多GPU微调-zero12.2.4.2 实验2:…...

浏览器报错:无法访问此网站 无法找到xxx.xxx.net的DNS地址。正在诊断该问题。尝试运行Windows网络诊断。DNS_PROBE_STARTED
🤟致敬读者 🟩感谢阅读🟦希望我的文章能帮到您🟪如有兴趣可点关注了解更多内容 📘博主信息 点击标题👆有惊喜 📃文章前言 🔷文章均为学习和工作中整理的笔记,分享记录…...

【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})
代理模式 代理模式是一种结构型设计模式,它提供了一种替代访问的方法,即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下,增加额外的功能,如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…...
网络安全-攻击流程-用户层
用户层攻击主要针对操作系统中的用户空间应用程序及用户权限,利用软件漏洞、配置错误或用户行为弱点进行攻击。以下是常见的用户层攻击类型及其流程,以及防御措施: 1. 缓冲区溢出攻击 攻击流程: 目标识别:确定存在漏…...

51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
Java 语言特性(面试系列1)
一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...

【单片机期末】单片机系统设计
主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...

Spring数据访问模块设计
前面我们已经完成了IoC和web模块的设计,聪明的码友立马就知道了,该到数据访问模块了,要不就这俩玩个6啊,查库势在必行,至此,它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据(数据库、No…...

听写流程自动化实践,轻量级教育辅助
随着智能教育工具的发展,越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式,也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建,…...
Webpack性能优化:构建速度与体积优化策略
一、构建速度优化 1、升级Webpack和Node.js 优化效果:Webpack 4比Webpack 3构建时间降低60%-98%。原因: V8引擎优化(for of替代forEach、Map/Set替代Object)。默认使用更快的md4哈希算法。AST直接从Loa…...