当前位置: 首页 > article >正文

大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):hidden_states = hidden_states.float()variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.float()def rotate_half(x):x1, x2 = x.chunk(2, dim=-1)return torch.cat((-x2, x1), dim=-1)def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):cos = cos.unsqueeze(unsqueeze_dim)sin = sin.unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed# 旋转位置编码
class RotaryEmbedding(nn.Module):def __init__(self, dim, max_seq_len=1024):super(RotaryEmbedding, self).__init__()self.dim = dimself.max_seq_len = max_seq_leninv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(max_seq_len).float().unsqueeze(1)freqs = t @ inv_freq.unsqueeze(0)freqs = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", freqs.cos())self.register_buffer("sin_cached", freqs.sin())def forward(self, q, k):cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)return apply_rotate_pos_emb(q, k, cos, sin)class MLA(nn.Module):def __init__(self,dim,n_heads,q_lora_rank,kv_lora_rank,qk_nope_head_dim,qk_rope_head_dim,v_head_dim,max_seq_len,max_batch_size,mode):super().__init__()self.dim = dim  # 隐藏层维度self.n_heads = n_heads  # 总头数self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度self.mode = modeself.max_seq_len = max_seq_lenself.max_batch_size = max_batch_sizeself.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵# 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵# nn.Linear(self.dim, self.kv_lora_rank)# nn.Linear(self.dim, self.qk_rope_head_dim)self.kv_norm = RMSNorm(self.kv_lora_rank)self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码# 没有矩阵融合if self.mode == 'naive':self.register_buffer('k_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),persistent=False)self.register_buffer('v_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),persistent=False)# 有矩阵融合else:self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),persistent=False)self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),persistent=False)def forward(self, x, mask=None):bs, seq_len, _ = x.shapeq = self.wq_a(x)  # [bs, seq_len, q_lora_rank]q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个keyq_pe, k_pe = self.rotary_emb(q_pe, k_pe)if self.mode == 'naive':q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)# k shape:[bs, seq_len, n_heads, qk_head_dim]self.k_cache[:bs, :seq_len, :, :] = kself.v_cache[:bs, :seq_len, :, :] = v# scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)scores = torch.matmul(q.transpose(1, 2),self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim))scores = scores.transpose(1, 2)else:k_pe = k_pe.squeeze(2)wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]wkv_b = wkv_b.view(self.n_heads, -1,self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]q_nope = torch.einsum("bshd,hdc->bshc", q_nope,wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]# q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v# wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/vkv = self.kv_norm(kv)self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]scores_nope = torch.einsum("bshc,btc->bsht", q_nope,self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bshtscores_pe = torch.einsum("bshr,btr->bsht", q_pe,self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bshtscores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]if mask is not None:# mask shape:[bs, seq_len, seq_len]scores += mask.unsqueeze(2)scores = scores.softmax(dim=-1)if self.mode == 'naive':x = torch.einsum("bsht,bthd->bshd", scores,self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshdelse:# scores * v = scores * c * wkv_b[:, -self.v_head_dim:]x = torch.einsum("bsht,btc->bshc", scores,self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshdx = x.contiguous().view(bs, seq_len, -1)x = self.wo(x) return xif __name__ == '__main__':torch.manual_seed(0)torch.set_printoptions(precision=3, sci_mode=False)x = torch.randn(1, 4, 16)dim = 16n_heads = 2q_lora_rank = 10kv_lora_rank = 6qk_nope_head_dim = 8qk_rope_head_dim = 4v_head_dim = 8max_seq_len = 10max_batch_size = 4mode = 'none'mla = MLA(dim=dim,n_heads=n_heads,q_lora_rank=q_lora_rank,kv_lora_rank=kv_lora_rank,qk_nope_head_dim=qk_nope_head_dim,qk_rope_head_dim=qk_rope_head_dim,v_head_dim=v_head_dim,max_seq_len=max_seq_len,max_batch_size=max_batch_size,mode=mode)print(mla(x))print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

相关文章:

大模型推理——MLA实现方案

1.整体流程 先上一张图来整体理解下MLA的计算过程 2.实现代码 import math import torch import torch.nn as nn# rms归一化 class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Pa…...

深度学习-神经机器翻译模型

以下为你介绍使用Python和深度学习框架Keras(基于TensorFlow后端)实现一个简单的神经机器翻译模型的详细步骤和代码示例,该示例主要处理英 - 法翻译任务。 1. 安装必要的库 首先,确保你已经安装了以下库: pip insta…...

Android Camera API 介绍

