当前位置: 首页 > news >正文

MoEs and Transformers 笔记

ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83
MoEs and Transformers

Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型的参数量扩展到超过 6000 亿并不令人惊讶。

GShard 将在编码器和解码器中的每个前馈网络 (FFN) 层中的替换为使用 Top-2 门控的混合专家模型 (MoE) 层。下图展示了编码器部分的结构。这种架构对于大规模计算非常有效: 当扩展到多个设备时,MoE 层在不同设备间共享,而其他所有层则在每个设备上复制。我们将在 “让 MoE 起飞” 部分对这一点进行更详细的讨论。
在这里插入图片描述

为了保持负载平衡和训练效率,GShard 的作者除了引入了上一节中讨论的类似辅助损失外,还引入了一些关键变化:

随机路由: 在 Top-2 设置中,我们始终选择排名最高的专家,但第二个专家是根据其权重比例随机选择的。
专家容量: 我们可以设定一个阈值,定义一个专家能处理多少令牌。如果两个专家的容量都达到上限,令牌就会溢出,并通过残差连接传递到下一层,或在某些情况下被完全丢弃。专家容量是 MoE 中最重要的概念之一。为什么需要专家容量呢?因为所有张量的形状在编译时是静态确定的,我们无法提前知道多少令牌会分配给每个专家,因此需要一个固定的容量因子。
GShard 的工作对适用于 MoE 的并行计算模式也做出了重要贡献,但这些内容的讨论超出了这篇博客的范围。

注意: 在推理过程中,只有部分专家被激活。同时,有些计算过程是共享的,例如自注意力 (self-attention) 机制,它适用于所有令牌。这就解释了为什么我们可以使用相当于 12B 稠密模型的计算资源来运行一个包含 8 个专家的 47B 模型。如果我们采用 Top-2 门控,模型会使用高达 14B 的参数。但是,由于自注意力操作 (专家间共享) 的存在,实际上模型运行时使用的参数数量是 12B。

Switch Transformers
尽管混合专家模型 (MoE) 显示出了很大的潜力,但它们在训练和微调过程中存在稳定性问题。Switch Transformers 是一项非常激动人心的工作,它深入研究了这些话题。作者甚至在 Hugging Face 上发布了一个 1.6 万亿参数的 MoE,拥有 2048 个专家,你可以使用 transformers 库来运行它。Switch Transformers 实现了与 T5-XXL 相比 4 倍的预训练速度提升。
就像在 GShard 中一样,作者用混合专家模型 (MoE) 层替换了前馈网络 (FFN) 层。Switch Transformers 提出了一个 Switch Transformer 层,它接收两个输入 (两个不同的令牌) 并拥有四个专家。
在这里插入图片描述
与最初使用至少两个专家的想法相反,Switch Transformers 采用了简化的单专家策略。这种方法的效果包括:

减少门控网络 (路由) 计算负担
每个专家的批量大小至少可以减半
降低通信成本
保持模型质量

Switch Transformers 采用了编码器 - 解码器的架构,实现了与 T5 类似的混合专家模型 (MoE) 版本。GLaM 这篇工作探索了如何使用仅为原来 1/3 的计算资源 (因为 MoE 模型在训练时需要的计算量较少,从而能够显著降低碳足迹) 来训练与 GPT-3 质量相匹配的模型来提高这些模型的规模。作者专注于仅解码器 (decoder-only) 的模型以及少样本和单样本评估,而不是微调。他们使用了 Top-2 路由和更大的容量因子。此外,他们探讨了将容量因子作为一个动态度量,根据训练和评估期间所使用的计算量进行调整。

用 Router z-loss 稳定模型训练
之前讨论的平衡损失可能会导致稳定性问题。我们可以使用许多方法来稳定稀疏模型的训练,但这可能会牺牲模型质量。例如,引入 dropout 可以提高稳定性,但会导致模型质量下降。另一方面,增加更多的乘法分量可以提高质量,但会降低模型稳定性。

ST-MoE 引入的 Router z-loss 在保持了模型性能的同时显著提升了训练的稳定性。这种损失机制通过惩罚门控网络输入的较大 logits 来起作用,目的是促使数值的绝对大小保持较小,这样可以有效减少计算中的舍入误差。这一点对于那些依赖指数函数进行计算的门控网络尤其重要。

专家的数量对预训练有何影响?
增加更多专家可以提升处理样本的效率和加速模型的运算速度,但这些优势随着专家数量的增加而递减 (尤其是当专家数量达到 256 或 512 之后更为明显)。同时,这也意味着在推理过程中,需要更多的显存来加载整个模型。值得注意的是,Switch Transformers 的研究表明,其在大规模模型中的特性在小规模模型下也同样适用,即便是每层仅包含 2、4 或 8 个专家。

