当前位置: 首页 > news >正文

【论文笔记】LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models

🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


基本信息

标题: LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
作者: Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, Jiaya Jia
发表: ICLR 2024
arXiv: https://arxiv.org/abs/2309.12307

基本信息

摘要

我们提出了LongLoRA,一种高效的微调方法,它通过有限的计算成本扩展了预训练大型语言模型(LLM)的上下文大小。

通常,使用长上下文大小训练LLM在计算上非常昂贵,需要大量的训练时间和GPU资源。例如,在 8192 8192 8192 个上下文长度的训练中,自注意力层的计算成本是 2048 2048 2048 个上下文长度的 16 16 16 倍。

在本文中,我们从两个方面加速了LLM上下文扩展。

一方面,尽管在推理过程中需要密集的全局注意力,但通过稀疏局部注意力可以有效地进行模型微调。提出的移位稀疏注意力(S2-Attn)有效地实现了上下文扩展,与使用标准注意力微调具有相似的性能,同时实现了显著的计算节省。特别是,它可以在训练中仅用两行代码实现,而在推理中是可选的。

另一方面,我们重新审视了参数高效的上下文扩展微调机制。值得注意的是,我们发现LoRA在可训练嵌入和归一化的前提下,对于上下文扩展效果良好。LongLoRA将这种改进的LoRA与 S 2 S^2 S2-Attn相结合。

LongLoRA在Llama2模型(从7B/13B到70B)的各种任务上展示了强大的实证结果。LongLoRA将Llama2 7B的上下文从4k扩展到100k,或将Llama2 70B扩展到32k,在单个8×A100机器上完成。

LongLoRA在保持原始架构的同时扩展了模型的上下文,并且与大多数现有技术兼容,如Flash-Attention2。

此外,我们还使用LongLoRA和我们的长指令遵循LongAlpaca数据集进行了监督微调。

我们所有的代码、模型、数据集和演示代码都可在github.com/dvlab-research/LongLoRA上找到。

简介

LongLoRA closes the accuracy gap that between conventional LoRA and full fine-tuning, while still maintaining up to 1.8× lower memory cost than full fine-tuning. Furthermore, LongLoRA improves the training speed of LoRA by up to 1.8× with S2-Attn. Llama2-7B are fine-tuned to various context lengths with Flash-Attention2 (Dao, 2023) and DeepSpeed (Rasley et al., 2020) stage 2 and evaluated on the proof-pile (Azerbayev et al., 2022) test set in perplexity.

LongLoRA缩小了传统LoRA和全量微调之间的精度差距,同时保持了比全量微调低1.8倍的内存成本。此外,LongLoRA通过 S 2 S^2 S2-Attn 将LoRA的训练速度提高了高达1.8倍。Llama2-7B使用Flash-Attention2和DeepSpeed的第二阶段进行微调,并在 proof-pile 测试集上评估了困惑度。

Illustration of S2-Attn. It involves three steps. First, it splits features along the head dimension into two chunks. Second, tokens in one of the chunks are shifted by half of the group size. Third, we split tokens into groups and reshape them into batch dimensions. Attention only computes in each group in ours while the information flows between groups via shifting. Potential information leakage might be introduced by shifting, while this is easy to prevent via a small modification on the attention mask. We ablate this in the variant 2 in Section B.3 in the appendix.

S 2 S^2 S2-Attn的示意图涉及三个步骤。首先,它将特征沿头部维度分为两个部分。其次,其中一个部分中的token向右移动了组大小的一半。第三,我们将token分成组,并将它们重塑为批量维度。在我们的模型中,注意力仅在每组中计算,而信息通过移动在组之间流动。移动可能会引入潜在的信息泄露,但通过在注意力掩码上进行微小修改可以轻松防止。

LongLoRA

背景

Transformer

