当前位置: 首页 > news >正文

【论文笔记】LLaVA-KD: A Framework of Distilling Multimodal Large Language Models

Abstract

大语言模型(Large Language Models, LLM)的成功,使得研究者为了统一视觉和语言的理解去探索多模态大预言模型(Multimodal Large Language Models, MLLM)。
但是MLLM庞大的模型和复杂的计算使其很难应用在资源受限的环境,小型MLLM(s-MLLM)的表现又远不如大型的MLLM(l-MLLM)。

基于上述提到的问题,本文提出了全新的LLaVA-KD框架,将l-MLLM的知识转移到s-MLLM。具体地,本文提出:

  • 多模态蒸馏(Multimodal Distillation, MDist),减小l-MLLM和s-MLLM之间的视觉-文本输出分布的差异;
  • 关系蒸馏(Relation Distillation, RDist),迁移l-MLLM对视觉特征之间相关性的建模能力。
    本文还提出了三阶段训练方案,充分挖掘s-MLLM的潜力:
  • 预训练时蒸馏,对齐视觉文本表示;
  • 有监督地微调,使模型具备多模态理解;
  • 微调时蒸馏,进一步迁移l-MLLM的能力。

本文方法在不改变s-MLLM结构的情况下显著提升了其性能。

github仓库

Introduction

本文从研究各种训练策略的角度出发,在不改变模型架构的情况下,探索提高、s-MLLM的性能。

![[Pasted image 20241123173726.png]]

图1:为了训练s-MLLM,(a)已有方法遵循两步训练,包括预训练(Pre-Training, PT)和监督微调(Supervised Fine-Tuning, SFT);(b)本文的LLaVA-KD提出了三步训练,包含蒸馏预训练(Distilled Pre-Training, DPT)来对齐视觉文本表示、SFT来提升模型的多模态理解能力、蒸馏微调(Distilled Fine-Tuning, DFT)来转移l-MLLM的能力;©本文将LLaVA-KD和其他sota的MLLM在五个热门的多模态benchmark上进行比较。

如图1所示,已有的s-MLLM遵循两步训练策略,包含PT和SFT。PT阶段将视觉特征投影到文本嵌入空间;SFT阶段增强模型的理解和推离能力。
但是s-MLLM的模型容量很小,很难像l-MLLM那样捕获复杂的知识。本文将研究如何借助蒸馏提升s-MLLM的训练。

3 LLaVA-KD

![[Pasted image 20241123184930.png]]

图2:LLaVA-KD的总图,包含三阶段训练。1) DPT,向l-MLLM对齐视觉和文本信息。2) SFT,为s-MLLM带来多模态理解能力。3)DFT,向s-MLLM迁移l-MLLM的能力。在DPT和DFT中应用MDist,使用RDist来使得s-MLLM捕获视觉信息的复杂关系。

3.1 Composition of Distilled MLLM Architecture

图2左侧展示了MLLM的蒸馏过程,包含l-MLLM作为教师模型,和s-MLLM作为学生模型,分别包含三个部分:

Frozen Visual Encoder:用于获得强力的视觉特征。给定输入图像 X v ∈ R H × W × 3 X_v\in\mathbb{R}^{H\times W\times 3} XvRH×W×3,排序成2D patches P v ∈ R N p × S p 2 × 3 P_v\in\mathbb{R}^{N_p\times S_p^2\times 3} PvRNp×Sp2×3,其中 S p S_p Sp N p N_p Np表示patch的大小和数量。最后的transformer层将 P v P_v Pv变成 Z v ∈ R N p × C Z_v\in\mathbb{R}^{N_p\times C} ZvRNp×C,其中特征维度为 C C C。教师和学生都使用相同的Frozen Visual Encoder。

Visual Projector:包含两个MLP层,带有激活函数GELU,将 Z v Z_v Zv映射到文本嵌入空间 H v ∈ R N p × D H_v\in\mathbb{R}^{N_p\times D} HvRNp×D,其中 D D D是嵌入空间维度。

