当前位置: 首页 > news >正文

LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

在这里插入图片描述

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale参数self.scale = scale# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 构建max_seq_len_cached大小的张量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 张量t归一化,RoPE没有这一步t /= self.scale# einsum计算频率矩阵# 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1维做一次拷贝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 与线性插值法相比,实现更简单,alpha仅用来改变basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否开启NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 计算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 缩放t *= self.max_position_embeddings / seq_len# 得到新的频率矩阵freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs与自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。

相关文章:

LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

前言 最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相…...

中国信息安全测评中心CISP家族认证一览

随着国家对网络安全的重视,中国信息安全测评中心根据国家政策、未来趋势、重点内容陆续增添了很多CISP细分认证。 今日份详细介绍,部分CISP及其子品牌相关认证内容,一定要收藏哟! 校园版CISP NISP国家信息安全水平考试&#xff…...

牛客网【面试必刷TOP101】~ 06 递归/回溯

牛客网【面试必刷TOP101】~ 06 递归/回溯 文章目录 牛客网【面试必刷TOP101】~ 06 递归/回溯[toc]BM55 没有重复项数字的全排列(★★)BM56 有重复项数字的全排列(★★)BM57 岛屿数量(★★)BM58 字符串的排列(★★)BM59 N皇后问题(★★★)BM60 括号生成(★★)BM61 矩阵最长递增路…...

ArcGIS Pro基础:【划分】工具实现等比例、等面积、等宽度划分图形操作

本次介绍【划分】工具的使用,如下所示,为该工具所处位置。使用该工具可以实现对某个图斑的等比例面积划分、相等面积划分和相等宽度划分。 【等比例面积】:其操作如下所示,其中: 1表示先选中待处理的图斑,2…...

括号匹配问题:栈的巧妙应用与代码优化【栈、优化、哈希表】

当解决算法问题时,灵活使用数据结构是至关重要的。在这个问题中,我们需要判断一个只包含括号的字符串是否有效,即括号是否能够正确匹配和闭合。使用栈这一数据结构可以很好地解决这个问题。 题目链接:有效的括号 解题思路&#xf…...

vue项目正确使用样式deep穿透

经常开发前端的程序员应该都知道前端一般都是组件化开发,为了避免样式污染通常会使用scoped添加属性选择器,此时如果我们想在父组件中修改子组件的样式便成了难题。其实,我们可以通过以下几种方式修改子组件样式, 组件样式穿透 …...

Jenkins持续集成-快速上手

Jenkins持续集成-快速上手 注:Jenkins一般不单独使用,而是需要依赖代码仓库,构建工具等。 搭配组合:GitGitee(GitHub、GitLab)MavenJenkins 前置准备 常见安装方式: war包Docker容器实例&…...

查看linux 所有运行的应用和端口命令

要查看 Linux 中所有运行的应用程序及其对应的端口,可以使用以下命令: 1. 使用 netstat 命令(已被弃用,建议使用 ss 命令): netstat -tuln 这会显示当前系统上所有打开的网络连接和监听的端口。其中&#…...

Maven安装与配置,Eclipse配置Maven【图文并茂的保姆级教程】

🥳🥳Welcome Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于Maven的相关操作吧 目录 🥳🥳Welcome Huihuis Code World ! !🥳🥳 一.Maven是什么? 二.Maven的下…...

利用XLL文件投递Qbot银行木马的钓鱼活动分析

1概述 近期,安天CERT发现了一起利用恶意Microsoft Excel加载项(XLL)文件投递Qbot银行木马的恶意活动。攻击者通过发送垃圾邮件来诱导用户打开附件中的XLL文件,一旦用户安装并激活Microsoft Excel加载项,恶意代码将被执…...

2023CNSS——WEB题解(持续更新)

[Baby] SignIn 进来看到 按钮点击不了,想到去修改代码,要“检查“,但这里的右键和F12都不可用 还好还有其他方法 检查的各种方法 选用一种后进入检查页面 删掉这里的disabled即可 点击后得到flag [Baby] Backdoor 进入&#xff0c…...

Unity之ShaderGraph 节点介绍 数学节点