大型语言模型(LLMs)通常是基于 Transformer 构建的。例如,以 Llama2 为例,一个 LLM 模型由一个嵌入输入层和若干解码器层组成。每个解码器层包含一个自注意力模块。它通过带有权重矩阵 { W q , W k , W v } \{W_q, W_k, W_v\} {Wq,Wk,Wv} 的线性投影层将输入特征映射为一组查询、键和值 { q , k , v } \{q, k, v\} {q,k,v}。给定 { q , k , v } \{q, k, v\} {q,k,v},它计算输出 o o o

o = softmax ( q k T ) v o = \text{softmax}(qk^T)v o=softmax(qkT)v

输出随后通过一个权重矩阵 W o W_o Wo 的线性层进行投影,接着是多层感知机(MLP)层。在自注意力模块之前和之后,会应用层归一化。所有解码器层完成后,还会进行一次最终归一化。

对于较长的序列,自注意力在计算成本方面表现出困难,其计算复杂度与序列长度成平方关系。这大幅减慢了训练过程,并增加了 GPU 内存的使用成本。

Low-rank Adaptation

LoRA假设预训练模型中的权重更新在适配期间具有较低的内在秩(intrinsic rank)。对于一个预训练权重矩阵 W ∈ R d × k W \in \mathbb{R}^{d \times k} WRd×k,它通过低秩分解 W + Δ W = W + B A W + \Delta W = W + BA W+ΔW=W+BA 来更新,其中 B ∈ R d × r B \in \mathbb{R}^{d \times r} BRd×r A ∈ R r × k A \in \mathbb{R}^{r \times k} ARr×k。秩 r ≪ min ⁡ ( d , k ) r \ll \min(d, k) rmin(d,k)。在训练期间, W W W 被冻结(没有梯度更新),而 A A A B B B 是可训练的。这就是 LoRA 训练比完全微调更高效的原因。

在 Transformer 结构中,LoRA 仅适配注意力权重 { W q , W k , W v , W o } \{W_q, W_k, W_v, W_o\} {Wq,Wk,Wv,Wo},并冻结所有其他层,包括 MLP 和归一化层。这种方式简单且参数高效。然而,我们通过实验证明,仅在注意力权重中的低秩适配并不能很好地适用于长上下文扩展任务。

Shifted Sparse Attention

标准的自注意力计算成本为 O ( n 2 ) \mathcal{O}(n^2) O(n2),这使得长序列上的LLM具有高内存成本和低速。为了在训练期间避免这一问题,我们提出了移位稀疏注意力( S 2 S^2 S2-Attn),如图2所示。接下来,我们将进行一项初步研究,并逐步解释我们的设计。

Overview of LongLoRA. We introduce Shifted Sparse Attention (S2-Attn) during finetuning. The trained model retains original standard self-attention at inference time. In addition to training LoRA weights in linear layers, LongLoRA further makes embedding and normalization layers trainable. This extension is pivotal for context extension, and only introduces a minimal number of additional trainable parameters.

Pilot Study

在表1中,我们建立了一个标准基线,该基线经过完整注意力和微调训练和测试,在各种上下文长度下表现出一致的良好质量。第一次试验是使用短注意力进行训练,仅模式1如图2所示。正如我们所知,在长上下文中,高昂的成本主要来自自注意力模块。因此,在这次试验中,由于输入很长,我们在自注意力中将其分为几个组。例如,模型在训练和测试阶段都以8192个token作为输入,但在每个组中进行自注意力操作,组大小为2048,组数为4。这种模式效率很高,但在非常长的上下文中仍然不起作用,如表1所示。随着上下文长度的增加,困惑度变大。其背后的原因是没有不同组之间的信息交换。

Effectiveness of S2-Attn under different context lengths