一 StreamConfigurationMap 1. StreamConfigurationMap 的作用 StreamConfigurationMap 是 Android Camera2 API 中的一个核心类,用于描述相机设备支持的输出流配置,包含以下信息: 支持的格式与分辨率:例如 YUV_420_888、JPEG、…...

大数据项目2:基于hadoop的电影推荐和分析系统设计和实现

前言 大数据项目源码资料说明: 大数据项目资料来自我多年工作中的开发积累与沉淀。 我分享的每个项目都有完整代码、数据、文档、效果图、部署文档及讲解视频。 可用于毕设、课设、学习、工作或者二次开发等,极大提升效率! 1、项目目标 本…...

Windows逆向工程入门之汇编环境搭建

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 Visual Studio逆向工程配置 基础环境搭建 Visual Studio 官方下载地址安装配置选项(后期可随时通过VS调整) 使用C的桌面开发 拓展可选选项 MASM汇编框架 配置MASM汇编项目 创建新项目 选择空…...

gc buffer busy acquire导致的重大数据库性能故障

📢📢📢📣📣📣 作者:IT邦德 中国DBA联盟(ACDU)成员,10余年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主,全网粉丝10万 擅长主流Oracle、MySQL、PG、高斯…...

前端学习-页面加载事件和页面滚动事件(三十二)

目录 前言 页面加载事件和页面滚动事件 页面加载事件 load事件 语法 注意 DOMContentLoaded事件 语法 总结 页面加载事件有哪两个?如何添加? load 事件 DOMContentLoaded事件 页面滚动事件 存在原因 scroll监听整个页面滚动 页面滚动事件-获取位置 scrollLef…...

C++:将函数参数定义为const T的意义

C++很多函数的参数都会定义为const T&,那么这么做的意义是什么呢? 避免拷贝:通过引用传递参数而不是值传递,可以避免对象的拷贝,从而提高性能,特别是当对象较大时。 保护数据:使用const关键字可以防止函数修改传入的参数,确保数据的安全性和一致性。 对于保护数据这…...

Formily 如何进行表单验证

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

安宝特方案 | AR眼镜:远程医疗的“时空折叠者”,如何为生命争夺每一分钟?

行业痛点:当“千里求医”遇上“资源鸿沟” 20世纪50年代,远程会诊的诞生曾让医疗界为之一振——患者不必跨越山河,专家无需舟车劳顿,一根电话线、一张传真纸便能架起问诊的桥梁。然而,传统远程医疗的局限也日益凸显&a…...

使用git commit时‘“node“‘ 不是内部或外部命令,也不是可运行的程序

第一种: 使用git commit -m "xxx"时会报错,我看网上的方法是在命令行后面添加--no-verify:git commit -m "主题更新" --no-verify,但是不可能每次都添加。 最后解决办法是:使用git config --lis…...

Python分享20个Excel自动化脚本

在数据处理和分析的过程中,Excel文件是我们日常工作中常见的格式。通过Python,我们可以实现对Excel文件的各种自动化操作,提高工作效率。 本文将分享20个实用的Excel自动化脚本,以帮助新手小白更轻松地掌握这些技能。 1. Excel单…...

nodejs - vue 视频切片上传,本地正常,线上环境导致磁盘爆满bug

nodejs 视频切片上传,本地正常,线上环境导致磁盘爆满bug 原因: 然后在每隔一分钟执行du -sh ls ,发现文件变得越来越大,即文件下的mp4文件越来越大 最后导致磁盘直接爆满 排查原因 1、尝试将m3u8文件夹下的所有视…...

瑞友天翼应用虚拟化系统 GetPwdPolicy SQL注入漏洞复现

免责声明 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在使…...

【MySQL — 数据库基础】深入解析MySQL的聚合查询

1. 聚合查询 1.1 聚合函数 函数说明COUNT ( [DISTINCT] expr)返回查询到的数据的数量( 行数 )SUM ( [DISTINCT] expr)返回查询到的数据的总和,不是数字没有意义AVG ( [DISTINCT] expr)返回查询到的数据的平均值,不是数字没有意义MAX( [DISTINCT] expr)…...

22.3、IIS安全分析与增强

目录 IIS安全威胁分析iis安全机制iis安全增强 IIS安全威胁分析 iis是微软公司的Web服务软件,主要提供网页服务,除此之外还可以提供其他服务,第一个最主要的是网页服务,第二个是SMTP邮件服务,第三个是FTP文件传输服务。…...

windows平台本地部署DeepSeek大模型+Open WebUI网页界面(可以离线使用)

