深度学习中的模型蒸馏技术:实现流程、作用及实践案例
在深度学习领域,模型压缩与部署是一项重要的研究课题,而模型蒸馏便是其中一种有效的方法。
模型蒸馏(Model Distillation)最初由Hinton等人在2015年提出,其核心思想是通过知识迁移的方式,将一个复杂的大模型(教师模型)的知识传授给一个相对简单的小模型(学生模型),简单概括就是利用教师模型的预测概率分布作为软标签对学生模型进行训练,从而在保持较高预测性能的同时,极大地降低了模型的复杂性和计算资源需求,实现模型的轻量化和高效化。
模型蒸馏技术在计算机视觉、自然语言处理等领域均取得了显著的成功。

一. 模型蒸馏技术的实现流程
模型蒸馏技术的实现流程通常包括以下几个步骤:
- (1)准备教师模型和学生模型:首先,我们需要一个已经训练好的教师模型和一个待训练的学生模型。教师模型通常是一个性能较好但计算复杂度较高的模型,而学生模型则是一个计算复杂度较低的模型。
- (2)使用教师模型对数据集进行预测,得到每个样本的预测概率分布(软目标)。这些概率分布包含了模型对每个类别的置信度信息。
- (3)定义损失函数:损失函数用于衡量学生模型的输出与教师模型的输出之间的差异。在模型蒸馏中,我们通常会使用一种结合了软标签损失和硬标签损失的混合损失函数(通常这两个损失都是交叉熵损失)。软标签损失鼓励学生模型模仿教师模型的输出概率分布,这通常使用 KL 散度(Kullback-Leibler Divergence)来度量,而硬标签损失则鼓励学生模型正确预测真实标签。
- (4)训练学生模型:在训练过程中,我们将教师模型的输出作为监督信号,通过优化损失函数来更新学生模型的参数。这样,学生模型就可以从教师模型中学到有用的知识。KL 散度的计算涉及一个温度参数,该参数可以调整软目标的分布。温度较高会使分布更加平滑。在训练过程中,可以逐渐降低温度以提高蒸馏效果。
- (5)微调学生模型:在蒸馏过程完成后,可以对学生模型进行进一步的微调,以提高其性能表现。
二. 模型蒸馏的作用
-
模型轻量化:通过将大型模型的知识迁移到小型模型中,可以显著降低模型的复杂度和计算量,从而提高模型的运行效率。
-
加速推理,降低运行成本:简化后的模型在运行时速度更快,降低了计算成本和能耗,进一步的,减少了对硬件资源的需求,降低模型运行成本。
-
提升泛化能力:研究表明,模型蒸馏有可能帮助学生模型学习到教师模型中蕴含的泛化模式,提高其在未见过的数据上的表现。
-
迁移学习:模型蒸馏技术可以作为一种迁移学习方法,将在一个任务上训练好的模型知识迁移到另一个任务上。
-
促进模型的可解释性和可部署性:轻量化后的模型通常更加简洁明了,有利于理解和分析模型的决策过程,同时也更容易进行部署和应用。
三. 代码示例
以下是一个简单的模型蒸馏代码示例,使用PyTorch框架实现。在这个示例中,我们将使用一个预训练的ResNet-18模型作为教师模型,并使用一个简单的CNN模型作为学生模型。同时,我们将使用交叉熵损失函数和L2正则化项来优化学生模型的性能表现。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms# 定义教师模型和学生模型
teacher_model = models.resnet18(pretrained=True)
student_model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(128 * 7 * 7, 10)
)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9)
optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)# 训练数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)# 蒸馏过程
for epoch in range(10):running_loss_teacher = 0.0running_loss_student = 0.0for inputs, labels in trainloader:# 教师模型的前向传播outputs_teacher = teacher_model(inputs)loss_teacher = criterion(outputs_teacher, labels)running_loss_teacher += loss_teacher.item()# 学生模型的前向传播outputs_student = student_model(inputs)loss_student = criterion(outputs_student, labels) + 0.1 * torch.sum((outputs_teacher - outputs_student) ** 2)running_loss_student += loss_student.item()# 反向传播和参数更新optimizer_teacher.zero_grad()optimizer_student.zero_grad()loss_teacher.backward()optimizer_teacher.step()loss_student.backward()optimizer_student.step()print(f'Epoch {epoch+1}/10 \t Loss Teacher: {running_loss_teacher / len(trainloader)} \t Loss Student: {running_loss_student / len(trainloader)}')
在这个示例中:
(1)首先定义了教师模型和学生模型,并初始化了相应的损失函数和优化器;
(2)然后,加载了MNIST手写数字数据集,并对其进行了预处理;
(3)接下来,进入蒸馏过程:对于每个批次的数据,首先使用教师模型进行前向传播并计算损失函数值;然后使用学生模型进行前向传播并计算损失函数值(同时加入了L2正则化项以鼓励学生模型学习教师模型的输出);
(4)最后,对损失函数值进行反向传播和参数更新:打印了每个批次的损失函数值以及每个epoch的平均损失函数值。
通过多次迭代训练后,我们可以得到一个性能较好且轻量化的学生模型。
四. 模型压缩和加速的其他技术
除了模型蒸馏技术外,还有一些类似的技术可以用于实现模型的压缩和加速,例如:
- 权重剪枝:通过删除神经网络中冗余的权重来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断权重的重要性,然后将不重要的权重设置为零或删除。
- 模型量化:将神经网络中的权重和激活值从浮点数转换为低精度的整数表示,从而减少模型的存储空间和计算量。
- 知识蒸馏(Knowledge Distillation):这是一种特殊的模型蒸馏技术,其中教师模型和学生模型具有相同的架构,但参数不同。通过让学生模型学习教师模型的输出,可以实现模型的压缩和加速。
- 知识提炼(Knowledge Carving):选择性地从教师模型中抽取部分子结构用于构建学生模型。
- 网络剪枝(Network Pruning):通过删除神经网络中冗余的神经元或连接来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断神经元或连接的重要性,然后将不重要的神经元或连接删除。
- 低秩分解(Low-Rank Factorization):将神经网络中的权重矩阵分解为两个低秩矩阵的乘积,从而减少模型的存储空间和计算量。这种方法可以应用于卷积层和全连接层等不同类型的神经网络层。
- 结构搜索(Neural Architecture Search):通过自动搜索最优的神经网络结构来实现模型的压缩和加速。这种方法可以根据特定任务的需求来定制适合的神经网络结构。
相关文章:
深度学习中的模型蒸馏技术:实现流程、作用及实践案例
在深度学习领域,模型压缩与部署是一项重要的研究课题,而模型蒸馏便是其中一种有效的方法。 模型蒸馏(Model Distillation)最初由Hinton等人在2015年提出,其核心思想是通过知识迁移的方式,将一个复杂的大模型…...
Java服务运行在Linux----维护常用命令
想起来哪些再添加上去 查看Java程序进程 jps -l 查出进程后根据pid 查询程序所在目录 pwdx 31313 根据端口查找PID 根据pid杀死程序 kill -p 31313 查看目录下所有包含9527的文件 grep -rn 9527 查看磁盘空间 查找文件名"nginx"文件或模糊查找"*nginx*&quo…...
夜晚水闸3D可视化:科技魔法点亮水利新纪元
在宁静的夜晚,当城市的霓虹灯逐渐暗淡,你是否曾想过,那些默默守护着城市安全的水闸,在科技的魔力下,正焕发出别样的光彩?今天,就让我们一起走进夜晚水闸3D模型,感受科技为水利带来的…...
从零开始的软件开发实战:互联网医院APP搭建详解
今天,笔者将以“从零开始的软件开发实战:互联网医院APP搭建详解”为主题,深入探讨互联网医院APP的开发过程和关键技术。 第一步:需求分析和规划 互联网医院APP的主要功能包括在线挂号、医生预约、医疗咨询、健康档案管理等。我们…...
【深度学习】YOLO检测器的发展历程
YOLO检测器的发展历程 YOLO(You Only Look Once)检测器是一种流行的实时对象检测系统,以其速度和准确性而闻名。自2016年首次推出以来,YOLO已经成为计算机视觉领域的一个重要里程碑。在本博客中,我们将探讨YOLO检测器…...
C语言--编译和链接
1.翻译环境 计算机能够执行二进制指令,我们的电脑不会直接执行C语言代码,编译器把代码转换成二进制的指令; 我们在VS上面写下printf("hello world");这行代码的时候,经过翻译环境,生成可执行的exe文件&…...
实现使用C#代码完成wifi的切换和连接功能
实现使用C#代码完成wifi的切换和连接功能 代码如下: namespace Wifi连接器 {public partial class Form1 : Form{private List<Wlan.WlanAvailableNetwork> NetWorkList new List<Wlan.WlanAvailableNetwork>();private WlanClient.WlanInterface Wla…...
Mac添加和关闭开机应用
文章目录 mac添加和关闭开机应用添加开机应用删除/查看 mac添加和关闭开机应用 添加开机应用 删除/查看 打开:系统设置–》通用–》登录项–》查看登录时打开列表 选中打开项目,点击“-”符号...
QT QInputDialog弹出消息框用法
使用QInputDialog类的静态方法来弹出对话框获取用户输入,缺点是不能自定义按钮的文字,默认为OK和Cancel: int main(int argc, char *argv[]) {QApplication a(argc, argv);bool isOK;QString text QInputDialog::getText(NULL, "Input …...
Unity3d使用Jenkins自动化打包(Windows)(一)
文章目录 前言一、安装JDK二、安装Jenkins三、Jenkins插件安装和使用基础操作 实战一基础操作 实战二 四、离线安装总结 前言 本篇旨在介绍基础的安装和操作流程,只需完成一次即可。后面的篇章将深入探讨如何利用Jenkins为Unity项目进行打包。 一、安装JDK 1、进入…...
HarmonyOS 应用开发之Want的定义与用途
Want 是一种对象,用于在应用组件之间传递信息。 其中,一种常见的使用场景是作为 startAbility() 方法的参数。例如,当UIAbilityA需要启动UIAbilityB并向UIAbilityB传递一些数据时,可以使用Want作为一个载体,将数据传递…...
enscan自动化主域名信息收集
enscan下载 Releases wgpsec/ENScan_GO (github.com) 能查的分类 实操: 首先打开linux 的虚拟机、 然后把下面这个粘贴到虚拟机中 解压后打开命令行 初始化 ./enscan-0.0.16-linux-amd64 -v 命令参数如下 oppo信息收集 运行下面代码时 先去配置文件把coo…...
分享全栈开发医疗小程序 -带源码课件(课件无解压密码),自行速度保存
课程介绍 分享全栈开发医疗小程序 -带源码课件(课件无解压密码),自行速度保存!看到好多坛友都在求SpringBoot2.X Vue UniAPP,全栈开发医疗小程序 - 带源码课件,我看了一下,要么链接过期&…...
基于YOLOv8与ByteTrack实现多目标跟踪——算法原理与代码实践
概述 在目标检测中,有许多经算法如Faster RCNN、SSD和YOLO的各种版本,这些算法利用深度学习技术,特别是卷积神经网络(CNN),能够高效地在图像中定位和识别不同类别的目标。Faster RCNN是一种基于区域提议的…...
C语言——函数练习程序
1.从终端接收一个数,封装一个函数判断该数是否为素数 #include <stdio.h>int pri(int num) {int i 0;for (i 2; i < num; i){if (num % i 0){return 0;break;}}if (i num-1){return 1;} }int main(void) {int num 0;int ret 0;scanf("%d", &num);…...
ssh 启动 docker 中 app, docker logs 无日志
ssh 启动 app, 标准输出被重定向 ssh 客户端,而不是 docker 容器的标准输出。只需要在启动时把app 标准输出重定向到 docker标准输出。 测试如下: 1.启动 docker docker run -it -p 60022:22 --name test test:v4 bash -c "service ssh restart;…...
WPF---1.入门学习
🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:WPF 🤝希望本文对您有所裨益,如有不足之处…...
Vue3 + Vite + TS + Element-Plus + Pinia项目(5)对axios进行封装
1、在src文件夹下新建config文件夹后,新建baseURL.ts文件,用来配置http主链接 2、在src文件夹下新建http文件夹后,新建request.ts文件,内容如下 import axios from "axios" import { ElMessage } from element-plus im…...
【Rust】——编写自动化测试(一)
🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL:…...
第十二章 微服务核心(一)
一、Spring Boot 1.1 SpringBoot 构建方式 1.1.1 通过官网自动生成 进入官网:https://spring.io/,点击 Projects --> Spring Framework; 拖动滚动条到中间位置,点击 Spring Initializr 或者直接通过 https://start.spring…...
利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
大数据学习栈记——Neo4j的安装与使用
本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...
【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
Mac软件卸载指南,简单易懂!
刚和Adobe分手,它却总在Library里给你写"回忆录"?卸载的Final Cut Pro像电子幽灵般阴魂不散?总是会有残留文件,别慌!这份Mac软件卸载指南,将用最硬核的方式教你"数字分手术"࿰…...
Java 加密常用的各种算法及其选择
在数字化时代,数据安全至关重要,Java 作为广泛应用的编程语言,提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景,有助于开发者在不同的业务需求中做出正确的选择。 一、对称加密算法…...
《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
