当前位置: 首页 > article >正文

DeepSeek V3 源码:从入门到放弃!

从入门到放弃

花了几天时间,看懂了DeepSeek V3 源码的逻辑。源码的逻辑是不难的,但为什么模型结构需要这样设计,为什么参数需要这样设置呢?知其然,但不知其所以然。除了模型结构以外,模型的训练数据、训练脚本和训练经验,也是DeepSeek V3能够训练出来的关键,但这些是DeepSeek母公司的核心机密,我们无从得知。
因此,看懂了源码,算是入门了DeepSeek V3,因为没有条件知道更多重要细节,因此不得不放弃重现整个模型的训练。

Paper 和源码

Paper URL: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
Code URL: https://github.com/deepseek-ai/DeepSeek-V3

模型逻辑

下面这张图,代表了DeepSeek的核心逻辑。左边是Transformer的逻辑结构,可以认为有N个左边这样的Block结构不断重复,组成Transformer模型。每个Block中,分成两个部分,Attention 和 Feed-Forward Network。对这两个部分使用不同的网络结构,我们就得到了不同的模型。
DeepSeek V3 的 Attention 用的是 Multi-Head Latent Attention(MLA) ,Feed-Forward Network 用的是DeepSeekMoE。
在这里插入图片描述

MLA

Multi-Head Latent Attention(MLA)即多头潜在注意力,是DeepSeek模型中引入的一种创新注意力机制,旨在优化传统多头注意力(Multi-Head Attention,MHA)的计算效率和内存占用。具体介绍如下:

核心创新点

  • 低秩键值压缩
    • KV的低秩压缩:不直接存储原始的Key和Value,而是先将隐藏状态投影到一个更小的压缩潜在向量。在推理时,只需缓存该压缩潜在向量,而不是完整的Key和Value,从而大大降低了KV缓存的存储需求。
    • Query的低秩压缩:对Query也进行低秩压缩,虽然不会减少KV缓存的大小,但可以减少训练时的激活存储需求,进而降低计算成本。
  • 解耦旋转位置嵌入(RoPE)
    • 额外引入“解耦查询”:将查询拆分为两个部分,一部分不经过RoPE变换,代表非位置敏感的特征信息;另一部分专门用于嵌入RoPE位置编码信息。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,减少了计算开销,也减小了KV缓存大小,降低了GPU内存占用,提高了推理速度,特别适用于长序列任务和大规模Transformer。

推理过程中的优化

将上投影矩阵吸收到里面,简化查询计算,并优化注意力分数的计算,减少了计算步骤,提升了计算效率。避免了先计算Value向量,减少了矩阵运算的开销,使推理更快。

整体优势

  • 降低内存占用:通过对键值进行低秩联合压缩以及解耦RoPE等策略,显著减少了KV缓存的存储需求,降低了GPU内存占用。
  • 提高计算效率:减少了训练和推理过程中的计算量,加快了模型的推理速度,在保持甚至提高模型性能的同时,提升了模型的运行效率。
  • 增强模型适应性:特别适用于长序列任务和大规模Transformer模型,能够更好地处理长序列输入,提高模型在各种自然语言处理任务中的表现。

MLA 有物理意义吗?

Multi-Head Latent Attention(MLA)能够起作用主要源于其独特的技术设计,在数学和信息处理层面有清晰的逻辑,不过它是一种抽象的算法概念,并不直接对应具体的物理意义,以下是对其作用原理的分析:

起作用的原因

  • 低秩压缩的有效性
    • 信息浓缩与降噪:通过低秩键值压缩,MLA将高维的Key和Value信息投影到低维的潜在向量空间,这一过程类似于对原始信息进行浓缩,提取出最关键、最具代表性的特征,去除了一些可能的噪声和冗余信息,使得模型能够更聚焦于重要信息,从而提高信息处理的效率和准确性。
    • 减少计算量和存储需求:低秩压缩大大降低了数据的维度,减少了模型训练和推理过程中的计算量和存储需求,使得模型能够更高效地运行,尤其是在处理大规模数据和长序列数据时,这种优势更为明显。
  • 解耦旋转位置嵌入的优势
    • 位置信息与内容信息的分离:传统的位置编码方式将位置信息和内容信息混合在一起进行处理,而MLA的解耦旋转位置嵌入将查询拆分为位置敏感和非位置敏感两部分,使模型能够更清晰地分离和处理位置信息与内容信息,更好地捕捉文本中的长距离依赖关系。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,不仅减少了计算开销,还使得模型能够从更宏观的角度利用位置信息,增强了模型对序列数据整体结构的理解和把握能力。
  • 多头机制的协同作用
    • 捕捉多维度信息:MLA中的多头机制允许模型同时从多个不同的角度和维度去捕捉输入数据中的信息,每个头可以关注到输入序列的不同方面,通过多个头的并行计算和协同工作,模型能够更全面、更深入地理解输入数据,提高模型的表示能力和泛化能力。

难以直接赋予物理意义的原因及近似理解

  • 抽象的算法概念:MLA是一种基于数学和计算机科学的算法概念,主要用于处理和分析数据中的模式和关系,它不像物理概念那样具有直接可观测的物理实体或现象与之对应,更多地是在数据空间和计算逻辑中发挥作用。
  • 类比物理现象理解:可以进行一些类比来帮助理解。比如低秩压缩类似于物理中的能量聚集,将分散的能量(信息)聚集到关键的“点”上;解耦旋转位置嵌入有点像物理中对不同性质力的分解,将位置信息和内容信息这两种“力”分开处理;多头机制如同多个物理传感器从不同方向和角度对环境进行感知,然后综合这些感知信息来对整个系统进行理解和判断。

