当前位置: 首页 > news >正文

【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)

目录

一、引言

1.1 本篇文章侧重点

1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)

二、MoE(Mixture-of-Experts,混合专家网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

经历了大模型2024一整年度的兵荒马乱,从年初的Sora文生视频到MiniMax顿悟后的开源,要说年度最大赢家,当属deepseek莫属:年中,deepseek-v2以其1/100的售价,横扫包括gpt4、qwen、百度等一系列商用模型;年底,deepseek-v3发布,以MoE为核心的专家网络技术,让其以极低的推理成本,获得了媲美gpt-4o的效果。

1.1 本篇文章侧重点

本篇文章作为年度技术洞察类文章,今天的重点不是deepseek的训练与推理,如果对训练和推理感兴趣,我在年中写过一篇训练与推理的实战,其中详细讲述了DeepSeek-V2大模型的训练和推理,详细可点击:AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战(只需将V2替换为V3,即可体验最新版本deepseek)。今天的重点是更深一个层次,带大家代码级认识MoE混合专家网络技术。

1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)

MoE(Mixture-of-Experts) 并不是一个新词,近7-8年间,在我做推荐系统精排模型过程中,业界将MoE技术应用于推荐系统多任务学习,以MMoE(2018,google)、PLE(2020,腾讯)为基石,通过门控网络为多个专家网络加权平均,定义每个专家的重要性,解决多目标、多场景、多任务等问题。近1-2年间,基于MoE思想构建的大模型层出不穷,通过路由网络对多个专家网络进行选择,提升推理效率,经典模型有DeepSeekMoE、Mixtral 8x7B、Flan-MoE等。 

万丈高楼平地起,今天我们不聊空中楼阁,而是带大家实现一个MoE网络,了解MoE代码是怎么构建的,大家可以以此代码为基础,继续垒砖,根据自己的业务场景,创新性的构建自己的专家网络。 

二、MoE(Mixture-of-Experts,混合专家网络)

2.1 技术原理

MoE(Mixture-of-Experts)全称为混合专家网络,主要由多个专家网络、多个任务塔、门控网络构成。核心原理:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim;同时,样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1);将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,所以Task的输入统一维度均为output_experts_dim。在每次反向传播迭代时,对Gate和num_experts个专家参数进行更新,Gate和专家网络的参数受任务Task A、B共同影响。

  • 专家网络:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim。
  • ​​​​​​​门控网络:样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1)。
  • 任务网络:将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于传统的DNN网络,MoE的本质是通过多个专家网络对预估任务共同决策,引入Gate作为专家的裁判,给每一个专家打分,判定哪个专家更加权威。(DeepSeekMoE的Router与Gate类似,区别是Gate为每一个专家赋分,加权平均,Router对专家进行选择,推理速度更快)。相较于传统的DNN网络:

优点:

  • 多个DNN专家网络投票共同决定推理结果,相较于单个DNN网络泛化性更好,准确率更高。
  • Gate网络基于多个Task任务进行反馈收敛,可以学到多个Task任务数据的平衡性。

缺点: 

  • 朴素的MoE仅使用了一个Gate网络,虽然Gate网络由多个Task任务共同收敛学习得到,具有一定的平衡性,但对于每个Task的个性化能力仍然不足。(Google针对此缺点发布了MMoE)
  • 底层多个专家网络均为共享专家,输入均为样本数据,参数的差异主要由初始化的不同得到,并不具备特异性。(腾讯针对此缺点发布了PLE)
  • 输入Input均为全部样本数据,学不出不同场景任务的差异性,需要在输入层对场景特征进行拆分(阿里针对此缺点发布了ESMM)

2.3 业务代码实践

2.3.1 业务场景与建模

我们仍然以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

我们构建一个100维特征输入,4个experts专家网络,2个task任务的,1个门控的MoE网络,用于建模跨场景多任务学习问题,模型架构图如下:

​​​​​​​

如架构图所示,其中有几个注意的点:

  1. num_experts:门控gate的输出维度和专家数相同,均为num_experts,因为gate的用途是对专家网络最后一层进行加权平均,gate维度与专家数是直接对应关系。
  2. output_experts_dim:专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  3. Softmax:Gate门控网络对最后一层采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass MoEModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):super(MoEModel, self).__init__()self.num_experts = num_expertsself.output_experts_dim = output_experts_dim# 初始化多个专家网络self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_experts)])# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) # 初始化门控网络self.gating_network = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_experts),nn.Softmax(dim=1))def forward(self, x):# 计算输入数据通过门控网络后的权重gates = self.gating_network(x)#print(gates)batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)# 计算每个专家的输出并加权求和for i in range(self.num_experts):expert_output = self.experts[i](x)task1_inputs += expert_output * gates[:, i].unsqueeze(1)task2_inputs += expert_output * gates[:, i].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
