当前位置: 首页 > news >正文

Llama架构及代码详解

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):def __init__(self, config: ModelArgs):super().__init__()self.config = config# embedding层self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)# RMSNormself.norm = RMSNorm(config.dim, eps=config.norm_eps)# n层Transformerself.layers = torch.nn.ModuleList()for i in range(self.config.n_layers):self.layers.append(TransformerBlock(config))def forward(self, tokens):# 进行token的嵌入编码h = self.tok_embeddings(tokens)# decoder架构需要生成一个maskseqlen = h.shape[1]mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)mask = torch.triu(mask, diagonal=1)# 进行n层Transformerfor i in range(self.config.n_layers):h = self.layers[i](h, mask)# 进行RMSNormtoken_embeddings = self.norm(h)return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):def __init__(self, dim, eps):super().__init__()self.eps = epsself.weight = torch.nn.Parameter(torch.ones(dim))def _norm(self, tensor):return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, tensor):output = self._norm(tensor)return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):def __init__(self, config):super().__init__()self.config = config# 多头注意力层self.attention = Attention(config)# Norm层self.attention_normal = RMSNorm(config.dim, config.norm_eps)self.ffn_norm = RMSNorm(config.dim, config.norm_eps)# 全连接层self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)def forward(self, embeddings, mask):# normh = self.attention_normal(embeddings)# attentionh = self.attention(h, mask)# add & normh = self.ffn_norm(h + embeddings)# fnnf = self.ffn(h)# addreturn f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):def __init__(self, config):super().__init__()self.config = configself.n_head = config.n_headsself.dim = config.dim // self.n_headself.k = torch.nn.Linear(config.dim, config.dim)self.q = torch.nn.Linear(config.dim, config.dim)self.v = torch.nn.Linear(config.dim, config.dim)def forward(self, embeddings, mask):bsz, seq_len, dim = embeddings.shapek_embeddings = self.k(embeddings)q_embeddings = self.q(embeddings)v_embeddings = self.v(embeddings)n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)/ math.sqrt(self.dim), dim=-1)n_embeddings = scores @ n_v_embeddingsembeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)return embeddings
class FeedForwad(torch.nn.Module):def __init__(self, dim, hidden_dim):super().__init__()self.linear1 = torch.nn.Linear(dim, hidden_dim)self.linear2 = torch.nn.Linear(dim, hidden_dim)self.linear3 = torch.nn.Linear(hidden_dim, dim)def forward(self, embeddings):gate = torch.nn.functional.silu(self.linear1(embeddings))up_proj = self.linear2(embeddings) * gatereturn self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):# 计算所有嵌入位置的旋转角度all_theta = compute_all_theta(dim, m, base)# 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标# 1、将嵌入投影到复数平面embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)embedding_complex_pair = torch.view_as_complex(embedding_real_pair)# 2、将旋转角度投影到复数平面all_theta = all_theta[: embedding.shape[-2]]theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)# 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标rotated_complex_embedding = embedding_complex_pair * theta_complex_pair# 4、将复数平面的嵌入投影到实数平面rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)return rotated_real_embeddingdef compute_all_theta(dim, m, base):theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))m = torch.arange(0, m)all_theta = torch.outer(m, theta)return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5rope_theta: float = 500000max_batch_size: int = 32max_seq_len: int = 2048use_scaled_rope: bool = True

相关文章:

Llama架构及代码详解

Llama的框架图如图: 源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下: Llama的整体组成…...

Android onConfigurationChanged 基础配置

onConfigurationChanged 代替重建 0. **定义与基本用途**1. **具体使用场景 - 屏幕方向改变**2. **具体使用场景 - 键盘可用性改变**3. **具体使用场景 - 语言设置变更**4. **具体使用场景 - 屏幕密度变化**5. **具体使用场景 - 字体大小改变**6. **具体使用场景 - 屏幕尺寸变化…...

3. Sharding-Jdbc核⼼流 程+多种分⽚策略

1. Sharding-Jdbc 分库分表执⾏核⼼流程 Sharding-JDBC执行流程 1. SQL解析 -> SQL优化 -> SQL路由 -> SQL改写 -> SQL执⾏-> 结果归并 ->返回结果简写为:解析->路由->改写->执⾏->结果归并1.1 SQL解析 1. SQL解析过程分为词法解析…...

为什么财富的蓝图如此重要

我们生活在一个二元对立的世界里:上与下、明与暗、冷与热内与外、快与慢、左与右。这些还只是千百种对立之中的几个例子而已。 有了一个极端,表示一定同时有相对的另一端存在。有了右边不可能没有左边。 所以,在钱这件事上,有外…...

【云计算解决方案面试整理】1-2云计算基础概念及云计算技术原理

准备面云计算解决方案的岗位,整理了一些,也请大佬们指点。 文档分为 云计算基础概念、云计算技术原理、主流云计算平台(以天翼云为例)、云计算架构(弹性设计、高可用设计、高性能设计)、安全防护几个方面。 一、云计算基础概念 1.请简要解释一下什么是云计算? 简单说呢…...

循环语句 while()... 与 for()...(day11)

一、while()与do...while()... 循环语句: 通过循环语句可以反复执行一段代码多次 1、while循环: - 语法: while(①条件表达式){ ②语句... } - while语句在执行时, 先对条件表达式进行求值判断, 如果值为true&#…...

Mysql篇-三大日志

概述 undo log(回滚日志):是 Innodb 存储引擎层生成的日志,实现了事务中的原子性,主要用于事务回滚和 MVCC。 redo log(重做日志):是 Innodb 存储引擎层生成的日志,实现…...

MySQL的SQL书写顺序和执行顺序

