当前位置: 首页 > news >正文

大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):hidden_states = hidden_states.float()variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.float()def rotate_half(x):x1, x2 = x.chunk(2, dim=-1)return torch.cat((-x2, x1), dim=-1)def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):cos = cos.unsqueeze(unsqueeze_dim)sin = sin.unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed# 旋转位置编码
class RotaryEmbedding(nn.Module):def __init__(self, dim, max_seq_len=1024):super(RotaryEmbedding, self).__init__()self.dim = dimself.max_seq_len = max_seq_leninv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(max_seq_len).float().unsqueeze(1)freqs = t @ inv_freq.unsqueeze(0)freqs = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", freqs.cos())self.register_buffer("sin_cached", freqs.sin())def forward(self, q, k):cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)return apply_rotate_pos_emb(q, k, cos, sin)class MLA(nn.Module):def __init__(self,dim,n_heads,q_lora_rank,kv_lora_rank,qk_nope_head_dim,qk_rope_head_dim,v_head_dim,max_seq_len,max_batch_size,mode):super().__init__()self.dim = dim  # 隐藏层维度self.n_heads = n_heads  # 总头数self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度self.mode = modeself.max_seq_len = max_seq_lenself.max_batch_size = max_batch_sizeself.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵# 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵# nn.Linear(self.dim, self.kv_lora_rank)# nn.Linear(self.dim, self.qk_rope_head_dim)self.kv_norm = RMSNorm(self.kv_lora_rank)self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码# 没有矩阵融合if self.mode == 'naive':self.register_buffer('k_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),persistent=False)self.register_buffer('v_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),persistent=False)# 有矩阵融合else:self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),persistent=False)self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),persistent=False)def forward(self, x, mask=None):bs, seq_len, _ = x.shapeq = self.wq_a(x)  # [bs, seq_len, q_lora_rank]q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个keyq_pe, k_pe = self.rotary_emb(q_pe, k_pe)if self.mode == 'naive':q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)# k shape:[bs, seq_len, n_heads, qk_head_dim]self.k_cache[:bs, :seq_len, :, :] = kself.v_cache[:bs, :seq_len, :, :] = v# scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)scores = torch.matmul(q.transpose(1, 2),self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim))scores = scores.transpose(1, 2)else:k_pe = k_pe.squeeze(2)wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]wkv_b = wkv_b.view(self.n_heads, -1,self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]q_nope = torch.einsum("bshd,hdc->bshc", q_nope,wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]# q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v# wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/vkv = self.kv_norm(kv)self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]scores_nope = torch.einsum("bshc,btc->bsht", q_nope,self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bshtscores_pe = torch.einsum("bshr,btr->bsht", q_pe,self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bshtscores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]if mask is not None:# mask shape:[bs, seq_len, seq_len]scores += mask.unsqueeze(2)scores = scores.softmax(dim=-1)if self.mode == 'naive':x = torch.einsum("bsht,bthd->bshd", scores,self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshdelse:# scores * v = scores * c * wkv_b[:, -self.v_head_dim:]x = torch.einsum("bsht,btc->bshc", scores,self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshdx = x.contiguous().view(bs, seq_len, -1)x = self.wo(x) return xif __name__ == '__main__':torch.manual_seed(0)torch.set_printoptions(precision=3, sci_mode=False)x = torch.randn(1, 4, 16)dim = 16n_heads = 2q_lora_rank = 10kv_lora_rank = 6qk_nope_head_dim = 8qk_rope_head_dim = 4v_head_dim = 8max_seq_len = 10max_batch_size = 4mode = 'none'mla = MLA(dim=dim,n_heads=n_heads,q_lora_rank=q_lora_rank,kv_lora_rank=kv_lora_rank,qk_nope_head_dim=qk_nope_head_dim,qk_rope_head_dim=qk_rope_head_dim,v_head_dim=v_head_dim,max_seq_len=max_seq_len,max_batch_size=max_batch_size,mode=mode)print(mla(x))print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

相关文章:

大模型推理——MLA实现方案

1.整体流程 先上一张图来整体理解下MLA的计算过程 2.实现代码 import math import torch import torch.nn as nn# rms归一化 class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Pa…...

redis之GEO 模块

文章目录 背景GeoHash 算法redis中的GeoHash 算法基本使用增加距离获取元素位置获取元素的 hash 值附近的元素 注意事项原理 背景 如果我们有需求需要存储地理坐标,为了满足高性能的矩形区域算法,数据表需要在经纬度坐标加上双向复合索引 (x, y)&#x…...

21.2.7 综合示例

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 【例 21.7】【项目:code21-007】填充职员表并打印。 本例使用到的Excel文件为:职员信息登记表.xlsx&#x…...