DeepSeekMoE

DeepSeekMoE是由深度求索(DeepSeek)研发的基于混合专家系统(Mixture of Experts,MoE)的技术架构,以下是具体介绍:

架构原理

  • 混合专家系统核心:采用MoE架构,核心在于通过动态路由机制,把输入数据分配给最相关的专家处理。比如在自然语言处理中,有的专家专门处理情感分析,有的处理主题建模。
  • 结合多头潜在注意力机制:与MLA相结合,MLA通过引入潜在向量,减少键值缓存(KV cache)需求,提升推理效率。
  • Transformer架构基础:以Transformer架构为基础,每个Transformer块由一个注意力模块和一个前馈网络(FFN)组成,在注意力机制和FFN方面采用创新架构。

技术优势

  • 降低算力需求:MoE的动态分配机制和MLA减少KV缓存需求等特点,使模型在训练和推理时对算力的要求降低。
  • 保持高性能:在参数量减少的情况下仍能保持高性能,例如DeepSeek-V2以236B总参数、21B激活,大致可以达到70B-110B Dense的模型能力。
  • 减少计算量:自研Sparse结构DeepSeekMoE进一步降低了计算量。
  • 长上下文理解能力强:支持超100万token的上下文窗口,显著优于行业平均水平,适用于长文档分析、代码开发等复杂场景的连贯交互。

DeepSeekMoE的物理意义是什么?

DeepSeekMoE作为一种人工智能技术架构,没有严格意义上的物理意义,但可以从一些角度进行类比和理解:

从系统资源分配角度

  • 资源按需分配类比:可以将DeepSeekMoE的专家网络和动态路由机制类比为一个智能电力分配系统。在这个系统中,不同的电器设备(任务)需要不同的电量(计算资源)来运行。专家网络就像不同功率的发电机,而动态路由机制则像是智能电表和分配器,它会根据每个电器设备的实际需求,将电力(计算资源)精准地分配给需要的设备,避免了资源的浪费,提高了整个系统的能源利用效率。
  • 负载均衡类比:类似于在一个大型物流中心,不同的仓库区域(专家)负责存储和处理不同类型的货物(数据)。当有货物运输任务时,调度系统(动态路由)会根据货物的特点和仓库的负载情况,合理地安排货物存储到哪个仓库,确保每个仓库都能在其承载能力范围内高效运作,不会出现某个仓库过度拥挤而其他仓库闲置的情况,实现了负载均衡,提高了物流中心的整体运营效率。

从信息处理角度

  • 多维度信息处理类比:可以把DeepSeekMoE处理信息的过程想象成一个由多个不同专业的侦探(专家)组成的侦探团队在调查一个复杂案件。每个侦探都有自己独特的专业技能和视角,比如有的擅长调查线索,有的擅长分析人物关系,有的擅长破解密码等。当面对案件(输入数据)时,队长(路由器)会根据案件的具体情况,分配合适的侦探去处理相应的部分,最后将各个侦探的调查结果综合起来,形成对整个案件的全面了解和判断,从而更高效地解决复杂问题。
  • 特征提取与融合类比:如同在一个化学实验中,不同的化学试剂(专家)可以与不同的物质发生反应,提取出特定的化学特征。DeepSeekMoE中的专家网络就像这些化学试剂,它们各自对输入数据进行处理,提取出不同的特征。然后通过融合机制,将这些特征像混合化学物质一样进行整合,得到更全面、更有价值的信息,用于后续的分析和决策。

从模型架构角度

  • 积木搭建类比:把DeepSeekMoE的架构比作搭建积木。每个专家网络就像不同形状和功能的积木块,有的积木块负责搭建基础结构,有的负责构建上层建筑,有的负责添加装饰等。路由器则像是搭建者的手,根据要搭建的目标模型的需求,选择合适的积木块进行组合,最终搭建出一个复杂而功能强大的模型结构,实现对各种自然语言处理任务的高效处理。
  • 人体神经系统类比:可以将DeepSeekMoE类比为人体的神经系统。专家网络类似于人体的不同神经细胞或神经中枢,它们各自负责处理特定类型的信息,如视觉神经细胞负责处理视觉信息,听觉神经细胞负责处理听觉信息等。路由器就像神经系统中的神经递质或信号传导机制,它负责将外界的刺激信号(输入数据)准确地传递给相应的神经细胞,并将各个神经细胞处理后的信号进行整合和传递,使人体能够做出协调的反应和决策,实现对外部世界的感知和交互。

代码逻辑

整体 - Transformer

下面这段代码是典型的 Transformer 实现,核心可以看 forward 函数逻辑:

  1. 进行 Embeding;
  2. 经过各个 Block;
  3. 归一化并输出。
    对应的代码:
# 通过嵌入层将输入标记转换为向量表示
h = self.embed(tokens)
# 依次通过每个Transformer块进行处理
for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)
# 对输出进行层归一化,并取最后一个时间步的输出
h = self.norm(h)[:, -1]
# 通过输出投影层得到对数概率
logits = self.head(h)