为了引入组之间的通信,我们包括了一个移位模式,如图2所示。我们在半注意力头中将组分区移位半个组大小。以总体8192个上下文长度为例,在模式1中,第一组从第1个到第2048个token进行自注意力。在模式2中,组分区移位1024。第一个注意力组从第1025个开始到第3072个token结束,而前1024个和最后1024个token属于同一组。我们在每个半自注意力头中分别使用模式1和模式2。这种方式不会增加额外的计算成本,但能够实现不同组之间的信息流。我们在表1中展示了它接近标准注意力基线的结果。

Consistency to Full Attention

现有的高效注意力设计也可以提高长上下文LLM的效率。然而,大多数这些设计并不适合长上下文微调。因为这些从头开始训练的Transformer与预训练中使用的标准全注意力存在差距。在表6中,我们展示了 S 2 S^2 S2-Attn 不仅能够实现高效的微调,还支持全注意力测试。尽管其他注意力机制也可以用于长上下文微调,但模型必须使用微调期间使用的注意力进行测试。移位防止了模型对特定注意力模式的过度拟合。

Easy Implementation

S 2 S^2 S2-Attn易于实现。它仅涉及两个步骤:

  1. 在半注意力头中移位token;
  2. 将特征从token维度转置到批次维度。

两行代码就足够了。我们在算法1中提供了一个PyTorch风格的代码示例。

Algorithm 1: Pseudocode of S2-Attn in PyTorch-like style.

Improved LoRA for Long Context

LoRA是一种高效且流行的将LLMs适应其他数据集的方法。与全微调相比,它节省了大量的可训练参数和内存成本。然而,将LLMs从短上下文长度适应到长上下文长度并不容易。我们观察到LoRA与全微调之间存在明显的差距。如表2所示,随着目标上下文长度的增加,LoRA与全微调之间的差距逐渐增大。并且,具有更大秩的LoRA无法缩小这个差距。

Finetuning normalization and embedding layers is crucial for low-rank long-context adaptation

为了弥合这一差距,我们为训练打开了嵌入层和归一化层。如表2所示,它们占用的参数有限,但对长上下文适应有显著效果。特别是对于归一化层,参数在整个Llama2 7B中仅占0.004%。在实验中,我们将这种改进版的LoRA称为LoRA+。

实验

主实验

Perplexity evaluation on proof-pile (Rae et al., 2020) test split

Maximum context length that we can fine-tune for various model sizes on a single 8× A100 machine

Topic retrieval evaluation with LongChat (Li et al., 2023)

Accuracy comparison on passkey retrieval between Llama2 7B and our 7B model fine-tuned on 32768 context length

消融实验

Ablation on fine-tuning steps in both full fine-tuning and LoRA+

Comparisons among S2-Attn and alternative attention patterns during fine-tuning

总结

在这项工作中,我们提出了LongLoRA,它能够高效地扩展LLMs的上下文长度,使其显著更大。

与标准全微调相比,LongLoRA具有更低的GPU内存成本和训练时间,同时精度损失最小。

在架构层面,我们提出了 S 2 S^2 S2-Attn,用于在训练过程中近似标准自注意力模式。 S 2 S^2 S2-Attn 易于实现,仅需两行代码。

此外,通过 S 2 S^2 S2-Attn 训练的模型在推理过程中保留了原始的标准注意力架构,使得大多数现有基础设施和优化可以重用。

在训练层面,我们通过可训练归一化和嵌入弥合了LoRA与全微调之间的差距。

我们的方法可以将Llama2 7B扩展到100k上下文长度,将70B模型扩展到32k上下文长度,在单个8×A100机器上实现。

我们还提出了一个长指令遵循数据集LongAlpaca,并使用LongLoRA进行了监督微调。

我们相信LongLoRA是一种通用方法,可以与更多类型的LLMs和位置编码兼容。

我们计划在未来工作中调查这些问题。

相关文章:

【论文笔记】LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models

🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 基本信息 标题: LongLoRA: Efficient Fine…...

数据挖掘——朴素贝叶斯分类