对于开源的混合专家模型 (MoE),你可以关注下面这些:

Switch Transformers (Google): 基于 T5 的 MoE 集合,专家数量从 8 名到 2048 名。最大的模型有 1.6 万亿个参数。
NLLB MoE (Meta): NLLB 翻译模型的一个 MoE 变体。
OpenMoE: 社区对基于 Llama 的模型的 MoE 尝试。
Mixtral 8x7B (Mistral): 一个性能超越了 Llama 2 70B 的高质量混合专家模型,并且具有更快的推理速度。此外,还发布了一个经过指令微调的模型。有关更多信息,可以在 Mistral 的 公告博客文章 中了解。

REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttentionclass SwitchGate(nn.Module):"""SwitchGate module for MoE (Mixture of Experts) model.Args:dim (int): Input dimension.num_experts (int): Number of experts.capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments."""def __init__(self,dim,num_experts: int,capacity_factor: float = 1.0,epsilon: float = 1e-6,*args,**kwargs,):super().__init__()self.dim = dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.epsilon = epsilonself.w_gate = nn.Linear(dim, num_experts)def forward(self, x: Tensor, use_aux_loss=False):"""Forward pass of the SwitchGate module.Args:x (Tensor): Input tensor.Returns:Tensor: Gate scores."""# Compute gate scoresgate_scores = F.softmax(self.w_gate(x), dim=-1)# Determine the top-1 expert for each tokencapacity = int(self.capacity_factor * x.size(0))top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)# Mask to enforce sparsitymask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)# Combine gating scores with the maskmasked_gate_scores = gate_scores * mask# Denominatorsdenominators = (masked_gate_scores.sum(0, keepdim=True) + self.epsilon)# Norm gate scores to sum to the capacitygate_scores = (masked_gate_scores / denominators) * capacityif use_aux_loss:load = gate_scores.sum(0)  # Sum over all examplesimportance = gate_scores.sum(1)  # Sum over all experts# Aux loss is mean suqared difference between load and importanceloss = ((load - importance) ** 2).mean()return gate_scores, lossreturn gate_scores, Noneclass SwitchMoE(nn.Module):"""A module that implements the Switched Mixture of Experts (MoE) architecture.Args:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float): The capacity factor that controls the capacity of the MoE.mult (int): The multiplier for the hidden dimension of the feedforward network.experts (nn.ModuleList): The list of feedforward networks representing the experts.gate (SwitchGate): The switch gate module."""def __init__(self,dim: int,hidden_dim: int,output_dim: int,num_experts: int,capacity_factor: float = 1.0,mult: int = 4,use_aux_loss: bool = False,*args,**kwargs,):super().__init__()self.dim = dimself.hidden_dim = hidden_dimself.output_dim = output_dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.mult = multself.use_aux_loss = use_aux_lossself.experts = nn.ModuleList([FeedForward(dim, dim, mult, *args, **kwargs)for _ in range(num_experts)])self.gate = SwitchGate(dim,num_experts,capacity_factor,)def forward(self, x: Tensor):"""Forward pass of the SwitchMoE module.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor of the MoE."""# (batch_size, seq_len, num_experts)gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)# Dispatch to expertsexpert_outputs = [expert(x) for expert in self.experts]# Check if any gate scores are nan and handleif torch.isnan(gate_scores).any():print("NaN in gate scores")gate_scores[torch.isnan(gate_scores)] = 0# Stack and weight outputsstacked_expert_outputs = torch.stack(expert_outputs, dim=-1)  # (batch_size, seq_len, output_dim, num_experts)if torch.isnan(stacked_expert_outputs).any():stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0# Combine expert outputs and gating scoresmoe_output = torch.sum(gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1)return moe_output, lossclass SwitchTransformerBlock(nn.Module):"""SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.Args:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.depth (int, optional): The number of layers in the block. Defaults to 12.num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 6.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int): The multiplier for the hidden dimension in the feed-forward network.dropout (float): The dropout rate.attn_layers (nn.ModuleList): List of MultiQueryAttention layers.ffn_layers (nn.ModuleList): List of SwitchMoE layers.Examples:>>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)>>> x = torch.randn(1, 10, 512)>>> out = block(x)>>> out.shape"""def __init__(self,dim: int,heads: int,dim_head: int,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,*args,**kwargs,):super().__init__()self.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.attn = MultiQueryAttention(dim, heads, qk_ln=True * args, **kwargs)self.ffn = SwitchMoE(dim, dim * mult, dim, num_experts, *args, **kwargs)self.add_norm = nn.LayerNorm(dim)def forward(self, x: Tensor):"""Forward pass of the SwitchTransformerBlock.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor."""resi = xx, _, _ = self.attn(x)x = x + resix = self.add_norm(x)add_normed = x##### MoE #####x, _ = self.ffn(x)x = x + add_normedx = self.add_norm(x)return xclass SwitchTransformer(nn.Module):"""SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.Args:num_tokens (int): The number of tokens in the input vocabulary.dim (int): The dimensionality of the token embeddings and hidden states.heads (int): The number of attention heads.dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.*args: Additional positional arguments.**kwargs: Additional keyword arguments."""def __init__(self,num_tokens: int,dim: int,heads: int,dim_head: int = 64,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,depth: int = 4,*args,**kwargs,):super().__init__()self.num_tokens = num_tokensself.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.num_experts = num_expertsself.depth = depthself.embedding = nn.Embedding(num_tokens, dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(SwitchTransformerBlock(dim,heads,dim_head,mult,dropout,num_experts,*args,**kwargs,))self.to_out = nn.Sequential(nn.Softmax(dim=-1),nn.LayerNorm(dim),nn.Linear(dim, num_tokens),)def forward(self, x: Tensor) -> Tensor:"""Forward pass of the SwitchTransformer.Args:x (Tensor): The input tensor of shape (batch_size, sequence_length).Returns:Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens)."""# Embed tokens through embedding layerx = self.embedding(x)# Pass through the transformer block with MoE, it's in modulelistfor layer in self.layers:x = layer(x)# Project to output tokensx = self.to_out(x)return x