完整代码:

# 定义Transformer类,继承自PyTorch的nn.Module类
class Transformer(Module):"""Transformer模型,包含位置嵌入、多个层以及输出投影。属性:max_seq_len (int): Transformer允许的最大序列长度。embed (nn.Module): 用于输入标记的嵌入层,将输入的标记转换为向量表示。layers (torch.nn.ModuleList): 存储多个Transformer块的列表,每个块包含多头注意力和前馈网络。norm (nn.Module): 层归一化层,在所有Transformer块之后应用,用于稳定训练。head (nn.Module): 输出投影层,将模型的输出映射到词汇表大小,用于预测下一个标记。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。"""def __init__(self, args):"""初始化Transformer模型。参数:args: 模型参数对象,包含Transformer的各种参数,如词汇表大小、维度、层数等。"""# 获取全局变量world_size和rank,分别表示分布式训练中的进程总数和当前进程的编号global world_size, rank# 如果分布式训练已初始化,则获取进程总数,否则默认为1world_size = dist.get_world_size() if dist.is_initialized() else 1# 如果分布式训练已初始化,则获取当前进程编号,否则默认为0rank = dist.get_rank() if dist.is_initialized() else 0# 根据参数设置线性层的数据类型Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16# 调用父类的初始化方法super().__init__()# 保存最大序列长度self.max_seq_len = args.max_seq_len# 初始化嵌入层,将输入标记转换为向量表示self.embed = ParallelEmbedding(args.vocab_size, args.dim)# 初始化一个空的ModuleList,用于存储Transformer块self.layers = torch.nn.ModuleList()# 循环创建指定数量的Transformer块,并添加到layers列表中for layer_id in range(args.n_layers):self.layers.append(Block(layer_id, args))# 初始化层归一化层self.norm = RMSNorm(args.dim)# 初始化输出投影层,将模型的输出映射到词汇表大小self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())# 预计算旋转位置嵌入所需的复指数值,并将其注册为缓冲区,不参与模型参数的更新self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)@torch.inference_mode()def forward(self, tokens, start_pos=0):"""Transformer模型的前向传播过程。参数:tokens (torch.Tensor): 输入的标记ID张量,形状为 (batch_size, seq_len)。start_pos (int, 可选): 旋转位置嵌入的起始位置,默认为0。返回:torch.Tensor: 对数概率张量,形状为 (batch_size, vocab_size),表示每个标记的预测概率。"""# 获取输入序列的长度seqlen = tokens.size(1)# 通过嵌入层将输入标记转换为向量表示h = self.embed(tokens)# 从预计算的复指数值中截取当前序列所需的部分freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]# 初始化掩码为Nonemask = None# 如果序列长度大于1,则创建一个上三角掩码,用于屏蔽未来的标记if seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)# 依次通过每个Transformer块进行处理for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 对输出进行层归一化,并取最后一个时间步的输出h = self.norm(h)[:, -1]# 通过输出投影层得到对数概率logits = self.head(h)# 如果使用分布式训练,则收集所有进程的对数概率if world_size > 1:# 创建一个列表,用于存储所有进程的对数概率all_logits = [torch.empty_like(logits) for _ in range(world_size)]# 收集所有进程的对数概率dist.all_gather(all_logits, logits)# 将所有进程的对数概率拼接在一起logits = torch.cat(all_logits, dim=-1)return logits

单个 - Block

在这里插入图片描述
核心代码非常简单MLA(attention) + MOE(Feed-Forward Network):

# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
x = x + self.ffn(self.ffn_norm(x))

全部代码:

# 定义一个Transformer块类,继承自PyTorch的nn.Module类
class Block(Module):"""Transformer块,结合了注意力层和前馈网络层。属性:attn (nn.Module): 注意力层(采用多头潜在注意力机制,即MLA),用于捕捉输入序列中不同位置之间的依赖关系。ffn (nn.Module): 前馈网络层(可以是多层感知机MLP或者混合专家模型MoE),对注意力层的输出进行非线性变换。attn_norm (nn.Module): 用于注意力层的层归一化层,对输入到注意力层的数据进行归一化处理,稳定训练过程。ffn_norm (nn.Module): 用于前馈网络层的层归一化层,对输入到前馈网络层的数据进行归一化处理。"""def __init__(self, layer_id, args):"""初始化Transformer块。参数:layer_id (int): 当前块在Transformer模型中的层索引,用于确定使用哪种前馈网络结构。args: 模型参数对象,包含了块的各种参数,如维度、层数等。"""# 调用父类的初始化方法super().__init__()# 初始化注意力层,使用多头潜在注意力机制(MLA)self.attn = MLA(args)# 根据当前层的索引来决定使用MLP还是MoE作为前馈网络# 如果当前层索引小于密集层的数量,则使用MLP# 否则使用混合专家模型(MoE)self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)# 初始化用于注意力层的层归一化层self.attn_norm = RMSNorm(args.dim)# 初始化用于前馈网络层的层归一化层self.ffn_norm = RMSNorm(args.dim)def forward(self, x, start_pos, freqs_cis, mask=None):"""Transformer块的前向传播过程。参数:x (torch.Tensor): 输入张量,包含了序列的特征信息。start_pos (int): 序列中的起始位置,用于旋转位置嵌入。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置,避免模型关注到不应该关注的信息。返回:torch.Tensor: 经过当前Transformer块计算后的输出张量。"""# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接x = x + self.ffn(self.ffn_norm(x))return x