数据挖掘——朴素贝叶斯分类 朴素贝叶斯分类极大后验假设独立性假设贝叶斯分类器总结 朴素贝叶斯分类 什么是分类? 找出描述和区分数据类或概念的模型,以便能够使用模型预测未知的对象的类标号 概念区分 分类与回归 分类是预测分类(离散、…...

unity中的UI系统---GUI

一、工作原理和主要作用 1.GUI是什么? 即即时模式游戏用户交互界面(IMGUI),在unity中一般简称为GUI,它是一个代码驱动的UI系统。 2.GUI的主要作用 2.1作为程序员的调试工具,创建游戏内调测试工具 2.2为…...

鸿蒙Flutter实战:15-Flutter引擎Impeller鸿蒙化、性能优化与未来

Flutter 技术原理 Flutter 是一个主流的跨平台应用开发框架,基于 Dart 语言开发 UI 界面,它将描述界面的 Dart 代码直接编译成机器码,并使用渲染引擎调用 GPU/CPU 渲染。 渲染引擎的优势 使用自己的渲染引擎,这也是 Flutter 与其…...

C语言冒泡排序教程简介

冒泡排序(Bubble Sort)是一种简单的排序算法,因其工作原理像气泡一样逐渐上浮而得名。其基本思想是通过一轮一轮地比较相邻的元素,将较大的元素逐步“冒泡”到数组的尾部。 在本篇博客中,我们将详细讲解冒泡排序的基本…...

Fabric链码部署测试

参考链接:运行 Fabric 应用程序 — Hyperledger Fabric Docs 主文档 (hyperledger-fabric.readthedocs.io) (2)fabric2.4.3部署运行自己的链码 - 知乎 (zhihu.com) Fabric2.0测试网络部署链码 - 辉哥哥~ - 博客园 (cnblogs.com) 1.启动测试…...

k620老显卡,装cuda.等。

CUDA安装教程(超详细)-CSDN博客 1.下载支持12.0以上的驱动 NVIDIA RTX Driver Release 550 R550 U12 (553.50) | Windows 11 解压。安装。一路下一步。查看结果 2.下载 cuda CUDA Toolkit Archive | NVIDIA Developer 安装cuda时,第一次…...

网站常用功能模块-鉴权

一:JWT是什么? 常用鉴权方式有很多种,今天主要介绍基于token的鉴权方式JWT(Json JSON Web Token)。因为这种方式实现起来方便快捷。整体实现逻辑如下 第一次登陆时,前端携带账号和密码请求登录接口。服务…...

直接插入排序、折半插入排序、2路插入排序、希尔排序

本篇是排序专栏博客的第一篇,主要探讨以 “插入” 为核心思想的排序算法该如何实现 文章目录 一、前言二、直接插入排序1. 算法思想与操作分析2. 代码实现version 1version 2 3. 复杂度分析 三、折半插入排序1. 算法思想与操作分析2. 代码实现3. 复杂度分析 四、2路…...

FQ-GAN代码解析

主要看 model 、loss 和 data 部分如何实现和处理的。 model—VQ_modelsVQModelEncoderVectorQuantizerDecoder loss—VQLoss_triple_codebook model—VQ_models 创建vq_model直接根据传入的模型压缩倍率8/16初始化对应的VQ_8/VQ_16,两者都是初始化一个VQModel的类…...

如何恢复已删除的 Telegram 消息 [iOSamp;Android]

Telegram 是一款功能强大的消息应用程序,因其易用性、隐私保护和众多炫酷功能而深受用户喜爱。然而,有时我们会不小心删除重要的消息。在这种情况下你应该做什么? 本文将为您提供简单有效的解决方案来恢复 Telegram 上已删除的消息&#xff…...

asp.net core中的 Cookie 和 Session

在 Web 开发中,用户会话管理是非常重要的,尤其是在需要保持用户状态和身份验证的应用中。ASP.NET Core 提供了多种状态管理技术,如 Cookie 和 Session,它们可以帮助你管理用户会话、存储数据并实现用户身份验证等功能。下面将详细…...

Python实现一个简单的 HTTP echo 服务器

一个用来做测试的简单的 HTTP echo 服务器。 from http.server import HTTPServer, BaseHTTPRequestHandler import jsonclass EchoHandler(BaseHTTPRequestHandler):def do_GET(self):# 构造响应数据response_data {path: self.path,method: GET,headers: dict(self.headers…...

Ruby 中文编码

Ruby 中文编码 在 Ruby 编程语言中处理中文编码是一个常见的需求,尤其是在中国和其他使用中文的地区。Ruby 是一种动态、开放源代码的编程语言,它支持多种字符编码,包括中文编码。本文将探讨在 Ruby 中处理中文编码的几种方法,以…...

淘金优化算法的信息共享与更新机制改进

淘金优化算法作为一种模拟自然界淘金过程的启发式搜索算法,在解决复杂优化问题时展现出独特优势。然而,其性能在很大程度上依赖于信息共享与更新机制的有效性。传统机制在面对高维、多模态等复杂问题时,往往存在信息交流不畅、更新滞后等问题,导致算法陷入局部最优或收敛速…...

Python中的ast.literal_eval:安全地解析字符串为Python对象

Python中的ast.literal_eval:安全地解析字符串为Python对象 什么是ast.literal_eval?为什么说它是“安全”的? 如何使用ast.literal_eval?示例1:将字符串转换为列表示例2:将字符串转换为字典示例3&#xff…...

【AI数学基础】线性代数:内积和范数

(观前提醒,这是工科AI相关的数学基础的学习笔记,不是数学专业的文章,所以没有严谨的证明和定义,数院大神请勿批评) 2. 内积和范数 2.1 内积的定义 从代数的角度来说,内积是两个向量之间的一种…...

Go语言的 的泛型(Generics)核心知识

Go语言的泛型(Generics)核心知识 引言 在编程语言的发展历程中,泛型是一项重要的特性。它使得程序员能够编写更加灵活和可重用的代码,减少了代码重复,提高了类型安全性和性能。从最初的C和Java,到现代的R…...

C++vector

1. vector 的介绍及使用 1.1vector的介绍 vector的文档介绍 1.vector是表示可变大小数组的序列容器 2.就像数组一样,vector也采用的连续存储空间来存储元素,也就是意味着可以采用下标对vector 的元素进行访问,和数组一样高效但是又不像数组…...

如何配置【Docker镜像】加速器+【Docker镜像】的使用

一、配置Docker镜像加速器 1. 安装/升级容器引擎客户端​ 推荐安装1.11.2以上版本的容器引擎客户端 2. 配置镜像加速器​ 针对容器引擎客户端版本大于1.11.2的用户 以root用户登录容器引擎所在的虚拟机 修改 "/etc/docker/daemon.json" 文件(如果没有…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…...

HTML 语义化

目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案&#xff1a; 语义化标签&#xff1a; <header>&#xff1a;页头<nav>&#xff1a;导航<main>&#xff1a;主要内容<article>&#x…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

Rapidio门铃消息FIFO溢出机制

关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系&#xff0c;以下是深入解析&#xff1a; 门铃FIFO溢出的本质 在RapidIO系统中&#xff0c;门铃消息FIFO是硬件控制器内部的缓冲区&#xff0c;用于临时存储接收到的门铃消息&#xff08;Doorbell Message&#xff09;。…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析

Java求职者面试指南&#xff1a;Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问&#xff08;基础概念问题&#xff09; 1. 请解释Spring框架的核心容器是什么&#xff1f;它在Spring中起到什么作用&#xff1f; Spring框架的核心容器是IoC容器&#…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目&#xff0c;该项目是一个 Spring AI 快速入门的样例工程项目&#xff0c;旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计&#xff0c;每个模块都专注于特定的功能领域&#xff0c;便于学习和…...