相关文章:

MoEs and Transformers 笔记

ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83 MoEs and Transformers Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型…...

在Linux中,如何禁用root用户直接SSH登录?

在Linux中禁用root用户的直接SSH登录是为了增强系统的安全性,因为允许root用户通过SSH远程登录会增加服务器被暴力破解的风险。以下是在Linux系统中禁止root用户直接SSH登录的步骤: 编辑SSH配置文件: 打开/etc/ssh/sshd_config文件&#xff…...

用Python实现简单的任务自动化

目录 1. 自动发送邮件提醒 2. 自动备份文件 3. 自动下载网页内容 总结 在现代工作和生活中,任务自动化可以极大地提高效率和准确性。Python,作为一种功能强大且易于学习的编程语言,是实现任务自动化的理想选择。本文将通过几个简单而实用的案例,展示如何用Python实现任…...

为AI聊天工具添加一个知识系统 之26 资源存储库和资源管理器

本文要点 资源存储库 为了能完成本项目(“为AI聊天工具增加一个知识系统”,其核心能力是“语言处理” ,该能力的最大挑战 当仁不让的应该是自然语言处理)的设计,我们考虑一个问题:在自然语言处理中&#…...

Windows10环境下安装RabbitMq折腾记

最近有个老项目需要迁移到windows10环境,用的是比较老的rabbitmq安装包,如下所示。经过一番折腾,死活服务起不来,最终果断放弃老版本启用新版本。现在把折腾过程记录下: 一、安装erlang 安装完成后的目录结构&#xff…...

对快速由表及里说拜拜/如何正确运用由表及里

你是不是还:看到一男子拖走一女子就以为小情侣吵架而已(可能人贩子);看到男友对你好个几次就从此死心塌地(可能有手就行,细节装装而已)结果耽误终身;看到女同事对你微笑不排斥就以为…...

spring mvc源码学习笔记之八

本文说点儿简单的。 如果你想研究基于 XML 配置的 spring mvc 的话,可以简单扫一眼本文。 在基于 XML 配置的 spring mvc 开发中,我们主要就是通过 spring 提供的各种标签来配置。 但是,大家是不是都有个疑问,spring 到底给我们提…...

探秘5网口IIOT网关

在当今这个科技飞速发展的时代,工业领域正经历着一场深刻的变革,而工业物联网网关在其中扮演着至关重要的角色。 什么是IIOT网关 工业物联网网关,简单来说,就是连接工业现场设备与云端或者上层管理系统的关键桥梁。 而明达技术研…...

左神算法基础巩固--5

文章目录 前缀树生成前缀树查询前缀树查询字符串加入过几次查询所有加入的字符串中,有几个是以pre这个字符串作为前缀 删除前缀树中的某个字符串 贪心算法解题 前缀树 生成前缀树 要想生成一棵前缀树,需要先创建一个根节点,这个根节点有26条…...

Python的Matplotlib库应用(超详细教程)

目录 一、环境搭建 1.1 配置matplotlib库 1.2 配置seaborn库 1.3 配置Skimage库 二、二维图像 2.1 曲线(直线)可视化 2.2 曲线(虚线)可视化 2.3 直方图 2.4 阶梯图 三、三维图像 3.1 3D曲面图 3.2 3D散点图 3.3 3D散…...

负载均衡服务器要怎么配置?

目录 一、概述: 二、硬件配置: 三、操作系统配置: 四、负载均衡软件: 五、网络配置: 六、软件安装步骤: 6.1 安装 Nginx 6.2 安装 LVS 6.3 安装 HAProxy 6.4 安装 Keepalived 一、概述&#xff1…...

