LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录
前言
最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。
关于长度外推性:https://kexue.fm/archives/9431
关于RoPE:https://kexue.fm/archives/8265
1、线性插值法
论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION
链接:https://arxiv.org/pdf/2306.15595.pdf
思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

源码阅读(附注释):
class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale参数self.scale = scale# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 构建max_seq_len_cached大小的张量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 张量t归一化,RoPE没有这一步t /= self.scale# einsum计算频率矩阵# 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1维做一次拷贝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/
2、NTK插值法
NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。
reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation
链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346
源码阅读:
class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 与线性插值法相比,实现更简单,alpha仅用来改变basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
3、动态插值法
动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。
reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning
链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
源码阅读:
class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否开启NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 计算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 缩放t *= self.max_position_embeddings / seq_len# 得到新的频率矩阵freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs与自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118
总结
本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。
相关文章:
LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录
前言 最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相…...
中国信息安全测评中心CISP家族认证一览
随着国家对网络安全的重视,中国信息安全测评中心根据国家政策、未来趋势、重点内容陆续增添了很多CISP细分认证。 今日份详细介绍,部分CISP及其子品牌相关认证内容,一定要收藏哟! 校园版CISP NISP国家信息安全水平考试ÿ…...
牛客网【面试必刷TOP101】~ 06 递归/回溯
牛客网【面试必刷TOP101】~ 06 递归/回溯 文章目录 牛客网【面试必刷TOP101】~ 06 递归/回溯[toc]BM55 没有重复项数字的全排列(★★)BM56 有重复项数字的全排列(★★)BM57 岛屿数量(★★)BM58 字符串的排列(★★)BM59 N皇后问题(★★★)BM60 括号生成(★★)BM61 矩阵最长递增路…...
ArcGIS Pro基础:【划分】工具实现等比例、等面积、等宽度划分图形操作
本次介绍【划分】工具的使用,如下所示,为该工具所处位置。使用该工具可以实现对某个图斑的等比例面积划分、相等面积划分和相等宽度划分。 【等比例面积】:其操作如下所示,其中: 1表示先选中待处理的图斑,2…...
括号匹配问题:栈的巧妙应用与代码优化【栈、优化、哈希表】
当解决算法问题时,灵活使用数据结构是至关重要的。在这个问题中,我们需要判断一个只包含括号的字符串是否有效,即括号是否能够正确匹配和闭合。使用栈这一数据结构可以很好地解决这个问题。 题目链接:有效的括号 解题思路…...
vue项目正确使用样式deep穿透
经常开发前端的程序员应该都知道前端一般都是组件化开发,为了避免样式污染通常会使用scoped添加属性选择器,此时如果我们想在父组件中修改子组件的样式便成了难题。其实,我们可以通过以下几种方式修改子组件样式, 组件样式穿透 …...
Jenkins持续集成-快速上手
Jenkins持续集成-快速上手 注:Jenkins一般不单独使用,而是需要依赖代码仓库,构建工具等。 搭配组合:GitGitee(GitHub、GitLab)MavenJenkins 前置准备 常见安装方式: war包Docker容器实例&…...
查看linux 所有运行的应用和端口命令
要查看 Linux 中所有运行的应用程序及其对应的端口,可以使用以下命令: 1. 使用 netstat 命令(已被弃用,建议使用 ss 命令): netstat -tuln 这会显示当前系统上所有打开的网络连接和监听的端口。其中&#…...
Maven安装与配置,Eclipse配置Maven【图文并茂的保姆级教程】
🥳🥳Welcome Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于Maven的相关操作吧 目录 🥳🥳Welcome Huihuis Code World ! !🥳🥳 一.Maven是什么? 二.Maven的下…...
利用XLL文件投递Qbot银行木马的钓鱼活动分析
1概述 近期,安天CERT发现了一起利用恶意Microsoft Excel加载项(XLL)文件投递Qbot银行木马的恶意活动。攻击者通过发送垃圾邮件来诱导用户打开附件中的XLL文件,一旦用户安装并激活Microsoft Excel加载项,恶意代码将被执…...
2023CNSS——WEB题解(持续更新)
[Baby] SignIn 进来看到 按钮点击不了,想到去修改代码,要“检查“,但这里的右键和F12都不可用 还好还有其他方法 检查的各种方法 选用一种后进入检查页面 删掉这里的disabled即可 点击后得到flag [Baby] Backdoor 进入,…...
Unity之ShaderGraph 节点介绍 数学节点
数学 高级Absolute(绝对值)Exponential(幂)Length(长度)Log(对数)Modulo(余数)Negate(相反数)Normalize(标准化矢量&…...
springboot mongo 使用
nosql对我来说,就是用它的变动列,如果列是固定的,我为什么不用mysql这种关系型数据库呢? 所以,现在网上搜出来的大部分,用实体类去接的做法,并不适合我的需求。 所以,整理记录一下…...
如何使用appuploader制作apple证书
转载:如何使用appuploader制作apple证书 如何使用appuploader制作apple证书 一.证书管理 点击首页的证书管理 二.新建证书 点击“添加”,新建一个证书文件 免费账号制作证书只有7天有效期,没有推送消息功能,推送证书…...
Promise详细版
promise基础原理到难点分析 常见的Promise的方法解读 扩展async和await深入分析 逐步分析Promise底层逻辑代码 一、Promise基础 1.什么是promise 为了解决回调地狱: //2.设置点击事件btn.onclick function() {//3.创建ajax实例化对象let xhr new XMLHttpRe…...
v-for循环生成的盒子只改变当前选中的盒子的样式
1.给盒子添加动态属性:class"[index isActive?active-box:choose-box]" <div v-for"(item,index) in zyList" :key"item.sid" :class"[index isActive?active-box:choose-box]" click"getKmList(item,index)"…...
Spring Data JPA源码
导读: 什么是Spring Data JPA? 要解释这个问题,我们先将Spring Data JPA拆成两个部分,即Sping Data和JPA。 从这两个部分来解释。 Spring Data是什么? 摘自: https://spring.io/projects/spring-data Spring Data’s mission is to provide a familiar and cons…...
如何防止CSRF攻击
背景 随着互联网的高速发展,信息安全问题已经成为企业最为关注的焦点之一,而前端又是引发企业安全问题的高危据点。在移动互联网时代,前端人员除了传统的 XSS、CSRF 等安全问题之外,又时常遭遇网络劫持、非法调用 Hybrid API 等新…...
linuxARM裸机学习笔记(7)----RTC实时时钟实验
基础概念: I.MX6U 内部也有个RTC 模块,但是不叫作“ RTC ”,而是叫做“ SNVS ”。 SNVS 直译过来就是安全的非易性存储, SNVS 里面主要是一些低功耗的外设,包括一个 安全的实时计数器 (RTC) 、一个单调计数器 (mo…...
NSS [UUCTF 2022 新生赛]ez_upload
NSS [UUCTF 2022 新生赛]ez_upload 考点:Apache解析漏洞 开题就是标准的上传框 起手式就是传入一个php文件,非常正常的有过滤。 .txt、.user.ini、.txxx都被过滤了,应该是白名单或者黑名单加MIME过滤,只允许.jpg、.png。 猜测二…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...
MODBUS TCP转CANopen 技术赋能高效协同作业
在现代工业自动化领域,MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步,这两种通讯协议也正在被逐步融合,形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...
OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...
Razor编程中@Html的方法使用大全
文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...
Web中间件--tomcat学习
Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机,它可以执行Java字节码。Java虚拟机是Java平台的一部分,Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...
解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用
在工业制造领域,无损检测(NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统,以非接触式光学麦克风技术为核心,打破传统检测瓶颈,为半导体、航空航天、汽车制造等行业提供了高灵敏…...
Python 高效图像帧提取与视频编码:实战指南
Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...
