当前位置: 首页 > news >正文

深度学习中的知识蒸馏

大家好,我是小青

今天给大家分享神经网络中的一个关键概念,知识蒸馏

知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单的模型(学生模型)中。通过这种方式,学生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。

知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储成本,同时尽可能保持模型性能。

图片

核心思想

知识蒸馏的核心思想是利用教师模型的输出(通常是软标签,即概率分布)来指导学生模型的训练。与传统的监督学习不同,知识蒸馏不仅使用真实标签(硬标签),还利用教师模型生成的软标签来传递更多的信息。

通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。

关键技术与方法

知识蒸馏的核心在于让学生模型不仅仅学习真实标签,还学习教师模型提供的软标签,即教师模型输出的概率分布。这种方式可以让学生模型获得更丰富的信息。

传统神经网络的交叉熵损失

在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:

传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即 仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。

知识蒸馏的损失函数

在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。

这些软标签由温度化 Softmax 得到。

知识蒸馏的优势

  • 模型压缩:学生模型通常比教师模型小得多,适合在资源受限的设备上部署。

  • 性能保持:通过知识蒸馏,学生模型能够在保持较高性能的同时,显著减少计算资源和存储需求。

  • 泛化能力:软标签提供了更多的信息,有助于学生模型更好地泛化。

知识蒸馏的变种

除了标准的知识蒸馏方法,研究人员还提出了多个改进版本。

  1. 自蒸馏(Self-Distillation):模型自身作为教师,将深层网络的知识蒸馏到浅层部分。

  2. 多教师蒸馏(Multi-Teacher Distillation):多个教师模型联合指导学生模型,融合不同教师的知识。

  3. 在线蒸馏(Online Distillation):教师模型和学生模型同步训练,而不是先训练教师模型再训练学生模型。

案例分享

下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。

这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。

首先,定义教师模型和学生模型。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 教师模型(较大的神经网络)
class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)  # 注意这里没有 Softmaxreturn x# 学生模型(较小的神经网络)
class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = self.fc2(x)  # 注意这里没有 Softmaxreturn x

然后加载数据集。

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

训练教师模型

def train_teacher(model, train_loader, epochs=5, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()for epoch in range(epochs):model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()output = model(images)loss = criterion(output, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化并训练教师模型
teacher_model = TeacherModel()
train_teacher(teacher_model, train_loader)

知识蒸馏训练学生模型

def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):"""计算蒸馏损失,结合知识蒸馏损失和交叉熵损失"""soft_targets = F.softmax(teacher_logits / T, dim=1)  # 教师模型的软标签soft_predictions = F.log_softmax(student_logits / T, dim=1)  # 学生模型的预测distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)ce_loss = F.cross_entropy(student_logits, labels)return alpha * ce_loss + (1 - alpha) * distillation_lossdef train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):optimizer = optim.Adam(student_model.parameters(), lr=lr)teacher_model.eval()  # 设定教师模型为评估模式for epoch in range(epochs):student_model.train()total_loss = 0for images, labels in train_loader:optimizer.zero_grad()student_logits = student_model(images)with torch.no_grad():teacher_logits = teacher_model(images)  # 获取教师模型输出loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 初始化学生模型
student_model = StudentModel()
train_student_with_distillation(student_model, teacher_model, train_loader)

评估模型

def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy = 100 * correct / totalreturn accuracy# 评估教师模型
teacher_acc = evaluate(teacher_model, test_loader)
print(f"教师模型准确率: {teacher_acc:.2f}%")# 评估知识蒸馏训练的学生模型
student_acc_distilled = evaluate(student_model, test_loader)
print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")

图片

相关文章:

深度学习中的知识蒸馏

大家好,我是小青 今天给大家分享神经网络中的一个关键概念,知识蒸馏 知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为教师模型)的知识迁移到小型、简单…...

【Windows软件 - HeidiSQL】导出数据库

HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...

苏剑林“闭门造车”之多模态思路浅谈思考