Attention 模块

经典的QKV计算公式。解释可以自行搜索,或者参考:Transformer结构和注意力机制
在这里插入图片描述
和传统的QKV相比,可以认为是做了压缩,主要是为了减小 KV Cache。
在这里插入图片描述
代码,就是做了一堆下采样上采样和矩阵的组合变换。最终目的是减少计算量和显存使用量。

# 定义多头注意力层类,继承自PyTorch的nn.Module类
class MLA(Module):"""多头注意力层(MLA)。属性:dim (int): 输入特征的维度。n_heads (int): 注意力头的数量。n_local_heads (int): 分布式系统中本地注意力头的数量。q_lora_rank (int): 查询(query)的低秩投影的秩。kv_lora_rank (int): 键(key)和值(value)的低秩投影的秩。qk_nope_head_dim (int): 非位置相关的查询/键投影的维度。qk_rope_head_dim (int): 旋转位置编码的查询/键投影的维度。qk_head_dim (int): 查询/键投影的总维度。v_head_dim (int): 值投影的维度。softmax_scale (float): 注意力计算中softmax函数的缩放因子。"""def __init__(self, args):# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 保存注意力头的数量self.n_heads = args.n_heads# 计算分布式系统中本地注意力头的数量self.n_local_heads = args.n_heads // world_size# 保存查询的低秩投影的秩self.q_lora_rank = args.q_lora_rank# 保存键和值的低秩投影的秩self.kv_lora_rank = args.kv_lora_rank# 保存非位置相关的查询/键投影的维度self.qk_nope_head_dim = args.qk_nope_head_dim# 保存旋转位置编码的查询/键投影的维度self.qk_rope_head_dim = args.qk_rope_head_dim# 计算查询/键投影的总维度self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim# 保存值投影的维度self.v_head_dim = args.v_head_dim# 如果查询的低秩投影的秩为0,直接使用列并行线性层进行查询投影if self.q_lora_rank == 0:self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)# 否则,使用低秩分解的方式进行查询投影else:self.wq_a = Linear(self.dim, self.q_lora_rank)self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)# 对输入进行线性变换得到键和值的低秩表示self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)# 对键和值的低秩表示进行归一化self.kv_norm = RMSNorm(self.kv_lora_rank)# 对归一化后的键和值进行线性变换self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))# 对多头注意力的输出进行行并行线性变换self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)# 计算softmax函数的缩放因子self.softmax_scale = self.qk_head_dim ** -0.5# 如果最大序列长度大于原始序列长度,对缩放因子进行调整if args.max_seq_len > args.original_seq_len:mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0self.softmax_scale = self.softmax_scale * mscale * mscale# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 注册键缓存self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)# 注册值缓存self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)# 否则else:# 注册键值缓存self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)# 注册位置编码缓存self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)def forward(self, x, start_pos, freqs_cis, mask=None):"""多头注意力层(MLA)的前向传播过程。参数:x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, dim)。start_pos (int): 序列中用于缓存的起始位置。freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置编码。mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置。返回:torch.Tensor: 输出张量,形状与输入相同。"""# 获取输入张量的批次大小、序列长度bsz, seqlen, _ = x.size()# 计算序列的结束位置end_pos = start_pos + seqlen# 如果查询的低秩投影的秩为0,直接通过线性层得到查询if self.q_lora_rank == 0:q = self.wq(x)# 否则,通过低秩分解的方式得到查询else:q = self.wq_b(self.q_norm(self.wq_a(x)))# 调整查询的形状,将其划分为多个头q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)# 将查询划分为非位置相关部分和位置相关部分q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)# 对位置相关部分应用旋转位置编码q_pe = apply_rotary_emb(q_pe, freqs_cis)# 通过线性层得到键和值的低秩表示kv = self.wkv_a(x)# 将键和值的低秩表示划分为低秩部分和位置编码部分kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)# 对位置编码部分应用旋转位置编码k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 将非位置相关部分和位置相关部分拼接得到完整的查询q = torch.cat([q_nope, q_pe], dim=-1)# 对键和值的低秩表示进行归一化和线性变换kv = self.wkv_b(self.kv_norm(kv))# 调整键和值的形状,将其划分为多个头kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)# 将键和值划分为非位置相关部分和值部分k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)# 将非位置相关部分和位置编码部分拼接得到完整的键k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)# 将键存入缓存self.k_cache[:bsz, start_pos:end_pos] = k# 将值存入缓存self.v_cache[:bsz, start_pos:end_pos] = v# 计算注意力分数scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale# 否则else:# 获取键和值的线性变换层的权重wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) # 调整权重的形状wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)# 计算非位置相关部分的注意力分数q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])# 将键和值的低秩表示归一化后存入缓存self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)# 将位置编码部分存入缓存self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)# 计算注意力分数scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale# 如果存在掩码,将掩码加到注意力分数上if mask is not None:scores += mask.unsqueeze(1)# 对注意力分数应用softmax函数scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)# 如果注意力实现方式为朴素方式if attn_impl == "naive":# 通过注意力分数和值缓存计算输出x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])# 否则else:# 通过注意力分数和键值缓存计算中间结果x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])# 通过中间结果和权重计算输出x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])# 对输出进行线性变换x = self.wo(x.flatten(2))return x

