【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)
目录
一、引言
1.1 本篇文章侧重点
1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)
二、MoE(Mixture-of-Experts,混合专家网络)
2.1 技术原理
2.2 技术优缺点
2.3 业务代码实践
2.3.1 业务场景与建模
2.3.2 模型代码实现
2.3.3 模型训练与推理测试
2.3.4 打印模型结构
三、总结
一、引言
经历了大模型2024一整年度的兵荒马乱,从年初的Sora文生视频到MiniMax顿悟后的开源,要说年度最大赢家,当属deepseek莫属:年中,deepseek-v2以其1/100的售价,横扫包括gpt4、qwen、百度等一系列商用模型;年底,deepseek-v3发布,以MoE为核心的专家网络技术,让其以极低的推理成本,获得了媲美gpt-4o的效果。
1.1 本篇文章侧重点
本篇文章作为年度技术洞察类文章,今天的重点不是deepseek的训练与推理,如果对训练和推理感兴趣,我在年中写过一篇训练与推理的实战,其中详细讲述了DeepSeek-V2大模型的训练和推理,详细可点击:AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战(只需将V2替换为V3,即可体验最新版本deepseek)。今天的重点是更深一个层次,带大家代码级认识MoE混合专家网络技术。
1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)
MoE(Mixture-of-Experts) 并不是一个新词,近7-8年间,在我做推荐系统精排模型过程中,业界将MoE技术应用于推荐系统多任务学习,以MMoE(2018,google)、PLE(2020,腾讯)为基石,通过门控网络为多个专家网络加权平均,定义每个专家的重要性,解决多目标、多场景、多任务等问题。近1-2年间,基于MoE思想构建的大模型层出不穷,通过路由网络对多个专家网络进行选择,提升推理效率,经典模型有DeepSeekMoE、Mixtral 8x7B、Flan-MoE等。
万丈高楼平地起,今天我们不聊空中楼阁,而是带大家实现一个MoE网络,了解MoE代码是怎么构建的,大家可以以此代码为基础,继续垒砖,根据自己的业务场景,创新性的构建自己的专家网络。
二、MoE(Mixture-of-Experts,混合专家网络)
2.1 技术原理
MoE(Mixture-of-Experts)全称为混合专家网络,主要由多个专家网络、多个任务塔、门控网络构成。核心原理:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim;同时,样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1);将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,所以Task的输入统一维度均为output_experts_dim。在每次反向传播迭代时,对Gate和num_experts个专家参数进行更新,Gate和专家网络的参数受任务Task A、B共同影响。
- 专家网络:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim。
- 门控网络:样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1)。
- 任务网络:将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,Task的输入统一维度均为output_experts_dim。
2.2 技术优缺点
相较于传统的DNN网络,MoE的本质是通过多个专家网络对预估任务共同决策,引入Gate作为专家的裁判,给每一个专家打分,判定哪个专家更加权威。(DeepSeekMoE的Router与Gate类似,区别是Gate为每一个专家赋分,加权平均,Router对专家进行选择,推理速度更快)。相较于传统的DNN网络:
优点:
- 多个DNN专家网络投票共同决定推理结果,相较于单个DNN网络泛化性更好,准确率更高。
- Gate网络基于多个Task任务进行反馈收敛,可以学到多个Task任务数据的平衡性。
缺点:
- 朴素的MoE仅使用了一个Gate网络,虽然Gate网络由多个Task任务共同收敛学习得到,具有一定的平衡性,但对于每个Task的个性化能力仍然不足。(Google针对此缺点发布了MMoE)
- 底层多个专家网络均为共享专家,输入均为样本数据,参数的差异主要由初始化的不同得到,并不具备特异性。(腾讯针对此缺点发布了PLE)
- 输入Input均为全部样本数据,学不出不同场景任务的差异性,需要在输入层对场景特征进行拆分(阿里针对此缺点发布了ESMM)
2.3 业务代码实践
2.3.1 业务场景与建模
我们仍然以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。
我们构建一个100维特征输入,4个experts专家网络,2个task任务的,1个门控的MoE网络,用于建模跨场景多任务学习问题,模型架构图如下:
如架构图所示,其中有几个注意的点:
- num_experts:门控gate的输出维度和专家数相同,均为num_experts,因为gate的用途是对专家网络最后一层进行加权平均,gate维度与专家数是直接对应关系。
- output_experts_dim:专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
- Softmax:Gate门控网络对最后一层采用Softmax归一化,保证专家网络加权平均后值域相同
2.3.2 模型代码实现
基于pytorch,实现上述网络架构,如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass MoEModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):super(MoEModel, self).__init__()self.num_experts = num_expertsself.output_experts_dim = output_experts_dim# 初始化多个专家网络self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_experts)])# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) # 初始化门控网络self.gating_network = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_experts),nn.Softmax(dim=1))def forward(self, x):# 计算输入数据通过门控网络后的权重gates = self.gating_network(x)#print(gates)batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)# 计算每个专家的输出并加权求和for i in range(self.num_experts):expert_output = self.experts[i](x)task1_inputs += expert_output * gates[:, i].unsqueeze(1)task2_inputs += expert_output * gates[:, i].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
num_experts = 4 # 假设有4个专家
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 3
output_task2_dim = 2# 构造虚拟样本数据
torch.manual_seed(42) # 设置随机种子以保证结果可重复
input_dim = 10
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim) # 假设任务1的输出维度为5
y_train_task2 = torch.rand(num_samples, output_task2_dim) # 假设任务2的输出维度为3# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = MoEModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float() # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'一级场景预测结果: {pred_task1}')print(f'二级场景预测结果: {pred_task2}')
2.3.3 模型训练与推理测试
运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:
2.3.4 打印模型结构
使用print(model)打印模型结构如下
三、总结
本文代码级脚踏实地讲解了DeepSeek大模型、MMoE推荐模型中的MoE(Mixture-of-Experts)技术,该技术的主要思想是通过门控(gate)或路由(router)网络,对多个专家进行加权平均或筛选,将一个DNN网络裂变为多个DNN网络后,投票决定预测结果,相较于单一的DNN网络,具有更强的容错性、泛化性与准确性,同时可以提高推理速度,节省推理资源。
技术洞察结论:MoE技术未来将成为大模型和推荐系统进一步突破的关键技术,个人认为该技术为2024年算法基础技术中的SOTA,但其实并没有那么神秘,通过本篇文章,可以试着动手实现一个MoE,再基于自己的业务场景,对齐专家网络、门控网络、任务网络进行创新,期待本篇文章对您有帮助!
如果您还有时间,欢迎阅读本专栏的其他文章:
【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)
【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)
相关文章:

【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)
目录 一、引言 1.1 本篇文章侧重点 1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络) 二、MoE(Mixture-of-Experts,混合专家网络) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 2.3.1 业务场…...

Linux——入门基本指令汇总
目录 1. ls指令2. pwd3. whoami指令4. cd指令5. clear指令6. touch指令7. mkdir指令8. rm指令9. man指令10. cp指令11. mv指令12. cat指令13. tac指令14. more指令15. less指令16. head指令17. tail指令18. date指令19. cal指令20. find指令21. which指令22. alias指令23. grep…...

54,【4】BUUCTF WEB GYCTF2020Ezsqli
进入靶场 吓我一跳,但凡放个彭于晏我都不说啥了 提交个1看看 1 and 11 1# 还尝试了很多,不过都被过滤了,头疼 看看别人的WP 竟然要写代码去跑!!!,不会啊,先用别人的代码吧…...
【Leetcode 热题 100】45. 跳跃游戏 II
问题背景 给定一个长度为 n n n 的 0 0 0 索引 整数数组 n u m s nums nums。初始位置为 n u m s [ 0 ] nums[0] nums[0]。 每个元素 n u m s [ i ] nums[i] nums[i] 表示从索引 i i i 向前跳转的最大长度。换句话说,如果你在 n u m s [ i ] nums[i] nums[i…...

C/C++ 时间复杂度(On)
定义: 在计算机科学中,时间复杂性,又称时间复杂度,算法的时间复杂度是一个函数,它定性描述该算法的运行时间。这是一个代表算法输入值的字符串的长度的函数。时间复杂度常用大O符号表述,不包括这个函数的低…...

【STM32-学习笔记-10-】BKP备份寄存器+时间戳
文章目录 BKP备份寄存器Ⅰ、BKP简介1. BKP的基本功能2. BKP的存储容量3. BKP的访问和操作4. BKP的应用场景5. BKP的控制寄存器 Ⅱ、BKP基本结构Ⅲ、BKP函数Ⅳ、BKP使用示例 时间戳一、Unix时间戳二、时间戳的转换(time.h函数介绍)Ⅰ、time()Ⅱ、mktime()…...
React 中hooks之 React.memo 和 useMemo用法总结
1. React.memo 基础 React.memo 是一个高阶组件(HOC),用于优化函数组件的性能,通过记忆组件渲染结果来避免不必要的重新渲染。 1.1 基本用法 const MemoizedComponent React.memo(function MyComponent(props) {/* 渲染逻辑 *…...

日志收集Day001
1.ElasticSearch 作用:日志存储和检索 2.单点部署Elasticsearch与基础配置 rpm -ivh elasticsearch-7.17.5-x86_64.rpm 查看配置文件yy /etc/elasticsearch/elasticsearch.yml(这里yy做了别名,过滤掉空行和注释行) yy /etc/el…...

机器人“大脑+小脑”范式:算力魔方赋能智能自主导航
在机器人技术的发展中,“大脑小脑”的架构模式逐渐成为推动机器人智能化的关键。其中,“大脑”作为机器人的核心决策单元,承担着复杂任务规划、环境感知和决策制定的重要角色,而“小脑”则专注于运动控制和实时调整。这种分工明确…...
python程序跑起来后,然后引用的数据文件发生了更新,python读取的数据会发生变化吗
在 Python 程序运行过程中,如果引用的数据文件被更新,程序能否读取到更新后的数据,取决于以下几个因素: 1. 是否动态读取文件 如果 Python 程序在运行过程中动态读取文件(例如通过循环或定时机制反复打开文件读取&…...

VSCode最新离线插件拓展下载方式
之前在vscode商店有以下类似的download按钮,但是2025年更新之后这个按钮就不提供了,所以需要使用新的方式下载 ps:给自己的网站推广下~~(国内直连GPT/Claude) 新的下载方式1 首先打开vscode商店官网:vscode插件下载…...
算法题目总结-栈和队列
文章目录 1.有效的括号1.答案2.思路 2.最小栈1.答案2.思路 3.前 K 个高频元素1.答案2.思路 4.用栈实现队列1.答案2.思路 5.删除字符串中的所有相邻重复项1.答案2.思路 1.有效的括号 1.答案 package com.sunxiansheng.arithmetic.day10;import java.util.Stack;/*** Descripti…...

IO进程----进程
进程 什么是进程 进程和程序的区别 概念: 程序:编译好的可执行文件 存放在磁盘上的指令和数据的有序集合(文件) 程序是静态的,没有任何执行的概念 进程:一个独立的可调度的任务 执行一个程序分配资…...

【机器学习实战高阶】基于深度学习的图像分割
机器学习项目图像分割 你可能已经注意到,大脑如何快速高效地识别并分类眼睛感知到的事物。大脑以某种方式进行训练,以便能够从微观层面分析所有内容。这种能力有助于我们从一篮子橙子中分辨出一个苹果。 计算机视觉是计算机科学的一个领域,…...

「免填邀请码」赋能各类APP,提升转化率与用户体验
在当前移动互联网的高速发展下,用户获取和留存已成为各类APP成功的关键。传统的注册流程虽然能够有效识别用户来源并进行用户管理,但随着市场竞争的激烈,复杂的注册和绑定步骤往往会成为用户流失的瓶颈。免填邀请码技术,结合自研的…...

基于海思soc的智能产品开发(视频的后续开发)
【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们讨论了camera,也讨论了屏幕驱动,这些都是基础的部分。关键是,我们拿到了这些视频数据之后,…...

创建 pdf 合同模板
创建 pdf 合同模板 一、前言二、模板展示三、制作过程 一、前言 前段时间要求创建“pdf”模板,学会了后感觉虽然简单,但开始也折腾了好久,这里做个记录。 二、模板展示 要创建这样的模板 三、制作过程 新建一个“Word”,这里命…...

2024 年度学习总结
目录 1. 前言 2. csdn 对于我的意义 3. 写博客的初衷 3.1 现在的想法 4. 写博客的意义 5. 关于生活和博客创作 5.1 写博客较于纸质笔记的优势 6. 致 2025 1. 前言 不知不觉, 来到 csdn 已经快一年了, 在这一年中, 我通过 csdn 学习到了很多知识, 结识了很多的良师益友…...

CSS笔记基础篇02——浮动、标准流、定位、CSS精灵、字体图标
黑马程序员视频地址: 前端Web开发HTML5CSS3移动web视频教程https://www.bilibili.com/video/BV1kM4y127Li?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p70https://www.bilibili.com/video/BV1kM4y127Li?vd_source…...

C++ 面向对象(继承)
三、继承 3.1 继承的概念 基于一个已有的类 去重新定义一个新的类,这种方式我们叫做继承 关于继承的称呼 一个类B 继承来自 类 A 我们一般称呼 A类:父类 基类 B类: 子类 派生类 B继承自A A 派生了B 示例图的语法 class vehicle // 车类 {}class …...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...
uniapp 实现腾讯云IM群文件上传下载功能
UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中,群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS,在uniapp中实现: 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...

VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...

Sklearn 机器学习 缺失值处理 获取填充失值的统计值
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 使用 Scikit-learn 处理缺失值并提取填充统计信息的完整指南 在机器学习项目中,数据清…...
SQL进阶之旅 Day 22:批处理与游标优化
【SQL进阶之旅 Day 22】批处理与游标优化 文章简述(300字左右) 在数据库开发中,面对大量数据的处理任务时,单条SQL语句往往无法满足性能需求。本篇文章聚焦“批处理与游标优化”,深入探讨如何通过批量操作和游标技术提…...