CANopen转EtherCAT网关连接伺服驱动

在现代工业自动化领域,CANopen和EtherCAT是两种常见的通信协议,各自在不同的应用场景中发挥着重要作用。然而,随着工业自动化系统的日益复杂化,不同设备间的通信需求也变得多样化。因此,如何实现不同协议设备之间的无缝…...

自动化测试脚本实践:基于 Bash 的模块化测试框架

前言 在现代软件开发中,测试自动化是确保软件质量和稳定性的核心手段之一。随着开发周期的缩短和功能模块的增多,手动测试逐渐无法满足高效性和准确性的需求。因此,测试人员需要依赖自动化工具来提升测试效率,减少人为干预和错误。…...

WebSocket 测试入门篇

Websocket 是一种用于 H5 浏览器的实时通讯协议,可以做到数据的实时推送,可适用于广泛的工作环境,例如客服系统、物联网数据传输系统, 基础介绍 我们平常接触最多的是 http 协议的接口,http 协议是请求与响应的模式&…...

Apache Traffic存在SQL注入漏洞(CVE-2024-45387)

免责声明: 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在…...

Centos7使用yum工具出现 Could not resolve host: mirrorlist.centos.org

在 CentOS 7 中使用 yum 工具时,出现 "Could not resolve host: mirrorlist.centos.org" 的错误,一般情况是因为默认的镜像源无法访问。 以下是一些常用的解决方法: 检查网络连接:首先使用 ping 命令测试网络连接是否…...

zookeeper shell操作和zookeeper 典型应用(配置中心、集群选举服务、分布式锁)

文章目录 引言I zookeeper客户端命令查看子节点 ls创建子节点 create获取节点信息 get更新节点数据 set删除节点 delete\ rmrII 监听机制node1:设置监听node3:修改监听节点node1:得到监听反馈III zookeeper 典型应用分布式锁集群选举服务数据发布/订阅(配置中心)引言 zk 的…...

Vue中Watch使用监听修改变动

使用注意 监听一个值时 多个值时...

Lua语言的文件IO

1、我们都知道,在任何语言当中都有输入输出,比如c语言当中就有很多printf,scanf,get ,put,gets,puts,文件io:open,read,write,close,标准io:fopen,fread,fwrite,fclose.在lua语言当中,也有相同的一些输入输出特性,叫io.open,io.re…...

C语言基本知识复习浓缩版:输出函数printf

输出函数printf学习 printf()的作用是将文本输出到屏幕上使用之前需要先引入stdio.h头文件printf函数在使用的时候,至少需要一个参数 printf() 是 C 语言标准库中的一个函数,用于将格式化的文本输出到标准输出设备(通常是屏幕)。…...

Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务

通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

Redis数据倾斜问题解决

Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

算法笔记2

1.字符串拼接最好用StringBuilder&#xff0c;不用String 2.创建List<>类型的数组并创建内存 List arr[] new ArrayList[26]; Arrays.setAll(arr, i -> new ArrayList<>()); 3.去掉首尾空格...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践

作者&#xff1a;吴岐诗&#xff0c;杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言&#xff1a;融合数据湖与数仓的创新之路 在数字金融时代&#xff0c;数据已成为金融机构的核心竞争力。杭银消费金…...

libfmt: 现代C++的格式化工具库介绍与酷炫功能

libfmt: 现代C的格式化工具库介绍与酷炫功能 libfmt 是一个开源的C格式化库&#xff0c;提供了高效、安全的文本格式化功能&#xff0c;是C20中引入的std::format的基础实现。它比传统的printf和iostream更安全、更灵活、性能更好。 基本介绍 主要特点 类型安全&#xff1a…...

Matlab实现任意伪彩色图像可视化显示

Matlab实现任意伪彩色图像可视化显示 1、灰度原始图像2、RGB彩色原始图像 在科研研究中&#xff0c;如何展示好看的实验结果图像非常重要&#xff01;&#xff01;&#xff01; 1、灰度原始图像 灰度图像每个像素点只有一个数值&#xff0c;代表该点的​​亮度&#xff08;或…...

数据库正常,但后端收不到数据原因及解决

从代码和日志来看&#xff0c;后端SQL查询确实返回了数据&#xff0c;但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离&#xff0c;并且ai辅助开发的时候&#xff0c;很容易出现前后端变量名不一致情况&#xff0c;还不报错&#xff0c;只是单…...

运行vue项目报错 errors and 0 warnings potentially fixable with the `--fix` option.

报错 找到package.json文件 找到这个修改成 "lint": "eslint --fix --ext .js,.vue src" 为elsint有配置结尾换行符&#xff0c;最后运行&#xff1a;npm run lint --fix...