MoEs and Transformers 笔记
ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83
MoEs and Transformers
Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型的参数量扩展到超过 6000 亿并不令人惊讶。
GShard 将在编码器和解码器中的每个前馈网络 (FFN) 层中的替换为使用 Top-2 门控的混合专家模型 (MoE) 层。下图展示了编码器部分的结构。这种架构对于大规模计算非常有效: 当扩展到多个设备时,MoE 层在不同设备间共享,而其他所有层则在每个设备上复制。我们将在 “让 MoE 起飞” 部分对这一点进行更详细的讨论。
为了保持负载平衡和训练效率,GShard 的作者除了引入了上一节中讨论的类似辅助损失外,还引入了一些关键变化:
随机路由: 在 Top-2 设置中,我们始终选择排名最高的专家,但第二个专家是根据其权重比例随机选择的。
专家容量: 我们可以设定一个阈值,定义一个专家能处理多少令牌。如果两个专家的容量都达到上限,令牌就会溢出,并通过残差连接传递到下一层,或在某些情况下被完全丢弃。专家容量是 MoE 中最重要的概念之一。为什么需要专家容量呢?因为所有张量的形状在编译时是静态确定的,我们无法提前知道多少令牌会分配给每个专家,因此需要一个固定的容量因子。
GShard 的工作对适用于 MoE 的并行计算模式也做出了重要贡献,但这些内容的讨论超出了这篇博客的范围。
注意: 在推理过程中,只有部分专家被激活。同时,有些计算过程是共享的,例如自注意力 (self-attention) 机制,它适用于所有令牌。这就解释了为什么我们可以使用相当于 12B 稠密模型的计算资源来运行一个包含 8 个专家的 47B 模型。如果我们采用 Top-2 门控,模型会使用高达 14B 的参数。但是,由于自注意力操作 (专家间共享) 的存在,实际上模型运行时使用的参数数量是 12B。
Switch Transformers
尽管混合专家模型 (MoE) 显示出了很大的潜力,但它们在训练和微调过程中存在稳定性问题。Switch Transformers 是一项非常激动人心的工作,它深入研究了这些话题。作者甚至在 Hugging Face 上发布了一个 1.6 万亿参数的 MoE,拥有 2048 个专家,你可以使用 transformers 库来运行它。Switch Transformers 实现了与 T5-XXL 相比 4 倍的预训练速度提升。
就像在 GShard 中一样,作者用混合专家模型 (MoE) 层替换了前馈网络 (FFN) 层。Switch Transformers 提出了一个 Switch Transformer 层,它接收两个输入 (两个不同的令牌) 并拥有四个专家。
与最初使用至少两个专家的想法相反,Switch Transformers 采用了简化的单专家策略。这种方法的效果包括:
减少门控网络 (路由) 计算负担
每个专家的批量大小至少可以减半
降低通信成本
保持模型质量
Switch Transformers 采用了编码器 - 解码器的架构,实现了与 T5 类似的混合专家模型 (MoE) 版本。GLaM 这篇工作探索了如何使用仅为原来 1/3 的计算资源 (因为 MoE 模型在训练时需要的计算量较少,从而能够显著降低碳足迹) 来训练与 GPT-3 质量相匹配的模型来提高这些模型的规模。作者专注于仅解码器 (decoder-only) 的模型以及少样本和单样本评估,而不是微调。他们使用了 Top-2 路由和更大的容量因子。此外,他们探讨了将容量因子作为一个动态度量,根据训练和评估期间所使用的计算量进行调整。
用 Router z-loss 稳定模型训练
之前讨论的平衡损失可能会导致稳定性问题。我们可以使用许多方法来稳定稀疏模型的训练,但这可能会牺牲模型质量。例如,引入 dropout 可以提高稳定性,但会导致模型质量下降。另一方面,增加更多的乘法分量可以提高质量,但会降低模型稳定性。
ST-MoE 引入的 Router z-loss 在保持了模型性能的同时显著提升了训练的稳定性。这种损失机制通过惩罚门控网络输入的较大 logits 来起作用,目的是促使数值的绝对大小保持较小,这样可以有效减少计算中的舍入误差。这一点对于那些依赖指数函数进行计算的门控网络尤其重要。
专家的数量对预训练有何影响?
增加更多专家可以提升处理样本的效率和加速模型的运算速度,但这些优势随着专家数量的增加而递减 (尤其是当专家数量达到 256 或 512 之后更为明显)。同时,这也意味着在推理过程中,需要更多的显存来加载整个模型。值得注意的是,Switch Transformers 的研究表明,其在大规模模型中的特性在小规模模型下也同样适用,即便是每层仅包含 2、4 或 8 个专家。
对于开源的混合专家模型 (MoE),你可以关注下面这些:
Switch Transformers (Google): 基于 T5 的 MoE 集合,专家数量从 8 名到 2048 名。最大的模型有 1.6 万亿个参数。
NLLB MoE (Meta): NLLB 翻译模型的一个 MoE 变体。
OpenMoE: 社区对基于 Llama 的模型的 MoE 尝试。
Mixtral 8x7B (Mistral): 一个性能超越了 Llama 2 70B 的高质量混合专家模型,并且具有更快的推理速度。此外,还发布了一个经过指令微调的模型。有关更多信息,可以在 Mistral 的 公告博客文章 中了解。
REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttentionclass SwitchGate(nn.Module):"""SwitchGate module for MoE (Mixture of Experts) model.Args:dim (int): Input dimension.num_experts (int): Number of experts.capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments."""def __init__(self,dim,num_experts: int,capacity_factor: float = 1.0,epsilon: float = 1e-6,*args,**kwargs,):super().__init__()self.dim = dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.epsilon = epsilonself.w_gate = nn.Linear(dim, num_experts)def forward(self, x: Tensor, use_aux_loss=False):"""Forward pass of the SwitchGate module.Args:x (Tensor): Input tensor.Returns:Tensor: Gate scores."""# Compute gate scoresgate_scores = F.softmax(self.w_gate(x), dim=-1)# Determine the top-1 expert for each tokencapacity = int(self.capacity_factor * x.size(0))top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)# Mask to enforce sparsitymask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)# Combine gating scores with the maskmasked_gate_scores = gate_scores * mask# Denominatorsdenominators = (masked_gate_scores.sum(0, keepdim=True) + self.epsilon)# Norm gate scores to sum to the capacitygate_scores = (masked_gate_scores / denominators) * capacityif use_aux_loss:load = gate_scores.sum(0) # Sum over all examplesimportance = gate_scores.sum(1) # Sum over all experts# Aux loss is mean suqared difference between load and importanceloss = ((load - importance) ** 2).mean()return gate_scores, lossreturn gate_scores, Noneclass SwitchMoE(nn.Module):"""A module that implements the Switched Mixture of Experts (MoE) architecture.Args:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float): The capacity factor that controls the capacity of the MoE.mult (int): The multiplier for the hidden dimension of the feedforward network.experts (nn.ModuleList): The list of feedforward networks representing the experts.gate (SwitchGate): The switch gate module."""def __init__(self,dim: int,hidden_dim: int,output_dim: int,num_experts: int,capacity_factor: float = 1.0,mult: int = 4,use_aux_loss: bool = False,*args,**kwargs,):super().__init__()self.dim = dimself.hidden_dim = hidden_dimself.output_dim = output_dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.mult = multself.use_aux_loss = use_aux_lossself.experts = nn.ModuleList([FeedForward(dim, dim, mult, *args, **kwargs)for _ in range(num_experts)])self.gate = SwitchGate(dim,num_experts,capacity_factor,)def forward(self, x: Tensor):"""Forward pass of the SwitchMoE module.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor of the MoE."""# (batch_size, seq_len, num_experts)gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)# Dispatch to expertsexpert_outputs = [expert(x) for expert in self.experts]# Check if any gate scores are nan and handleif torch.isnan(gate_scores).any():print("NaN in gate scores")gate_scores[torch.isnan(gate_scores)] = 0# Stack and weight outputsstacked_expert_outputs = torch.stack(expert_outputs, dim=-1) # (batch_size, seq_len, output_dim, num_experts)if torch.isnan(stacked_expert_outputs).any():stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0# Combine expert outputs and gating scoresmoe_output = torch.sum(gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1)return moe_output, lossclass SwitchTransformerBlock(nn.Module):"""SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.Args:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.depth (int, optional): The number of layers in the block. Defaults to 12.num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 6.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int): The multiplier for the hidden dimension in the feed-forward network.dropout (float): The dropout rate.attn_layers (nn.ModuleList): List of MultiQueryAttention layers.ffn_layers (nn.ModuleList): List of SwitchMoE layers.Examples:>>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)>>> x = torch.randn(1, 10, 512)>>> out = block(x)>>> out.shape"""def __init__(self,dim: int,heads: int,dim_head: int,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,*args,**kwargs,):super().__init__()self.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.attn = MultiQueryAttention(dim, heads, qk_ln=True * args, **kwargs)self.ffn = SwitchMoE(dim, dim * mult, dim, num_experts, *args, **kwargs)self.add_norm = nn.LayerNorm(dim)def forward(self, x: Tensor):"""Forward pass of the SwitchTransformerBlock.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor."""resi = xx, _, _ = self.attn(x)x = x + resix = self.add_norm(x)add_normed = x##### MoE #####x, _ = self.ffn(x)x = x + add_normedx = self.add_norm(x)return xclass SwitchTransformer(nn.Module):"""SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.Args:num_tokens (int): The number of tokens in the input vocabulary.dim (int): The dimensionality of the token embeddings and hidden states.heads (int): The number of attention heads.dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.*args: Additional positional arguments.**kwargs: Additional keyword arguments."""def __init__(self,num_tokens: int,dim: int,heads: int,dim_head: int = 64,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,depth: int = 4,*args,**kwargs,):super().__init__()self.num_tokens = num_tokensself.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.num_experts = num_expertsself.depth = depthself.embedding = nn.Embedding(num_tokens, dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(SwitchTransformerBlock(dim,heads,dim_head,mult,dropout,num_experts,*args,**kwargs,))self.to_out = nn.Sequential(nn.Softmax(dim=-1),nn.LayerNorm(dim),nn.Linear(dim, num_tokens),)def forward(self, x: Tensor) -> Tensor:"""Forward pass of the SwitchTransformer.Args:x (Tensor): The input tensor of shape (batch_size, sequence_length).Returns:Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens)."""# Embed tokens through embedding layerx = self.embedding(x)# Pass through the transformer block with MoE, it's in modulelistfor layer in self.layers:x = layer(x)# Project to output tokensx = self.to_out(x)return x
相关文章:

MoEs and Transformers 笔记
ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83 MoEs and Transformers Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型…...
在Linux中,如何禁用root用户直接SSH登录?
在Linux中禁用root用户的直接SSH登录是为了增强系统的安全性,因为允许root用户通过SSH远程登录会增加服务器被暴力破解的风险。以下是在Linux系统中禁止root用户直接SSH登录的步骤: 编辑SSH配置文件: 打开/etc/ssh/sshd_config文件ÿ…...

用Python实现简单的任务自动化
目录 1. 自动发送邮件提醒 2. 自动备份文件 3. 自动下载网页内容 总结 在现代工作和生活中,任务自动化可以极大地提高效率和准确性。Python,作为一种功能强大且易于学习的编程语言,是实现任务自动化的理想选择。本文将通过几个简单而实用的案例,展示如何用Python实现任…...
为AI聊天工具添加一个知识系统 之26 资源存储库和资源管理器
本文要点 资源存储库 为了能完成本项目(“为AI聊天工具增加一个知识系统”,其核心能力是“语言处理” ,该能力的最大挑战 当仁不让的应该是自然语言处理)的设计,我们考虑一个问题:在自然语言处理中&#…...

Windows10环境下安装RabbitMq折腾记
最近有个老项目需要迁移到windows10环境,用的是比较老的rabbitmq安装包,如下所示。经过一番折腾,死活服务起不来,最终果断放弃老版本启用新版本。现在把折腾过程记录下: 一、安装erlang 安装完成后的目录结构ÿ…...

