当前位置: 首页 > article >正文

【深度学习】多目标融合算法(五):定制门控网络CGC(Customized Gate Control)

目录

一、引言

二、CGC(Customized Gate Control,定制门控网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

上一篇我们讲了MMoE多任务网络,通过对每一个任务塔建立Gate门控,对专家网络进行加权平均,Gate门控起到了对多个共享专家重要度筛选的作用。在每轮反向传播时,每个任务tower分别更新对应Gate的参数,以及共享专家的参数。模型主要起到了多目标任务平衡的作用。

今天我们重点将CGC(Customized Gate Control)定制门控网络,核心思想是在MMoE基础上,为每一个任务tower定制独享专家,实用任务独享专家与共享专家共同决定任务Tower的输入,相比于MMoE仅用Gate门控表征任务Tower的方法,CGC引入独享专家,对任务表征更加全面,又通过共享专家保证关联性。

二、CGC(Customized Gate Control,定制门控网络)

2.1 技术原理

CGC(Customized Gate Control)全称为定制门控网络,主要由多个任务塔、对应多组独享专家网络,对应多个门控网络以及一组共享专家网络,专家网络组内可以包含多个专家MLP。核心原理:样本input分别输入共享专家MLP、独立专家MLP、独立专家对应门控网络,门控网络输出为经过softmax的权重分布,维度对应共享专家数num_shared_experts和独立专家数num_task_experts的和,通过对独立专家输出和共享专家输出采用Gate门控加权平均后, 输入到对应的任务Tower。每个任务Tower输入自己对应的独享专家、共享专家、门控加权平均的输入。反向传播时,每个任务更新自己独享专家、独享门控以及共享专家的参数。

  • 共享专家网络:样本数据分别输入num_shared_experts个专家网络进行推理,每个共享专家网络实际上是一个多层感知机(MLP),输入维度为x,输出维度为output_experts_dim。
  • 独享专家网络:样本数据分别输入num_task_experts个专家网络进行推理,每个共享专家网络实际上是一个多层感知机(MLP),输入维度为x,输出维度为output_experts_dim。
  • 门控网络:样本数据输出各自任务对应的门控网络,每个门控网络可以是一个多层感知机,也可以是一个双层的交叉,主要是为了输出专家网络的加权平均权重。
  • 任务网络:对于每一个Task,将各自对应num_shared_experts个共享专家和num_task_experts个独立专家,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于MMoE网络,CGC为每一个任务tower定制独享专家,实用任务独享专家与共享专家共同决定任务Tower的输入,相比于MMoE仅用Gate门控表征任务Tower的方法,CGC引入独享专家,对任务表征更加全面,又通过共享专家保证关联性。

优点:

  • 切断任务tower与其他任务独享专家的联系,使得独享专家能够更专注的学习本任务内的知识与信息。比如切断互动塔与点击专家的联系,只和互动专家同时迭代,让互动目标的学习更加纯粹。
  • 独享专家只受对应任务梯度的影响,不受其他任务梯度的影响,而共享专家可以被多个任务梯度同时更新。
  • 本质上,CGC就是在MMoE上新增了独享专家,MMoE仅有共享专家。

缺点: 

  • 相较于PLE、SNR等,没有学习到专家与专家之间的相互关系,层级堆叠不够。
  • 相较于DeepSeekMoE的路由方法,CGC还是过于定制化与单一话,专家组合不足。

2.3 业务代码实践

2.3.1 业务场景与建模

我们还是以小红书推荐场景为例,针对一个视频,用户可以点红心(互动),也可以点击视频进行播放(点击),针对互动和点击两个目标进行多目标建模

我们构建一个100维特征输入,1组共享专家网络(含2个共享专家),2组独享专家网络(各含2个独享专家),2个门控,2个任务塔的CGC网络,用于建模多目标学习问题,模型架构图如下:

​​​​​​​​​​​​​​

如架构图所示,其中有几个注意的点:

  • num_shared_experts+num_task_expertsGate的维度等于共享专家的维度加上任务独享专家的维度。
  • output_experts_dim:共享专家、独享专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  • Softmax:Gate门控网络对共享专家和独享专家的偏好权重采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述CGC网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass CGCModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts):super(CGCModel, self).__init__()# 初始化函数外使用初始化变量需要赋值,否则默认使用全局变量# 初始化函数内使用初始化变量不需要赋值 self.num_shared_experts = num_shared_expertsself.num_task_experts = num_task_expertsself.output_experts_dim = output_experts_dim# 初始化共享专家self.shared_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_shared_experts)])# 初始化任务1专家self.task1_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化任务2专家self.task2_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化门控网络任务1self.gating1_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 初始化门控网络任务2self.gating2_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) def forward(self, x):gates1 = self.gating1_network_2(x)gates2 = self.gating2_network_2(x)#定义专家网络输出作为任务塔输入batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)for i in range(self.num_shared_experts):task1_inputs += self.shared_experts_2[i](x) * gates1[:, i].unsqueeze(1) + self.task1_experts_2[i](x) * gates1[:, i+self.num_shared_experts].unsqueeze(1)task2_inputs += self.shared_experts_2[i](x) * gates2[:, i].unsqueeze(1) + self.task2_experts_2[i](x) * gates2[:, i+self.num_shared_experts].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 1
