当前位置: 首页 > article >正文

MOEFeedForward 模块

代码

class FeedForward(nn.Module):def __init__(self, config: LMConfig):super().__init__()if config.hidden_dim is None:hidden_dim = 4 * config.dimhidden_dim = int(2 * hidden_dim / 3)config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)self.dropout = nn.Dropout(config.dropout)def forward(self, x):return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))class MoEGate(nn.Module):def __init__(self, config: LMConfig):super().__init__()self.config = configself.top_k = config.num_experts_per_tokself.n_routed_experts = config.n_routed_expertsself.scoring_func = config.scoring_funcself.alpha = config.aux_loss_alphaself.seq_aux = config.seq_auxself.norm_topk_prob = config.norm_topk_probself.gating_dim = config.dimself.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))self.reset_parameters()def reset_parameters(self) -> None:import torch.nn.init as initinit.kaiming_uniform_(self.weight, a=math.sqrt(5))def forward(self, hidden_states):bsz, seq_len, h = hidden_states.shapehidden_states = hidden_states.view(-1, h)logits = F.linear(hidden_states, self.weight, None)if self.scoring_func == 'softmax':scores = logits.softmax(dim=-1)else:raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)if self.top_k > 1 and self.norm_topk_prob:denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20topk_weight = topk_weight / denominatorif self.training and self.alpha > 0.0:scores_for_aux = scoresaux_topk = self.top_ktopk_idx_for_aux_loss = topk_idx.view(bsz, -1)if self.seq_aux:scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)ce.scatter_add_(1, topk_idx_for_aux_loss,torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alphaelse:mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)ce = mask_ce.float().mean(0)Pi = scores_for_aux.mean(0)fi = ce * self.n_routed_expertsaux_loss = (Pi * fi).sum() * self.alphaelse:aux_loss = 0return topk_idx, topk_weight, aux_lossclass MOEFeedForward(nn.Module):def __init__(self, config: LMConfig):super().__init__()self.config = configself.experts = nn.ModuleList([FeedForward(config)for _ in range(config.n_routed_experts)])self.gate = MoEGate(config)if config.n_shared_experts is not None:self.shared_experts = FeedForward(config)def forward(self, x):identity = xorig_shape = x.shapebsz, seq_len, _ = x.shape# 使用门控机制选择专家topk_idx, topk_weight, aux_loss = self.gate(x)x = x.view(-1, x.shape[-1])flat_topk_idx = topk_idx.view(-1)if self.training:# 训练模式下,重复输入数据x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)y = torch.empty_like(x, dtype=torch.float16)for i, expert in enumerate(self.experts):y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # 确保类型一致y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)y = y.view(*orig_shape)else:# 推理模式下,只选择最优专家y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)if self.config.n_shared_experts is not None:y = y + self.shared_experts(identity)self.aux_loss = aux_lossreturn y@torch.no_grad()def moe_infer(self, x, flat_expert_indices, flat_expert_weights):expert_cache = torch.zeros_like(x)idxs = flat_expert_indices.argsort()tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)token_idxs = idxs // self.config.num_experts_per_tok# 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52]# 当token_idxs=[3, 7, 19, 21, 24, 25,  4,  5,  6, 10, 11, 12...]# 意味着当token_idxs[:6] -> [3,  7, 19, 21, 24, 25,  4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理......for i, end_idx in enumerate(tokens_per_expert):start_idx = 0 if i == 0 else tokens_per_expert[i - 1]if start_idx == end_idx:continueexpert = self.experts[i]exp_token_idx = token_idxs[start_idx:end_idx]expert_tokens = x[exp_token_idx]expert_out = expert(expert_tokens).to(expert_cache.dtype)expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])# 使用 scatter_add_ 进行 sum 操作expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)return expert_cache

代码解释

