当前位置: 首页 > news >正文

大模型入门3:理解LLAMA

  • LLama在transformers库中的代码,以及各部分原理
  • Llama3.1技术报告
  • LLama 33b 微调尝试

Model

  • a stack of DecoderBlocks(SelfAttention, FeedForward, and RMSNorm)
    在这里插入图片描述
    decoder block 整体结构:最大的区别在pre-norm

x -> norm(x) -> attention() -> residual connect -> norm() -> ffn -> residual connect

class DecoderBlock(nn.Module):def __init__(self, config):super().__init__()self.n_heads = config['n_heads']self.dim = config['embed_dim']self.head_dim = self.dim // self.n_headsself.attention = SelfAttention(config)self.feed_forward = FeedForward(config)# rms before attention blockself.attention_norm = RMSNorm(self.dim, eps=config['norm_eps'])# rms before  feed forward blockself.ffn_norm = RMSNorm(self.dim, eps=config['norm_eps'])def forward(self, x, start_pos, freqs_complex):# (m, seq_len, dim)h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_complex)# (m, seq_len, dim)out = h + self.feed_forward.forward(self.ffn_norm(h))return outclass Transformer(nn.Module):def __init__(self, config):super().__init__()self.vocab_size = config['vocab_size']self.n_layers = config['n_layers']self.tok_embeddings = nn.Embedding(self.vocab_size, config['embed_dim'])self.head_dim = config['embed_dim'] // config['n_heads']self.layers = nn.ModuleList()for layer_id in range(config['n_layers']):self.layers.append(DecoderBlock(config))self.norm = RMSNorm(config['embed_dim'], eps=config['norm_eps'])self.output = nn.Linear(config['embed_dim'], self.vocab_size, bias=False)self.freqs_complex = precompute_theta_pos_frequencies(self.head_dim, config['max_seq_len'] * 2, device=(config['device']))def forward(self, tokens, start_pos):# (m, seq_len)batch_size, seq_len = tokens.shape# (m, seq_len) -> (m, seq_len, embed_dim)h = self.tok_embeddings(tokens)# (seq_len, (embed_dim/n_heads)/2]freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]# Consecutively apply all the encoder layers# (m, seq_len, dim)for layer in self.layers:h = layer(h, start_pos, freqs_complex)h = self.norm(h)# (m, seq_len, vocab_size)output = self.output(h).float()return outputmodel = Transformer(config).to(config['device'])
res = model.forward(test_set['input_ids'].to(config['device']), 0)
print(res.size())

RoPE

在这里插入图片描述

def precompute_theta_pos_frequencies(head_dim, seq_len, device, theta=10000.0):# theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2]# (head_dim / 2)theta_numerator = torch.arange(0, head_dim, 2).float()theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)# (seq_len)m = torch.arange(seq_len, device=device)# (seq_len, head_dim / 2)freqs = torch.outer(m, theta).float()# complex numbers in polar, c = R * exp(m * theta), where R = 1:# (seq_len, head_dim/2)freqs_complex = torch.polar(torch.ones_like(freqs), freqs)return freqs_complexdef apply_rotary_embeddings(x, freqs_complex, device):# last dimension pairs of two values represent real and imaginary# two consecutive values will become a single complex number# (m, seq_len, num_heads, head_dim/2, 2)x = x.float().reshape(*x.shape[:-1], -1, 2)# (m, seq_len, num_heads, head_dim/2)x_complex = torch.view_as_complex(x)# (seq_len, head_dim/2) --> (1, seq_len, 1, head_dim/2)freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)# multiply each complex number# (m, seq_len, n_heads, head_dim/2)x_rotated = x_complex * freqs_complex# convert back to the real number# (m, seq_len, n_heads, head_dim/2, 2)x_out = torch.view_as_real(x_rotated)# (m, seq_len, n_heads, head_dim)x_out = x_out.reshape(*x.shape)return x_out.type_as(x).to(device)

RMS norm

class RMSNorm(nn.Module):def __init__(self, dim, eps=1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x: torch.Tensor):# (m, seq_len, dim) * (m, seq_len, 1) = (m, seq_len, dim)# rsqrt: 1 / sqrt(x)return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x: torch.Tensor):# weight is a gain parameter used to re-scale the standardized summed inputs# (dim) * (m, seq_len, dim) = (m, seq_Len, dim)return self.weight * self._norm(x.float()).type_as(x)