对快速由表及里说拜拜/如何正确运用由表及里
你是不是还:看到一男子拖走一女子就以为小情侣吵架而已(可能人贩子);看到男友对你好个几次就从此死心塌地(可能有手就行,细节装装而已)结果耽误终身;看到女同事对你微笑不排斥就以为…...
spring mvc源码学习笔记之八
本文说点儿简单的。 如果你想研究基于 XML 配置的 spring mvc 的话,可以简单扫一眼本文。 在基于 XML 配置的 spring mvc 开发中,我们主要就是通过 spring 提供的各种标签来配置。 但是,大家是不是都有个疑问,spring 到底给我们提…...

探秘5网口IIOT网关
在当今这个科技飞速发展的时代,工业领域正经历着一场深刻的变革,而工业物联网网关在其中扮演着至关重要的角色。 什么是IIOT网关 工业物联网网关,简单来说,就是连接工业现场设备与云端或者上层管理系统的关键桥梁。 而明达技术研…...

左神算法基础巩固--5
文章目录 前缀树生成前缀树查询前缀树查询字符串加入过几次查询所有加入的字符串中,有几个是以pre这个字符串作为前缀 删除前缀树中的某个字符串 贪心算法解题 前缀树 生成前缀树 要想生成一棵前缀树,需要先创建一个根节点,这个根节点有26条…...

