当前位置: 首页 > news >正文

深度学习中的知识蒸馏

大家好,我是小青

今天给大家分享神经网络中的一个关键概念,知识蒸馏

知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单的模型(学生模型)中。通过这种方式,学生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。

知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储成本,同时尽可能保持模型性能。

图片

核心思想

知识蒸馏的核心思想是利用教师模型的输出(通常是软标签,即概率分布)来指导学生模型的训练。与传统的监督学习不同,知识蒸馏不仅使用真实标签(硬标签),还利用教师模型生成的软标签来传递更多的信息。

通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。

关键技术与方法

知识蒸馏的核心在于让学生模型不仅仅学习真实标签,还学习教师模型提供的软标签,即教师模型输出的概率分布。这种方式可以让学生模型获得更丰富的信息。

传统神经网络的交叉熵损失

在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:

传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即 仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。

知识蒸馏的损失函数

在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。

这些软标签由温度化 Softmax 得到。

知识蒸馏的优势

  • 模型压缩:学生模型通常比教师模型小得多,适合在资源受限的设备上部署。

  • 性能保持:通过知识蒸馏,学生模型能够在保持较高性能的同时,显著减少计算资源和存储需求。

  • 泛化能力:软标签提供了更多的信息,有助于学生模型更好地泛化。

知识蒸馏的变种

除了标准的知识蒸馏方法,研究人员还提出了多个改进版本。

  1. 自蒸馏(Self-Distillation):模型自身作为教师,将深层网络的知识蒸馏到浅层部分。

  2. 多教师蒸馏(Multi-Teacher Distillation):多个教师模型联合指导学生模型,融合不同教师的知识。

  3. 在线蒸馏(Online Distillation):教师模型和学生模型同步训练,而不是先训练教师模型再训练学生模型。

案例分享

下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。

这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。

首先,定义教师模型和学生模型。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 教师模型(较大的神经网络)
class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)  # 注意这里没有 Softmaxreturn x# 学生模型(较小的神经网络)
class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = self.fc2(x)  # 注意这里没有 Softmaxreturn x

然后加载数据集。

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

训练教师模型

