【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持
1. 引言
Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。
2. 位置编码的外推实现
2.1 旋转位置编码(RoPE)基础
Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):"""预计算RoPE的频率Args:dim: 隐藏层维度end: 序列最大长度theta: RoPE的基频参数scale: 位置缩放因子Returns:freqs_cis: 复数形式的频率矩阵"""# 生成维度序列 [0, 2, ..., dim-2]dims = torch.arange(0, dim, 2)[: (dim // 2)].float()# 计算频率基数 1/θ^(2i/d)freqs = 1.0 / (theta ** (dims / dim))# 生成位置序列并应用缩放t = torch.arange(end, device=freqs.device) * scale# 计算位置和频率的外积freqs = torch.outer(t, freqs)# 转换为复数形式 e^(iθ)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:"""应用旋转位置编码Args:xq: query张量 [batch_size, seq_len, num_heads, head_dim]xk: key张量 [batch_size, seq_len, num_heads, head_dim]freqs_cis: 预计算的频率 [seq_len, head_dim//2]"""# 重塑张量以方便运算xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)# 提取频率的实部和虚部freqs_cos = freqs_cis.real()freqs_sin = freqs_cis.imag()# 应用旋转变换# xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cos# 重新组合实部和虚部xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)return xq_out.type_as(xq), xk_out.type_as(xk)
2.2 动态NTK外推方案
动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:
class LlamaConfig:def __init__(self):self.rope_scaling = {"type": "dynamic", # 动态缩放类型"factor": 2.0, # 基础缩放因子"original_max_position_embeddings": 2048 # 原始训练长度}def compute_dynamic_ntk_scaling(ctx_len: int,orig_ctx_len: int = 2048,base_scale: float = 0.25,alpha: float = 1.0
) -> float:"""计算动态NTK缩放因子Args:ctx_len: 当前上下文长度orig_ctx_len: 原始训练上下文长度base_scale: 基础缩放系数alpha: 缩放曲线的陡峭程度"""# 使用对数曲线计算缩放因子return base_scale * math.log(ctx_len / orig_ctx_len) ** alphaclass LlamaAttention(nn.Module):def __init__(self, config: LlamaConfig):super().__init__()self.config = configself.rope_scaling = config.rope_scalingdef forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,) -> torch.Tensor:"""注意力前向计算Args:hidden_states: 输入张量 [batch_size, seq_len, hidden_size]attention_mask: 注意力掩码position_ids: 位置索引"""seq_len = hidden_states.shape[1]# 计算动态缩放因子if self.rope_scaling["type"] == "dynamic":rope_scale = compute_dynamic_ntk_scaling(seq_len,self.config.rope_scaling["original_max_position_embeddings"],base_scale=self.rope_scaling["factor"])else:rope_scale = 1.0# 计算位置编码freqs_cis = precompute_freqs_cis(self.head_dim,seq_len,scale=rope_scale)# 应用旋转位置编码query_states, key_states = apply_rotary_emb(self.q_proj(hidden_states),self.k_proj(hidden_states),freqs_cis)
3. 注意力机制优化
3.1 分块注意力计算
为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:
class ChunkedAttention(nn.Module):def __init__(self, chunk_size: int = 1024):super().__init__()self.chunk_size = chunk_sizedef forward(self,query: torch.Tensor, # [batch, num_heads, seq_len, head_dim]key: torch.Tensor, # [batch, num_heads, seq_len, head_dim]value: torch.Tensor, # [batch, num_heads, seq_len, head_dim]mask: Optional[torch.Tensor] = None) -> torch.Tensor:"""分块计算注意力"""batch_size, num_heads, seq_len, head_dim = query.shape# 计算需要的块数num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size# 存储每个块的输出chunked_outputs = []# 按块计算注意力for chunk_idx in range(num_chunks):# 计算当前块的起止位置chunk_start = chunk_idx * self.chunk_sizechunk_end = min(chunk_start + self.chunk_size, seq_len)# 提取当前块的querychunk_query = query[:, :, chunk_start:chunk_end]# 计算注意力得分chunk_scores = torch.matmul(chunk_query, # [b, h, chunk_size, d]key.transpose(-2, -1) # [b, h, d, seq_len]) # 得到 [b, h, chunk_size, seq_len]# 缩放注意力得分chunk_scores = chunk_scores / math.sqrt(head_dim)# 应用attention maskif mask is not None:chunk_mask = mask[:, :, chunk_start:chunk_end, :]chunk_scores = chunk_scores + chunk_mask# 应用softmaxchunk_attn = F.softmax(chunk_scores, dim=-1)# 计算输出chunk_output = torch.matmul(chunk_attn, value)chunked_outputs.append(chunk_output)# 拼接所有块的输出return torch.cat(chunked_outputs, dim=2)
3.2 优化的KV Cache实现
KV Cache的实现需要考虑内存效率和计算性能:
class KVCache:def __init__(self,max_batch_size: int,max_seq_length: int,num_heads: int,head_dim: int,dtype: torch.dtype = torch.float16):"""初始化KV缓存Args:max_batch_size: 最大批次大小max_seq_length: 最大序列长度num_heads: 注意力头数head_dim: 每个头的维度dtype: 数据类型"""self.max_seq_length = max_seq_length# 初始化缓存张量self.k_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)self.v_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)# 记录当前序列长度self.current_length = 0def update(self,key: torch.Tensor,value: torch.Tensor,position: int) -> None:"""更新缓存Args:key: key状态 [batch_size, num_heads, seq_len, head_dim]value: value状态 [batch_size, num_heads, seq_len, head_dim]position: 起始位置"""seq_len = key.shape[2]if position + seq_len > self.max_seq_length:raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")# 更新缓存self.k_cache[:, :, position:position+seq_len] = keyself.v_cache[:, :, position:position+seq_len] = value# 更新当前长度self.current_length = max(self.current_length, position + seq_len)def get_cached_kv(self,start_pos: int,end_pos: int) -> Tuple[torch.Tensor, torch.Tensor]:"""获取指定范围的缓存内容"""return (self.k_cache[:, :, start_pos:end_pos],self.v_cache[:, :, start_pos:end_pos])def clear(self) -> None:"""清空缓存"""self.k_cache.zero_()self.v_cache.zero_()self.current_length = 0
4. 实际应用示例
让我们看一个完整的使用示例,展示如何处理长文本:
class LongContextProcessor:def __init__(self,model: LlamaModel,tokenizer,max_length: int = 16384,chunk_size: int = 1024):self.model = modelself.tokenizer = tokenizerself.chunk_size = chunk_size# 初始化KV缓存self.kv_cache = KVCache(max_batch_size=1,max_seq_length=max_length,num_heads=model.config.num_attention_heads,head_dim=model.config.hidden_size // model.config.num_attention_heads)def process_long_text(self, text: str) -> torch.Tensor:"""处理长文本输入Args:text: 输入文本Returns:处理后的隐藏状态"""# 分词tokens = self.tokenizer(text,return_tensors="pt",truncation=False).input_ids# 清空KV缓存self.kv_cache.clear()# 分块处理all_hidden_states = []for i in range(0, tokens.size(1), self.chunk_size):# 获取当前块chunk = tokens[:, i:i+self.chunk_size]# 获取位置编码索引position_ids = torch.arange(i,i + chunk.size(1),dtype=torch.long,device=chunk.device).unsqueeze(0)# 获取当前位置的缓存k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)# 前向计算outputs = self.model(chunk,position_ids=position_ids,past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers)# 更新缓存self.kv
相关文章:
【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持
1. 引言 Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。 2. 位置编码的外推实现 2.1 旋转位置…...
单片机基础模块学习——NE555芯片
一、NE555电路图 NE555也称555定时器,本文主要利用NE555产生方波发生电路。整个电路相当于频率可调的方波发生器。 通过调整电位器的阻值,方波的频率也随之改变。 RB3在开发板的位置如下图 测量方波信号的引脚为SIGHAL,由上面的电路图可知,NE555已经构成完整的方波发生电…...
Hive:struct数据类型,内置函数(日期,字符串,类型转换,数学)
struct STRUCT(结构体)是一种复合数据类型,它允许你将多个字段组合成一个单一的值, 常用于处理嵌套数据,例如当你需要在一个表中存储有关另一个实体的信息时。你可以使用 STRUCT 函数来创建一个结构体。STRUCT 函数接受多个参数&…...
最优化问题 - 内点法
以下是一种循序推理的方式,来帮助你从基础概念出发,理解 内点法(Interior-Point Method, IPM) 是什么、为什么要用它,以及它是如何工作的。 1. 问题起点:带不等式约束的优化 假设你有一个带不等式约束的优…...
vim交换文件的工作原理
在vim中,交换文件是一个临时文件,当我们使用vim打开一个文件进行编辑(一定得是做出了修改才会产生交换文件)时候,vim就会自动创建一个交换文件,而之后我们对于文件的一系列修改都是在交换文件中进行的&…...
CISCO路由基础全集
第一章:交换机的工作原理和基本技能_交换机有操作系统吗-CSDN博客文章浏览阅读1.1k次,点赞24次,收藏24次。交换机可看成是一台特殊的计算机,同样有CPU、存储介质和操作系统,只是与计算机的稍有不同。作为数据交换设备&…...
网络直播时代的营销新策略:基于受众分析与开源AI智能名片2+1链动模式S2B2C商城小程序源码的探索
摘要:随着互联网技术的飞速发展,网络直播作为一种新兴的、极具影响力的媒体形式,正逐渐改变着人们的娱乐方式、消费习惯乃至社交模式。据中国互联网络信息中心数据显示,网络直播用户规模已达到3.25亿,占网民总数的45.8…...
2024年终总结——今年是蜕变的一年
2024年终总结 摘要前因转折找工作工作的成长人生的意义 摘要 2024我从国企出来,兜兜转转还是去了北京,一边是工资低、感情受挫,一边是压力大、项目经历少,让我一度找不到自己梦寐以求的工作,我投了一家又一家ÿ…...
AutoDL 云服务器:普通 用户 miniconda 配置
AutoDL 初始状态下只有root用户,miniconda 安装在root用户目录下 /// 增加普通用户 rootautodl-container-1c0641804d-5bb7040c:~/Desktop# apt updaterootautodl-container-1c0641804d-5bb7040c:~/Desktop# apt install sudorootautodl-container-1c0641804d-5…...
渲染流程概述
渲染流程包括 CPU应用程序端渲染逻辑 和 GPU渲染管线 一、CPU应用程序端渲染逻辑 剔除操作对物体进行渲染排序打包数据调用Shader SetPassCall 和 Drawcall 1.剔除操作 视椎体剔除 (给物体一个包围盒,利用包围盒和摄像机的视椎体进行碰撞检测…...
前端力扣刷题 | 4:hot100之 子串
560. 和为K的子数组 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例: 输入:nums [1,1,1], k 2 输出:2 法一:暴力法 var subar…...
Julia 之 @btime 精准测量详解
Julia 语言因其高性能和易用性在科学计算、数据分析等领域获得了广泛关注。在性能优化中,精准测量代码执行时间是至关重要的任务,而 Julia 提供了强大的工具 btime 来辅助这一任务。本文将围绕 Julia 的 btime 来展开,帮助读者深入理解并高效…...
【Django教程】用户管理系统
Get Started With Django User Management 开始使用Django用户管理 By the end of this tutorial, you’ll understand that: 在本教程结束时,您将了解: Django’s user authentication is a built-in authentication system that comes with pre-conf…...
【机器学习】自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
一、使用pytorch框架实现逻辑回归 1. 数据部分: 首先自定义了一个简单的数据集,特征 X 是 100 个随机样本,每个样本一个特征,目标值 y 基于线性关系并添加了噪声。将 numpy 数组转换为 PyTorch 张量,方便后续在模型中…...
C语言连接Mysql
目录 C语言连接Mysql下载 mysql 开发库 方法介绍mysql_init()mysql_real_connect()mysql_query()mysql_store_result()mysql_num_fields()mysql_fetch_fields()mysql_fetch_row()mysql_free_result()mysql_close() 完整代码 C语言连接Mysql 下载 mysql 开发库 方法一…...
Windows上通过Git Bash激活Anaconda
在Windows上配置完Anaconda后,普遍通过Anaconda Prompt激活虚拟环境并执行Python,如下图所示: 有时需要连续执行多个python脚本时,直接在Anaconda Prompt下可以通过在以下方式,即命令间通过&&连接,…...
面试经典150题——图
文章目录 1、岛屿数量1.1 题目链接1.2 题目描述1.3 解题代码1.4 解题思路 2、被围绕的区域2.1 题目链接2.2 题目描述2.3 解题代码2.4 解题思路 3、克隆图3.1 题目链接3.2 题目描述3.3 解题代码3.4 解题思路 4、除法求值4.1 题目链接4.2 题目描述4.3 解题代码4.4 解题思路 5、课…...
学习数据结构(1)时间复杂度
1.数据结构和算法 (1)数据结构是计算机存储、组织数据的方式,指相互之间存在⼀种或多种特定关系的数据元素的集合 (2)算法就是定义良好的计算过程,取一个或一组的值为输入,并产生出一个或一组…...
项目集成GateWay
文章目录 1.环境搭建1.创建sunrays-common-cloud-gateway-starter模块2.目录结构3.自动配置1.GateWayAutoConfiguration.java2.spring.factories 3.pom.xml4.注意:GateWay不能跟Web一起引入! 1.环境搭建 1.创建sunrays-common-cloud-gateway-starter模块…...
【Ubuntu】使用远程桌面协议(RDP)在Windows上远程连接Ubuntu
使用远程桌面协议(RDP)在Windows上远程连接Ubuntu 远程桌面协议(RDP)是一种允许用户通过图形界面远程控制计算机的协议。本文将详细介绍如何在Ubuntu上安装和配置xrdp,并通过Windows的远程桌面连接工具访问Ubuntu。 …...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
Docker 离线安装指南
参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...
解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
Leetcode 3577. Count the Number of Computer Unlocking Permutations
Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...
什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...
