【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持
1. 引言
Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。
2. 位置编码的外推实现
2.1 旋转位置编码(RoPE)基础
Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):"""预计算RoPE的频率Args:dim: 隐藏层维度end: 序列最大长度theta: RoPE的基频参数scale: 位置缩放因子Returns:freqs_cis: 复数形式的频率矩阵"""# 生成维度序列 [0, 2, ..., dim-2]dims = torch.arange(0, dim, 2)[: (dim // 2)].float()# 计算频率基数 1/θ^(2i/d)freqs = 1.0 / (theta ** (dims / dim))# 生成位置序列并应用缩放t = torch.arange(end, device=freqs.device) * scale# 计算位置和频率的外积freqs = torch.outer(t, freqs)# 转换为复数形式 e^(iθ)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:"""应用旋转位置编码Args:xq: query张量 [batch_size, seq_len, num_heads, head_dim]xk: key张量 [batch_size, seq_len, num_heads, head_dim]freqs_cis: 预计算的频率 [seq_len, head_dim//2]"""# 重塑张量以方便运算xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)# 提取频率的实部和虚部freqs_cos = freqs_cis.real()freqs_sin = freqs_cis.imag()# 应用旋转变换# xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cos# 重新组合实部和虚部xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)return xq_out.type_as(xq), xk_out.type_as(xk)
2.2 动态NTK外推方案
动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:
class LlamaConfig:def __init__(self):self.rope_scaling = {"type": "dynamic", # 动态缩放类型"factor": 2.0, # 基础缩放因子"original_max_position_embeddings": 2048 # 原始训练长度}def compute_dynamic_ntk_scaling(ctx_len: int,orig_ctx_len: int = 2048,base_scale: float = 0.25,alpha: float = 1.0
) -> float:"""计算动态NTK缩放因子Args:ctx_len: 当前上下文长度orig_ctx_len: 原始训练上下文长度base_scale: 基础缩放系数alpha: 缩放曲线的陡峭程度"""# 使用对数曲线计算缩放因子return base_scale * math.log(ctx_len / orig_ctx_len) ** alphaclass LlamaAttention(nn.Module):def __init__(self, config: LlamaConfig):super().__init__()self.config = configself.rope_scaling = config.rope_scalingdef forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,) -> torch.Tensor:"""注意力前向计算Args:hidden_states: 输入张量 [batch_size, seq_len, hidden_size]attention_mask: 注意力掩码position_ids: 位置索引"""seq_len = hidden_states.shape[1]# 计算动态缩放因子if self.rope_scaling["type"] == "dynamic":rope_scale = compute_dynamic_ntk_scaling(seq_len,self.config.rope_scaling["original_max_position_embeddings"],base_scale=self.rope_scaling["factor"])else:rope_scale = 1.0# 计算位置编码freqs_cis = precompute_freqs_cis(self.head_dim,seq_len,scale=rope_scale)# 应用旋转位置编码query_states, key_states = apply_rotary_emb(self.q_proj(hidden_states),self.k_proj(hidden_states),freqs_cis)
3. 注意力机制优化
3.1 分块注意力计算
为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:
class ChunkedAttention(nn.Module):def __init__(self, chunk_size: int = 1024):super().__init__()self.chunk_size = chunk_sizedef forward(self,query: torch.Tensor, # [batch, num_heads, seq_len, head_dim]key: torch.Tensor, # [batch, num_heads, seq_len, head_dim]value: torch.Tensor, # [batch, num_heads, seq_len, head_dim]mask: Optional[torch.Tensor] = None) -> torch.Tensor:"""分块计算注意力"""batch_size, num_heads, seq_len, head_dim = query.shape# 计算需要的块数num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size# 存储每个块的输出chunked_outputs = []# 按块计算注意力for chunk_idx in range(num_chunks):# 计算当前块的起止位置chunk_start = chunk_idx * self.chunk_sizechunk_end = min(chunk_start + self.chunk_size, seq_len)# 提取当前块的querychunk_query = query[:, :, chunk_start:chunk_end]# 计算注意力得分chunk_scores = torch.matmul(chunk_query, # [b, h, chunk_size, d]key.transpose(-2, -1) # [b, h, d, seq_len]) # 得到 [b, h, chunk_size, seq_len]# 缩放注意力得分chunk_scores = chunk_scores / math.sqrt(head_dim)# 应用attention maskif mask is not None:chunk_mask = mask[:, :, chunk_start:chunk_end, :]chunk_scores = chunk_scores + chunk_mask# 应用softmaxchunk_attn = F.softmax(chunk_scores, dim=-1)# 计算输出chunk_output = torch.matmul(chunk_attn, value)chunked_outputs.append(chunk_output)# 拼接所有块的输出return torch.cat(chunked_outputs, dim=2)
3.2 优化的KV Cache实现
KV Cache的实现需要考虑内存效率和计算性能:
class KVCache:def __init__(self,max_batch_size: int,max_seq_length: int,num_heads: int,head_dim: int,dtype: torch.dtype = torch.float16):"""初始化KV缓存Args:max_batch_size: 最大批次大小max_seq_length: 最大序列长度num_heads: 注意力头数head_dim: 每个头的维度dtype: 数据类型"""self.max_seq_length = max_seq_length# 初始化缓存张量self.k_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)self.v_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)# 记录当前序列长度self.current_length = 0def update(self,key: torch.Tensor,value: torch.Tensor,position: int) -> None:"""更新缓存Args:key: key状态 [batch_size, num_heads, seq_len, head_dim]value: value状态 [batch_size, num_heads, seq_len, head_dim]position: 起始位置"""seq_len = key.shape[2]if position + seq_len > self.max_seq_length:raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")# 更新缓存self.k_cache[:, :, position:position+seq_len] = keyself.v_cache[:, :, position:position+seq_len] = value# 更新当前长度self.current_length = max(self.current_length, position + seq_len)def get_cached_kv(self,start_pos: int,end_pos: int) -> Tuple[torch.Tensor, torch.Tensor]:"""获取指定范围的缓存内容"""return (self.k_cache[:, :, start_pos:end_pos],self.v_cache[:, :, start_pos:end_pos])def clear(self) -> None:"""清空缓存"""self.k_cache.zero_()self.v_cache.zero_()self.current_length = 0
4. 实际应用示例
让我们看一个完整的使用示例,展示如何处理长文本:
class LongContextProcessor:def __init__(self,model: LlamaModel,tokenizer,max_length: int = 16384,chunk_size: int = 1024):self.model = modelself.tokenizer = tokenizerself.chunk_size = chunk_size# 初始化KV缓存self.kv_cache = KVCache(max_batch_size=1,max_seq_length=max_length,num_heads=model.config.num_attention_heads,head_dim=model.config.hidden_size // model.config.num_attention_heads)def process_long_text(self, text: str) -> torch.Tensor:"""处理长文本输入Args:text: 输入文本Returns:处理后的隐藏状态"""# 分词tokens = self.tokenizer(text,return_tensors="pt",truncation=False).input_ids# 清空KV缓存self.kv_cache.clear()# 分块处理all_hidden_states = []for i in range(0, tokens.size(1), self.chunk_size):# 获取当前块chunk = tokens[:, i:i+self.chunk_size]# 获取位置编码索引position_ids = torch.arange(i,i + chunk.size(1),dtype=torch.long,device=chunk.device).unsqueeze(0)# 获取当前位置的缓存k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)# 前向计算outputs = self.model(chunk,position_ids=position_ids,past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers)# 更新缓存self.kv
相关文章:

【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持
1. 引言 Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。 2. 位置编码的外推实现 2.1 旋转位置…...

单片机基础模块学习——NE555芯片
一、NE555电路图 NE555也称555定时器,本文主要利用NE555产生方波发生电路。整个电路相当于频率可调的方波发生器。 通过调整电位器的阻值,方波的频率也随之改变。 RB3在开发板的位置如下图 测量方波信号的引脚为SIGHAL,由上面的电路图可知,NE555已经构成完整的方波发生电…...

Hive:struct数据类型,内置函数(日期,字符串,类型转换,数学)
struct STRUCT(结构体)是一种复合数据类型,它允许你将多个字段组合成一个单一的值, 常用于处理嵌套数据,例如当你需要在一个表中存储有关另一个实体的信息时。你可以使用 STRUCT 函数来创建一个结构体。STRUCT 函数接受多个参数&…...

最优化问题 - 内点法
以下是一种循序推理的方式,来帮助你从基础概念出发,理解 内点法(Interior-Point Method, IPM) 是什么、为什么要用它,以及它是如何工作的。 1. 问题起点:带不等式约束的优化 假设你有一个带不等式约束的优…...

vim交换文件的工作原理
在vim中,交换文件是一个临时文件,当我们使用vim打开一个文件进行编辑(一定得是做出了修改才会产生交换文件)时候,vim就会自动创建一个交换文件,而之后我们对于文件的一系列修改都是在交换文件中进行的&…...

CISCO路由基础全集
第一章:交换机的工作原理和基本技能_交换机有操作系统吗-CSDN博客文章浏览阅读1.1k次,点赞24次,收藏24次。交换机可看成是一台特殊的计算机,同样有CPU、存储介质和操作系统,只是与计算机的稍有不同。作为数据交换设备&…...

网络直播时代的营销新策略:基于受众分析与开源AI智能名片2+1链动模式S2B2C商城小程序源码的探索
摘要:随着互联网技术的飞速发展,网络直播作为一种新兴的、极具影响力的媒体形式,正逐渐改变着人们的娱乐方式、消费习惯乃至社交模式。据中国互联网络信息中心数据显示,网络直播用户规模已达到3.25亿,占网民总数的45.8…...

2024年终总结——今年是蜕变的一年
2024年终总结 摘要前因转折找工作工作的成长人生的意义 摘要 2024我从国企出来,兜兜转转还是去了北京,一边是工资低、感情受挫,一边是压力大、项目经历少,让我一度找不到自己梦寐以求的工作,我投了一家又一家ÿ…...

AutoDL 云服务器:普通 用户 miniconda 配置
AutoDL 初始状态下只有root用户,miniconda 安装在root用户目录下 /// 增加普通用户 rootautodl-container-1c0641804d-5bb7040c:~/Desktop# apt updaterootautodl-container-1c0641804d-5bb7040c:~/Desktop# apt install sudorootautodl-container-1c0641804d-5…...

渲染流程概述
渲染流程包括 CPU应用程序端渲染逻辑 和 GPU渲染管线 一、CPU应用程序端渲染逻辑 剔除操作对物体进行渲染排序打包数据调用Shader SetPassCall 和 Drawcall 1.剔除操作 视椎体剔除 (给物体一个包围盒,利用包围盒和摄像机的视椎体进行碰撞检测…...

前端力扣刷题 | 4:hot100之 子串
560. 和为K的子数组 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例: 输入:nums [1,1,1], k 2 输出:2 法一:暴力法 var subar…...

Julia 之 @btime 精准测量详解
Julia 语言因其高性能和易用性在科学计算、数据分析等领域获得了广泛关注。在性能优化中,精准测量代码执行时间是至关重要的任务,而 Julia 提供了强大的工具 btime 来辅助这一任务。本文将围绕 Julia 的 btime 来展开,帮助读者深入理解并高效…...

【Django教程】用户管理系统
Get Started With Django User Management 开始使用Django用户管理 By the end of this tutorial, you’ll understand that: 在本教程结束时,您将了解: Django’s user authentication is a built-in authentication system that comes with pre-conf…...

【机器学习】自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
一、使用pytorch框架实现逻辑回归 1. 数据部分: 首先自定义了一个简单的数据集,特征 X 是 100 个随机样本,每个样本一个特征,目标值 y 基于线性关系并添加了噪声。将 numpy 数组转换为 PyTorch 张量,方便后续在模型中…...

C语言连接Mysql
目录 C语言连接Mysql下载 mysql 开发库 方法介绍mysql_init()mysql_real_connect()mysql_query()mysql_store_result()mysql_num_fields()mysql_fetch_fields()mysql_fetch_row()mysql_free_result()mysql_close() 完整代码 C语言连接Mysql 下载 mysql 开发库 方法一…...

Windows上通过Git Bash激活Anaconda
在Windows上配置完Anaconda后,普遍通过Anaconda Prompt激活虚拟环境并执行Python,如下图所示: 有时需要连续执行多个python脚本时,直接在Anaconda Prompt下可以通过在以下方式,即命令间通过&&连接,…...

面试经典150题——图
文章目录 1、岛屿数量1.1 题目链接1.2 题目描述1.3 解题代码1.4 解题思路 2、被围绕的区域2.1 题目链接2.2 题目描述2.3 解题代码2.4 解题思路 3、克隆图3.1 题目链接3.2 题目描述3.3 解题代码3.4 解题思路 4、除法求值4.1 题目链接4.2 题目描述4.3 解题代码4.4 解题思路 5、课…...

学习数据结构(1)时间复杂度
1.数据结构和算法 (1)数据结构是计算机存储、组织数据的方式,指相互之间存在⼀种或多种特定关系的数据元素的集合 (2)算法就是定义良好的计算过程,取一个或一组的值为输入,并产生出一个或一组…...

项目集成GateWay
文章目录 1.环境搭建1.创建sunrays-common-cloud-gateway-starter模块2.目录结构3.自动配置1.GateWayAutoConfiguration.java2.spring.factories 3.pom.xml4.注意:GateWay不能跟Web一起引入! 1.环境搭建 1.创建sunrays-common-cloud-gateway-starter模块…...

【Ubuntu】使用远程桌面协议(RDP)在Windows上远程连接Ubuntu
使用远程桌面协议(RDP)在Windows上远程连接Ubuntu 远程桌面协议(RDP)是一种允许用户通过图形界面远程控制计算机的协议。本文将详细介绍如何在Ubuntu上安装和配置xrdp,并通过Windows的远程桌面连接工具访问Ubuntu。 …...

python3+TensorFlow 2.x 基础学习(一)
目录 TensorFlow 2.x基础 1、安装 TensorFlow 2.x 2、TensorFlow 2.x 基础概念 2、1 Eager Execution 2、2 TensorFlow 张量(Tensor) 3、使用Keras构建神经网络模型 3、1 构建 Sequential 模型 3、2 编译模型 1、Optimizer(优化器&a…...

《活出人生的厚度》
《活出人生的厚度》可以从不同角度来理解和实践,以下为你提供一些拓展内容: ### 不断学习与自我提升 - **持续知识更新**:保持对新知识的渴望,利用各种渠道学习,如在线课程、学术讲座、行业研讨会等。例如,…...

安装 docker 详解
在平常的开发工作中,我们经常需要部署项目。随着 Docker 容器的出现,大大提高了部署效率。Docker 容器包含了应用程序运行所需的所有依赖,避免了换环境运行问题。可以在短时间内创建、启动和停止容器,大大提高了应用的部署速度&am…...

【Rust自学】16.3. 共享状态的并发
喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 16.3.1. 使用共享来实现并发 还记得Go语言有一句名言是这么说的:Do not commun…...

开发者交流平台项目部署到阿里云服务器教程
本文使用PuTTY软件在本地Windows系统远程控制Linux服务器;其中,Windows系统为Windows 10专业版,Linux系统为CentOS 7.6 64位。 1.工具软件的准备 maven:https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-m…...

【2024年华为OD机试】 (B卷,100分)- 乘坐保密电梯(JavaScriptJava PythonC/C++)
一、问题描述 问题描述 我们需要从0楼到达指定楼层m,乘坐电梯的规则如下: 给定一个数字序列,每次根据序列中的数字n,上升n层或下降n层。前后两次的方向必须相反,且首次方向向上。必须使用序列中的所有数字,不能只使用一部分。目标是到达指定楼层m,如果无法到达,则给出…...

maven的打包插件如何使用
默认的情况下,当直接执行maven项目的编译命令时,对于结果来说是不打第三方包的,只有一个单独的代码jar,想要打一个包含其他资源的完整包就需要用到maven编译插件,使用时分以下几种情况 第一种:当只是想单纯…...

solidity高阶 -- 线性继承
Solidity是一种面向合约的高级编程语言,用于编写智能合约。在Solidity中,多线继承是一个强大的特性,允许合约从多个父合约继承属性和方法。本文将详细介绍Solidity中的多线继承,并通过不同的实例展示其使用方法和注意事项。 在Sol…...

国内外大语言模型领域发展现状与预期
在数字化浪潮中,大语言模型已成为人工智能领域的关键力量,深刻影响着各个行业的发展轨迹。下面我们将深入探讨国内外大语言模型领域的发展现状以及未来预期。 一、发展现状 (一)国外进展 美国的引领地位:OpenAI 的 …...

【Leetcode 热题 100】416. 分割等和子集
问题背景 给你一个 只包含正整数 的 非空 数组 n u m s nums nums。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 数据约束 1 ≤ n u m s . l e n g t h ≤ 200 1 \le nums.length \le 200 1≤nums.length≤200 1 ≤ n u m s [ i ] ≤ …...