Python的Matplotlib库应用(超详细教程)
目录 一、环境搭建 1.1 配置matplotlib库 1.2 配置seaborn库 1.3 配置Skimage库 二、二维图像 2.1 曲线(直线)可视化 2.2 曲线(虚线)可视化 2.3 直方图 2.4 阶梯图 三、三维图像 3.1 3D曲面图 3.2 3D散点图 3.3 3D散…...
负载均衡服务器要怎么配置?
目录 一、概述: 二、硬件配置: 三、操作系统配置: 四、负载均衡软件: 五、网络配置: 六、软件安装步骤: 6.1 安装 Nginx 6.2 安装 LVS 6.3 安装 HAProxy 6.4 安装 Keepalived 一、概述࿱…...

CANopen转EtherCAT网关连接伺服驱动
在现代工业自动化领域,CANopen和EtherCAT是两种常见的通信协议,各自在不同的应用场景中发挥着重要作用。然而,随着工业自动化系统的日益复杂化,不同设备间的通信需求也变得多样化。因此,如何实现不同协议设备之间的无缝…...
自动化测试脚本实践:基于 Bash 的模块化测试框架
前言 在现代软件开发中,测试自动化是确保软件质量和稳定性的核心手段之一。随着开发周期的缩短和功能模块的增多,手动测试逐渐无法满足高效性和准确性的需求。因此,测试人员需要依赖自动化工具来提升测试效率,减少人为干预和错误。…...

WebSocket 测试入门篇
Websocket 是一种用于 H5 浏览器的实时通讯协议,可以做到数据的实时推送,可适用于广泛的工作环境,例如客服系统、物联网数据传输系统, 基础介绍 我们平常接触最多的是 http 协议的接口,http 协议是请求与响应的模式&…...
Apache Traffic存在SQL注入漏洞(CVE-2024-45387)
免责声明: 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在…...
Centos7使用yum工具出现 Could not resolve host: mirrorlist.centos.org
在 CentOS 7 中使用 yum 工具时,出现 "Could not resolve host: mirrorlist.centos.org" 的错误,一般情况是因为默认的镜像源无法访问。 以下是一些常用的解决方法: 检查网络连接:首先使用 ping 命令测试网络连接是否…...
zookeeper shell操作和zookeeper 典型应用(配置中心、集群选举服务、分布式锁)
文章目录 引言I zookeeper客户端命令查看子节点 ls创建子节点 create获取节点信息 get更新节点数据 set删除节点 delete\ rmrII 监听机制node1:设置监听node3:修改监听节点node1:得到监听反馈III zookeeper 典型应用分布式锁集群选举服务数据发布/订阅(配置中心)引言 zk 的…...

Vue中Watch使用监听修改变动
使用注意 监听一个值时 多个值时...
Lua语言的文件IO
1、我们都知道,在任何语言当中都有输入输出,比如c语言当中就有很多printf,scanf,get ,put,gets,puts,文件io:open,read,write,close,标准io:fopen,fread,fwrite,fclose.在lua语言当中,也有相同的一些输入输出特性,叫io.open,io.re…...
C语言基本知识复习浓缩版:输出函数printf
输出函数printf学习 printf()的作用是将文本输出到屏幕上使用之前需要先引入stdio.h头文件printf函数在使用的时候,至少需要一个参数 printf() 是 C 语言标准库中的一个函数,用于将格式化的文本输出到标准输出设备(通常是屏幕)。…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...

地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

转转集团旗下首家二手多品类循环仓店“超级转转”开业
6月9日,国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解,“超级…...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...

uniapp手机号一键登录保姆级教程(包含前端和后端)
目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号(第三种)后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...

android RelativeLayout布局
<?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android:gravity&…...
raid存储技术
1. 存储技术概念 数据存储架构是对数据存储方式、存储设备及相关组件的组织和规划,涵盖存储系统的布局、数据存储策略等,它明确数据如何存储、管理与访问,为数据的安全、高效使用提供支撑。 由计算机中一组存储设备、控制部件和管理信息调度的…...
统计学(第8版)——统计抽样学习笔记(考试用)
一、统计抽样的核心内容与问题 研究内容 从总体中科学抽取样本的方法利用样本数据推断总体特征(均值、比率、总量)控制抽样误差与非抽样误差 解决的核心问题 在成本约束下,用少量样本准确推断总体特征量化估计结果的可靠性(置…...