当前位置: 首页 > news >正文

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):初始化共享底部的三层全连接层初始化任务1的三层全连接层初始化任务2的三层全连接层定义 forward 方法 参数 (self, 输入数据):计算输入数据通过共享底部后的输出从共享底部输出分别计算任务1和任务2的结果返回任务1和任务2的结果生成虚拟样本数据:创建训练集和测试集实例化模型对象
定义损失函数和优化器
训练循环:前向传播: 获取预测值计算每个任务的损失反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass SharedBottomMultiTaskModel(nn.Module):def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):super(SharedBottomMultiTaskModel, self).__init__()# 定义共享底部的三层全连接层self.shared_bottom = nn.Sequential(nn.Linear(input_dim, hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, hidden3_dim),nn.ReLU())# 定义任务1的三层全连接层self.task1_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task1_dim))# 定义任务2的三层全连接层self.task2_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task2_dim))def forward(self, x):# 计算输入数据通过共享底部后的输出shared_output = self.shared_bottom(x)# 从共享底部输出分别计算任务1和任务2的结果task1_output = self.task1_head(shared_output)task2_output = self.task2_head(shared_output)return task1_output, task2_output# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)#print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randn(1, input_dim)  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'任务1预测结果: {pred_task1}')print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

相关文章:

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

目录 一、引言 1.1 往期回顾 1.2 本期概要 二、Shared-Bottom Multi-task Model(SBMM) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 三、总结 一、引言 在朴素的深度学习ctr预估模型中(如DNN),通常以一个行…...

如何使用vue引入three.js

在 Vue.js 项目中引入和使用 Three.js 是一个常见的需求,Three.js 是一个用于在浏览器中创建和显示动画 3D 计算机图形的 JavaScript 库。以下是一个基本的示例,展示如何在 Vue 项目中引入和使用 Three.js。 1. 创建 Vue 项目 如果你还没有一个 Vue 项…...

城市生命线安全综合监管平台

【落地产品,有需要可留言联系,支持项目合作或源码合作】 一、建设背景 以关于城市安全的重要论述为建设纲要,聚焦城市安全重点领域,围绕燃气爆炸、城市内涝、地下管线交互风险、第三方施工破坏、供水爆管、桥梁坍塌、道路塌陷七…...

计算机毕设【开题报告】怎么写?

技巧 1. 标题简洁且具体 技巧:开题报告的标题要简明扼要,并准确表达研究的核心内容。避免使用复杂的术语或过于宽泛的题目。 实用方法:根据你的研究方向,标题应该包括你的系统类型、技术框架或研究对象。例如,“基于…...

Go学习:多重赋值与匿名变量

1. 变量的多重赋值 1.1 基本语法格式 go语言中,可以将多个赋值语句 合并成 一句,比如: a : 10 b : 20 c : 30//a,b,c三个变量的赋值语句可以简练成以下格式a, b, c : 10, 20, 30 1.2 交换变量值 当需要交换两个变量的值时&#…...

【Ubuntu 上搭建 Nginx-RTMP 服务】

本章目录: 环境1. 安装依赖2. 创建 Nginx 编译目录3. 下载 Nginx 和 Nginx-RTMP-Module4. 编译 Nginx 并添加 RTMP 模块5. 验证 Nginx 安装成功6. 配置环境变量7. 修改 Nginx 配置文件8. 启动 Nginx 服务查看 Nginx 是否启动成功查看端口监听状态 8. 常见问题及解决方法1. 缺少…...

使用uniapp 微信小程序一些好用的插件分享