output_task2_dim = 1
num_shared_experts = 2
num_task_experts = 2# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 100
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为1
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为1# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = CGCModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'互动目标预测结果: {pred_task1}')print(f'点击目标预测结果: {pred_task2}')

相比于上一篇MMoE中的代码,CGC复杂了很多,新增了2组独享专家,且在门控与独享、共享专家加权平均计算的时候需要进行处理,很容易出问题。

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 ​​​​​​​

三、总结

本文详细介绍了CGC多任务模型的算法原理、算法优势,他是下一篇PLE多层多任务模型的基础,并以小红书业务场景为例,构建CGC网络结构并使用pytorch代码实现对应的网络结构、训练流程。相比于MMoE,CGC新增独享专家网络,通过gate门控的串联,切断任务Tower与其他任务独享专家的联系,使得独享专家能够更专注的学习本任务内的知识与信息。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

【深度学习】多目标融合算法(三):混合专家网络MOE(Mixture-of-Experts) 

 【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)​​​​​​​

相关文章:

【深度学习】多目标融合算法(五):定制门控网络CGC(Customized Gate Control)

目录 一、引言 二、CGC(Customized Gate Control,定制门控网络) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 2.3.1 业务场景与建模 2.3.2 模型代码实现 2.3.3 模型训练与推理测试 2.3.4 打印模型结构 三、总结 一、引言 上一…...

【NLP 42、实践 ⑪ 用Bert模型结构实现自回归语言模型的训练】

如果结局早已注定,那么过程就将大于结局 —— 25.3.18 自回归语言模型:由前文预测后文的语言模型 特点:单向 训练方式:利用前n个字预测第n1个字,实现一个mask矩阵,送入Bert模型,让其前文看不到…...

TCP | 序列号和确认号 [逐包分析] | seq / ack 详解

