当前位置: 首页 > news >正文

Llama架构及代码详解

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):def __init__(self, config: ModelArgs):super().__init__()self.config = config# embedding层self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)# RMSNormself.norm = RMSNorm(config.dim, eps=config.norm_eps)# n层Transformerself.layers = torch.nn.ModuleList()for i in range(self.config.n_layers):self.layers.append(TransformerBlock(config))def forward(self, tokens):# 进行token的嵌入编码h = self.tok_embeddings(tokens)# decoder架构需要生成一个maskseqlen = h.shape[1]mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)mask = torch.triu(mask, diagonal=1)# 进行n层Transformerfor i in range(self.config.n_layers):h = self.layers[i](h, mask)# 进行RMSNormtoken_embeddings = self.norm(h)return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):def __init__(self, dim, eps):super().__init__()self.eps = epsself.weight = torch.nn.Parameter(torch.ones(dim))def _norm(self, tensor):return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, tensor):output = self._norm(tensor)return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):def __init__(self, config):super().__init__()self.config = config# 多头注意力层self.attention = Attention(config)# Norm层self.attention_normal = RMSNorm(config.dim, config.norm_eps)self.ffn_norm = RMSNorm(config.dim, config.norm_eps)# 全连接层self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)def forward(self, embeddings, mask):# normh = self.attention_normal(embeddings)# attentionh = self.attention(h, mask)# add & normh = self.ffn_norm(h + embeddings)# fnnf = self.ffn(h)# addreturn f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):def __init__(self, config):super().__init__()self.config = configself.n_head = config.n_headsself.dim = config.dim // self.n_headself.k = torch.nn.Linear(config.dim, config.dim)self.q = torch.nn.Linear(config.dim, config.dim)self.v = torch.nn.Linear(config.dim, config.dim)def forward(self, embeddings, mask):bsz, seq_len, dim = embeddings.shapek_embeddings = self.k(embeddings)q_embeddings = self.q(embeddings)v_embeddings = self.v(embeddings)n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)/ math.sqrt(self.dim), dim=-1)n_embeddings = scores @ n_v_embeddingsembeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)return embeddings
class FeedForwad(torch.nn.Module):def __init__(self, dim, hidden_dim):super().__init__()self.linear1 = torch.nn.Linear(dim, hidden_dim)self.linear2 = torch.nn.Linear(dim, hidden_dim)self.linear3 = torch.nn.Linear(hidden_dim, dim)def forward(self, embeddings):gate = torch.nn.functional.silu(self.linear1(embeddings))up_proj = self.linear2(embeddings) * gatereturn self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):# 计算所有嵌入位置的旋转角度all_theta = compute_all_theta(dim, m, base)# 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标# 1、将嵌入投影到复数平面embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)embedding_complex_pair = torch.view_as_complex(embedding_real_pair)# 2、将旋转角度投影到复数平面all_theta = all_theta[: embedding.shape[-2]]theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)# 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标rotated_complex_embedding = embedding_complex_pair * theta_complex_pair# 4、将复数平面的嵌入投影到实数平面rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)return rotated_real_embeddingdef compute_all_theta(dim, m, base):theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))m = torch.arange(0, m)all_theta = torch.outer(m, theta)return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5rope_theta: float = 500000max_batch_size: int = 32max_seq_len: int = 2048use_scaled_rope: bool = True

相关文章:

Llama架构及代码详解

Llama的框架图如图: 源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下: Llama的整体组成…...

Android onConfigurationChanged 基础配置

onConfigurationChanged 代替重建 0. **定义与基本用途**1. **具体使用场景 - 屏幕方向改变**2. **具体使用场景 - 键盘可用性改变**3. **具体使用场景 - 语言设置变更**4. **具体使用场景 - 屏幕密度变化**5. **具体使用场景 - 字体大小改变**6. **具体使用场景 - 屏幕尺寸变化…...

3. Sharding-Jdbc核⼼流 程+多种分⽚策略

1. Sharding-Jdbc 分库分表执⾏核⼼流程 Sharding-JDBC执行流程 1. SQL解析 -> SQL优化 -> SQL路由 -> SQL改写 -> SQL执⾏-> 结果归并 ->返回结果简写为:解析->路由->改写->执⾏->结果归并1.1 SQL解析 1. SQL解析过程分为词法解析…...

为什么财富的蓝图如此重要

我们生活在一个二元对立的世界里:上与下、明与暗、冷与热内与外、快与慢、左与右。这些还只是千百种对立之中的几个例子而已。 有了一个极端,表示一定同时有相对的另一端存在。有了右边不可能没有左边。 所以,在钱这件事上,有外…...