解释一下这段代码的主要组成部分:

  1. FeedForward 类:
  • 实现了一个基础的前馈网络
  • 使用 SwiGLU 激活函数(F.silu(self.w1(x)) * self.w3(x)
  • 包含三个线性层(w1、w2、w3)和一个 dropout 层
  1. MoEGate 类(门控机制):
  • 负责决定每个 token 应该由哪些专家处理
  • 主要步骤:
    1. 计算每个 token 对应每个专家的分数(使用 softmax)
    2. 选择 top-k 个最高分的专家
    3. 计算辅助损失(aux_loss)来平衡专家的使用
  1. MOEFeedForward 类(混合专家系统):
  • 包含多个专家(FeedForward)和一个门控网络(MoEGate)

  • 训练模式:

    1. 使用门控网络选择每个 token 的专家
    2. 将输入数据复制多份,分发给不同专家
    3. 专家并行处理数据
    4. 根据门控权重合并结果
  • 推理模式(moe_infer):

    1. 对专家索引排序,将相同专家的 token 批量处理
    2. 使用 scatter_add_ 将专家输出累加到正确位置
    3. 更高效的推理实现,避免了数据重复
  1. 特殊功能:
  • 支持共享专家(n_shared_experts
  • 实现了专家负载均衡(通过辅助损失)
  • 支持每个 token 选择多个专家(num_experts_per_tok

这是一个典型的 MoE(Mixture of Experts)实现,用于大型语言模型中提高模型容量和计算效率。

示例

# 创建 MoE 实例
dim = 512                    # 输入维度
n_routed_experts = 4         # 专家数量
num_experts_per_tok = 2      # 每个token选择的专家数量moe = MOEFeedForward(dim=dim,n_routed_experts=n_routed_experts,num_experts_per_tok=num_experts_per_tok,hidden_dim=None,         # FFN隐藏层维度,None时自动计算dropout=0.1             # dropout比率
)# 创建示例输入
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, dim)  # 形状: [2, 10, 512]moe(x)

输出

After gate - topk_idx.shape: torch.Size([20, 2]), topk_weight.shape: torch.Size([20, 2])
After view - x.shape: torch.Size([20, 512]), flat_topk_idx.shape: torch.Size([40])
After repeat_interleave - x.shape: torch.Size([40, 512])
Empty y tensor shape: torch.Size([40, 512])
Expert 0 - input shape: torch.Size([9, 512])
Expert 0 - output shape: torch.Size([9, 512])
Expert 1 - input shape: torch.Size([13, 512])
Expert 1 - output shape: torch.Size([13, 512])
Expert 2 - input shape: torch.Size([11, 512])
Expert 2 - output shape: torch.Size([11, 512])
Expert 3 - input shape: torch.Size([7, 512])
Expert 3 - output shape: torch.Size([7, 512])
Before view - y.shape: torch.Size([40, 512])
topk_weight.shape: torch.Size([20, 2])
After view and sum - y.shape: torch.Size([20, 512])
Final y.shape: torch.Size([2, 10, 512])

相应的torch函数

import torch
# empty: 创建未初始化的张量
x = torch.empty((2, 3))  # 创建形状为 2x3 的未初始化张量# zeros_like: 创建与输入相同形状的全零张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.zeros_like(a)  # 创建形状为 2x2 的全零张量
print(b)  # tensor([[0, 0], [0, 0]])
tensor([[0, 0],[0, 0]])
import torch.nn.functional as F
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
# view: 改变张量形状
y = x.view(-1)  # 展平为一维
print(y)  # tensor([1, 2, 3, 4, 5, 6, 7, 8])# -1 表示自动计算该维度大小
z = x.view(-1, 2)  # 重塑为 4x2
print(z)  # tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
tensor([1, 2, 3, 4, 5, 6, 7, 8])
tensor([[1, 2],[3, 4],[5, 6],[7, 8]])
# linear: 线性变换 y = xA^T + b
input = torch.randn(2, 3)  # 2个样本,每个3维
weight = torch.randn(4, 3)  # 输出4维
output = F.linear(input, weight)  # 形状变为 [2, 4]# softmax: 将数值转换为概率分布
logits = torch.tensor([1.0, 2.0, 3.0])
probs = F.softmax(logits, dim=0)
print(probs)  # tensor([0.0900, 0.2447, 0.6652])
tensor([0.0900, 0.2447, 0.6652])
# 找出最大的k个值及其索引
x = torch.tensor([1, 5, 2, 8, 3])
values, indices = torch.topk(x, k=2)
print(values)   # tensor([8, 5])
print(indices)  # tensor([3, 1])
tensor([8, 5])
tensor([3, 1])
x = torch.tensor([1, 2, 3])
# 每个元素重复2次
y = x.repeat_interleave(2)
print(y)  # tensor([1, 1, 2, 2, 3, 3])
tensor([1, 1, 2, 2, 3, 3])
# 统计每个数字出现的次数
x = torch.tensor([1, 1, 2, 3, 1, 2])
counts = x.bincount()
print(counts)  # tensor([0, 3, 2, 1])  # 0出现0次,1出现3次,2出现2次,3出现1次
tensor([0, 3, 2, 1])
# 在指定位置累加值
src = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)  # 指定数据类型为 float
index = torch.tensor([[0, 1], [0, 1]])
out = torch.zeros(2, 2, dtype=torch.float)  # 确保与 src 的数据类型相同
out.scatter_add_(0, index, src)
print(out) 
tensor([[4., 0.],[0., 6.]])
# 返回排序后的索引
x = torch.tensor([3, 1, 4, 1, 5])
indices = x.argsort()
print(indices)  # tensor([1, 3, 0, 2, 4])  # 最小值在位置1和3,然后是0,2,4
tensor([1, 3, 0, 2, 4])