注 : 本文为 “TCP 序号(seq)与确认序号(ack)” 相关文章合辑。 英文引文,机翻未校。 中文引文,略作重排。 如有内容异常,请看原文。 Understanding TCP Seq & Ack Numbers […...

在Linux、Windows系统上安装开源InfluxDB——InfluxDB OSS v2并设置开机自启的保姆级图文教程

一、进入InfluxDB下载官网 InfluxData 文档https://docs.influxdata.com/Install InfluxDB OSS v2 | InfluxDB OSS v2 Documentation...

考研复习之队列

循环队列 队列为满的条件 队列为满的条件需要特殊处理,因为当队列满时,队尾指针的下一个位置应该是队头指针。但是,我们不能直接比较 rear 1 和 front 是否相等,因为 rear 1 可能会超出数组索引的范围。因此,我们需…...

Day 21: 数组中的逆序对

在股票交易中,如果前一天的股价高于后一天的股价,则可以认为存在一个「交易逆序对」。请设计一个程序,输入一段时间内的股票交易记录 record,返回其中存在的「交易逆序对」总数。 示例 1: 输入:record [9…...

【C++教程】setw()函数的使用方法

setw 是 C 中用于设置输出字段宽度的函数&#xff0c;属于 <iomanip> 头文件。以下是其使用方法及注意事项&#xff1a; 基本用法 包含头文件&#xff1a; #include <iostream> #include <iomanip> // 必须包含此头文件 using namespace std; // 避免写 std…...

智慧高速,安全护航:视频监控平台助力高速公路高效运营

随着我国高速公路里程的不断增长&#xff0c;交通安全和运营效率面临着前所未有的挑战。传统的监控方式已难以满足现代化高速公路管理的需求&#xff0c;而监控视频平台的出现&#xff0c;则为高速公路的安全运营提供了强有力的技术支撑。高速公路视频监控联网解决方案 高速公路…...

Jboss漏洞再现

一、CVE-2015-7501 1、开环境 2、访问地址 / invoker/JMXInvokerServlet 出现了让下载的页面&#xff0c;说明有漏洞 3、下载ysoserial工具进行漏洞利用 4、在cmd运行 看到可以成功运行&#xff0c;接下来去base64编码我们反弹shell的命令 5、执行命令 java -jar ysoserial-…...

C语言三大程序结构 单分支语句

核心概念&#xff1a;程序就像流水线&#xff0c;通过顺序、选择、循环三种结构完成复杂任务 &#x1f31f; 一、三大程序结构图解 结构类型形象比喻代码示例顺序直行马路 → 不拐弯printf("A"); printf("B");选择岔路口 → 二选一if...else循环环形跑道 …...

【Linux系统】Linux权限讲解!!!超详细!!!

目录 Linux文件类型 区分方法 文件类型 Linux用户 用户创建与删除 用户之间的转换 su指令 普通用户->超级用户(root) 超级用户(root) ->普通用户 普通账户->普通账户 普通用户的权限提高 sudo指令 注&#xff1a; Linux权限 定义 权限操作 1、修改文…...

Winform零基础从入门到精通(5)——WinForm菜单与工具栏开发详解

一、核心控件与功能 MenuStrip&#xff08;顶部菜单栏&#xff09; • 功能&#xff1a;创建应用程序主菜单&#xff0c;支持多级子菜单和快捷键。 • 关键操作&#xff1a; ◦ 添加菜单项&#xff1a;直接在菜单栏输入文字&#xff08;如“文件(F)”&#xff09;&#xff0c;…...

2.创建Collection、添加索引、加载内存、预览和搜索数据

milvus官方文档 milvus2.3.1的官方文档地址: https://milvus.io/docs/v2.3.x 使用attu创建collection collection必须要有一个主键字段、向量字段 确保字段类型与索引类型兼容 字符串类型&#xff08;VARCHAR&#xff09;通常需要使用 Trie 索引&#xff0c;而不是 AutoInd…...

AIGC 新势力:探秘海螺 AI 与蓝耘 MaaS 平台的协同创新之旅

探秘海螺AI&#xff1a;多模态架构下的认知智能新引擎 在人工智能持续进阶的进程中&#xff0c;海螺AI作为一款前沿的多功能AI工具&#xff0c;正凭借其独特的多模态架构崭露头角。它由上海稀宇科技有限公司&#xff08;MiniMax&#xff09;精心打造&#xff0c;依托自研的万亿…...

一文解读DeepSeek在法律商业仲裁细分行业的应用

引言 当AI闯入法律界&#xff1a;DeepSeek如何把商业仲裁变成“纠纷快车道” AI技术正在像水电煤一样渗透生活&#xff0c;随着DeepSeek的爆火出圈&#xff0c;全国各行各业都在如火如荼地接入DeepSeek&#xff0c;以期望利用DeepSeek的“超能力”来重塑各自行业的效能和格局&a…...

快速入手-基于Django的主子表间操作mysql(五)

1、如果该表中存在外键&#xff0c;结合实际业务情况&#xff0c;那可以这么写&#xff1a; 2、针对特殊的字典类型&#xff0c;可以这么定义 3、获取元组中的字典值和子表中的value值方法 4、对应的前端页面写法...

HTTPS协议—加密算法和中间攻击人的博弈

活动发起人小虚竹 想对你说&#xff1a; 这是一个以写作博客为目的的创作活动&#xff0c;旨在鼓励大学生博主们挖掘自己的创作潜能&#xff0c;展现自己的写作才华。如果你是一位热爱写作的、想要展现自己创作才华的小伙伴&#xff0c;那么&#xff0c;快来参加吧&#xff01…...

【大模型理论篇】CogVLM:多模态预训练语言模型

1. 模型背景 前两天我们在《Skywork R1V: Pioneering Multimodal Reasoning with Chain-of-Thought》中介绍了将ViT与推理模型结合构造多模态推理模型的案例&#xff0c;其中提到了VLM的应用。追溯起来就是两篇前期工作&#xff1a;Vision LLM以及CogVLM。 今天准备回顾一下Cog…...

AI知识补全(一):tokens是什么?

名人说&#xff1a;苔花如米小&#xff0c;也学牡丹开。——袁枚《苔》 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 一、什么是Tokens&#xff1f;二、为什么Tokens如此重要&#xff1f;1.模型的输入输出限制2.…...

Wpf Avalonia-实现中英文切换工程

文章目录 language工程项目代码创建获取资源文件string工程图片主项目引用LanguageView中使用ViewModel中使用language工程项目 <Project Sdk="Microsoft.NET.Sdk"><PropertyGroup><ImplicitUsings>enable</ImplicitUsings><TargetFrame…...

pyqt5报错:qt.qpa.plugin: Could not find the Qt platform plugin “xcb“(已解决)

我在使用pyqt库的时候报错&#xff1a; qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in \ "/mnt/private_disk/anaconda3/envs/aot-manip/lib/python3.8/site-packages/PyQt5/Qt5/plugins/platforms" even though it was found. This ap…...

【MySQL数据库】触发器与事件

MySQL触发器 trigger&#xff0c;在表的插入insert、更新update、删除delete操作发生时自动执行MySQL语句。 学过Qt的都知道信号槽&#xff0c;一旦发出某个信号&#xff0c;那么就会触发关联的信号槽函数。触发器就类似于这个操作。 创建触发器时需要给出一些信息&#xff…...

【LC插件开发】基于Java实现FSRS(自由间隔重复调度算法)

&#x1f60a;你好&#xff0c;我是小航&#xff0c;一个正在变秃、变强的文艺倾年。 &#x1f514;本文讲解【LC插件开发】基于Java实现FSRS&#xff08;自由间隔重复调度算法&#xff09;&#xff0c;期待与你一同探索、学习、进步&#xff0c;一起卷起来叭&#xff01; 目录…...

java后端接收数组,数组长度超256个就会报错

1.原因 DataBinder 中默认限制了list最大只能增长到256。 2.解决方案 1.在BaseController添加InitBinder方法&#xff0c;其余继承BaseController InitBinder //类初始化是调用的方法注解public void initBinder(WebDataBinder binder) {//给这个controller配置接收list的长…...

第45章:配置更新与应用热重载策略

第45章:配置更新与应用热重载策略 作者:DogDog_Shuai 阅读时间:约25分钟 难度:中级 目录 1. 引言2. 配置更新挑战3. Kubernetes原生配置更新机制4. 应用热重载技术5. 配置更新最佳实践...

数据库MVCC详解

MVCC 1.基本介绍 数据库&#xff1a;MySQL。【很多主流数据库都使用了MVCC&#xff0c;比如MySQL的InnoDB引擎、PostgreSQL、Oracle】 MVCC&#xff0c;全称Multi-Version Concurrency Control&#xff0c;即多版本并发控制。是数据库管理系统中的一种并发控制方法。 MVCC的…...

MySQL数据库基础篇

目录 SQL的分类 数据定义语言&#xff08;DDL&#xff09;---Data Definition Language 数据操作语言(DML) ---Data Manipulation Language 数据查询语言(DQL) ---Data Query Language 数据控制语言(DCL) ---Data Control Language 事务控制语言(TCL) --- Transaction Cont…...

Rust函数、条件语句、循环

文章目录 函数**语句与表达式**条件语句循环 函数 Rust的函数基本形式是这样的 fn a_func(a: i32) -> i32 {}函数名是蛇形风格&#xff0c;rust不在意函数的声明顺序&#xff0c;只需要有声明即可 函数参数必须声明参数名称和类型 语句与表达式 这是rust非常重要的基础…...

AI比人脑更强,因为被植入思维模型【17】万物联系思维模型

万物联系,万物,并不孤立。 定义 万物联系思维模型是一种强调世界上所有事物都相互关联、相互影响的思维方式。它认为任何事物都不是孤立存在的,而是与周围的环境、其他事物以及整个宇宙构成一个有机的整体。这种联系不仅包括直接的因果关系,还涵盖了间接的、潜在的、动态的…...

Android Compose 约束布局(ConstraintLayout、Modifier.constrainAs)源码深度剖析(十二)

Android Compose 约束布局&#xff08;ConstraintLayout、Modifier.constrainAs&#xff09;源码深度剖析 一、引言 在 Android 开发中&#xff0c;布局是构建用户界面的基础。随着 Android 开发技术的不断发展&#xff0c;Jetpack Compose 作为一种全新的声明式 UI 框架应运…...