def train_teacher(model, train_loader, epochs=5, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()for epoch in range(epochs):model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()output = model(images)loss = criterion(output, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化并训练教师模型
teacher_model = TeacherModel()
train_teacher(teacher_model, train_loader)

知识蒸馏训练学生模型

def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):"""计算蒸馏损失,结合知识蒸馏损失和交叉熵损失"""soft_targets = F.softmax(teacher_logits / T, dim=1)  # 教师模型的软标签soft_predictions = F.log_softmax(student_logits / T, dim=1)  # 学生模型的预测distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)ce_loss = F.cross_entropy(student_logits, labels)return alpha * ce_loss + (1 - alpha) * distillation_lossdef train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):optimizer = optim.Adam(student_model.parameters(), lr=lr)teacher_model.eval()  # 设定教师模型为评估模式for epoch in range(epochs):student_model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()student_logits = student_model(images)with torch.no_grad():teacher_logits = teacher_model(images)  # 获取教师模型输出loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化学生模型
student_model = StudentModel()
train_student_with_distillation(student_model, teacher_model, train_loader)

评估模型

def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy = 100 * correct / totalreturn accuracy# 评估教师模型
teacher_acc = evaluate(teacher_model, test_loader)
print(f"教师模型准确率: {teacher_acc:.2f}%")# 评估知识蒸馏训练的学生模型
student_acc_distilled = evaluate(student_model, test_loader)
print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")

图片

相关文章:

深度学习中的知识蒸馏

大家好,我是小青 今天给大家分享神经网络中的一个关键概念,知识蒸馏 知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单…...

【Windows软件 - HeidiSQL】导出数据库

HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...

苏剑林“闭门造车”之多模态思路浅谈思考

原文来自科学空间苏剑林 “闭门造车”之多模态思路浅谈(一):无损输入和“闭门造车”之多模态思路浅谈(二):自回归,学习后总结。 文章目录 “闭门造车”之多模态思路浅谈(一&#xff…...

绿联nas docker 安装 rocketmq 队列。亲测可用

首先拉取docker 镜像,所需镜像如下: 安装 nameserver docker run -d -p 9876:9876 \ -v ${HOME}/docker/software/rocketmq/data/namesrv/logs:/opt/logs \ -v ${HOME}/docker/software/rocketmq/data/namesrv/store:/opt/store \ --name rmqnamesrv \ …...

C++(23):unreachable

C++23在头文件 "><utility>定义了std::unreachable(),用于指示编译器,该段代码不应该被允许,因此编译器可以对该位置进行优化,如果一旦允许了该位置的代码,行为未定义: #include <utility> #include <iostream>using namespace std;int func(…...

初等数论--欧几里得算法

1. 定义 u 0 u 1 ∈ Z , u 1 ≠ 0 , u 1 ∤ u 0 u_0\ u_1\in Z,u_1 \ne0,u_1 \nmid u_0 u0​ u1​∈Z,u1​0,u1​∤u0​ 根据带余除法可得下面一系列等式 u 0 q 0 u 1 u 2 0 < u 2 < ∣ u 1 ∣ u 1 q 0 u 2 u 3 0 < u 3 < u 2 ⋯ u k − 1 q k − 1 u k …...

阿里云前端自动化部署流程指南

本文详细介绍从前端代码开发到阿里云 OSS/CDN 自动化部署的完整流程。 一、流程概览 © ivwdcwso (ID: u012172506) 1.1 部署流程图 #mermaid-svg-H1LBBmwTHAAF3QTL {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermai…...

EXCEL解决IF函数“您已为此函数输入太多个参数”的报错

IF函数的基本结构是IF(条件, 值为真时的结果, 值为假时的结果)&#xff0c;所以标准的IF函数最多只能有三个参数。当用户输入的参数超过三个时&#xff0c;Excel就会报这个错误。比如多个IF语句叠加&#xff0c;但可能在嵌套的过程中没有正确关闭每个IF函数的括号&#xff0c;导…...

CAS单点登录(第7版)18.日志和审计

如有疑问&#xff0c;请看视频&#xff1a;CAS单点登录&#xff08;第7版&#xff09; 日志和审计 Logging 概述 Logging CAS 提供了一个日志记录工具&#xff0c;用于记录重要信息事件&#xff0c;如身份验证成功和失败;可以对其进行自定义以生成用于故障排除的其他信息。…...

2025年软件测试面试题大全(附答案+文档)

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 一、测试基础 1、测试策略或测试包括哪些&#xff0c;测试要覆盖哪些方面 UI、功能、性能、可靠性、易用性、兼容性、安全性、安装卸载 2、设计测试用例的办法 …...

太空飞船任务,生成一个地球发射、火星着陆以及下一次发射窗口返回地球的动画3D代码

import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from mpl_toolkits.mplot3d import Axes3D# 天体参数设置&#xff08;简化模型&#xff09; AU 1.5e8 # 天文单位&#xff08;公里&#xff09; earth_orbital_radius …...

IDEA——Mac版快捷键

目录 按键含义常用组合代码生成快捷键&#xff1a;代码追踪快捷键&#xff1a;高效编辑快捷键&#xff1a;代码重构快捷键&#xff1a;工具类快捷键&#xff1a;常规文件操作快捷键&#xff1a; 按键含义 ⌘ command Command键&#xff08;⌘&#xff09;相当于Windows中的Con…...

智能体系统(AI Agent System)是什么?——从概念解析到企业数字化转型的全景落地及投资视角

文章目录 一、 前言1.1 背景介绍1.2 写作目的 二、 智能体系统及相关概念解析2.1 智能体系统定义2.2 关键概念区分2.2.1 自主代理&#xff08;Autonomous Agent&#xff09;2.2.2 多智能体系统&#xff08;MAS&#xff09;2.2.3 人工智能/机器学习&#xff08;AI/ML&#xff09…...

Vue 前端开发中的路由知识:从入门到精通

文章目录 引言1. Vue Router 简介1.1 安装 Vue Router1.2 配置 Vue Router1.3 在 Vue 实例中使用 Vue Router 2. 路由的基本用法2.1 路由映射2.2 路由视图2.3 路由链接 3. 动态路由3.1 动态路径参数3.2 访问动态参数3.3 响应路由参数的变化 4. 嵌套路由4.1 定义嵌套路由4.2 渲染…...

前端VUE+后端uwsgi 环境搭建

1整体架构 请求流程the web clinet--the web server->the socket->uwsgi--django 第一级的nginx并不是必须的&#xff0c;uwsgi完全可以完成整个的和浏览器交互的流程&#xff1b;在nginx上加上安全性或其他的限制&#xff0c;可以达到保护程序的作用&#xff1b;uWSGI本…...

I2C实践开发 ---【STM32-I2C-HDC1080温湿度采集系统】

I2C实践开发 — STM32-I2C-HDC1080温湿度采集系统 目录 I2C实践开发 --- STM32-I2C-HDC1080温湿度采集系统1. 引言2. 系统架构2.1 硬件架构2.2 软件架构 3. 代码分析3.1 I2C驱动文件 (i2c.h 和 i2c.c)3.2 HDC1080传感器驱动文件 (hdc1080.h 和 hdc1080.c) 4. 功能总结【HDC1080…...

【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调【完整教程】

文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-1微调2.2.3.2 zero-2微调2.2.3.3 zero-3微调2.2.3.4 单卡Lora微调 2.2.4 实验2.2.4.1 实验1&#xff1a;多GPU微调-zero12.2.4.2 实验2&#xff1a;…...

浏览器报错:无法访问此网站 无法找到xxx.xxx.net的DNS地址。正在诊断该问题。尝试运行Windows网络诊断。DNS_PROBE_STARTED

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;希望我的文章能帮到您&#x1f7ea;如有兴趣可点关注了解更多内容 &#x1f4d8;博主信息 点击标题&#x1f446;有惊喜 &#x1f4c3;文章前言 &#x1f537;文章均为学习和工作中整理的笔记&#xff0c;分享记录…...

【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})