代码解释总结

这段代码定义了一个多头注意力层(MLA)类。在初始化时,根据传入的参数设置各种维度、低秩投影的秩等,并初始化相应的线性层和归一化层,同时根据注意力实现方式注册不同的缓存。在前向传播过程中,对输入进行处理得到查询、键和值,应用旋转位置编码,根据不同的注意力实现方式计算注意力分数,最后通过注意力分数和缓存得到输出并进行线性变换。

Feed-Forward Network

在这里插入图片描述
这个MoE分成两个部分,左边是一些可以分享的专家,就是每次都需要去计算的,右边的是根据分数来选择的。如何选择,是通过一个门控机制来选择。这个门控是如何设计的,代码里有实现,但论文和代码都没有对它物理意义的解释。门控机制,简单来说,就是设计了一个网络,来选出K个候选。

核心逻辑:

  1. 通过门控确定本轮要用到的本地专家:weights, indices = self.gate(x)
  2. 用选择的每个本地专家进行计算:y[idx] += expert(x[idx]) * weights[idx, top, None]
  3. 用共享专家进行计算:z = self.shared_experts(x)
  4. 将本地专家的输出和共享专家的输出相加,并恢复到原始形状: return (y + z).view(shape)

全部代码:

# 定义混合专家(Mixture-of-Experts, MoE)模块类,继承自PyTorch的nn.Module类
class MoE(nn.Module):"""混合专家(Mixture-of-Experts, MoE)模块。属性:dim (int): 输入特征的维度。n_routed_experts (int): 模型中专家的总数。n_local_experts (int): 在分布式系统中本地处理的专家数量。n_activated_experts (int): 每个输入激活的专家数量。gate (nn.Module): 门控机制,用于将输入路由到不同的专家。experts (nn.ModuleList): 专家模块列表,包含多个专家网络。shared_experts (nn.Module): 共享专家模块,应用于所有输入。"""def __init__(self, args):"""初始化MoE模块。参数:args: 模型参数对象,包含MoE模块的相关参数。"""# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 确保专家总数能被分布式系统中的进程数整除assert args.n_routed_experts % world_size == 0, f"专家数量必须能被进程数整除 (进程数={world_size})"# 保存模型中专家的总数self.n_routed_experts = args.n_routed_experts# 计算本地处理的专家数量self.n_local_experts = args.n_routed_experts // world_size# 保存每个输入激活的专家数量self.n_activated_experts = args.n_activated_experts# 计算本地专家在所有专家中的起始索引self.experts_start_idx = rank * self.n_local_experts# 计算本地专家在所有专家中的结束索引self.experts_end_idx = self.experts_start_idx + self.n_local_experts# 初始化门控机制self.gate = Gate(args)# 初始化专家模块列表,本地负责的专家使用Expert模块,其他位置置为Noneself.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])# 初始化共享专家模块self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x):"""MoE模块的前向传播过程。参数:x (torch.Tensor): 输入张量。返回:torch.Tensor: 经过专家路由和计算后的输出张量。"""# 保存输入张量的原始形状shape = x.size()# 将输入张量展平为二维张量,方便后续处理x = x.view(-1, self.dim)# 通过门控机制得到每个输入分配到各个专家的权重和对应的专家索引weights, indices = self.gate(x)# 初始化输出张量,形状与输入相同,初始值全为0y = torch.zeros_like(x)# 统计每个专家被分配到的输入数量counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()# 遍历本地负责的专家for i in range(self.experts_start_idx, self.experts_end_idx):# 如果该专家没有被分配到输入,则跳过if counts[i] == 0:continue# 获取当前专家模块expert = self.experts[i]# 找出分配到当前专家的输入的索引idx, top = torch.where(indices == i)# 将这些输入通过当前专家模块进行计算,并乘以对应的权重,累加到输出张量中y[idx] += expert(x[idx]) * weights[idx, top, None]# 将输入通过共享专家模块进行计算z = self.shared_experts(x)# 如果使用分布式训练,对本地专家的输出进行全局归约操作if world_size > 1:dist.all_reduce(y)# 将本地专家的输出和共享专家的输出相加,并恢复到原始形状return (y + z).view(shape)

代码解释总结

这段代码定义了一个混合专家(MoE)模块。在初始化时,根据传入的参数设置专家的数量、门控机制、专家模块列表和共享专家模块。在前向传播过程中,首先通过门控机制将输入路由到不同的专家,然后对本地负责的专家进行计算并累加结果,同时将输入通过共享专家模块进行计算,最后将两部分结果相加并恢复原始形状。如果使用分布式训练,还会对本地专家的输出进行全局归约操作。

Gate

不考虑分组路由来看看它的核心逻辑,实际上就是线下变换,然后激活,选择K个极值(如果用了分组,就是选择K的方式发生了一些变化):