KV Caching

在这里插入图片描述
在这里插入图片描述

class KVCache:def __init__(self, max_batch_size, max_seq_len, n_kv_heads, head_dim, device):self.cache_k = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)self.cache_v = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)def update(self, batch_size, start_pos, xk, xv):self.cache_k[:batch_size, start_pos :start_pos + xk.size(1)] = xkself.cache_v[:batch_size, start_pos :start_pos + xv.size(1)] = xvdef get(self, batch_size, start_pos, seq_len):keys = self.cache_k[:batch_size,  :start_pos + seq_len]values = self.cache_v[:batch_size, :start_pos + seq_len]return keys, values

Grouped Query Attention

在这里插入图片描述

def repeat_kv(x, n_rep):batch_size, seq_len, n_kv_heads, head_dim = x.shapeif n_rep == 1:return xelse:# (m, seq_len, n_kv_heads, 1, head_dim)# --> (m, seq_len, n_kv_heads, n_rep, head_dim)# --> (m, seq_len, n_kv_heads * n_rep, head_dim)return (x[:, :, :, None, :].expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim).reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim))class SelfAttention(nn.Module):def __init__(self, config):super().__init__()self.n_heads = config['n_heads']self.n_kv_heads = config['n_kv_heads']self.dim = config['embed_dim']self.n_kv_heads = self.n_heads if self.n_kv_heads is None else self.n_kv_headsself.n_heads_q = self.n_headsself.n_rep = self.n_heads_q // self.n_kv_headsself.head_dim = self.dim // self.n_headsself.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)self.cache = KVCache(max_batch_size=config['max_batch_size'],max_seq_len=config['max_seq_len'],n_kv_heads=self.n_kv_heads,head_dim=self.head_dim,device=config['device'])def forward(self, x, start_pos, freqs_complex):# seq_len is always 1 during inferencebatch_size, seq_len, _ = x.shape# (m, seq_len, dim)xq = self.wq(x)# (m, seq_len, h_kv * head_dim)xk = self.wk(x)xv = self.wv(x)# (m, seq_len, n_heads, head_dim)xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)# (m, seq_len, h_kv, head_dim)xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# (m, seq_len, num_head, head_dim)xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)# (m, seq_len, h_kv, head_dim)xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)# replace the entry in the cacheself.cache.update(batch_size, start_pos, xk, xv)# (m, seq_len, h_kv, head_dim)keys, values = self.cache.get(batch_size, start_pos, seq_len)# (m, seq_len, h_kv, head_dim) --> (m, seq_len, n_heads, head_dim)keys = repeat_kv(keys, self.n_rep)values = repeat_kv(values, self.n_rep)# (m, n_heads, seq_len, head_dim)# seq_len is 1 for xq during inferencexq = xq.transpose(1, 2)# (m, n_heads, seq_len, head_dim)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# (m, n_heads, seq_len_q, head_dim) @ (m, n_heads, head_dim, seq_len) -> (m, n_heads, seq_len_q, seq_len)scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)# (m, n_heads, seq_len_q, seq_len)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# (m, n_heads, seq_len_q, seq_len) @ (m, n_heads, seq_len, head_dim) -> (m, n_heads, seq_len_q, head_dim)output = torch.matmul(scores, values)# ((m, n_heads, seq_len_q, head_dim) -> (m, seq_len_q, dim)output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))# (m, seq_len_q, dim)return self.wo(output)

SwiGlu

def sigmoid(x, beta=1):return 1 / (1 + torch.exp(-x * beta))def swiglu(x, beta=1):return x * sigmoid(x, beta)
class FeedForward(nn.Module):def __init__(self, config):super().__init__()hidden_dim = 4 * config['embed_dim']hidden_dim = int(2 * hidden_dim / 3)if config['ffn_dim_multiplier'] is not None:hidden_dim = int(config['ffn_dim_multiplier'] * hidden_dim)# Round the hidden_dim to the nearest multiple of the multiple_of parameterhidden_dim = config['multiple_of'] * ((hidden_dim + config['multiple_of'] - 1) // config['multiple_of'])self.w1 = nn.Linear(config['embed_dim'], hidden_dim, bias=False)self.w2 = nn.Linear(config['embed_dim'], hidden_dim, bias=False)self.w3 = nn.Linear(hidden_dim, config['embed_dim'], bias=False)def forward(self, x: torch.Tensor):# (m, seq_len, dim) --> (m, seq_len, hidden_dim)swish = swiglu(self.w1(x))# (m, seq_len, dim) --> (m, seq_len, hidden_dim)x_V = self.w2(x)# (m, seq_len, hidden_dim)x = swish * x_V# (m, seq_len, hidden_dim) --> (m, seq_len, dim)return self.w3(x)