【云计算解决方案面试整理】1-2云计算基础概念及云计算技术原理

准备面云计算解决方案的岗位,整理了一些,也请大佬们指点。 文档分为 云计算基础概念、云计算技术原理、主流云计算平台(以天翼云为例)、云计算架构(弹性设计、高可用设计、高性能设计)、安全防护几个方面。 一、云计算基础概念 1.请简要解释一下什么是云计算? 简单说呢…...

循环语句 while()... 与 for()...(day11)

一、while()与do...while()... 循环语句: 通过循环语句可以反复执行一段代码多次 1、while循环: - 语法: while(①条件表达式){ ②语句... } - while语句在执行时, 先对条件表达式进行求值判断, 如果值为true&#…...

Mysql篇-三大日志

概述 undo log(回滚日志):是 Innodb 存储引擎层生成的日志,实现了事务中的原子性,主要用于事务回滚和 MVCC。 redo log(重做日志):是 Innodb 存储引擎层生成的日志,实现…...

MySQL的SQL书写顺序和执行顺序

老是忘记执行顺序,记录一下: 1. SQL语句的书写顺序 书写顺序通常是我们编写SQL查询时的顺序,主要包括以下关键字: SELECT:选择要查询的字段。FROM:指定数据来源表。JOIN(可选)&am…...

摄像机视频分析软件下载LiteAIServer视频智能分析软件抖动检测的技术实现

在现代社会中,视频监控系统扮演着至关重要的角色,其可靠性和有效性在很大程度上取决于视频质量。然而,由于多种因素,如摄像机安装不当、外部环境振动或视频信号传输的不稳定,视频画面常常出现抖动问题,这不…...

spring gateway 动态路由

##yml配置 spring:application:name: public-gateway # cloud: # gateway: # routes: # - id: mybatis-plus-test # 路由的唯一标识 # uri: http://192.168.3.188:9898 # 目标服务的地址 # predicates: # - Path/test/** # 匹配…...

除了 Postman,还有什么好用的 API 管理工具吗?

Postman在团队协作上的支持相对有限,且免费版本的功能较为基础,高级功能需要付费解锁。 为了寻找更加符合团队需求的解决方案,许多开发者开始探索其他API管理工具,其中Apifox便是备受推崇的选择之一。下面通过一个表格来简单了解…...

JAVA:探索 EasyExcel 的技术指南

1、简述 在 Java 开发中,Excel 文件的读写操作是一项常见的需求。阿里巴巴开源的 EasyExcel 提供了一种高效、简洁的解决方案,特别是在处理大规模数据时表现尤为突出。本文将详细介绍 EasyExcel 的优缺点、应用场景,并通过实例展示其基本用法…...

【数字图像处理+MATLAB】对图片进行伽马校正(Gamma Correction):使用幂律变换公式进行伽马变换

引言 伽马校正(Gamma Correction)是一种用于图像处理的技术,主要用于调整图像的亮度或对比度。其基本原理是对图像的每一个像素应用一个非线性变换,以更好地适应人眼的视觉感知。在数字图像处理中,伽马校正通常用于调…...

算法——螺旋矩阵II(leetcode59)

给你一个正整数 n ,生成一个包含 1 到 n^2所有元素,且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 对于螺旋矩阵来讲难点主要在于行或列放置元素时的边界条件,我们遵循一个循环不变量原则在放置行或列元素时遵循左闭右开来放置元…...

以往运维岗本人面试真题分享

以下是本人面试运维岗的一些面试经历,在此做个记录分享 目录 TCP/IP三次握手 IPtables IPtables四表五链都是什么? nat端口如何做? 开放本机的80端口该如何做? 如何在单用户模式下引导Centos? nginx轮询模式都有…...

macOS解决U盘装完系统容量变小的问题

发现原来256GB容量的U盘在mac电脑上只显示34GB,想起来之前用该U盘装过系统,最终搜到了以下解决方案,在此记录: (1) 查看盘符列表,找到需要格式化的U盘,假设为disk4 diskutil list(2) 卸载分区disk4 disk…...

ORA-00257: archiver error

ORA-00257: archiver error 归档满问题: 报错: SQL> conn admin/admin ERROR: ORA-00257: archiver error. Connect internal only, until freed. Warning: You are no longer connected to ORACLE. 检查空间: SQL> select name, tot…...

IO技术详解

IO监控项在监控中一直是很重要的存在,服务有IO,磁盘有IO,操作系统也有IO,IO到底是什么呢 IO IO,即“输入/输出”(Input/Output),是指计算机系统或设备之间交换数据的过程。这个概念…...

