大模型推理——MLA实现方案
1.整体流程
先上一张图来整体理解下MLA的计算过程

2.实现代码
import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):hidden_states = hidden_states.float()variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.float()def rotate_half(x):x1, x2 = x.chunk(2, dim=-1)return torch.cat((-x2, x1), dim=-1)def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):cos = cos.unsqueeze(unsqueeze_dim)sin = sin.unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed# 旋转位置编码
class RotaryEmbedding(nn.Module):def __init__(self, dim, max_seq_len=1024):super(RotaryEmbedding, self).__init__()self.dim = dimself.max_seq_len = max_seq_leninv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(max_seq_len).float().unsqueeze(1)freqs = t @ inv_freq.unsqueeze(0)freqs = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", freqs.cos())self.register_buffer("sin_cached", freqs.sin())def forward(self, q, k):cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)return apply_rotate_pos_emb(q, k, cos, sin)class MLA(nn.Module):def __init__(self,dim,n_heads,q_lora_rank,kv_lora_rank,qk_nope_head_dim,qk_rope_head_dim,v_head_dim,max_seq_len,max_batch_size,mode):super().__init__()self.dim = dim # 隐藏层维度self.n_heads = n_heads # 总头数self.q_lora_rank = q_lora_rank # q低秩压缩到的维度self.kv_lora_rank = kv_lora_rank # k/v低秩压缩到的维度self.qk_nope_head_dim = qk_nope_head_dim # q/k不带旋转位置编码的维度self.qk_rope_head_dim = qk_rope_head_dim # q/k带旋转位置编码的维度self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度self.v_head_dim = v_head_dim # value的维度,等于不带旋转位置编码的k维度self.mode = modeself.max_seq_len = max_seq_lenself.max_batch_size = max_batch_sizeself.wq_a = nn.Linear(self.dim, self.q_lora_rank) # q的降维矩阵self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim) # q的升维矩阵# 4096*128+128*4864 = 524,288 + 622592 = 1146880 4096*4864 = 19,922,944self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) # k/v的降维矩阵# nn.Linear(self.dim, self.kv_lora_rank)# nn.Linear(self.dim, self.qk_rope_head_dim)self.kv_norm = RMSNorm(self.kv_lora_rank)self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) # k/v的升维矩阵self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim) # 旋转位置编码# 没有矩阵融合if self.mode == 'naive':self.register_buffer('k_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),persistent=False)self.register_buffer('v_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),persistent=False)# 有矩阵融合else:self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),persistent=False)self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),persistent=False)def forward(self, x, mask=None):bs, seq_len, _ = x.shapeq = self.wq_a(x) # [bs, seq_len, q_lora_rank]q = self.q_norm(q) # [bs, seq_len, q_lora_rank]q = self.wq_b(q) # [bs, seq_len, n_heads * qk_head_dim]q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim) # [bs, seq_len, n_heads, qk_head_dim]q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],dim=-1) # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]kv = self.wkv_a(x) # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],dim=-1) # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]k_pe = k_pe.unsqueeze(2) # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim] 一层共享一个keyq_pe, k_pe = self.rotary_emb(q_pe, k_pe)if self.mode == 'naive':q = torch.cat([q_nope, q_pe], dim=-1) # * [bs, seq_len, n_heads, qk_head_dim]kv = self.kv_norm(kv) # [bs, seq_len, kv_lora_rank)]kv = self.wkv_b(kv) # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)# k shape:[bs, seq_len, n_heads, qk_head_dim]self.k_cache[:bs, :seq_len, :, :] = kself.v_cache[:bs, :seq_len, :, :] = v# scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)scores = torch.matmul(q.transpose(1, 2),self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim))scores = scores.transpose(1, 2)else:k_pe = k_pe.squeeze(2)wkv_b = self.wkv_b.weight # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]wkv_b = wkv_b.view(self.n_heads, -1,self.kv_lora_rank) # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]q_nope = torch.einsum("bshd,hdc->bshc", q_nope,wkv_b[:, :self.qk_nope_head_dim]) # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]# q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T) c为压缩后的k/v# wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵 c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/vkv = self.kv_norm(kv)self.kv_cache[:bs, :seq_len, :] = kv # kv shape:[bs, seq_len, kv_lora_rank]self.pe_cache[:bs, :seq_len, :] = k_pe # k_pe shape:[bs, seq_len, qk_rope_head_dim]scores_nope = torch.einsum("bshc,btc->bsht", q_nope,self.kv_cache[:bs, :seq_len, :]) # bshc btc -> bshc bct -> bshtscores_pe = torch.einsum("bshr,btr->bsht", q_pe,self.pe_cache[:bs, :seq_len, :]) # bshr btr -> bshr bt1r -> bshr bthr -> bshtscores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) # [bs, seq_len, n_heads, seq_len]if mask is not None:# mask shape:[bs, seq_len, seq_len]scores += mask.unsqueeze(2)scores = scores.softmax(dim=-1)if self.mode == 'naive':x = torch.einsum("bsht,bthd->bshd", scores,self.v_cache[:bs, :seq_len]) # bsht,bthd -> bhst, bhtd -> bhsd -> bshdelse:# scores * v = scores * c * wkv_b[:, -self.v_head_dim:]x = torch.einsum("bsht,btc->bshc", scores,self.kv_cache[:bs, :seq_len]) # x shape:[bs, seq_len, n_heads, kv_lora_rank]x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) # bshc, hdc -> bshc,dch -> bsdh -> bshdx = x.contiguous().view(bs, seq_len, -1)x = self.wo(x) return xif __name__ == '__main__':torch.manual_seed(0)torch.set_printoptions(precision=3, sci_mode=False)x = torch.randn(1, 4, 16)dim = 16n_heads = 2q_lora_rank = 10kv_lora_rank = 6qk_nope_head_dim = 8qk_rope_head_dim = 4v_head_dim = 8max_seq_len = 10max_batch_size = 4mode = 'none'mla = MLA(dim=dim,n_heads=n_heads,q_lora_rank=q_lora_rank,kv_lora_rank=kv_lora_rank,qk_nope_head_dim=qk_nope_head_dim,qk_rope_head_dim=qk_rope_head_dim,v_head_dim=v_head_dim,max_seq_len=max_seq_len,max_batch_size=max_batch_size,mode=mode)print(mla(x))print(mla.kv_cache)
参考资料:
https://zhuanlan.zhihu.com/p/16730036197
https://github.com/wyf3/llm_related/tree/main/deepseek_learn
相关文章:
大模型推理——MLA实现方案
1.整体流程 先上一张图来整体理解下MLA的计算过程 2.实现代码 import math import torch import torch.nn as nn# rms归一化 class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Pa…...
redis之GEO 模块
文章目录 背景GeoHash 算法redis中的GeoHash 算法基本使用增加距离获取元素位置获取元素的 hash 值附近的元素 注意事项原理 背景 如果我们有需求需要存储地理坐标,为了满足高性能的矩形区域算法,数据表需要在经纬度坐标加上双向复合索引 (x, y)&#x…...
21.2.7 综合示例
版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 【例 21.7】【项目:code21-007】填充职员表并打印。 本例使用到的Excel文件为:职员信息登记表.xlsx&#x…...
使用Docker + Ollama在Ubuntu中部署deepseek
1、安装docker 这里建议用docker来部署,方便简单 安装教程需要自己找详细的,会用到跳过 如果你没有安装 Docker,可以按照以下步骤安装: sudo apt update sudo apt install apt-transport-https ca-certificates curl software-p…...
【C语言标准库函数】三角函数
目录 一、头文件 二、函数简介 2.1. 正弦函数:sin(double angle) 2.2. 余弦函数:cos(double angle) 2.3. 正切函数:tan(double angle) 2.4. 反正弦函数:asin(double value) 2.5. 反余弦函数:acos(double value)…...
CNN-day9-经典神经网络ResNet
day10-经典神经网络ResNet 1 梯度消失问题 深层网络有个梯度消失问题:模型变深时,其错误率反而会提升,该问题非过拟合引起,主要是因为梯度消失而导致参数难以学习和更新。 2 网络创新 2015年何凯明等人提出deep residual netw…...
淘宝分类详情数据获取:Python爬虫的高效实现
在电商领域,淘宝作为中国最大的电商平台之一,其分类详情数据对于市场分析、竞争对手研究以及电商运营优化具有不可估量的价值。通过Python爬虫技术,我们可以高效地获取这些数据,为电商从业者提供强大的数据支持。 一、为什么选择…...
机器学习 —— 深入剖析线性回归模型
一、线性回归模型简介 线性回归是机器学习中最为基础的模型之一,主要用于解决回归问题,即预测一个连续的数值。其核心思想是构建线性方程,描述自变量(特征)和因变量(目标值)之间的关系。简单来…...
33.日常算法
1.螺旋矩阵 题目来源 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出:[1,2,3,6,9,8,7,4,5] class Solution { public:vec…...
#渗透测试#批量漏洞挖掘#微商城系统 goods SQL注入漏洞
免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停止本文章读。 目录 一、漏洞概述 二、漏洞复现步骤 三、技术…...
【翻译+论文阅读】DeepSeek-R1评测:粉碎GPT-4和Claude 3.5的开源AI革命
目录 一、DeepSeek-R1 势不可挡二、DeepSeek-R1 卓越之处三、DeepSeek-R1 创新设计四、DeepSeek-R1 进化之路1. 强化学习RL代替监督微调学习SFL2. Aha Moment “啊哈”时刻3. 蒸馏版本仅采用SFT4. 未来研究计划 部分内容有拓展,部分内容有删除,与原文会有…...
Vision Transformer学习笔记(2020 ICLR)
摘要(Abstract):简述了ViT(Vision Transformer)模型的设计和实验结果,展示了其在大规模图像数据集上进行训练时的优越性能。该模型直接采用原始图像块作为输入,而不是传统的卷积神经网络(CNNs),并通过Transformer架构处理这些图像块以实现高效的图像识别。引言(Introdu…...
一步一步生成音乐类小程序的详细指南,结合AI辅助开发的思路
以下是一步一步生成音乐类小程序的详细指南,结合AI辅助开发的思路: 需求分析阶段核心功能梳理 音乐播放器(播放/暂停/进度条/音量)歌单分类(流行/古典/摇滚等)用户系统(登录/收藏/历史记录)搜索功能(歌曲/歌手/专辑)推荐系统(根据用户偏好推荐)技术选型 前端:微信…...
25/2/8 <机器人基础> 阻抗控制
1. 什么是阻抗控制? 阻抗控制旨在通过调节机器人与环境的相互作用,控制其动态行为。阻抗可以理解为一个力和位移之间的关系,涉及力、速度和位置的协同控制。 2. 阻抗控制的基本概念 力控制:根据感测的外力调节机械手的动作。位置…...
golang 开启HTTP代理认证
内部网路不能直接访问外网接口,可以通过代理发送HTTP请求。 HTTP代理服务需要进行认证。 package cmdimport ("fmt""io/ioutil""log""net/http""net/url""strings" )// 推送CBC07功能 func main() {l…...
详解Nginx no live upstreams while connecting to upstream
网上看到几个相关的文章,觉得很不错,这里整理记录分享一下,供大家参考。 upstream配置分 在分析问题原因之前,我们先来看下关于上面upstream配置一些相关的参数配置说明,参考下面表格 ngx_http_proxy_module 这里重…...
Open3d Qt的环境配置
Open3d Qt的环境配置 一、概述二、操作流程2.1 下载文件2.2 新建文件夹2.3 环境变量设置2.4 qt6 引用3、qt中调用4、资源下载一、概述 目前统一使用qt6配置,open3d中可视化功能目前使用vtk代替,语言为c++。 二、操作流程 2.1 下载文件 访问open3d github链接,进入releas…...
5.Python字典和元组:字典的增删改查、字典遍历、访问元组、修改元组、集合(set)
1. 字典(dict) 字典是一个无序的键值对集合,每个键对应一个值。 字典的增、删、改、查: 添加键值对: my_dict {a: 1, b: 2} my_dict[c] 3 # 添加新键c,值为3 print(my_dict) # 输出:{a: 1, b: 2, c: …...
深度学习系列--04.梯度下降以及其他优化器
目录 一.梯度概念 1.一元函数 2.二元函数 3.几何意义上的区别 二.梯度下降 1.原理 2.步骤 3.示例代码(Python) 4.不同类型的梯度下降 5.优缺点 三.动量优化器(Momentum) 适用场景 1.复杂地形的优化问题 2.数据具有噪声的问…...
2022java面试总结,1000道(集合+JVM+并发编程+Spring+Mybatis)的Java高频面试题
1、面试题模块汇总 面试题包括以下十九个模块: Java 基础、容器、多线程、反射、对象拷贝、Java Web 模块、异常、网络、设计模式、Spring/Spring MVC、Spring Boot/Spring Cloud、Hibernate、Mybatis、RabbitMQ、Kafka、Zookeeper、MySql、Redis、JVM 。如下图所示…...
OpenClaw+ollama-QwQ-32B内容处理:自动生成周报与会议纪要
OpenClawollama-QwQ-32B内容处理:自动生成周报与会议纪要 1. 为什么需要自动化内容处理工具 每周五下午三点,我的日历总会准时弹出"编写本周工作报告"的提醒。这个看似简单的任务,却常常让我陷入两难:要么花半小时手动…...
各向异性方解石晶体的双折射效应
1. 摘要 双折射效应是各向异性材料最重要的光学特性,并广泛应用于多种光学器件。当入射光波撞击各向异性材料,会以不同的偏振态分束到不同路径,即众所周知的寻常光束和异常光束。在本示例中,描述了如何利用VirtualLab Fusion对双折…...
当多线雷达遇上RTK:一个能跑工业现场的SLAM方案
多传感器融合建图及定位的工程化落地方案,多线雷达rtk;室内室外导航都适用。 包含部署文档和代码注释;包含工程落地角度的优化。 不含运动控制。 室外场景用RTK信号稳如老狗,一进厂房立马抓瞎;多线雷达在室内横扫千军…...
brpc连接池动态调整算法:基于排队理论的设计与实现
brpc连接池动态调整算法:基于排队理论的设计与实现 【免费下载链接】brpc brpc is an Industrial-grade RPC framework using C Language, which is often used in high performance system such as Search, Storage, Machine learning, Advertisement, Recommendat…...
Path of Building:流放之路玩家必备的终极Build规划神器
Path of Building:流放之路玩家必备的终极Build规划神器 【免费下载链接】PathOfBuilding Offline build planner for Path of Exile. 项目地址: https://gitcode.com/GitHub_Trending/pa/PathOfBuilding 如果你正在玩《流放之路》并为复杂的Build规划感到头…...
猫抓插件:革新性浏览器资源捕获工具,让媒体下载效率倍增
猫抓插件:革新性浏览器资源捕获工具,让媒体下载效率倍增 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 在数字内容爆炸的时代,如何高效获取网页中的视频、音频和图…...
从SWF中提取加密通信协议:JPEXS Free Flash Decompiler安全分析报告
从SWF中提取加密通信协议:JPEXS Free Flash Decompiler安全分析报告 【免费下载链接】jpexs-decompiler JPEXS Free Flash Decompiler 项目地址: https://gitcode.com/gh_mirrors/jp/jpexs-decompiler 在网络安全分析领域,SWF(Shockwa…...
别再ping IP了!手把手教你给ZeroTier虚拟网络里的设备起个‘好记’的名字(DNS/mDNS实战)
告别IP记忆困扰:ZeroTier网络中的智能命名方案实战指南 每次在ZeroTier虚拟网络中访问设备时,你是否也厌倦了反复查看和输入那串冗长的IP地址?想象一下,当你想连接家庭NAS时,只需输入nas.home就能立即访问,…...
保姆级教程:在OrangePi 5 Plus上从SSD启动Ubuntu 22.04,并配置ROS2 Humble环境
OrangePi 5 Plus开发板全栈配置指南:从SSD启动到ROS2 Humble环境搭建 拿到一块OrangePi 5 Plus开发板时,如何快速搭建一个稳定高效的开发环境?本文将手把手带你完成从系统烧录到ROS2环境配置的全过程,特别针对ARM64架构的优化方案…...
便携激光云高仪:精确测量云底高度、云层厚度等关键参数
便携激光云高仪是一种用于测量云层高度、厚度及分布情况的气象观测设备,广泛应用于气象监测、航空安全、环境研究等领域。其便携式设计特别适合野外作业和临时观测需求。设备通过激光脉冲探测云底高度,并实时分析云层垂直结构,为气象预报、灾…...