num_experts = 4  # 假设有4个专家
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 3
output_task2_dim = 2# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为5
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为3# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = MoEModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'一级场景预测结果: {pred_task1}')print(f'二级场景预测结果: {pred_task2}')

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 

使用print(model)打印模型结构如下

三、总结

本文代码级脚踏实地讲解了DeepSeek大模型、MMoE推荐模型中的MoE(Mixture-of-Experts)技术,该技术的主要思想是通过门控(gate)或路由(router)网络,对多个专家进行加权平均或筛选,将一个DNN网络裂变为多个DNN网络后,投票决定预测结果,相较于单一的DNN网络,具有更强的容错性、泛化性与准确性,同时可以提高推理速度,节省推理资源。

技术洞察结论:MoE技术未来将成为大模型和推荐系统进一步突破的关键技术,个人认为该技术为2024年算法基础技术中的SOTA,但其实并没有那么神秘,通过本篇文章,可以试着动手实现一个MoE,再基于自己的业务场景,对齐专家网络、门控网络、任务网络进行创新,期待本篇文章对您有帮助!

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

相关文章:

【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)

目录 一、引言 1.1 本篇文章侧重点 1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络) 二、MoE(Mixture-of-Experts,混合专家网络) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 2.3.1 业务场…...

Linux——入门基本指令汇总

目录 1. ls指令2. pwd3. whoami指令4. cd指令5. clear指令6. touch指令7. mkdir指令8. rm指令9. man指令10. cp指令11. mv指令12. cat指令13. tac指令14. more指令15. less指令16. head指令17. tail指令18. date指令19. cal指令20. find指令21. which指令22. alias指令23. grep…...

54,【4】BUUCTF WEB GYCTF2020Ezsqli

进入靶场 吓我一跳,但凡放个彭于晏我都不说啥了 提交个1看看 1 and 11 1# 还尝试了很多,不过都被过滤了,头疼 看看别人的WP 竟然要写代码去跑!!!,不会啊,先用别人的代码吧&#xf…...

【Leetcode 热题 100】45. 跳跃游戏 II

