LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录
前言
最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。
关于长度外推性:https://kexue.fm/archives/9431
关于RoPE:https://kexue.fm/archives/8265
1、线性插值法
论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION
链接:https://arxiv.org/pdf/2306.15595.pdf
思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

源码阅读(附注释):
class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale参数self.scale = scale# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 构建max_seq_len_cached大小的张量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 张量t归一化,RoPE没有这一步t /= self.scale# einsum计算频率矩阵# 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1维做一次拷贝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/
2、NTK插值法
NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。
reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation
链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346
源码阅读:
class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 与线性插值法相比,实现更简单,alpha仅用来改变basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
3、动态插值法
动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。
reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning
链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
源码阅读:
class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否开启NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 计算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 缩放t *= self.max_position_embeddings / seq_len# 得到新的频率矩阵freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs与自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118
总结
本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。
相关文章:
LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录
前言 最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相…...
中国信息安全测评中心CISP家族认证一览
随着国家对网络安全的重视,中国信息安全测评中心根据国家政策、未来趋势、重点内容陆续增添了很多CISP细分认证。 今日份详细介绍,部分CISP及其子品牌相关认证内容,一定要收藏哟! 校园版CISP NISP国家信息安全水平考试ÿ…...
牛客网【面试必刷TOP101】~ 06 递归/回溯
牛客网【面试必刷TOP101】~ 06 递归/回溯 文章目录 牛客网【面试必刷TOP101】~ 06 递归/回溯[toc]BM55 没有重复项数字的全排列(★★)BM56 有重复项数字的全排列(★★)BM57 岛屿数量(★★)BM58 字符串的排列(★★)BM59 N皇后问题(★★★)BM60 括号生成(★★)BM61 矩阵最长递增路…...
ArcGIS Pro基础:【划分】工具实现等比例、等面积、等宽度划分图形操作
本次介绍【划分】工具的使用,如下所示,为该工具所处位置。使用该工具可以实现对某个图斑的等比例面积划分、相等面积划分和相等宽度划分。 【等比例面积】:其操作如下所示,其中: 1表示先选中待处理的图斑,2…...
括号匹配问题:栈的巧妙应用与代码优化【栈、优化、哈希表】
当解决算法问题时,灵活使用数据结构是至关重要的。在这个问题中,我们需要判断一个只包含括号的字符串是否有效,即括号是否能够正确匹配和闭合。使用栈这一数据结构可以很好地解决这个问题。 题目链接:有效的括号 解题思路…...
vue项目正确使用样式deep穿透
经常开发前端的程序员应该都知道前端一般都是组件化开发,为了避免样式污染通常会使用scoped添加属性选择器,此时如果我们想在父组件中修改子组件的样式便成了难题。其实,我们可以通过以下几种方式修改子组件样式, 组件样式穿透 …...
Jenkins持续集成-快速上手
Jenkins持续集成-快速上手 注:Jenkins一般不单独使用,而是需要依赖代码仓库,构建工具等。 搭配组合:GitGitee(GitHub、GitLab)MavenJenkins 前置准备 常见安装方式: war包Docker容器实例&…...
查看linux 所有运行的应用和端口命令
要查看 Linux 中所有运行的应用程序及其对应的端口,可以使用以下命令: 1. 使用 netstat 命令(已被弃用,建议使用 ss 命令): netstat -tuln 这会显示当前系统上所有打开的网络连接和监听的端口。其中&#…...
Maven安装与配置,Eclipse配置Maven【图文并茂的保姆级教程】
🥳🥳Welcome Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于Maven的相关操作吧 目录 🥳🥳Welcome Huihuis Code World ! !🥳🥳 一.Maven是什么? 二.Maven的下…...
利用XLL文件投递Qbot银行木马的钓鱼活动分析
1概述 近期,安天CERT发现了一起利用恶意Microsoft Excel加载项(XLL)文件投递Qbot银行木马的恶意活动。攻击者通过发送垃圾邮件来诱导用户打开附件中的XLL文件,一旦用户安装并激活Microsoft Excel加载项,恶意代码将被执…...
2023CNSS——WEB题解(持续更新)
[Baby] SignIn 进来看到 按钮点击不了,想到去修改代码,要“检查“,但这里的右键和F12都不可用 还好还有其他方法 检查的各种方法 选用一种后进入检查页面 删掉这里的disabled即可 点击后得到flag [Baby] Backdoor 进入,…...
Unity之ShaderGraph 节点介绍 数学节点
数学 高级Absolute(绝对值)Exponential(幂)Length(长度)Log(对数)Modulo(余数)Negate(相反数)Normalize(标准化矢量&…...
springboot mongo 使用
nosql对我来说,就是用它的变动列,如果列是固定的,我为什么不用mysql这种关系型数据库呢? 所以,现在网上搜出来的大部分,用实体类去接的做法,并不适合我的需求。 所以,整理记录一下…...
如何使用appuploader制作apple证书
转载:如何使用appuploader制作apple证书 如何使用appuploader制作apple证书 一.证书管理 点击首页的证书管理 二.新建证书 点击“添加”,新建一个证书文件 免费账号制作证书只有7天有效期,没有推送消息功能,推送证书…...
Promise详细版
promise基础原理到难点分析 常见的Promise的方法解读 扩展async和await深入分析 逐步分析Promise底层逻辑代码 一、Promise基础 1.什么是promise 为了解决回调地狱: //2.设置点击事件btn.onclick function() {//3.创建ajax实例化对象let xhr new XMLHttpRe…...
v-for循环生成的盒子只改变当前选中的盒子的样式
1.给盒子添加动态属性:class"[index isActive?active-box:choose-box]" <div v-for"(item,index) in zyList" :key"item.sid" :class"[index isActive?active-box:choose-box]" click"getKmList(item,index)"…...
Spring Data JPA源码
导读: 什么是Spring Data JPA? 要解释这个问题,我们先将Spring Data JPA拆成两个部分,即Sping Data和JPA。 从这两个部分来解释。 Spring Data是什么? 摘自: https://spring.io/projects/spring-data Spring Data’s mission is to provide a familiar and cons…...
如何防止CSRF攻击
背景 随着互联网的高速发展,信息安全问题已经成为企业最为关注的焦点之一,而前端又是引发企业安全问题的高危据点。在移动互联网时代,前端人员除了传统的 XSS、CSRF 等安全问题之外,又时常遭遇网络劫持、非法调用 Hybrid API 等新…...
linuxARM裸机学习笔记(7)----RTC实时时钟实验
基础概念: I.MX6U 内部也有个RTC 模块,但是不叫作“ RTC ”,而是叫做“ SNVS ”。 SNVS 直译过来就是安全的非易性存储, SNVS 里面主要是一些低功耗的外设,包括一个 安全的实时计数器 (RTC) 、一个单调计数器 (mo…...
NSS [UUCTF 2022 新生赛]ez_upload
NSS [UUCTF 2022 新生赛]ez_upload 考点:Apache解析漏洞 开题就是标准的上传框 起手式就是传入一个php文件,非常正常的有过滤。 .txt、.user.ini、.txxx都被过滤了,应该是白名单或者黑名单加MIME过滤,只允许.jpg、.png。 猜测二…...
相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了: 这一篇我们开始讲: 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下: 一、场景操作步骤 操作步…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
USB Over IP专用硬件的5个特点
USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中,从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备(如专用硬件设备),从而消除了直接物理连接的需要。USB over IP的…...
听写流程自动化实践,轻量级教育辅助
随着智能教育工具的发展,越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式,也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建,…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
佰力博科技与您探讨热释电测量的几种方法
热释电的测量主要涉及热释电系数的测定,这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中,积分电荷法最为常用,其原理是通过测量在电容器上积累的热释电电荷,从而确定热释电系数…...
Java + Spring Boot + Mybatis 实现批量插入
在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法:使用 MyBatis 的 <foreach> 标签和批处理模式(ExecutorType.BATCH)。 方法一:使用 XML 的 <foreach> 标签ÿ…...
JVM 内存结构 详解
内存结构 运行时数据区: Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器: 线程私有,程序控制流的指示器,分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 每个线程都有一个程序计数…...
Kafka入门-生产者
生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...
认识CMake并使用CMake构建自己的第一个项目
1.CMake的作用和优势 跨平台支持:CMake支持多种操作系统和编译器,使用同一份构建配置可以在不同的环境中使用 简化配置:通过CMakeLists.txt文件,用户可以定义项目结构、依赖项、编译选项等,无需手动编写复杂的构建脚本…...