总结一下自己在开发中遇见的一问题,通过引入组件可以快速的解决 1.zxz-uni-data-select 下拉框选择器(添加下拉框检索,多选功能,多选搜索功能,自定义 下拉框插件,使用这个的原因是因为 uniui uview 组件库下拉框太…...

linux centos挂载未分配的磁盘空间

使用到的命令 lshw -class disk -short hostnamectl fdisk /dev/sdb partprobe /dev/sdb mount /dev/sdb2 /opt/fastdfs/ mkfs.ext4 /dev/sdb2 mount -t ext4 /dev/sdb2 /opt/fastdfs/...

C语言凯撒密码程序分享

把刚才编写的程序又加工了一下,变成了程序,发给大家 我用夸克网盘分享了「凯撒密码」,点击链接即可保存。打开「夸克APP」,无需下载在线播放视频,畅享原画5倍速,支持电视投屏。 链接:https://p…...

2025新年源码免费送

2025很开门很开门的源码免费传递。不需要馒头就能获取4套大开门源码。 听泉偷宝,又进来偷我源码啦👊👊👊。欢迎偷源码 🔥🔥🔥 获取免费源码以及更多源码,可以私信联系我 我们常常…...

阿里云ethereum

https://geth.ethereum.org/docs/getting-started/installing-geth#linux-and-mac git clone https://github.com/ethereum/go-ethereum.git git checkout v1.10.11 cd go-ethereum # 阿里云添加goproxy export GOPROXYhttps://mirrors.aliyun.com/goproxy/ make geth创建gene…...

子父组件传值

Angular 2 及以上版本中的父子组件通信方式 在 Angular 2 及以上版本中,父子组件通信主要通过以下几种方式实现: 一、使用Input()进行父向子通信 父组件通过属性绑定的方式将数据传递给子组件,子组件使用Input()装饰器来接收这些数据。 二…...

QT自定义工具条渐变背景颜色一例

使用样式定义: QWidget* toolbar new QWidget(this);toolbar->setObjectName("main_tool");toolbar->setStyleSheet("#main_tool{background: qlineargradient(x1:0 , y1:0 , x2:1 , y2:0,""stop:0 rgba(0,255,0, 0.2),"&q…...

2025最新Facebook广告投放常见问题:如何提高广告效果?

Facebook广告投放已成为众多品牌拓展市场、提升品牌知名度和促进销售增长的关键手段。然而经常有人提出遇到广告没人看、定位不准或者内容不吸引人这些问题。那怎么办呢?别急,下面咱们就来聊聊Facebook广告投放常见问题以及如何提高Facebook广告的效果。…...

双向导航和单向导航

目录 双向导航 单向导航 迁移数据库异常 解决办法 1.导航属性改为空 2.使用 ON DELETE NO ACTION 或 ON UPDATE NO ACTION 选择 双向导航 一对多:一个Article有多个Comment class Article {public long Id { get; set; }public string Title { get; set; }pu…...

Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能

前言 近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中…...

Lambda离线实时分治架构深度解析与实战

一、引言 在大数据技术日新月异的今天,Lambda架构作为一种经典的数据处理模型,在应对大规模数据应用方面展现出了强大的能力。它整合了离线批处理和实时流处理,为需要同时处理批量和实时数据的应用场景提供了成熟的解决方案。本文将对Lambda…...

Spring Boot教程之五十一:Spring Boot – CrudRepository 示例

Spring Boot – CrudRepository 示例 Spring Boot 建立在 Spring 之上,包含 Spring 的所有功能。由于其快速的生产就绪环境,使开发人员能够直接专注于逻辑,而不必费力配置和设置,因此如今它正成为开发人员的最爱。Spring Boot 是…...

jenkins入门6 --拉取代码

Jenkins代码拉取 需要的插件,缺少的安装下 新建一个item,选择freestyle project 源码管理配置如下:需要添加git库地址,和登录git的用户密码 配置好后执行编译,成功后拉取的代码在工作空间里...

CAPL概述与环境搭建

CAPL概述与环境搭建 目录 CAPL概述与环境搭建1. CAPL简介与应用领域1.1 CAPL简介1.2 CAPL的应用领域 2. CANoe/CANalyzer 安装与配置2.1 CANoe/CANalyzer 简介2.2 安装CANoe/CANalyzer2.2.1 系统要求2.2.2 安装步骤 2.3 配置CANoe/CANalyzer2.3.1 配置CAN通道2.3.2 配置CAPL节点…...

SpringCloud学习笔记-2

说明:来源于网络,如有侵权请联系我删除 1.提问:如果注册中心宕机,远程调用还能成功吗 答:当微服务发起请求时,会向注册中心请求所有的微服务地址,然后在向指定的微服务地址发起请求。在设计实…...

Python使用总结之Mac安装docker并配置wechaty

Python使用总结之Mac安装docker并配置wechaty ✅ 一、安装 Docker Desktop for macOS 1. 下载 Docker Desktop 安装包 访问官网下载安装包: 👉 https://www.docker.com/products/docker-desktop 选择 macOS (Apple 芯片或 Intel 芯片) 版本下载。 …...

基于有效集MPC控制算法的直线同步电机simulink建模与仿真,MPC使用S函数实现

目录 1.课题概述 2.系统仿真结果 3.核心程序 4.系统仿真参数 5.系统原理简介 6.参考文献 7.完整工程文件 1.课题概述 有效集算法通过迭代地选择一组 "有效" 约束,将约束优化问题转化为一系列无约束或等式约束优化问题。直线同步电机 (Linear Synch…...

5.2 HarmonyOS NEXT应用性能诊断与优化:工具链、启动速度与功耗管理实战

HarmonyOS NEXT应用性能诊断与优化:工具链、启动速度与功耗管理实战 在HarmonyOS NEXT的全场景生态中,应用性能直接影响用户体验。通过专业的性能分析工具链、针对性的启动速度优化,以及精细化的功耗管理,开发者能够构建"秒…...

【应用】Ghost Dance:利用惯性动捕构建虚拟舞伴

Ghost Dance是葡萄牙大学的一个研究项目,研究方向是探索人与人之间的联系,以及如何通过虚拟舞伴重现这种联系。项目负责人Cecilia和Rui利用惯性动捕创造出具有流畅动作的虚拟舞伴,让现实中的舞者也能与之共舞。 挑战:Ghost Danc…...

每日算法 -【Swift 算法】电话号码字母组合

🚀 LeetCode 字符串数字映射(Swift)——电话号码字母组合 在日常刷题或面试中,我们经常会遇到字符串 回溯组合的问题。这道经典题——电话号码的字母组合 就是典型代表。本文将带你用 Swift 实现这道题,思路清晰&…...

Vue3 + UniApp 蓝牙连接与数据发送(稳定版)

本教程适用于使用 uni-app Vue3 (script setup) 开发的跨平台 App(支持微信小程序、H5、Android/iOS 等) 🎯 功能目标 ✅ 获取蓝牙权限✅ 扫描周围蓝牙设备✅ 连接指定蓝牙设备✅ 获取服务和特征值✅ 向设备发送数据包(ArrayBu…...

2025最新Java日志框架深度解析:Log4j 2 vs Logback性能实测+企业级实战案例

一、为什么printStackTrace是"代码坟场"? 你写的日志可能正在拖垮系统! 在Java开发中,直接调用printStackTrace()打印异常堆栈是最常见的"自杀式操作"。这种方式会导致三大致命问题: 无法分级控制&#xff…...

python3.9带 C++绑定的基础镜像

FROM ubuntu:20.04 # 设置非交互式环境变量(避免apt安装时提示时区选择) ENV DEBIAN_FRONTENDnoninteractive RUN ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime # 安装基础编译工具和依赖 # 添加Python 3.9 PPA并安装依赖 RUN apt-get upda…...

c# :this() 和 :base()区别

在 C# 中,:this() 和 :base() 都用于构造函数的重载和继承,但它们有不同的用途和上下文: 1. :this() 用途:用于调用当前类中的其他构造函数(构造函数重载)。场景:当你希望一个构造函数先执行另…...