# 通过线性变换计算每个输入对应各个专家的分数
scores = linear(x, self.weight)
# 根据评分函数类型对分数进行处理
if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)
else:scores = scores.sigmoid()
# 选择分数最高的若干专家
indices = torch.topk(scores, self.topk, dim=-1)[1]
# 根据选择的专家索引,从原始分数中获取对应的权重
weights = scores.gather(1, indices)
# 如果评分函数是sigmoid,对权重进行归一化
if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)
# 对权重进行缩放
weights *= self.route_scale
return weights.type_as(x), indices
# 定义门控机制类,用于在混合专家(MoE)模型中对输入进行路由
class Gate(nn.Module):"""混合专家(MoE)模型中用于输入路由的门控机制。属性:dim (int): 输入特征的维度。topk (int): 每个输入激活的顶级专家数量。n_groups (int): 用于路由的分组数量。topk_groups (int): 输入将被路由到的分组数量。score_func (str): 评分函数,取值为 'softmax' 或 'sigmoid'。route_scale (float): 路由权重的缩放因子。weight (torch.nn.Parameter): 门控机制的可学习权重。bias (Optional[torch.nn.Parameter]): 门控机制的可选偏置项。"""def __init__(self, args):"""初始化门控机制模块。参数:args: 模型参数对象,包含门控机制的相关参数。"""# 调用父类的初始化方法super().__init__()# 保存输入特征的维度self.dim = args.dim# 保存每个输入激活的顶级专家数量self.topk = args.n_activated_experts# 保存用于路由的分组数量self.n_groups = args.n_expert_groups# 保存输入将被路由到的分组数量self.topk_groups = args.n_limited_groups# 保存评分函数类型self.score_func = args.score_func# 保存路由权重的缩放因子self.route_scale = args.route_scale# 初始化可学习权重self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))# 根据输入特征维度决定是否初始化偏置项self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else Nonedef forward(self, x):"""门控机制的前向传播过程。参数:x (torch.Tensor): 输入张量。返回:Tuple[torch.Tensor, torch.Tensor]: 路由权重和选择的专家索引。"""# 通过线性变换计算每个输入对应各个专家的分数scores = linear(x, self.weight)# 根据评分函数类型对分数进行处理if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)else:scores = scores.sigmoid()# 保存原始分数,后续计算权重时使用original_scores = scores# 如果存在偏置项,将其加到分数上if self.bias is not None:scores = scores + self.bias# 如果分组数量大于1,进行分组路由操作if self.n_groups > 1:# 调整分数的形状,以便按组处理scores = scores.view(x.size(0), self.n_groups, -1)# 根据是否有偏置项,计算每个组的分数表示if self.bias is None:group_scores = scores.amax(dim=-1)else:group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)# 选择分数最高的若干组indices = group_scores.topk(self.topk_groups, dim=-1)[1]# 创建掩码,用于屏蔽未选择的组mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)# 将屏蔽组的分数设为负无穷scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)# 选择分数最高的若干专家indices = torch.topk(scores, self.topk, dim=-1)[1]# 根据选择的专家索引,从原始分数中获取对应的权重weights = original_scores.gather(1, indices)# 如果评分函数是sigmoid,对权重进行归一化if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)# 对权重进行缩放weights *= self.route_scalereturn weights.type_as(x), indices

代码解释总结

这段代码定义了一个门控机制(Gate)类,用于在混合专家(MoE)模型中对输入进行路由。在初始化时,根据传入的参数设置各种属性,如输入维度、激活专家数量、分组数量等,并初始化可学习的权重和偏置项。在前向传播过程中,首先计算每个输入对应各个专家的分数,然后根据评分函数类型进行处理,接着根据分组情况进行分组路由操作,选择激活的专家并计算对应的权重,最后返回路由权重和选择的专家索引。

B站大牛详解视频链接

B站大牛有详细视频讲解:
https://www.bilibili.com/video/BV1RtNLeqEeu/

相关文章:

DeepSeek V3 源码:从入门到放弃!

从入门到放弃 花了几天时间&#xff0c;看懂了DeepSeek V3 源码的逻辑。源码的逻辑是不难的&#xff0c;但为什么模型结构需要这样设计&#xff0c;为什么参数需要这样设置呢&#xff1f;知其然&#xff0c;但不知其所以然。除了模型结构以外&#xff0c;模型的训练数据、训练…...

海量数据融合互通丨TiDB 在安徽省住房公积金监管服务平台的应用实践

导读 安徽省住房公积金监管服务平台通过整合全省 17 家公积金中心的数据&#xff0c;致力于实现数据共享、规范化管理与高效数据分析。为了应对海量数据处理需求&#xff0c;安徽省选择 TiDB 作为底层数据库&#xff0c;利用其分布式架构和 HTAP 能力&#xff0c;实现了快速的…...

Sqlserver安全篇之_手工创建TLS用到的pfx证书文件

Sqlserver官方提供的Windows Powershell脚本 https://learn.microsoft.com/zh-cn/sql/database-engine/configure-windows/configure-sql-server-encryption?viewsql-server-ver16 # Define parameters $certificateParams {Type "SSLServerAuthentication"Subje…...

Linux12-UDP\TCP

一、UDP 1.特点: 尽最大努力交付,存在丢包的可能 无连接 面向数据报 机制简单,传输效率高 2.应用场景: 1.画面传输 VNC 直播:要求实时性高、允许数据丢失、 二、TCP 1.特点: 面向数据流(流式套接字) 建立连接 安全可靠的传输协议 三次握手:TCP建立连接时,…...

Tailwind CSS 问题:npm error could not determine executable to run