代理模式 代理模式是一种结构型设计模式&#xff0c;它提供了一种替代访问的方法&#xff0c;即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下&#xff0c;增加额外的功能&#xff0c;如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…...

网络安全-攻击流程-用户层

用户层攻击主要针对操作系统中的用户空间应用程序及用户权限&#xff0c;利用软件漏洞、配置错误或用户行为弱点进行攻击。以下是常见的用户层攻击类型及其流程&#xff0c;以及防御措施&#xff1a; 1. 缓冲区溢出攻击 攻击流程&#xff1a; 目标识别&#xff1a;确定存在漏…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留&#xff0c;CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制&#xff08;CCA-Attention&#xff09;&#xff0c;…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代&#xff0c;智能代理&#xff08;agents&#xff09;不再是孤立的个体&#xff0c;而是能够像一个数字团队一样协作。然而&#xff0c;当前 AI 生态系统的碎片化阻碍了这一愿景的实现&#xff0c;导致了“AI 巴别塔问题”——不同代理之间…...

【单片机期末】单片机系统设计

主要内容&#xff1a;系统状态机&#xff0c;系统时基&#xff0c;系统需求分析&#xff0c;系统构建&#xff0c;系统状态流图 一、题目要求 二、绘制系统状态流图 题目&#xff1a;根据上述描述绘制系统状态流图&#xff0c;注明状态转移条件及方向。 三、利用定时器产生时…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计&#xff0c;聪明的码友立马就知道了&#xff0c;该到数据访问模块了&#xff0c;要不就这俩玩个6啊&#xff0c;查库势在必行&#xff0c;至此&#xff0c;它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据&#xff08;数据库、No…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...