pySpark乱码

1.现象 python的变量包含中文,用format放入SQL中时,出现乱码 2.原因 python2默认编码是ascii 3.解决办法 使用python3,并且把所有print,改成带括号的 4.在pyspark中加入参数 spark.pyspark.driver.python/usr/bin/python3 …...

【MySQL 保姆级教学】事务的隔离级别(详细)--下(13)

事务的隔离级别 1. 如何理解事务的隔离性2. 事务隔离级别的分类3. 查看和设置事务隔离级别3.1 全局和会话隔离级别3.2 查看和设置隔离级别 4. 事务隔离级别的演示4.1 读未提交(Read Uncommitted)4.2 读已提交(Read Committed)4.3 …...

关注模块 API

关注用户 POST /api/v1/relations/followHeaders:Authorization: Bearer {token}Request: {"user_id": "target_user_id" }Response: {"code": 0,"data": {"relation_type": "following"} }接口语义设计 POST /…...

YOLOv11光伏板二极管异常目标检测数据集-45张-Solar-panel-anomalies-1

YOLOv11光伏板二极管异常目标检测数据集 📊 数据集基本信息 目标类别: [‘Diode anomaly’, ‘Hot Spots’, ‘Reverse polarity’]中文类别:[‘二极管异常’, ‘热点’, ‘反向极性’]训练集:31 张验证集:9 张测试集&…...

机器学习论文有效阅读:三层穿透法定位技术杠杆点

1. 这不是“读论文”,而是“拆解模型生长的土壤”你有没有过这种体验:打开一篇顶会论文,标题写着《Neural Architecture Search with Reinforcement Learning》,摘要读得热血沸腾,结果翻到Methodology部分,…...

AI Agent Harness Engineering 在餐饮行业的应用:智能点餐与库存管理

标题选项 《从排队到零浪费:AI Agent Harness Engineering 重构餐饮智能点餐与库存管理全链路》 《AI Agent 落地餐饮行业实战:基于Harness框架打造高可用智能点餐+库存联动系统》 《告别漏单、超卖、食材浪费:AI Agent Harness 工程化在餐饮场景的落地指南》 《垂直行业Age…...

注塑行业的数智化突围:告别“黑盒”生产,拥抱透明化管理新纪元

在从“经验驱动”向“数据驱动”的关键跃迁中,注塑成型作为典型的离散制造环节,其数字化转型的痛点尤为尖锐。盘古信息基于近二十年的行业深耕,依托其自主研发的IMS工软底座,为注塑行业带来了一套完整的数智化破局方案&#xff0c…...

雷达信号体制识别

雷达信号体制识别 摘要 本文档基于工程中的信号识别流水线入口脚本及其所依赖的核心模块,系统梳理该工程如何实现雷达脉冲信号的体制分类(Signal Type Recognition)。该流水线采用“脉冲检测 → 脉冲描述字提取 → 脉内特征分析 → 驻留段分段…...

毕业设计 深度学习车道线检测(源码+论文)

文章目录 0 前言1 项目运行效果2 课题背景3 卷积神经网络3.1卷积层3.2 池化层3.3 激活函数:3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV56 数据集处理7 模型训练8 最后 0 前言 🔥这两年开始毕业设计和毕业答辩的要求和难度不断…...

日薪2700的护网HW面试,以及HW全面熟悉必看流程

前言 参与hvv的事情还是要想办法规避掉很多坑的。网络安全这个行业现阶段还是主要政策驱动,后面应该是客户意识,现在用户教育成本明显比以前低太多。 1.关于HVV的一个简单流程 首先我带大家从甲方和厂商的角度来分解一下整个护网流程的核心逻辑 第一阶段…...

2026年免费照片去水印软件App排行榜|去水印App推荐和评测指南

照片被水印困扰是很多用户的常见问题。无论是保存网络上的精美图片、处理工作资料,还是制作个人素材库,去水印都是一个实用的需求。本篇文章根据2026年最新的工具体验,为你梳理免费照片去水印软件app有哪些、各类去水印App怎么选择&#xff0…...

GitHub Desktop中文汉化解决方案:智能文本映射技术实现界面本地化

GitHub Desktop中文汉化解决方案:智能文本映射技术实现界面本地化 【免费下载链接】GitHubDesktop2Chinese GithubDesktop语言本地化(汉化)工具 【GitHub桌面客户端中文汉化】 项目地址: https://gitcode.com/gh_mirrors/gi/GitHubDesktop2Chinese GitHub De…...