Large Language Model (LLM):用于实现对视觉和语言信息的统一认识。给定视觉嵌入的多模态输入 H v H_v Hv和文本嵌入 H t H_t Ht,LLM将二者的连接 H = [ H v , H t ] H=[H_v,H_t] H=[Hv,Ht]作为输入,生成输出 y = [ y p , y v , y r ] = { y t } t = 1 T y=[y_p,y_v,y_r]=\{y_t\}_{t=1}^T y=[yp,yv,yr]={yt}t=1T,其中 y p , y v , y r y_p,y_v,y_r yp,yv,yr分别代表prompt、视觉和响应tokens, T T T代表所有预测token的长度。本文将教师和学生的LLM分别称为l-LLM和s-LLM。

3.2 Training Scheme of Teacher Model L-MLLM

Pre-Training:Visual Encoder和l-LLM冻结,只有Projector被优化,用于对齐视觉和文本特征。训练过程中,使用图像-描述对,对应的目标公式表示为:
L reg = − ∑ m = 1 M log ⁡ ϕ l ( y m ∣ y < m ) (1) \mathcal{L}_\text{reg}=-\sum_{m=1}^M \log\phi_l(y_m|y_{<m})\tag{1} Lreg=m=1Mlogϕl(ymy<m)(1)
其中 M M M表示预测的响应tokens的长度, ϕ l ( y m ∣ y < m ) \phi_l(y_m|y_{<m}) ϕl(ymy<m)表示响应tokens y m y_m ym的分布基于先前预测 y < m y_{<m} y<m的条件。

Supervised Fine-Tuning:该阶段保持Visual Encoder的冻结,旨在联合优化Projector和l -LLM,以增强教师模型l-MLLM的理解和教学跟随能力。训练过程中,利用高质量的对话数据集,训练目标 L S F T \mathcal{L}_{SFT} LSFT如Eq.1所示。

3.3 Framework of LLaVA-KD

3.3.1 MLLM-Oriented KD Strategy

Multimodal Distillation (MDist):考虑到MLLM本质上是利用LLM进行多模态信息理解和推理,我们沿用LLM的朴素蒸馏方法,即利用KL散度(KLD)对响应预测进行蒸馏。训练目标可以定义为:
L res = ∑ m = 1 M KLD ( ϕ l ( y m ∣ y < m ) , ϕ s ( y m ∣ y < m ) ) = ∑ m = 1 M ∑ j = 1 V ϕ l ( Y j ∣ y < m ) log ⁡ ( ϕ l ( Y j ∣ y < m ) ϕ s ( Y j ∣ y < m ) ) (2) \begin{aligned} \mathcal{L}_\text{res}&=\sum_{m=1}^M \text{KLD}(\phi_l(y_m|y_{<m}),\phi_s(y_m|y_{<m})) \\ &=\sum_{m=1}^M \sum_{j=1}^V \phi_l(Y_j|y_{<m})\log (\frac{\phi_l(Y_j|y_{<m})}{\phi_s(Y_j|y_{<m})})\tag{2} \end{aligned} Lres=m=1MKLD(ϕl(ymy<m),ϕs(ymy<m))=m=1Mj=1Vϕl(Yjy<m)log(ϕs(Yjy<m)ϕl(Yjy<m))(2)
其中 M M M表示响应tokens的长度, V V V表示词汇空间。 ϕ l \phi_l ϕl ϕ s \phi_s ϕs表示l-MLLM和s-MLLM的参数, ϕ l ( Y j ∣ y < m ) \phi_l(Y_j|y_{<m}) ϕl(Yjy<m) ϕ s ( Y j ∣ y < m ) \phi_s(Y_j|y_{<m}) ϕs(Yjy<m)表示由l-MLLM和s-MLLM预测的词汇 Y j Y_j Yj出现在token y m y_m ym的概率。