数学 高级Absolute(绝对值)Exponential(幂)Length(长度)Log(对数)Modulo(余数)Negate(相反数)Normalize(标准化矢量&…...

springboot mongo 使用

nosql对我来说,就是用它的变动列,如果列是固定的,我为什么不用mysql这种关系型数据库呢? 所以,现在网上搜出来的大部分,用实体类去接的做法,并不适合我的需求。 所以,整理记录一下…...

如何使用appuploader制作apple证书​

转载:如何使用appuploader制作apple证书​ 如何使用appuploader制作apple证书​ 一.证书管理​ 点击首页的证书管理 二.新建证书​ 点击“添加”,新建一个证书文件 免费账号制作证书只有7天有效期,没有推送消息功能,推送证书…...

Promise详细版

promise基础原理到难点分析 常见的Promise的方法解读 扩展async和await深入分析 逐步分析Promise底层逻辑代码 一、Promise基础 1.什么是promise 为了解决回调地狱: //2.设置点击事件btn.onclick function() {//3.创建ajax实例化对象let xhr new XMLHttpRe…...

v-for循环生成的盒子只改变当前选中的盒子的样式

1.给盒子添加动态属性:class"[index isActive?active-box:choose-box]" <div v-for"(item,index) in zyList" :key"item.sid" :class"[index isActive?active-box:choose-box]" click"getKmList(item,index)"…...

Spring Data JPA源码

导读: 什么是Spring Data JPA? 要解释这个问题,我们先将Spring Data JPA拆成两个部分&#xff0c;即Sping Data和JPA。 从这两个部分来解释。 Spring Data是什么? 摘自: https://spring.io/projects/spring-data Spring Data’s mission is to provide a familiar and cons…...

如何防止CSRF攻击

背景 随着互联网的高速发展&#xff0c;信息安全问题已经成为企业最为关注的焦点之一&#xff0c;而前端又是引发企业安全问题的高危据点。在移动互联网时代&#xff0c;前端人员除了传统的 XSS、CSRF 等安全问题之外&#xff0c;又时常遭遇网络劫持、非法调用 Hybrid API 等新…...

linuxARM裸机学习笔记(7)----RTC实时时钟实验

基础概念&#xff1a; I.MX6U 内部也有个RTC 模块&#xff0c;但是不叫作“ RTC ”&#xff0c;而是叫做“ SNVS ”。 SNVS 直译过来就是安全的非易性存储&#xff0c; SNVS 里面主要是一些低功耗的外设&#xff0c;包括一个 安全的实时计数器 (RTC) 、一个单调计数器 (mo…...

NSS [UUCTF 2022 新生赛]ez_upload

NSS [UUCTF 2022 新生赛]ez_upload 考点&#xff1a;Apache解析漏洞 开题就是标准的上传框 起手式就是传入一个php文件&#xff0c;非常正常的有过滤。 .txt、.user.ini、.txxx都被过滤了&#xff0c;应该是白名单或者黑名单加MIME过滤&#xff0c;只允许.jpg、.png。 猜测二…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

Java多线程实现之Callable接口深度解析

Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

【Linux】Linux 系统默认的目录及作用说明

博主介绍&#xff1a;✌全网粉丝23W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...

【 java 虚拟机知识 第一篇 】

目录 1.内存模型 1.1.JVM内存模型的介绍 1.2.堆和栈的区别 1.3.栈的存储细节 1.4.堆的部分 1.5.程序计数器的作用 1.6.方法区的内容 1.7.字符串池 1.8.引用类型 1.9.内存泄漏与内存溢出 1.10.会出现内存溢出的结构 1.内存模型 1.1.JVM内存模型的介绍 内存模型主要分…...

SQL Server 触发器调用存储过程实现发送 HTTP 请求

文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...

第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)

第一篇&#xff1a;Liunx环境下搭建PaddlePaddle 3.0基础环境&#xff08;Liunx Centos8.5安装Python3.10pip3.10&#xff09; 一&#xff1a;前言二&#xff1a;安装编译依赖二&#xff1a;安装Python3.10三&#xff1a;安装PIP3.10四&#xff1a;安装Paddlepaddle基础框架4.1…...

多元隐函数 偏导公式

我们来推导隐函数 z z ( x , y ) z z(x, y) zz(x,y) 的偏导公式&#xff0c;给定一个隐函数关系&#xff1a; F ( x , y , z ( x , y ) ) 0 F(x, y, z(x, y)) 0 F(x,y,z(x,y))0 &#x1f9e0; 目标&#xff1a; 求 ∂ z ∂ x \frac{\partial z}{\partial x} ∂x∂z​、 …...

Linux基础开发工具——vim工具

文章目录 vim工具什么是vimvim的多模式和使用vim的基础模式vim的三种基础模式三种模式的初步了解 常用模式的详细讲解插入模式命令模式模式转化光标的移动文本的编辑 底行模式替换模式视图模式总结 使用vim的小技巧vim的配置(了解) vim工具 本文章仍然是继续讲解Linux系统下的…...

路由基础-路由表

本篇将会向读者介绍路由的基本概念。 前言 在一个典型的数据通信网络中&#xff0c;往往存在多个不同的IP网段&#xff0c;数据在不同的IP网段之间交互是需要借助三层设备的&#xff0c;这些设备具备路由能力&#xff0c;能够实现数据的跨网段转发。 路由是数据通信网络中最基…...