深度学习中的模型蒸馏技术:实现流程、作用及实践案例
在深度学习领域,模型压缩与部署是一项重要的研究课题,而模型蒸馏便是其中一种有效的方法。
模型蒸馏(Model Distillation)最初由Hinton等人在2015年提出,其核心思想是通过知识迁移的方式,将一个复杂的大模型(教师模型)的知识传授给一个相对简单的小模型(学生模型),简单概括就是利用教师模型的预测概率分布作为软标签对学生模型进行训练,从而在保持较高预测性能的同时,极大地降低了模型的复杂性和计算资源需求,实现模型的轻量化和高效化。
模型蒸馏技术在计算机视觉、自然语言处理等领域均取得了显著的成功。
一. 模型蒸馏技术的实现流程
模型蒸馏技术的实现流程通常包括以下几个步骤:
- (1)准备教师模型和学生模型:首先,我们需要一个已经训练好的教师模型和一个待训练的学生模型。教师模型通常是一个性能较好但计算复杂度较高的模型,而学生模型则是一个计算复杂度较低的模型。
- (2)使用教师模型对数据集进行预测,得到每个样本的预测概率分布(软目标)。这些概率分布包含了模型对每个类别的置信度信息。
- (3)定义损失函数:损失函数用于衡量学生模型的输出与教师模型的输出之间的差异。在模型蒸馏中,我们通常会使用一种结合了软标签损失和硬标签损失的混合损失函数(通常这两个损失都是交叉熵损失)。软标签损失鼓励学生模型模仿教师模型的输出概率分布,这通常使用 KL 散度(Kullback-Leibler Divergence)来度量,而硬标签损失则鼓励学生模型正确预测真实标签。
- (4)训练学生模型:在训练过程中,我们将教师模型的输出作为监督信号,通过优化损失函数来更新学生模型的参数。这样,学生模型就可以从教师模型中学到有用的知识。KL 散度的计算涉及一个温度参数,该参数可以调整软目标的分布。温度较高会使分布更加平滑。在训练过程中,可以逐渐降低温度以提高蒸馏效果。
- (5)微调学生模型:在蒸馏过程完成后,可以对学生模型进行进一步的微调,以提高其性能表现。
二. 模型蒸馏的作用
-
模型轻量化:通过将大型模型的知识迁移到小型模型中,可以显著降低模型的复杂度和计算量,从而提高模型的运行效率。
-
加速推理,降低运行成本:简化后的模型在运行时速度更快,降低了计算成本和能耗,进一步的,减少了对硬件资源的需求,降低模型运行成本。
-
提升泛化能力:研究表明,模型蒸馏有可能帮助学生模型学习到教师模型中蕴含的泛化模式,提高其在未见过的数据上的表现。
-
迁移学习:模型蒸馏技术可以作为一种迁移学习方法,将在一个任务上训练好的模型知识迁移到另一个任务上。
-
促进模型的可解释性和可部署性:轻量化后的模型通常更加简洁明了,有利于理解和分析模型的决策过程,同时也更容易进行部署和应用。
三. 代码示例
以下是一个简单的模型蒸馏代码示例,使用PyTorch框架实现。在这个示例中,我们将使用一个预训练的ResNet-18模型作为教师模型,并使用一个简单的CNN模型作为学生模型。同时,我们将使用交叉熵损失函数和L2正则化项来优化学生模型的性能表现。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms# 定义教师模型和学生模型
teacher_model = models.resnet18(pretrained=True)
student_model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(128 * 7 * 7, 10)
)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9)
optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)# 训练数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)# 蒸馏过程
for epoch in range(10):running_loss_teacher = 0.0running_loss_student = 0.0for inputs, labels in trainloader:# 教师模型的前向传播outputs_teacher = teacher_model(inputs)loss_teacher = criterion(outputs_teacher, labels)running_loss_teacher += loss_teacher.item()# 学生模型的前向传播outputs_student = student_model(inputs)loss_student = criterion(outputs_student, labels) + 0.1 * torch.sum((outputs_teacher - outputs_student) ** 2)running_loss_student += loss_student.item()# 反向传播和参数更新optimizer_teacher.zero_grad()optimizer_student.zero_grad()loss_teacher.backward()optimizer_teacher.step()loss_student.backward()optimizer_student.step()print(f'Epoch {epoch+1}/10 \t Loss Teacher: {running_loss_teacher / len(trainloader)} \t Loss Student: {running_loss_student / len(trainloader)}')
在这个示例中:
(1)首先定义了教师模型和学生模型,并初始化了相应的损失函数和优化器;
(2)然后,加载了MNIST手写数字数据集,并对其进行了预处理;
(3)接下来,进入蒸馏过程:对于每个批次的数据,首先使用教师模型进行前向传播并计算损失函数值;然后使用学生模型进行前向传播并计算损失函数值(同时加入了L2正则化项以鼓励学生模型学习教师模型的输出);
(4)最后,对损失函数值进行反向传播和参数更新:打印了每个批次的损失函数值以及每个epoch的平均损失函数值。
通过多次迭代训练后,我们可以得到一个性能较好且轻量化的学生模型。
四. 模型压缩和加速的其他技术
除了模型蒸馏技术外,还有一些类似的技术可以用于实现模型的压缩和加速,例如:
- 权重剪枝:通过删除神经网络中冗余的权重来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断权重的重要性,然后将不重要的权重设置为零或删除。
- 模型量化:将神经网络中的权重和激活值从浮点数转换为低精度的整数表示,从而减少模型的存储空间和计算量。
- 知识蒸馏(Knowledge Distillation):这是一种特殊的模型蒸馏技术,其中教师模型和学生模型具有相同的架构,但参数不同。通过让学生模型学习教师模型的输出,可以实现模型的压缩和加速。
- 知识提炼(Knowledge Carving):选择性地从教师模型中抽取部分子结构用于构建学生模型。
- 网络剪枝(Network Pruning):通过删除神经网络中冗余的神经元或连接来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断神经元或连接的重要性,然后将不重要的神经元或连接删除。
- 低秩分解(Low-Rank Factorization):将神经网络中的权重矩阵分解为两个低秩矩阵的乘积,从而减少模型的存储空间和计算量。这种方法可以应用于卷积层和全连接层等不同类型的神经网络层。
- 结构搜索(Neural Architecture Search):通过自动搜索最优的神经网络结构来实现模型的压缩和加速。这种方法可以根据特定任务的需求来定制适合的神经网络结构。
相关文章:

深度学习中的模型蒸馏技术:实现流程、作用及实践案例
在深度学习领域,模型压缩与部署是一项重要的研究课题,而模型蒸馏便是其中一种有效的方法。 模型蒸馏(Model Distillation)最初由Hinton等人在2015年提出,其核心思想是通过知识迁移的方式,将一个复杂的大模型…...

Java服务运行在Linux----维护常用命令
想起来哪些再添加上去 查看Java程序进程 jps -l 查出进程后根据pid 查询程序所在目录 pwdx 31313 根据端口查找PID 根据pid杀死程序 kill -p 31313 查看目录下所有包含9527的文件 grep -rn 9527 查看磁盘空间 查找文件名"nginx"文件或模糊查找"*nginx*&quo…...

夜晚水闸3D可视化:科技魔法点亮水利新纪元
在宁静的夜晚,当城市的霓虹灯逐渐暗淡,你是否曾想过,那些默默守护着城市安全的水闸,在科技的魔力下,正焕发出别样的光彩?今天,就让我们一起走进夜晚水闸3D模型,感受科技为水利带来的…...

从零开始的软件开发实战:互联网医院APP搭建详解
今天,笔者将以“从零开始的软件开发实战:互联网医院APP搭建详解”为主题,深入探讨互联网医院APP的开发过程和关键技术。 第一步:需求分析和规划 互联网医院APP的主要功能包括在线挂号、医生预约、医疗咨询、健康档案管理等。我们…...
【深度学习】YOLO检测器的发展历程
YOLO检测器的发展历程 YOLO(You Only Look Once)检测器是一种流行的实时对象检测系统,以其速度和准确性而闻名。自2016年首次推出以来,YOLO已经成为计算机视觉领域的一个重要里程碑。在本博客中,我们将探讨YOLO检测器…...

C语言--编译和链接
1.翻译环境 计算机能够执行二进制指令,我们的电脑不会直接执行C语言代码,编译器把代码转换成二进制的指令; 我们在VS上面写下printf("hello world");这行代码的时候,经过翻译环境,生成可执行的exe文件&…...
实现使用C#代码完成wifi的切换和连接功能
实现使用C#代码完成wifi的切换和连接功能 代码如下: namespace Wifi连接器 {public partial class Form1 : Form{private List<Wlan.WlanAvailableNetwork> NetWorkList new List<Wlan.WlanAvailableNetwork>();private WlanClient.WlanInterface Wla…...

Mac添加和关闭开机应用
文章目录 mac添加和关闭开机应用添加开机应用删除/查看 mac添加和关闭开机应用 添加开机应用 删除/查看 打开:系统设置–》通用–》登录项–》查看登录时打开列表 选中打开项目,点击“-”符号...