小结

  • padding 方式

reference

  • llama tech report
  • 源码:transformers
  • 参数量计算: https://zhuanlan.zhihu.com/p/676113501
  • 基于 MLX 的 LLAMA2-13B 的详细分析 - 亚东的文章 - 知乎 https://zhuanlan.zhihu.com/p/677125915
  • 2023年你最喜欢的MLSys相关的工作是什么? - Lin Zhang的回答 - 知乎
  • https://ai.plainenglish.io/understanding-llama2-kv-cache-grouped-query-attention-rotary-embedding-and-more-c17e5f49a6d7
  • https://github.com/wdndev/llama3-from-scratch-zh/blob/main/llama3/model.py

相关文章:

大模型入门3:理解LLAMA

LLama在transformers库中的代码,以及各部分原理Llama3.1技术报告LLama 33b 微调尝试 Model a stack of DecoderBlocks(SelfAttention, FeedForward, and RMSNorm) decoder block 整体结构:最大的区别在pre-norm x -> norm(x) -> attention() -…...

React学习day07-ReactRouter-抽象路由模块、路由导航、路由导航传参、嵌套路由、默认二级路由的设置、两种路由模式

14、ReactRouter续 (2)抽象路由模块 1)新建page文件夹,存放组件 组件内容: 2)新建router文件夹,在其下创建实例 3)实例导入,使用 4)效果 (3&…...

Unity项目的脚本继承关系

1.Unity项目的脚本继承关系包括四层:自己的脚本、MonoBehaviour、Behaviour、Component、Object。 2.通过F12跳转可以查看各继承类中的方法和属性,如MonoBehaviour类中主要包括协程和相关API。 3.Component类中包含组件的只读属性、消息发送等API&…...

【自动驾驶】决策规划算法(一)决策规划仿真平台搭建 | Matlab + Prescan + Carsim 联合仿真基本操作

写在前面: 🌟 欢迎光临 清流君 的博客小天地,这里是我分享技术与心得的温馨角落。📝 个人主页:清流君_CSDN博客,期待与您一同探索 移动机器人 领域的无限可能。 🔍 本文系 清流君 原创之作&…...

grep 命令:文本搜索

一、grep 命令简介 ​grep ​命令用于在文件中搜索指定模式的文本,并显示匹配的行。 ‍ 二、grep 命令参数 匹配规则:可以是 普通字符 ​串或 正则表达式​。 grep [选项] [匹配规则] [指定目录]常用选项: ​-i, --ignore-case​&#…...

python画图|中秋到了,尝试画个月亮(球体画法)

学习了一段时间的画图,已经掌握了一些3D图的画法,部分链接如下: python画图|极坐标下的3D surface-CSDN博客 python画图|3D参数化图形输出-CSDN博客 我们今天尝试一下月亮的画法。 【1】官网教程 首先还是到达官网教程学习: …...

【网络安全的神秘世界】攻防环境搭建及漏洞原理学习

🌝博客主页:泥菩萨 💖专栏:Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 Kali安装docker 安装教程 PHP攻防环境搭建 中间件 介于应用系统和系统软件之间的软件。 能为多种应用程序合作互通、资源…...

pythonnet python图像 C# .NET图像 互转

C#是dotnet的代表虽然不是一个东西但是在这里代表同一件事,不要在意细节。 pythonnet是 python 和.net无缝连接的桥梁。那么python的图像是numpy表示,C#图象是Bitmap。 做图像想要python的便利又想要dotnet的强大就需要图像类型转换。 上程序。 1.Bi…...

spring security OAuth2 搭建资源服务器以及授权服务器/jdbc/jwt两种方案

一、认证服务器基于jdbc方式 如果不懂请移步上一篇文章:Spring security OAuth2 授权服务器搭建-CSDN博客 在上一篇文章中,TokenStore的默认实现为 InHenoryTokenStore 即内存存储,对于 CLient 信息,userDetaitsServce 接负责从存…...

计算机视觉—3d点云数据基础

点云数据 3d点云数据由来 3d点云 3D Point Cloud是一种用于表示三维空间中对象或场景的数据结构。在最基础的形式中,它是一个包含多个三维坐标点(X, Y, Z)的集合。这些点是通过对实际物体或场景表面进行离散采样而获得的,因此&a…...

Matlab simulink建模与仿真 第十八章(Stateflow状态机)

参考视频:Simulink/stateflow的入门培训_哔哩哔哩_bilibili 一、概述 Stateflow是集成于Simulink中的图形化设计与开发工具,主要用于针对控制系统中的复杂控制逻辑进行建模与仿真,或者说,Stateflow适用于针对事件响应系统进行建模…...

Linux系统终端中文件权限的10位字符是什么意思

Linux操作系统终端长格式显示的文件 在Linux操作系统终端中用文件长格式命令ls -l显示文件,如上图。第一列10个字符表示的含义如下: drwxrwxrwx 第一个字符是表示该文件的类型,如红色d表示该文件是一个目录,详细内容可以参考我…...

Qt QSerialPort串口编程

文章目录 Qt QSerialPort串口编程Qt Serial Port模块简述1.QSerialPortInfo类1.1示例用法 2.QSerialPort类2.1设置串口参数2.2打开串口2.3数据读写2.4关闭串口 3.串口编程基本流程3.1 简单实例 Qt QSerialPort串口编程 Qt 框架的Qt Serial Port 模块提供了访问串口的基本功能&…...

扫雷游戏及其中的知识点

大家好呀,今天我们给大家讲解扫雷游戏如何用C语言制作,以及制作扫雷游戏中的一些C语言知识。 想到扫雷游戏,大家有什么想法吗?大家还记得扫雷游戏是什么样子的吗?我在网上找了一些扫雷游戏的图片给大家提供参考: 如图所示,扫雷游戏需要的元素有以下几个: 1.进入游戏界面…...

【乐企-业务篇】开票前置校验服务-规则链服务接口实现(发票基础信息校验)

开票前置校验服务-规则链服务接口实现(发票基础信息校验) 代码 import liquibase.pro.packaged.L; import org.apache.commons.collections4.Collec...

【搜索算法】以扩召回为目标,item-tag不如query-tag能扩更多数量

首先ElasticSearch的召回结果已大量解决了精确召回的问题,扩召回主要就是增加一些推荐的搜索结果。 以item类目tag为例, 如果item类目体系一共20个类目,每个item都有一个类目,一共有10000个item,则平均每个类目tag下有…...

SpringBoot入门(黑马)

1. SpringBootWeb入门开发 需求:使用SpringBoot 开发一个web 应用,浏览器发起请求 /hello 后,给浏览器返回字符串"Hello World~"。 步骤: 1. 创建springBoot工程,并勾选web开发相关依赖。 2. 定义 HelloCo…...

Stream流操作

准备工作 准备 Gender 枚举类以及 Customer 类 enum Gender {MALE("男性"), FEMALE("女性");private String value;Gender() {}Gender(String value) {this.value value;}Overridepublic String toString() {return "Gender{" "value&qu…...

【Linux】查看操作系统开机时初始化的驱动模块列表的一个方法

这个方法是摸索出来的,也不一定对: 1、驱动层module_init(module_init_function)作为模块初始化,并且提供模块内部初始化的函数名; 2、找到所有驱动目录drivers下所有module_init(module_init_function),在内核6.9.0…...

快速入门Vue

Vue是什么 Vue.js(通常简称为Vue)是一个开源的JavaScript框架,用于构建用户界面和单页应用程序(SPA)。它由尤雨溪(Evan You)在2014年开发并发布。Vue的核心库只关注视图层,易于上手…...

Ubuntu系统下交叉编译openssl

一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

基础测试工具使用经验

背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

3-11单元格区域边界定位(End属性)学习笔记

返回一个Range 对象&#xff0c;只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意&#xff1a;它移动的位置必须是相连的有内容的单元格…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题

【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要&#xff1a; 近期&#xff0c;在使用较新版本的OpenSSH客户端连接老旧SSH服务器时&#xff0c;会遇到 "no matching key exchange method found"​, "n…...