使用Docker + Ollama在Ubuntu中部署deepseek

1、安装docker 这里建议用docker来部署,方便简单 安装教程需要自己找详细的,会用到跳过 如果你没有安装 Docker,可以按照以下步骤安装: sudo apt update sudo apt install apt-transport-https ca-certificates curl software-p…...

【C语言标准库函数】三角函数

目录 一、头文件 二、函数简介 2.1. 正弦函数:sin(double angle) 2.2. 余弦函数:cos(double angle) 2.3. 正切函数:tan(double angle) 2.4. 反正弦函数:asin(double value) 2.5. 反余弦函数:acos(double value)…...

CNN-day9-经典神经网络ResNet

day10-经典神经网络ResNet 1 梯度消失问题 深层网络有个梯度消失问题:模型变深时,其错误率反而会提升,该问题非过拟合引起,主要是因为梯度消失而导致参数难以学习和更新。 2 网络创新 2015年何凯明等人提出deep residual netw…...

淘宝分类详情数据获取:Python爬虫的高效实现

在电商领域,淘宝作为中国最大的电商平台之一,其分类详情数据对于市场分析、竞争对手研究以及电商运营优化具有不可估量的价值。通过Python爬虫技术,我们可以高效地获取这些数据,为电商从业者提供强大的数据支持。 一、为什么选择…...

机器学习 —— 深入剖析线性回归模型

一、线性回归模型简介 线性回归是机器学习中最为基础的模型之一,主要用于解决回归问题,即预测一个连续的数值。其核心思想是构建线性方程,描述自变量(特征)和因变量(目标值)之间的关系。简单来…...

33.日常算法