问题与处理策略 问题描述 npx tailwindcss init -p在使用 Tailwind CSS 的前端项目中&#xff0c;执行上述指令&#xff0c;即初始化 Tailwind CSS 时&#xff0c;报如下错误 npm error could not determine executable to run# 报错npm 错误无法确定要运行的可执行文件问题…...

C# 实现鼠标轨迹录制与回放自动化功能(附源码)

在软件自动化测试或者重复性办公任务中&#xff0c;鼠标操作的自动化可以大大减少人工干预&#xff0c;提高工作效率。这里将详细介绍如何使用 C# 实现鼠标轨迹的录制与回放功能&#xff0c;代码结构清晰&#xff0c;具有较强的扩展性。 引用 NuGet 包 在开发这个功能时&…...

【HeadFirst系列之HeadFirst设计模式】第14天之与设计模式相处:真实世界中的设计模式

与设计模式相处&#xff1a;真实世界中的设计模式 设计模式是软件开发中的经典解决方案&#xff0c;它们帮助我们解决常见的设计问题&#xff0c;并提高代码的可维护性和可扩展性。在《Head First设计模式》一书中&#xff0c;作者通过生动的案例和通俗的语言&#xff0c;深入…...

自由学习记录(42)

可能会出现到后面没有教程可以看&#xff0c;走不动&#xff0c;&#xff0c;但还是尝试吧 过程远比想象的要多 那连Live2d的这些脚本怎么控制的都要了解一下 ------------ 文件类型和扩展名 | 编辑手册 | Live2D Manuals & Tutorials 全部导入之后 在这下载SDK Live2D…...

mac安装nvm=>node=>nrm

下载并安装 NVM 运行以下命令下载并安装 NVM&#xff1a; curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.4/install.sh | bash 配置环境变量 vim ~/.zshrc 按 i 将如下代码复制进去&#xff0c;controlc &#xff0c;再按 :wq完成编辑 export NVM_DIR…...

excel vlookup的精确查询、模糊查询、反向查询、多列查询

目录 入门 精确查询 模糊查询 反向查询 (搭配 if 函数) 多列查询 (搭配 match 函数) 入门 精确查询 需求: 查找 学生编号是008 所在的班级 操作: 在I2单元格输入公式如下,VLOOKUP(H2,B1:E12,4,FALSE), 得出结果 看一下vlookup 公式每一个参数应该怎么写? 语法: vlookup…...

安装remixd,在VScode创建hardhat

在终端&#xff0c;以管理员身份&#xff0c;cmd 需要科学上网 npm install -g remix-project/remixd 在vscode插件中&#xff0c;安装solidity插件&#xff0c;是暗灰色那款 1.将nodeJs的版本升级至18以上 2.在vscode打开一个新的文件&#xff0c;在终端输入 npx hardhat 3.…...

【Python爬虫】利用代理IP爬取跨境电商AI选品分析

引言 随着DeepSeek的流行&#xff0c;越来越多的用户开始尝试将AI工具融入到日常工作当中&#xff0c;借助AI的强大功能提高工作效率。最近又掀起了一波企业出海的小高潮&#xff0c;那么如果是做跨境电商业务&#xff0c;怎么将AI融入工作流中呢&#xff1f;在做跨境电商的时候…...

捣鼓180天,我写了一个相册小程序

&#x1f64b;为什么要做土著相册这样一个产品&#xff1f; ➡️在高压工作之余&#xff0c;我喜欢浏览B站上的熊猫幼崽视频来放松心情。有天在家族群里看到了大嫂分享的侄女卖萌照片&#xff0c;同样感到非常解压。于是开始翻阅过去的聊天记录&#xff0c;却发现部分图片和视…...

Linux 上离线安装 python3

在Linux系统上进行离线安装 Python3&#xff0c;通常是因为目标机器没有网络连接。以下是一个通用的步骤指南&#xff0c;帮助你在这种情况下成功安装Python 3&#xff1a; 下载安装包 选择一台有网络连接的机器&#xff1a;这台机器的操作系统应该尽可能与目标机器相同或相似…...

洛谷 P1480 A/B Problem(高精度详解)c++

题目链接&#xff1a;P1480 A/B Problem - 洛谷 1.题目分析 1&#xff1a;说明这里是高精度除以低精度的形式&#xff0c;为什么不是高精度除以高精度的形式&#xff0c;是因为它很少见&#xff0c;它的模拟方式是用高精度减法来做的&#xff0c;并不能用小学列竖式的方法模拟…...

图像滑块对比功能的开发记录

背景介绍 最近&#xff0c;公司需要开发一款在线图像压缩工具&#xff0c;其中的一个关键功能是让用户直观地比较压缩前后的图像效果。因此&#xff0c;我们设计了一个对比组件&#xff0c;它允许用户通过拖动滑块&#xff0c;动态调整两张图像的显示区域&#xff0c;从而清晰…...

电商行业门店管理软件架构设计与数据可视化实践

一、行业痛点与核心诉求 在电商多平台运营成为主流的背景下,企业普遍面临三大管理难题: ​数据碎片化:某头部服饰品牌2023年运营报告显示,其分布在8个平台的162家门店,日均产生23万条订单数据,但财务部门需要5个工作日才能完成跨平台利润核算。​成本核算失真:行业调研…...

基于Arcgis的python脚本实现相邻矢量面的高度字段取平均值