同时,视觉表征对于LLM的多模态理解也至关重要。因此,进一步优化教师和学生输出视觉分布之间的KLD:
L vis = ∑ k = 1 K ∑ j = 1 V ϕ l ( Y j ∣ y < k ) log ⁡ ( ϕ l ( Y j ∣ y < k ) ϕ s ( Y j ∣ y < k ) ) (3) \mathcal{L}_\text{vis}=\sum_{k=1}^K \sum_{j=1}^V \phi_l(Y_j|y_{<k})\log (\frac{\phi_l(Y_j|y_{<k})}{\phi_s(Y_j|y_{<k})})\tag{3} Lvis=k=1Kj=1Vϕl(Yjy<k)log(ϕs(Yjy<k)ϕl(Yjy<k))(3)
其中 K K K表示视觉token的长度, ϕ l ( Y j ∣ y < k ) \phi_l(Y_j|y_{<k}) ϕl(Yjy<k) ϕ s ( Y j ∣ y < k ) \phi_s(Y_j|y_{<k}) ϕs(Yjy<k)分别表示由l-MLLM和s-MLLM预测的词汇 Y j Y_j Yj出现在token y k y_k yk的概率。

本文在DPT阶段用MDist来对齐s-MLLM中的视觉和语言特征,加强了s-MLLM的理解能力。