1.螺旋矩阵 题目来源 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出:[1,2,3,6,9,8,7,4,5] class Solution { public:vec…...

#渗透测试#批量漏洞挖掘#微商城系统 goods SQL注入漏洞

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停止本文章读。 目录 一、漏洞概述 二、漏洞复现步骤 三、技术…...

【翻译+论文阅读】DeepSeek-R1评测:粉碎GPT-4和Claude 3.5的开源AI革命

目录 一、DeepSeek-R1 势不可挡二、DeepSeek-R1 卓越之处三、DeepSeek-R1 创新设计四、DeepSeek-R1 进化之路1. 强化学习RL代替监督微调学习SFL2. Aha Moment “啊哈”时刻3. 蒸馏版本仅采用SFT4. 未来研究计划 部分内容有拓展,部分内容有删除,与原文会有…...

Vision Transformer学习笔记(2020 ICLR)

摘要(Abstract):简述了ViT(Vision Transformer)模型的设计和实验结果,展示了其在大规模图像数据集上进行训练时的优越性能。该模型直接采用原始图像块作为输入,而不是传统的卷积神经网络(CNNs),并通过Transformer架构处理这些图像块以实现高效的图像识别。引言(Introdu…...

一步一步生成音乐类小程序的详细指南,结合AI辅助开发的思路

以下是一步一步生成音乐类小程序的详细指南,结合AI辅助开发的思路: 需求分析阶段核心功能梳理 音乐播放器(播放/暂停/进度条/音量)歌单分类(流行/古典/摇滚等)用户系统(登录/收藏/历史记录)搜索功能(歌曲/歌手/专辑)推荐系统(根据用户偏好推荐)技术选型 前端:微信…...

25/2/8 <机器人基础> 阻抗控制

1. 什么是阻抗控制? 阻抗控制旨在通过调节机器人与环境的相互作用,控制其动态行为。阻抗可以理解为一个力和位移之间的关系,涉及力、速度和位置的协同控制。 2. 阻抗控制的基本概念 力控制:根据感测的外力调节机械手的动作。位置…...

golang 开启HTTP代理认证

内部网路不能直接访问外网接口,可以通过代理发送HTTP请求。 HTTP代理服务需要进行认证。 package cmdimport ("fmt""io/ioutil""log""net/http""net/url""strings" )// 推送CBC07功能 func main() {l…...

详解Nginx no live upstreams while connecting to upstream

网上看到几个相关的文章,觉得很不错,这里整理记录分享一下,供大家参考。 upstream配置分 在分析问题原因之前,我们先来看下关于上面upstream配置一些相关的参数配置说明,参考下面表格 ngx_http_proxy_module 这里重…...

Open3d Qt的环境配置

Open3d Qt的环境配置 一、概述二、操作流程2.1 下载文件2.2 新建文件夹2.3 环境变量设置2.4 qt6 引用3、qt中调用4、资源下载一、概述 目前统一使用qt6配置,open3d中可视化功能目前使用vtk代替,语言为c++。 二、操作流程 2.1 下载文件 访问open3d github链接,进入releas…...

5.Python字典和元组:字典的增删改查、字典遍历、访问元组、修改元组、集合(set)

1. 字典(dict) 字典是一个无序的键值对集合,每个键对应一个值。 字典的增、删、改、查: 添加键值对: my_dict {a: 1, b: 2} my_dict[c] 3 # 添加新键c,值为3 print(my_dict) # 输出:{a: 1, b: 2, c: …...

深度学习系列--04.梯度下降以及其他优化器

目录 一.梯度概念 1.一元函数 2.二元函数 3.几何意义上的区别 二.梯度下降 1.原理 2.步骤 3.示例代码(Python) 4.不同类型的梯度下降 5.优缺点 三.动量优化器(Momentum) 适用场景 1.复杂地形的优化问题 2.数据具有噪声的问…...

2022java面试总结,1000道(集合+JVM+并发编程+Spring+Mybatis)的Java高频面试题

1、面试题模块汇总 面试题包括以下十九个模块: Java 基础、容器、多线程、反射、对象拷贝、Java Web 模块、异常、网络、设计模式、Spring/Spring MVC、Spring Boot/Spring Cloud、Hibernate、Mybatis、RabbitMQ、Kafka、Zookeeper、MySql、Redis、JVM 。如下图所示…...

Ubuntu MKL(Intel Math Kernel Library)

Get Intel oneAPI Math Kernel Library wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940_offline.sh sudo sh ./intel-onemkl-2025.0.0.940_offline.sh MKL库的配置和使用-CSDN博客 CMak…...

消费电子产品中的噪声对TPS54202的影响

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时,也能帮助其他需要参考的朋友。如有谬误,欢迎大家进行指正。 一、概述 在白色家电领域,降压转换器的应用非常广泛,为了实现不同的功能就需要不同的电源轨。TPS542…...

第四十章:职场转折:突破困境,重新出发

从绍兴与岳父岳母温馨相聚归来后,小冷满心都是温暖与幸福,本以为生活与工作会继续平稳前行,然而,命运却悄然为他的职场之路埋下了转折的伏笔。 平静工作下的暗潮涌动 小冷所在的公司是一家专注于地图导航与位置服务的企业&#xf…...

c++ 不定参数,不定类型的 max,min 函数

MSVC\14.29.30133\include\utility(33,19): error C2064: 项不会计算为接受 2 个参数的函数 max min #include <iostream> #include <type_traits>// 自定义 min_gd&#xff08;支持任意类型和数量参数&#xff09; template <typename... Args> auto min_g…...

数据库的关系代数

关系就是表 属性&#xff08;Attribute&#xff09;是关系中的列.例如&#xff0c;关系 “学生” 中可能有属性 “学号”、“姓名”、“班级”。 元组(Tuple)是关系中的一行数据 1. 基本运算符 选择&#xff08;Selection&#xff09; 符号&#xff1a;σ 作用&#xff1a;从关…...

VSCode使用总结

1、VSCode左边资源窗口字体大小设置 方法一&#xff08;使用&#xff0c;已成功&#xff09; 进入安装目录Microsoft VS Code\resources\app\out\vs\workbench(如果是下载的压缩包&#xff0c;解压后resources\app\out\vs\workbench) 打开文件 workbench.desktop.main.css 搜…...

关系模型的数据结构及形式化定义

1 关系模型的核心结构 ①单一的数据结构&#xff08;关系&#xff09; 现实世界的实体以及实体间的各种联系均用关系来表示 ②逻辑结构&#xff08;二维表&#xff09; 从用户角度&#xff0c;关系模型中数据的逻辑结构是一张二维表&#xff0c;行代表元组&#xff08;记录&a…...

【C++入门讲解】

目录 ​编辑 --------------------------------------begin---------------------------------------- 一、C简介 二、开发环境搭建 主流开发工具推荐 第一个C程序 三、核心语法精讲 1. 变量与数据类型 2. 运算符大全 3. 流程控制结构 4. 函数深度解析 5. 数组与容…...

数据表中的视图操作

文章目录 一、视图概述二、为什么要使用视图三、创建视图四、查看视图 一、视图概述 小学的时候&#xff0c;每年都会举办一次抽考活动&#xff0c;意思是从每一个班级里面筛选出几个优秀的同学去参加考试&#xff0c;这时候很多班级筛选出来的这些同学就可以临时组成一个班级…...

BFS算法篇——广度优先搜索,探索未知的旅程(上)

文章目录 前言一、BFS的思路二、BFS的C语言实现1. 图的表示2. BFS的实现 三、代码解析四、输出结果五、总结 前言 广度优先搜索&#xff08;BFS&#xff09;是一种广泛应用于图论中的算法&#xff0c;常用于寻找最短路径、图的遍历等问题。与深度优先搜索&#xff08;DFS&…...