文章目录 背景效果实现逻辑步骤1、准备数据2、python脚本3、执行通过脚本工具箱来执行背景 在地理信息系统(GIS)数据处理或三维建模等实际应用场景中,我们常常会遇到需要对矢量面数据进行精细化处理的需求。其中一个常见的任务便是对相邻的矢量面中的高度字段开展特定操作。…...

基于uniapp的蓝牙打印功能(佳博打印机已测试)

相关步骤 1.蓝牙打印与低功耗打印的区别2.蓝牙打印流程2.1 搜索蓝牙2.2 连接蓝牙 3.连接蓝牙设备4.获取服务5.写入命令源码gbk.jsglobalindex.ts 1.蓝牙打印与低功耗打印的区别 低功耗蓝牙是一种无线、低功耗个人局域网&#xff0c;运行在 2.4 GHz ISM 频段 1、低功耗蓝牙能够…...

Golang的网络流量控制

# Golang的网络流量控制 什么是网络流量控制&#xff1f; 网络流量控制是指针对网络数据传输过程中的流量进行管理和调控的一种技术手段。通过网络流量控制&#xff0c;我们可以对网络中的数据传输速率、带宽使用情况、数据包丢失率等进行监控和调整&#xff0c;以达到优化网络…...

在Spring Boot + MyBatis中优雅处理多表数据清洗:基于XML的配置化方案

问题背景 在实际业务中&#xff0c;我们常会遇到数据冗余问题。例如&#xff0c;一个公司表&#xff08;sys_company&#xff09;中存在多条相同公司名的记录&#xff0c;但只有一条有效&#xff08;del_flag0&#xff09;&#xff0c;其余需要删除。删除前需将关联表&#xf…...

Python教程(一):基本语法、流程控制、数据容器

Python&#xff08;一&#xff09; 文章目录 Python&#xff08;一&#xff09;一、基础语法二、数据类型2.1 字符串2.2 空值2.3 类型转换&运算符 三、流程控制3.1 条件判断3.2 循环3.2.1 while循环3.2.2 for循环 四、数据结构4.1 字符串str4.1.1 字符串的格式化输出4.1.1.…...

Spring Boot 3.x 核心注解详解与最佳实践

Spring Boot 3.x 核心注解详解与最佳实践 前言 随着Spring Boot 3.x的正式发布&#xff0c;这个基于Spring Framework 6的里程碑版本带来了诸多新特性。本文将深入剖析Spring Boot 3.x的核心注解体系&#xff0c;结合代码示例讲解其作用及使用场景&#xff0c;助您快速掌握新…...

【AI深度学习基础】PyTorch初探

引言 PyTorch 是由 Facebook 开源的深度学习框架&#xff0c;专门针对 GPU 加速的深度神经网络编程&#xff0c;它的核心概念包括张量&#xff08;Tensor&#xff09;、计算图和自动求导机制。PyTorch作为Facebook开源的深度学习框架&#xff0c;凭借其动态计算图和直观的API设…...

Windows下安装VMware Workstation 17并设置支持MacOS

VMware Workstation 17 介绍 VMware Workstation 17 是 VMware 公司推出的一款强大的桌面虚拟化软件&#xff0c;适用于 Windows 、 Linux 和FreeBSD等操作系统。它允许用户在单一物理计算机上创建、运行和管理多个虚拟机&#xff08;VM&#xff09;&#xff0c;每个虚拟机都可…...

Mysql-主从搭建如何指定库表同步以及新增库表同步

背景&#xff1a; 当主库数据量过大&#xff0c;从库仅需要同步A库的所有表&#xff0c;并且在后续运行中&#xff0c;又提出需要在从库新增B库的users表进行同步。本文会详细列出过程与具体命令&#xff0c;并告诉你其中的深坑&#xff01; 步骤一&#xff1a; 修改从库参数…...

爬虫逆向:脱壳工具Youpk的使用详解

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 1. Youpk 简介1.1 Youpk介绍1.2 Youpk支持场景1.3 Youpk基本流程1.4 使用 Youpk 脱壳步骤1.5 常用的脱壳工具对比2. Youpk 的安装与使用2.1 安装 Youpk2.2 使用 Youpk 脱壳3. 脱壳后的 Dex 文件分析3.1 使用 JADX 反编译…...

UE4 组件 (对话组件)

制作一个可以生成对话气泡&#xff0c;显示对话台词的简单组件。这个组件要的变量&#xff1a;台词&#xff08;外部传入&#xff09;。功能&#xff1a;开始对话&#xff08;生成气泡UI&#xff09; &#xff0c;结束对话。 一、对话组件创建 二、开始对话事件 1、注意这里获…...

LeetCode 2588.统计美丽子数组数目:前缀和 + 位运算(异或) + 哈希表

【LetMeFly】2588.统计美丽子数组数目&#xff1a;前缀和 位运算(异或) 哈希表 力扣题目链接&#xff1a;https://leetcode.cn/problems/count-the-number-of-beautiful-subarrays/ 给你一个下标从 0 开始的整数数组nums 。每次操作中&#xff0c;你可以&#xff1a; 选择…...

blender看不到导入的模型

参考&#xff1a;blender 快捷键 常见问题_blender材质预览快捷键-CSDN博客 方法一&#xff1a;视图-裁剪起点&#xff0c;设置一个很大的值 方法二&#xff1a;选中所有对象&#xff0c;对齐视图-视图对齐活动项-选择一个视图...