老是忘记执行顺序,记录一下: 1. SQL语句的书写顺序 书写顺序通常是我们编写SQL查询时的顺序,主要包括以下关键字: SELECT:选择要查询的字段。FROM:指定数据来源表。JOIN(可选)&am…...

摄像机视频分析软件下载LiteAIServer视频智能分析软件抖动检测的技术实现

在现代社会中,视频监控系统扮演着至关重要的角色,其可靠性和有效性在很大程度上取决于视频质量。然而,由于多种因素,如摄像机安装不当、外部环境振动或视频信号传输的不稳定,视频画面常常出现抖动问题,这不…...

spring gateway 动态路由

##yml配置 spring:application:name: public-gateway # cloud: # gateway: # routes: # - id: mybatis-plus-test # 路由的唯一标识 # uri: http://192.168.3.188:9898 # 目标服务的地址 # predicates: # - Path/test/** # 匹配…...

除了 Postman,还有什么好用的 API 管理工具吗?

Postman在团队协作上的支持相对有限,且免费版本的功能较为基础,高级功能需要付费解锁。 为了寻找更加符合团队需求的解决方案,许多开发者开始探索其他API管理工具,其中Apifox便是备受推崇的选择之一。下面通过一个表格来简单了解…...

JAVA:探索 EasyExcel 的技术指南

1、简述 在 Java 开发中,Excel 文件的读写操作是一项常见的需求。阿里巴巴开源的 EasyExcel 提供了一种高效、简洁的解决方案,特别是在处理大规模数据时表现尤为突出。本文将详细介绍 EasyExcel 的优缺点、应用场景,并通过实例展示其基本用法…...

【数字图像处理+MATLAB】对图片进行伽马校正(Gamma Correction):使用幂律变换公式进行伽马变换

引言 伽马校正(Gamma Correction)是一种用于图像处理的技术,主要用于调整图像的亮度或对比度。其基本原理是对图像的每一个像素应用一个非线性变换,以更好地适应人眼的视觉感知。在数字图像处理中,伽马校正通常用于调…...

算法——螺旋矩阵II(leetcode59)

给你一个正整数 n ,生成一个包含 1 到 n^2所有元素,且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 对于螺旋矩阵来讲难点主要在于行或列放置元素时的边界条件,我们遵循一个循环不变量原则在放置行或列元素时遵循左闭右开来放置元…...

以往运维岗本人面试真题分享

以下是本人面试运维岗的一些面试经历,在此做个记录分享 目录 TCP/IP三次握手 IPtables IPtables四表五链都是什么? nat端口如何做? 开放本机的80端口该如何做? 如何在单用户模式下引导Centos? nginx轮询模式都有…...

macOS解决U盘装完系统容量变小的问题

发现原来256GB容量的U盘在mac电脑上只显示34GB,想起来之前用该U盘装过系统,最终搜到了以下解决方案,在此记录: (1) 查看盘符列表,找到需要格式化的U盘,假设为disk4 diskutil list(2) 卸载分区disk4 disk…...

ORA-00257: archiver error

ORA-00257: archiver error 归档满问题: 报错: SQL> conn admin/admin ERROR: ORA-00257: archiver error. Connect internal only, until freed. Warning: You are no longer connected to ORACLE. 检查空间: SQL> select name, tot…...

IO技术详解

IO监控项在监控中一直是很重要的存在,服务有IO,磁盘有IO,操作系统也有IO,IO到底是什么呢 IO IO,即“输入/输出”(Input/Output),是指计算机系统或设备之间交换数据的过程。这个概念…...

pySpark乱码

1.现象 python的变量包含中文,用format放入SQL中时,出现乱码 2.原因 python2默认编码是ascii 3.解决办法 使用python3,并且把所有print,改成带括号的 4.在pyspark中加入参数 spark.pyspark.driver.python/usr/bin/python3 …...

【MySQL 保姆级教学】事务的隔离级别(详细)--下(13)

事务的隔离级别 1. 如何理解事务的隔离性2. 事务隔离级别的分类3. 查看和设置事务隔离级别3.1 全局和会话隔离级别3.2 查看和设置隔离级别 4. 事务隔离级别的演示4.1 读未提交(Read Uncommitted)4.2 读已提交(Read Committed)4.3 …...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

生成 Git SSH 证书

🔑 1. ​​生成 SSH 密钥对​​ 在终端(Windows 使用 Git Bash,Mac/Linux 使用 Terminal)执行命令: ssh-keygen -t rsa -b 4096 -C "your_emailexample.com" ​​参数说明​​: -t rsa&#x…...

Reasoning over Uncertain Text by Generative Large Language Models

https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829 1. 概述 文本中的不确定性在许多语境中传达,从日常对话到特定领域的文档(例如医学文档)(Heritage 2013;Landmark、Gulbrandsen 和 Svenevei…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题

分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...

Kafka入门-生产者

生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...

CSS | transition 和 transform的用处和区别

省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...

jmeter聚合报告中参数详解

sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...

xmind转换为markdown

文章目录 解锁思维导图新姿势:将XMind转为结构化Markdown 一、认识Xmind结构二、核心转换流程详解1.解压XMind文件(ZIP处理)2.解析JSON数据结构3:递归转换树形结构4:Markdown层级生成逻辑 三、完整代码 解锁思维导图新…...

Mysql故障排插与环境优化

前置知识点 最上层是一些客户端和连接服务,包含本 sock 通信和大多数jiyukehuduan/服务端工具实现的TCP/IP通信。主要完成一些简介处理、授权认证、及相关的安全方案等。在该层上引入了线程池的概念,为通过安全认证接入的客户端提供线程。同样在该层上可…...