Relation Distillation (RDist):为了使学生模型能够捕获视觉信息中的复杂关系,本文从LLM输出的视觉tokens中构造自相关矩阵。通过优化矩阵之间的相似性,学生模型继承了教师模型理解视觉tokens之间错综复杂关系的能力。为了达到这个目的,首先计算自相关矩阵如下:
{ R v s = y v s ⊗ y v s ∈ R N p × N p R v t = y v t ⊗ y v t ∈ R N p × N p \begin{equation} \left\{ \begin{aligned} R_v^s &= y_v^s\otimes y_v^s\in\mathbb{R}^{N_p\times N_p} \\ R_v^t &= y_v^t\otimes y_v^t\in\mathbb{R}^{N_p\times N_p} \end{aligned} \right.\tag{4} \end{equation} {RvsRvt=yvsyvsRNp×Np=yvtyvtRNp×Np(4)
其中 ⊗ \otimes 表示矩阵乘法, y v s y_v^s yvs y v t y_v^t yvt表示学生和教师的视觉logits, N p N_p Np表示视觉token的数量。目标是最大化 R v s R_v^s Rvs R v t R_v^t Rvt的余弦相似度:
L rel = 1 − Cos ( R v s , R v t ) = 1 − R v s ⋅ R v t ∣ ∣ R v s ∣ ∣ ∣ ∣ R v t ∣ ∣ (5) \mathcal{L}_\text{rel}=1-\text{Cos}(R_v^s,R_v^t)=1-\frac{R_v^s\cdot R_v^t}{||R_v^s||\ ||R_v^t||}\tag{5} Lrel=1Cos(Rvs,Rvt)=1∣∣Rvs∣∣ ∣∣Rvt∣∣RvsRvt(5)
用RDist可以进一步提升s-MLLM在DPT和DFT阶段的视觉表达能力。

3.3.2 Three-stage Distillation Scheme

Distilled Pre-Training (DPT):该阶段的主要目的是将视觉特征投射到文本嵌入空间。在LLaVA-KD中,使用蒸馏过程来像l-MLLM一样更好地对齐视觉和文本信息。

具体地,冻结visual encoder和s-MLLM中的LLM,只优化projector。在训练过程中,通过MDist最小化学生模型和教师模型在视觉和反应的输出分布上的差异。
为了优化这个目标,可以进一步促进投影的视觉特征与文本嵌入的对齐。此外,我们利用RDist来增强视觉特征的质量,使学生模型能够借鉴教师模型处理复杂视觉信息的能力。

总的来说,除了优化自回归预测结果,还使用了MDist和RDist:
L DPT = L PT + α L res + β L vis + γ L rel (6) \mathcal{L}_\text{DPT}=\mathcal{L}_\text{PT}+\alpha\mathcal{L}_\text{res}+\beta\mathcal{L}_\text{vis}+\gamma\mathcal{L}_\text{rel}\tag{6} LDPT=LPT+αLres+βLvis+γLrel(6)

Supervised Fine-Tuning (SFT):这个阶段遵循l-MLLM训练阶段的通用SFT过程(Sec.3.2)。通过联合训练Projector和l-LLM,使模型具有推理能力和指令跟踪能力。训练目标由Eq.1定义,表示为 L SFT ′ \mathcal{L}_\text{SFT}' LSFT

Distilled Fine-Tuning (DFT):该阶段的主要目标是进一步增强s-MLLM的理解和推理能力。具体来说,采用了MDist和RDist相结合的蒸馏策略,冻结了Visual Encoder,优化了Projector和sLLM。通过使用MDist,可以对s-MLLM中的小规模s-LLM进行充分优化,从而更好地模拟大规模l-LLM的推理能力。和RDist可以进一步促进s-MLLM学习l-MLLM的视觉表征。

总体训练目标可以表示为:
L D F T = L reg + α ′ L res + β ′ L vis + γ ′ L rel (7) \mathcal{L}_{DFT}=\mathcal{L}_\text{reg}+\alpha'\mathcal{L}_\text{res}+\beta'\mathcal{L}_\text{vis}+\gamma'\mathcal{L}_\text{rel}\tag{7} LDFT=Lreg+αLres+βLvis+γLrel(7)
其中 L reg \mathcal{L}_\text{reg} Lreg表示自回归预测loss。

相关文章:

【论文笔记】LLaVA-KD: A Framework of Distilling Multimodal Large Language Models

Abstract 大语言模型(Large Language Models, LLM)的成功&#xff0c;使得研究者为了统一视觉和语言的理解去探索多模态大预言模型(Multimodal Large Language Models, MLLM)。 但是MLLM庞大的模型和复杂的计算使其很难应用在资源受限的环境&#xff0c;小型MLLM(s-MLLM)的表现…...

M|大脑越狱

rating: 7.0 豆瓣: 7.6 上映时间: “2015” 类型: M悬疑 导演: 约瑟夫怀特 Joseph White 主演: 亚历山大欧文 Alexander Owen爱德华富兰克林 Edward Franklin 国家/地区: 英国 片长/分钟: 20分钟 M&#xff5c;大脑越狱 想法不错&#xff0c;但是逻辑比较一般。属于…...

数据库编程(sqlite3)

一&#xff1a;数据库分类 常用的数据库 大型数据库 &#xff1a;Oracle商业、多平台、关系型数据库功能最强大、最复杂、市场占比最高的商业数据库 中型数据库 &#xff1a;Server是微软开发的数据库产品&#xff0c;主要支持windows平台 小型数据库 : mySQL是一个小型关系型…...

【C语言】关键字详解

【C语言】关键字详解 文章目录 [TOC](文章目录) 前言一、char1.定义字符串类型2.定义字符类型 二、short三、int四、long五、signed六、unsigned七、float八、double九、struct、union、enum十、void1.void用于函数声明&#xff0c;没有返回值的函数&#xff0c;其类型为 void。…...

什么是计算机网络

什么是计算机网络&#xff1f; 计算机网络的定义计算机网络的分类按覆盖范围分类按拓扑结构分类按通信传输介质分类按信号频带占用方式分类 计算机网络的功能信息交换资源共享分布式处理 计算机网络的组成计算机网络的定义计算机网络的分类按覆盖范围分类按拓扑结构分类按通信传…...

【大数据学习 | Spark-Core】Spark的分区器(HashPartitioner和RangePartitioner)

之前学过的kv类型上面的算子 groupby groupByKey reduceBykey sortBy sortByKey join[cogroup left inner right] shuffle的 mapValues keys values flatMapValues 普通算子&#xff0c;管道形式的算子 shuffle的过程是因为数据产生了打乱重分&#xff0c;分组、排序、join等…...

CSS3_BFC(十二)

BFC MDN对BFC的解释&#xff1a;块格式化上下文&#xff08;Block Formating Context, BFC&#xff09;是web页面的可视CSS渲染的一部分&#xff0c;是块盒子的布局过程发生的区域&#xff0c;也是浮动元素与其他元素交互的区域。 1、开启BFC flow-root对内容的影响是最低的&am…...

C0032.在Clion中使用MSVC编译器编译opencv的配置方法

使用MSVC编译器编译opencv的配置方法...

微信小程序中会议列表页面的前后端实现

题外话&#xff1a;想通过集成腾讯IM来解决即时聊天的问题&#xff0c;如果含语音视频&#xff0c;腾讯组件一年5万起步&#xff0c;贵了&#xff01;后面我们改为自己实现这个功能&#xff0c;这里只是个总结而已。 图文会诊需求 首先是个图文列表界面 同个界面可以查看具体…...

WEB攻防-通用漏洞文件上传二次渲染.htaccess变异免杀

知识点&#xff1a; 1、文件上传-二次渲染 2、文件上传-简单免杀变异 3、文件上传-.htaccess妙用 4、文件上传-PHP语言特性 1、上传后门时&#xff0c;文件内容带.就不行 这时可以上传一个转换后的ip地址&#xff0c;ip地址对应网站包含后门代码 转换后的int会在访问的时候…...

vue实现列表滑动下拉加载数据

一、实现效果 二、实现思路 使用滚动事件监听器来检测用户是否滚动到底部&#xff0c;然后加载更多数据 监听滚动事件。检测用户是否滚动到底部。加载更多数据。 三、案例代码 <div class"drawer-content"><div ref"loadMoreTrigger" class&q…...

全面解析:HTML页面的加载全过程(四)--浏览器渲染之样式计算

主线程遍历得到的 DOM 树&#xff0c;依次为树中的每个节点计算出它最终的样式&#xff0c;称之为 Computed Style。 通过前面生成的DOM 树和 CSSOM 树&#xff0c;遍历 DOM 树&#xff0c;为每一个 DOM 节点&#xff0c;计算它的所有 CSS 属性&#xff0c;最后会得到一棵带有…...

#Verilog HDL# 谈谈代码中如何跨层次引用

目录 一 先谈作用问题 二 再谈跨层次问题 2.1 向下引用 2.2 向上引用 一 先谈作用问题 大多数编程语言都有一个称为作用域(scope)的特征,它定义了代码的某些部分对于变量和方法的可见性。作用域定义了一个命名空间,以避免同一命名空间内不同对象名称之间的冲突。 V…...

LeetCode 每日一题 2024/11/18-2024/11/24

记录了初步解题思路 以及本地实现代码&#xff1b;并不一定为最优 也希望大家能一起探讨 一起进步 目录 11/18 661. 图片平滑器11/19 3243. 新增道路查询后的最短距离 I11/20 3244. 新增道路查询后的最短距离 II11/21 3248. 矩阵中的蛇11/22 3233. 统计不是特殊数字的数字数量1…...

客户流失分析综述

引言 客户流失这个术语通常用来描述在特定时间或合同期内停止与公司进行业务往来的客户倾向性[1]。传统上&#xff0c;关于客户流失的研究始于客户关系管理&#xff08;CRM&#xff09;[2]。在运营服务时&#xff0c;防止客户流失至关重要。过去&#xff0c;客户获取相对于流失…...

基于51单片机的红包抽奖proteus仿真

地址&#xff1a; https://pan.baidu.com/s/1nYZlLb64kdZAWSydT_uHfA 提取码&#xff1a;1234 仿真图&#xff1a; 芯片/模块的特点&#xff1a; AT89C52/AT89C51简介&#xff1a; AT89C52/AT89C51是一款经典的8位单片机&#xff0c;是意法半导体&#xff08;STMicroelectro…...

cangjie (仓颉) vscode环境搭建

sdk下载 下载中心-仓颉编程语言官网 可选择半年更新版&#xff0c;不用申请。目前版本&#xff1a;0.53.13 &#xff0c;选择不同平台压缩包下载解压到任意位置即可 补充下载&#xff0c;vscode插件解压后&#xff0c;在vscode扩展中选择从vsix安装&#xff0c;安装后新增名为…...

阿里云私服地址

1.解压apache-maven-3.6.1-bin 2.配置本地仓库&#xff1a;修改conf/dettings.xml中的<localReoisitory>为一个指定目录。56行 <localRepository>D:\apache-maven-3.6.1-bin\apache-maven-3.6.1\mvn_repo</localRepository> 3.配置阿里云私服&#xff1a;…...

HTMLCSS:3D金字塔加载动画

效果演示 这段代码通过CSS3的3D变换和动画功能&#xff0c;创建了一个旋转的金字塔加载动画&#xff0c;每个侧面都有不同的颜色渐变&#xff0c;底部还有一个模糊的阴影效果&#xff0c;增加了视觉的立体感。 HTML <div class"pyramid-loader"><div cl…...

shell编程(2)(3)

目录 一、永久环境变量 按用户设置永久环境变量 文件路径&#xff1a; 示例步骤&#xff1a; 删除永久环境变量 二、脚本程序传递参数怎么实现 三、用编程进行数学运算 shell中利用expr进行运算 运算与变量结合 1. 变量赋值和基本运算 2. 使用expr进行运算 3. 变量…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes&#xff08;简称K8s&#xff09;中&#xff0c;Ingress是一个API对象&#xff0c;它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress&#xff0c;你可…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

C++实现分布式网络通信框架RPC(2)——rpc发布端

有了上篇文章的项目的基本知识的了解&#xff0c;现在我们就开始构建项目。 目录 一、构建工程目录 二、本地服务发布成RPC服务 2.1理解RPC发布 2.2实现 三、Mprpc框架的基础类设计 3.1框架的初始化类 MprpcApplication 代码实现 3.2读取配置文件类 MprpcConfig 代码实现…...

Java详解LeetCode 热题 100(26):LeetCode 142. 环形链表 II(Linked List Cycle II)详解

文章目录 1. 题目描述1.1 链表节点定义 2. 理解题目2.1 问题可视化2.2 核心挑战 3. 解法一&#xff1a;HashSet 标记访问法3.1 算法思路3.2 Java代码实现3.3 详细执行过程演示3.4 执行结果示例3.5 复杂度分析3.6 优缺点分析 4. 解法二&#xff1a;Floyd 快慢指针法&#xff08;…...

PydanticAI快速入门示例

参考链接&#xff1a;https://ai.pydantic.dev/#why-use-pydanticai 示例代码 from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider# 配置使用阿里云通义千问模型 model OpenAIMode…...

Qwen系列之Qwen3解读:最强开源模型的细节拆解

文章目录 1.1分钟快览2.模型架构2.1.Dense模型2.2.MoE模型 3.预训练阶段3.1.数据3.2.训练3.3.评估 4.后训练阶段S1: 长链思维冷启动S2: 推理强化学习S3: 思考模式融合S4: 通用强化学习 5.全家桶中的小模型训练评估评估数据集评估细节评估效果弱智评估和民间Arena 分析展望 如果…...

统计学(第8版)——统计抽样学习笔记(考试用)

一、统计抽样的核心内容与问题 研究内容 从总体中科学抽取样本的方法利用样本数据推断总体特征&#xff08;均值、比率、总量&#xff09;控制抽样误差与非抽样误差 解决的核心问题 在成本约束下&#xff0c;用少量样本准确推断总体特征量化估计结果的可靠性&#xff08;置…...