相关文章:

MOEFeedForward 模块

代码 class FeedForward(nn.Module):def __init__(self, config: LMConfig):super().__init__()if config.hidden_dim is None:hidden_dim 4 * config.dimhidden_dim int(2 * hidden_dim / 3)config.hidden_dim config.multiple_of * ((hidden_dim config.multiple_of - 1…...

笔记:代码随想录算法训练营day41:LeetCode121. 买卖股票的最佳时机、122.买卖股票的最佳时机II、123.买卖股票的最佳时机III

学习资料:代码随想录 121. 买卖股票的最佳时机 力扣题目链接 思路:注意题意只能买卖一次 定义:dp[i][0]表示不持有当前股票,dp[i][1]表示持有当前股票 递推公式:今天持有分之前就持有和今天才买,今天不…...

政策助力,3C 数码行业数字化起航

政策引领,数字经济浪潮来袭 在当今时代,数字经济已成为全球经济发展的核心驱动力,引领着新一轮科技革命和产业变革的潮流。我国深刻洞察这一发展趋势,大力推进数字化经济发展战略,为经济的高质量发展注入了强大动力。 …...

MySQL数据库复制

文章目录 MySQL数据库复制一、复制的原理二、复制的搭建1.编辑配置文件2.在主库上创建复制的用户3.获取主库的备份4.基于从库的恢复5.建立主从复制6.开启主从复制7.查看主从复制状态 MySQL数据库复制 MySQL作为非常流行的数据库,支撑它如此出彩的因素主要有两个&am…...

安装 ubuntu 2404 LTS 服务器 设置 服务器名称

安装 ubuntu服务器 设置 服务器名称 hostname 打开终端(Terminal),通过快捷键CtrlAltT或在应用程序中搜索"终端"来打开;在终端中输入以下命令:hostname,然后按下回车键即可查看本机服务器名称。…...

101.在 Vue 3 + OpenLayers 使用 declutter 避免文字标签重叠

1. 前言 在使用 OpenLayers 进行地图开发时,我们经常需要在地图上添加点、线、区域等图形,并给它们附加文字标签。但当地图上的标注较多时,文字标签可能会发生重叠,导致用户无法清晰地查看地图信息。 幸运的是,OpenL…...

uniapp移动端图片比较器组件,仿英伟达官网rtx光追图片比较器功能

组件下载地址:https://ext.dcloud.net.cn/plugin?id22609 已测试h5和微信小程序,理论支持全平台 亮点: 简单易用 使用js计算而不是resize属性,定制化程度更高 组件挂在后可播放指示线动画,提示用户可以拖拽比较图片…...

深度学习与大模型-矩阵

矩阵其实在我们的生活中也有很多应用,只是我们没注意罢了。 1. 矩阵是什么? 简单来说,矩阵就是一个长方形的数字表格。比如你有一个2行3列的矩阵,可以写成这样: 这个矩阵有2行3列,每个数字都有一个位置&a…...

搭建基于chatgpt的问答系统

一、语言模型,提问范式与 Token 1.语言模型 大语言模型(LLM)是通过预测下一个词的监督学习方式进行训练的,通过预测下一个词为训练目标的方法使得语言模型获得强大的语言生成能力。 a.基础语言模型 (Base LLM&…...

LuaJIT 学习(2)—— 使用 FFI 库的几个例子

文章目录 介绍Motivating Example: Calling External C Functions例子:Lua 中调用 C 函数 Motivating Example: Using C Data StructuresAccessing Standard System FunctionsAccessing the zlib Compression LibraryDefining Metamethods for a C Type例子&#xf…...

解锁 AI 开发的无限可能:邀请您加入 coze-sharp 开源项目

大家好!今天我要向大家介绍一个充满潜力的开源项目——coze-sharp!这是一个基于 C# 开发的 Coze 客户端,旨在帮助开发者轻松接入 Coze AI 平台,打造智能应用。项目地址在这里:https://github.com/zhulige/coze-sharp&a…...

全面解析与实用指南:如何有效解决ffmpeg.dll丢失问题并恢复软件正常运行

在使用多媒体处理软件或进行视频编辑时,你可能会遇到一个常见的问题——ffmpeg.dll文件丢失。这个错误不仅会中断你的工作流程,还可能导致软件无法正常运行。ffmpeg.dll是FFmpeg库中的一个关键动态链接库文件,负责处理视频和音频的编码、解码…...

Python----计算机视觉处理(opencv:像素,RGB颜色,图像的存储,opencv安装,代码展示)

一、计算机眼中的图像 像素 像素是图像的基本单元,每个像素存储着图像的颜色、亮度和其他特征。一系列像素组合到一起就形成 了完整的图像,在计算机中,图像以像素的形式存在并采用二进制格式进行存储。根据图像的颜色不 同,每个像…...

Nginx 限流功能:原理、配置与应用

Nginx 限流功能:原理、配置与应用 在当今互联网应用的高并发场景下,服务器面临着巨大的压力。为了确保系统的稳定运行,保障核心业务的正常开展,限流成为了一项至关重要的技术手段。Nginx 作为一款高性能的 Web 服务器和反向代理服…...

【大模型学习】第十九章 什么是迁移学习

目录 1. 迁移学习的起源背景 1.1 传统机器学习的问题 1.2 迁移学习的提出背景 2. 什么是迁移学习 2.1 迁移学习的定义 2.2 生活实例解释 3. 技术要点与原理 3.1 迁移学习方法分类 3.1.1 基于特征的迁移学习(Feature-based Transfer) 案例说明 代码示例 3.1.2 基于…...

小米路由器SSH下安装DDNS-GO

文章目录 前言一、下载&安装DDNS-GO二、配置ddns-go设置开机启动 前言 什么是DDNS? DDNS(Dynamic Domain Name Server)是动态域名服务的缩写。 目前路由器拨号上网获得的多半都是动态IP,DDNS可以将路由器变化的外网I…...

C++ 布尔类型(bool)深度解析

引言 在 C 编程里,布尔类型(bool)是一种基础且极为关键的数据类型。它专门用于表达逻辑值,在程序的条件判断、循环控制等诸多方面都发挥着重要作用。接下来,我们将对 C 中的布尔类型展开全面且深入的探讨。 一、布尔…...

树莓科技集团董事长:第五代产业园运营模式的深度剖析与展望​

第五代产业园运营模式,以创新为核心驱动,强调数字化、网络化和资源整合。树莓科技集团在这一领域具有代表性,其运营模式值得深入剖析。 核心特征 数字化转型:第五代产业园高度重视数字化技术的应用,通过构建数字化平…...

go语言zero框架拉取内部平台开发的sdk报错的修复与实践

在开发过程中,我们可能会遇到由于认证问题无法拉取私有 SDK 的情况。这种情况常发生在使用 Go 语言以及 Zero 框架时,尤其是在连接到私有平台,如阿里云 Codeup 上托管的 Go SDK。如果你遇到这种错误,通常是因为 Go 没有适当的认证…...

手机屏幕摔不显示了,如何用其他屏幕临时显示,用来导出资料或者清理手机

首先准备一个拓展坞 然后 插入一个外接的U盘 插入鼠标 插入有数字小键盘区的键盘 然后准备一根高清线,一端链接电脑显示器,一端插入拓展坞 把拓展坞的连接线,插入手机充电口(可能会需要转接头) 然后确保手机开机 按下键盘…...

工业三防平板AORO-P300 Ultra,开创铁路检修与调度数字化新范式

在现代化铁路系统的庞大网络中,其设备维护与运营调度的精准性直接影响着运输效率和公共安全。在昼夜温差大、电磁环境复杂、震动粉尘交织的铁路作业场景中,AORO-P300 Ultra工业三防平板以高防护标准与智能化功能体系,开创了铁路行业移动端数字…...

LInux基础--apache部署网站

httpd的安装 yum -y install httpdhttpd的使用 启动httpd systemctl enable --now httpd使用enable --now 进行系统设置时,会将该服务设置为开机自启并且同时开启服务 访问httpd 创建虚拟主机 基于域名 在一台主机上配置两个服务server1和server2,其…...

Linux内核套接字以及分层模型

一、套接字通信 内核开发工程师将网络部分的头文件存储到一个专门的目录include/net中,而不是存储到标准位置include/linux。 计算机之间通信是一个非常复杂的问题: 如何建立物理连接?使用什么样的线缆?通信介质有那些限制和特殊…...

Linux《基础开发工具(中)》

在之前的Linux《基础开发工具(上)》当中已经了解了Linux当中到的两大基础的开发工具yum与vim;了解了在Linux当中如何进行软件的下载以及实现的基本原理、知道了编辑器vim的基本使用方式,那么接下来在本篇当中将接下去继续来了解另…...

使用1Panel一键搭建WordPress网站的详细教程(全)

嘿,各位想搭建自己网站的朋友们!今天我要跟大家分享我用1Panel搭建WordPress网站的全过程。说实话,我之前对服务器运维一窍不通,但通过这次尝试,我发现原来建站可以这么简单!下面是我的亲身经历和一些小技巧…...

uni-app学习笔记——自定义模板

一、流程 1.这是一个硬性的流程,只要按照如此程序化就可以实现 二、步骤 1.第一步 2.第二步 3.第三步 4.每一次新建页面,都如第二步一样;可以选择自定义的模版(vue3Setup——这是我自己的模版),第二步的…...

kotlin基础知识点汇总

对象类继承变量常量静态常量定义方法重载方法基本数据类型比较类型转换符字符串比较数组循环角标循环高级循环判断器构造函数类创建私有化 set 方法私有化 get 方法枚举接口匿名内部类内部类内部类访问外部类同名变量抽象类静态变量和方法可变参数泛型构造代码块静态代码块方法…...

git备份or打补丁

起因 在工作中使用git pull突然发现仓库出现了找不到代码库问题,但是这个时候有个对策又急着需要,于是乎,就需要备份,拷贝给另一个工程师输出。 git 打补丁操作 工程师A生成补丁文件 touch a.txtgit add a.txtgit commit -m &qu…...

如何使用GuzzleHttp库:详细教程与代码示例

GuzzleHttp 是一个功能强大的 PHP HTTP 客户端库,它可以帮助开发者方便地发送 HTTP 请求。与传统的 cURL 相比,Guzzle 提供了一个更简单且易于使用的 API,并且支持同步和异步请求。以下是 GuzzleHttp 的使用方法和一些高级特性。 一、安装 G…...

数据结构——顺序表seqlist

前言:大家好😍,本文主要介绍了数据结构——顺序表部分的内容 目录 一、线性表的定义 二、线性表的基本操作 三.顺序表 1.定义 2. 存储结构 3. 特点 四 顺序表操作 4.1初始化 4.2 插入 4.2.1头插 4.2.2 尾插 4.2.3 按位置插 4.3 …...