环境准备: 确定部署方案请参考:DeepSeek-R1系列(1.5b/7b/8b/32b/70b/761b)大模型部署需要什么硬件条件-CSDN博客 根据本人电脑配置:windows11 + i9-13900HX+RTX4060+DDR5 5600 32G内存 确定部署方案:DeepSeek-R1:7b + Ollama + Open WebUI 1. 安装 Ollama Ollama 是一…...

港中文腾讯提出可穿戴3D资产生成方法BAG,可自动生成服装和配饰等3D资产如,并适应特定的人体模型。

今天给大家介绍一种名为BAG(Body-Aligned 3D Wearable Asset Generation)的新方法,可以自动生成可穿戴的3D资产,如服装和配饰,以适应特定的人体模型。BAG方法通过构建一个多视图图像扩散模型,生成与人体对齐…...

【人工智能】Python中的序列到序列(Seq2Seq)模型:实现机器翻译

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 序列到序列(Seq2Seq)模型是自然语言处理(NLP)中一项核心技术,广泛应用于机器翻译、语音识别、文本摘要等任务。本文深入探讨Seq2Seq模…...

34.日常算法

1.合并区间 题目来源 以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好覆盖输入中的所有区间 。 示例 1: 输入&#x…...

DeepSeek深度思考:客户端(Android/iOS)架构设计指南

目标读者:中高级开发者、架构师 适用场景:大型复杂应用开发、跨团队协作、长期维护迭代 一、架构设计核心原则 1.模块化(Modularization) 横向拆分:按功能边界划分(如登录、支付、消息模块)纵向…...

2025 年前端开发现状分析:卷疯了还是卷麻了?

一、前端现状:框架狂飙,开发者崩溃 如果你是个前端开发者,那么你大概率经历过这些场景: 早上打开 CSDN(或者掘金,随便),发现又有新框架发布了,名字可能是 VueXNext.js 之…...

数据库 绪论

目录 数据库基本概念 一.基本概念 1.信息 2.数据 3.数据库(DB) 4.数据库管理系统(DBMS) 5.数据库系统(DBS) 二.数据管理技术的发展 1.人工管理阶段 2.文件系统阶段 3.数据库系统阶段 4.数据库管…...

【AIGC魔童】DeepSeek v3提示词Prompt书写技巧

【AIGC魔童】DeepSeek v3提示词Prompt书写技巧 (1)基础通用公式(适用80%场景)(2)问题解决公式(决策支持)(3)创意生成公式(4)学习提升公…...

Docker 部署 RabbitMQ | 自带延时队列

一、获取镜像 docker pull farerboy/rabbitmq:3.9.9 二、运行镜像 docker run -d --name rabbitmq \n --hostname rabbitmq \n -p 15672:15672/tcp \n -p 5672:5672/tcp \n -v /wwwroot/opt/docker/rabbitmq:/var/lib/rabbitmq \n farerboy/rabbitmq:3.9.9 备注:…...

【‌Unity】Unity中物体的static属性作用

‌Unity中物体的static属性主要用于优化游戏性能和简化渲染过程。‌ Unity中物体的static属性的作用 优化渲染性能‌:当物体被标记为static时,Unity会在游戏运行时将其视为静止的物体,这意味着这些物体的渲染信息不会随着每一帧的更新而变化…...

网络编程基础1

七层协议模型和四层协议模型 七层协议模型:物理层、数据链路层、网络层、传输层、会话层、表示层、应用层 四层协议模型:链路层、网络层、传输层、应用层 TCP通信流程 服务器端 (1)创建socket(socket) (2)绑定自己的IP(bind) (3)监听客户端连接(liste…...

跨越边界,大模型如何助推科技与社会的完美结合?

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 概述 2024年,大模型技术已成为人工智能领域的焦点。这不仅仅是一项技术进步,更是一次可能深刻影响社会发展方方面面的变革。大模型的交叉能否推动技术与社会的真正融合?2025年…...

kafka生产端之架构及工作原理

文章目录 整体架构元数据更新 整体架构 消息在真正发往Kafka之前,有可能需要经历拦截器(Interceptor)、序列化器(Serializer)和分区器(Partitioner)等一系列的作用,那么在此之后又会…...

在 Windows 上使用 ZIP 包安装 MySQL 的详细步骤

以下是使用官方 ZIP 包在 Windows 上安装 MySQL 的详细步骤,确保能通过 mysql -uroot -p 成功连接。 步骤 1:下载 MySQL ZIP 包 访问 MySQL 官方下载页面: https://dev.mysql.com/downloads/mysql/选择 Windows (x86, 64-bit), ZIP Archive&…...