原文来自科学空间苏剑林 “闭门造车”之多模态思路浅谈(一):无损输入和“闭门造车”之多模态思路浅谈(二):自回归,学习后总结。 文章目录 “闭门造车”之多模态思路浅谈(一&#xff…...

绿联nas docker 安装 rocketmq 队列。亲测可用

首先拉取docker 镜像,所需镜像如下: 安装 nameserver docker run -d -p 9876:9876 \ -v ${HOME}/docker/software/rocketmq/data/namesrv/logs:/opt/logs \ -v ${HOME}/docker/software/rocketmq/data/namesrv/store:/opt/store \ --name rmqnamesrv \ …...

C++(23):unreachable

C++23在头文件 "><utility>定义了std::unreachable(),用于指示编译器,该段代码不应该被允许,因此编译器可以对该位置进行优化,如果一旦允许了该位置的代码,行为未定义: #include <utility> #include <iostream>using namespace std;int func(…...

初等数论--欧几里得算法

1. 定义 u 0 u 1 ∈ Z , u 1 ≠ 0 , u 1 ∤ u 0 u_0\ u_1\in Z,u_1 \ne0,u_1 \nmid u_0 u0​ u1​∈Z,u1​0,u1​∤u0​ 根据带余除法可得下面一系列等式 u 0 q 0 u 1 u 2 0 < u 2 < ∣ u 1 ∣ u 1 q 0 u 2 u 3 0 < u 3 < u 2 ⋯ u k − 1 q k − 1 u k …...

阿里云前端自动化部署流程指南

本文详细介绍从前端代码开发到阿里云 OSS/CDN 自动化部署的完整流程。 一、流程概览 © ivwdcwso (ID: u012172506) 1.1 部署流程图 #mermaid-svg-H1LBBmwTHAAF3QTL {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermai…...

EXCEL解决IF函数“您已为此函数输入太多个参数”的报错

IF函数的基本结构是IF(条件, 值为真时的结果, 值为假时的结果)&#xff0c;所以标准的IF函数最多只能有三个参数。当用户输入的参数超过三个时&#xff0c;Excel就会报这个错误。比如多个IF语句叠加&#xff0c;但可能在嵌套的过程中没有正确关闭每个IF函数的括号&#xff0c;导…...

CAS单点登录(第7版)18.日志和审计

如有疑问&#xff0c;请看视频&#xff1a;CAS单点登录&#xff08;第7版&#xff09; 日志和审计 Logging 概述 Logging CAS 提供了一个日志记录工具&#xff0c;用于记录重要信息事件&#xff0c;如身份验证成功和失败;可以对其进行自定义以生成用于故障排除的其他信息。…...

2025年软件测试面试题大全(附答案+文档)

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 一、测试基础 1、测试策略或测试包括哪些&#xff0c;测试要覆盖哪些方面 UI、功能、性能、可靠性、易用性、兼容性、安全性、安装卸载 2、设计测试用例的办法 …...

太空飞船任务,生成一个地球发射、火星着陆以及下一次发射窗口返回地球的动画3D代码

import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from mpl_toolkits.mplot3d import Axes3D# 天体参数设置&#xff08;简化模型&#xff09; AU 1.5e8 # 天文单位&#xff08;公里&#xff09; earth_orbital_radius …...

IDEA——Mac版快捷键

目录 按键含义常用组合代码生成快捷键&#xff1a;代码追踪快捷键&#xff1a;高效编辑快捷键&#xff1a;代码重构快捷键&#xff1a;工具类快捷键&#xff1a;常规文件操作快捷键&#xff1a; 按键含义 ⌘ command Command键&#xff08;⌘&#xff09;相当于Windows中的Con…...

智能体系统(AI Agent System)是什么?——从概念解析到企业数字化转型的全景落地及投资视角

文章目录 一、 前言1.1 背景介绍1.2 写作目的 二、 智能体系统及相关概念解析2.1 智能体系统定义2.2 关键概念区分2.2.1 自主代理&#xff08;Autonomous Agent&#xff09;2.2.2 多智能体系统&#xff08;MAS&#xff09;2.2.3 人工智能/机器学习&#xff08;AI/ML&#xff09…...

Vue 前端开发中的路由知识:从入门到精通

文章目录 引言1. Vue Router 简介1.1 安装 Vue Router1.2 配置 Vue Router1.3 在 Vue 实例中使用 Vue Router 2. 路由的基本用法2.1 路由映射2.2 路由视图2.3 路由链接 3. 动态路由3.1 动态路径参数3.2 访问动态参数3.3 响应路由参数的变化 4. 嵌套路由4.1 定义嵌套路由4.2 渲染…...

前端VUE+后端uwsgi 环境搭建

1整体架构 请求流程the web clinet--the web server->the socket->uwsgi--django 第一级的nginx并不是必须的&#xff0c;uwsgi完全可以完成整个的和浏览器交互的流程&#xff1b;在nginx上加上安全性或其他的限制&#xff0c;可以达到保护程序的作用&#xff1b;uWSGI本…...

I2C实践开发 ---【STM32-I2C-HDC1080温湿度采集系统】

I2C实践开发 — STM32-I2C-HDC1080温湿度采集系统 目录 I2C实践开发 --- STM32-I2C-HDC1080温湿度采集系统1. 引言2. 系统架构2.1 硬件架构2.2 软件架构 3. 代码分析3.1 I2C驱动文件 (i2c.h 和 i2c.c)3.2 HDC1080传感器驱动文件 (hdc1080.h 和 hdc1080.c) 4. 功能总结【HDC1080…...

【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调【完整教程】

文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-1微调2.2.3.2 zero-2微调2.2.3.3 zero-3微调2.2.3.4 单卡Lora微调 2.2.4 实验2.2.4.1 实验1&#xff1a;多GPU微调-zero12.2.4.2 实验2&#xff1a;…...

浏览器报错:无法访问此网站 无法找到xxx.xxx.net的DNS地址。正在诊断该问题。尝试运行Windows网络诊断。DNS_PROBE_STARTED

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;希望我的文章能帮到您&#x1f7ea;如有兴趣可点关注了解更多内容 &#x1f4d8;博主信息 点击标题&#x1f446;有惊喜 &#x1f4c3;文章前言 &#x1f537;文章均为学习和工作中整理的笔记&#xff0c;分享记录…...

【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})