QT QInputDialog弹出消息框用法
使用QInputDialog类的静态方法来弹出对话框获取用户输入,缺点是不能自定义按钮的文字,默认为OK和Cancel: int main(int argc, char *argv[]) {QApplication a(argc, argv);bool isOK;QString text QInputDialog::getText(NULL, "Input …...

Unity3d使用Jenkins自动化打包(Windows)(一)
文章目录 前言一、安装JDK二、安装Jenkins三、Jenkins插件安装和使用基础操作 实战一基础操作 实战二 四、离线安装总结 前言 本篇旨在介绍基础的安装和操作流程,只需完成一次即可。后面的篇章将深入探讨如何利用Jenkins为Unity项目进行打包。 一、安装JDK 1、进入…...

HarmonyOS 应用开发之Want的定义与用途
Want 是一种对象,用于在应用组件之间传递信息。 其中,一种常见的使用场景是作为 startAbility() 方法的参数。例如,当UIAbilityA需要启动UIAbilityB并向UIAbilityB传递一些数据时,可以使用Want作为一个载体,将数据传递…...

enscan自动化主域名信息收集
enscan下载 Releases wgpsec/ENScan_GO (github.com) 能查的分类 实操: 首先打开linux 的虚拟机、 然后把下面这个粘贴到虚拟机中 解压后打开命令行 初始化 ./enscan-0.0.16-linux-amd64 -v 命令参数如下 oppo信息收集 运行下面代码时 先去配置文件把coo…...

分享全栈开发医疗小程序 -带源码课件(课件无解压密码),自行速度保存
课程介绍 分享全栈开发医疗小程序 -带源码课件(课件无解压密码),自行速度保存!看到好多坛友都在求SpringBoot2.X Vue UniAPP,全栈开发医疗小程序 - 带源码课件,我看了一下,要么链接过期&…...

基于YOLOv8与ByteTrack实现多目标跟踪——算法原理与代码实践
概述 在目标检测中,有许多经算法如Faster RCNN、SSD和YOLO的各种版本,这些算法利用深度学习技术,特别是卷积神经网络(CNN),能够高效地在图像中定位和识别不同类别的目标。Faster RCNN是一种基于区域提议的…...
C语言——函数练习程序
1.从终端接收一个数,封装一个函数判断该数是否为素数 #include <stdio.h>int pri(int num) {int i 0;for (i 2; i < num; i){if (num % i 0){return 0;break;}}if (i num-1){return 1;} }int main(void) {int num 0;int ret 0;scanf("%d", &num);…...
ssh 启动 docker 中 app, docker logs 无日志
ssh 启动 app, 标准输出被重定向 ssh 客户端,而不是 docker 容器的标准输出。只需要在启动时把app 标准输出重定向到 docker标准输出。 测试如下: 1.启动 docker docker run -it -p 60022:22 --name test test:v4 bash -c "service ssh restart;…...

WPF---1.入门学习
🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:WPF 🤝希望本文对您有所裨益,如有不足之处…...

Vue3 + Vite + TS + Element-Plus + Pinia项目(5)对axios进行封装
1、在src文件夹下新建config文件夹后,新建baseURL.ts文件,用来配置http主链接 2、在src文件夹下新建http文件夹后,新建request.ts文件,内容如下 import axios from "axios" import { ElMessage } from element-plus im…...
【Rust】——编写自动化测试(一)
🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL:…...

第十二章 微服务核心(一)
一、Spring Boot 1.1 SpringBoot 构建方式 1.1.1 通过官网自动生成 进入官网:https://spring.io/,点击 Projects --> Spring Framework; 拖动滚动条到中间位置,点击 Spring Initializr 或者直接通过 https://start.spring…...

XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

linux arm系统烧录
1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...
将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?
Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

Python爬虫(一):爬虫伪装
一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...

手机平板能效生态设计指令EU 2023/1670标准解读
手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读,综合法规核心要求、最新修正及企业合规要点: 一、法规背景与目标 生效与强制时间 发布于2023年8月31日(OJ公报&…...
ThreadLocal 源码
ThreadLocal 源码 此类提供线程局部变量。这些变量不同于它们的普通对应物,因为每个访问一个线程局部变量的线程(通过其 get 或 set 方法)都有自己独立初始化的变量副本。ThreadLocal 实例通常是类中的私有静态字段,这些类希望将…...
JavaScript 标签加载
目录 JavaScript 标签加载script 标签的 async 和 defer 属性,分别代表什么,有什么区别1. 普通 script 标签2. async 属性3. defer 属性4. type"module"5. 各种加载方式的对比6. 使用建议 JavaScript 标签加载 script 标签的 async 和 defer …...

统计按位或能得到最大值的子集数目
我们先来看题目描述: 给你一个整数数组 nums ,请你找出 nums 子集 按位或 可能得到的 最大值 ,并返回按位或能得到最大值的 不同非空子集的数目 。 如果数组 a 可以由数组 b 删除一些元素(或不删除)得到,…...