问题背景 给定一个长度为 n n n 的 0 0 0 索引 整数数组 n u m s nums nums。初始位置为 n u m s [ 0 ] nums[0] nums[0]。 每个元素 n u m s [ i ] nums[i] nums[i] 表示从索引 i i i 向前跳转的最大长度。换句话说,如果你在 n u m s [ i ] nums[i] nums[i…...

C/C++ 时间复杂度(On)

定义: 在计算机科学中,时间复杂性,又称时间复杂度,算法的时间复杂度是一个函数,它定性描述该算法的运行时间。这是一个代表算法输入值的字符串的长度的函数。时间复杂度常用大O符号表述,不包括这个函数的低…...

【STM32-学习笔记-10-】BKP备份寄存器+时间戳

文章目录 BKP备份寄存器Ⅰ、BKP简介1. BKP的基本功能2. BKP的存储容量3. BKP的访问和操作4. BKP的应用场景5. BKP的控制寄存器 Ⅱ、BKP基本结构Ⅲ、BKP函数Ⅳ、BKP使用示例 时间戳一、Unix时间戳二、时间戳的转换(time.h函数介绍)Ⅰ、time()Ⅱ、mktime()…...

React 中hooks之 React.memo 和 useMemo用法总结

1. React.memo 基础 React.memo 是一个高阶组件(HOC),用于优化函数组件的性能,通过记忆组件渲染结果来避免不必要的重新渲染。 1.1 基本用法 const MemoizedComponent React.memo(function MyComponent(props) {/* 渲染逻辑 *…...

日志收集Day001

1.ElasticSearch 作用:日志存储和检索 2.单点部署Elasticsearch与基础配置 rpm -ivh elasticsearch-7.17.5-x86_64.rpm 查看配置文件yy /etc/elasticsearch/elasticsearch.yml(这里yy做了别名,过滤掉空行和注释行) yy /etc/el…...

机器人“大脑+小脑”范式:算力魔方赋能智能自主导航

在机器人技术的发展中,“大脑小脑”的架构模式逐渐成为推动机器人智能化的关键。其中,“大脑”作为机器人的核心决策单元,承担着复杂任务规划、环境感知和决策制定的重要角色,而“小脑”则专注于运动控制和实时调整。这种分工明确…...

python程序跑起来后,然后引用的数据文件发生了更新,python读取的数据会发生变化吗

在 Python 程序运行过程中,如果引用的数据文件被更新,程序能否读取到更新后的数据,取决于以下几个因素: 1. 是否动态读取文件 如果 Python 程序在运行过程中动态读取文件(例如通过循环或定时机制反复打开文件读取&…...

VSCode最新离线插件拓展下载方式

之前在vscode商店有以下类似的download按钮,但是2025年更新之后这个按钮就不提供了,所以需要使用新的方式下载 ps:给自己的网站推广下~~(国内直连GPT/Claude) 新的下载方式1 首先打开vscode商店官网:vscode插件下载…...

算法题目总结-栈和队列

文章目录 1.有效的括号1.答案2.思路 2.最小栈1.答案2.思路 3.前 K 个高频元素1.答案2.思路 4.用栈实现队列1.答案2.思路 5.删除字符串中的所有相邻重复项1.答案2.思路 1.有效的括号 1.答案 package com.sunxiansheng.arithmetic.day10;import java.util.Stack;/*** Descripti…...

IO进程----进程

进程 什么是进程 进程和程序的区别 概念: 程序:编译好的可执行文件 存放在磁盘上的指令和数据的有序集合(文件) 程序是静态的,没有任何执行的概念 进程:一个独立的可调度的任务 执行一个程序分配资…...

【机器学习实战高阶】基于深度学习的图像分割

机器学习项目图像分割 你可能已经注意到,大脑如何快速高效地识别并分类眼睛感知到的事物。大脑以某种方式进行训练,以便能够从微观层面分析所有内容。这种能力有助于我们从一篮子橙子中分辨出一个苹果。 计算机视觉是计算机科学的一个领域,…...

「免填邀请码」赋能各类APP,提升转化率与用户体验

在当前移动互联网的高速发展下,用户获取和留存已成为各类APP成功的关键。传统的注册流程虽然能够有效识别用户来源并进行用户管理,但随着市场竞争的激烈,复杂的注册和绑定步骤往往会成为用户流失的瓶颈。免填邀请码技术,结合自研的…...

基于海思soc的智能产品开发(视频的后续开发)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们讨论了camera,也讨论了屏幕驱动,这些都是基础的部分。关键是,我们拿到了这些视频数据之后,…...

创建 pdf 合同模板

创建 pdf 合同模板 一、前言二、模板展示三、制作过程 一、前言 前段时间要求创建“pdf”模板,学会了后感觉虽然简单,但开始也折腾了好久,这里做个记录。 二、模板展示 要创建这样的模板 三、制作过程 新建一个“Word”,这里命…...

2024 年度学习总结

目录 1. 前言 2. csdn 对于我的意义 3. 写博客的初衷 3.1 现在的想法 4. 写博客的意义 5. 关于生活和博客创作 5.1 写博客较于纸质笔记的优势 6. 致 2025 1. 前言 不知不觉, 来到 csdn 已经快一年了, 在这一年中, 我通过 csdn 学习到了很多知识, 结识了很多的良师益友…...

CSS笔记基础篇02——浮动、标准流、定位、CSS精灵、字体图标

黑马程序员视频地址: 前端Web开发HTML5CSS3移动web视频教程https://www.bilibili.com/video/BV1kM4y127Li?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p70https://www.bilibili.com/video/BV1kM4y127Li?vd_source…...

C++ 面向对象(继承)

三、继承 3.1 继承的概念 基于一个已有的类 去重新定义一个新的类,这种方式我们叫做继承 关于继承的称呼 一个类B 继承来自 类 A 我们一般称呼 A类:父类 基类 B类: 子类 派生类 B继承自A A 派生了B 示例图的语法 class vehicle // 车类 {}class …...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统,可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析:自动解析Markdown文档结构PPT模板分析:分析PPT模板的布局和风格智能布局决策:匹配内容与合适的PPT布局自动…...

今日科技热点速览

🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...

大数据学习(132)-HIve数据分析

​​​​🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言&#x1f4…...