代理模式 代理模式是一种结构型设计模式&#xff0c;它提供了一种替代访问的方法&#xff0c;即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下&#xff0c;增加额外的功能&#xff0c;如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…...

网络安全-攻击流程-用户层

用户层攻击主要针对操作系统中的用户空间应用程序及用户权限&#xff0c;利用软件漏洞、配置错误或用户行为弱点进行攻击。以下是常见的用户层攻击类型及其流程&#xff0c;以及防御措施&#xff1a; 1. 缓冲区溢出攻击 攻击流程&#xff1a; 目标识别&#xff1a;确定存在漏…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站&#xff0c;会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后&#xff0c;网站没有变化的情况。 不熟悉siteground主机的新手&#xff0c;遇到这个问题&#xff0c;就很抓狂&#xff0c;明明是哪都没操作错误&#x…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下&#xff0c;商品详情API作为连接电商平台与开发者、商家及用户的关键纽带&#xff0c;其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息&#xff08;如名称、价格、库存等&#xff09;的获取与展示&#xff0c;已难以满足市场对个性化、智能…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

算法岗面试经验分享-大模型篇

文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer &#xff08;1&#xff09;资源 论文&a…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

【Linux】自动化构建-Make/Makefile

前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具&#xff1a;make/makfile 1.背景 在一个工程中源文件不计其数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…...

【前端异常】JavaScript错误处理:分析 Uncaught (in promise) error

在前端开发中&#xff0c;JavaScript 异常是不可避免的。随着现代前端应用越来越多地使用异步操作&#xff08;如 Promise、async/await 等&#xff09;&#xff0c;开发者常常会遇到 Uncaught (in promise) error 错误。这个错误是由于未正确处理 Promise 的拒绝&#xff08;r…...