当前位置: 首页 > news >正文

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):初始化共享底部的三层全连接层初始化任务1的三层全连接层初始化任务2的三层全连接层定义 forward 方法 参数 (self, 输入数据):计算输入数据通过共享底部后的输出从共享底部输出分别计算任务1和任务2的结果返回任务1和任务2的结果生成虚拟样本数据:创建训练集和测试集实例化模型对象
定义损失函数和优化器
训练循环:前向传播: 获取预测值计算每个任务的损失反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass SharedBottomMultiTaskModel(nn.Module):def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):super(SharedBottomMultiTaskModel, self).__init__()# 定义共享底部的三层全连接层self.shared_bottom = nn.Sequential(nn.Linear(input_dim, hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, hidden3_dim),nn.ReLU())# 定义任务1的三层全连接层self.task1_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task1_dim))# 定义任务2的三层全连接层self.task2_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task2_dim))def forward(self, x):# 计算输入数据通过共享底部后的输出shared_output = self.shared_bottom(x)# 从共享底部输出分别计算任务1和任务2的结果task1_output = self.task1_head(shared_output)task2_output = self.task2_head(shared_output)return task1_output, task2_output# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)#print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randn(1, input_dim)  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'任务1预测结果: {pred_task1}')print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

相关文章:

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

目录 一、引言 1.1 往期回顾 1.2 本期概要 二、Shared-Bottom Multi-task Model(SBMM) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 三、总结 一、引言 在朴素的深度学习ctr预估模型中(如DNN),通常以一个行…...

如何使用vue引入three.js

在 Vue.js 项目中引入和使用 Three.js 是一个常见的需求,Three.js 是一个用于在浏览器中创建和显示动画 3D 计算机图形的 JavaScript 库。以下是一个基本的示例,展示如何在 Vue 项目中引入和使用 Three.js。 1. 创建 Vue 项目 如果你还没有一个 Vue 项…...

城市生命线安全综合监管平台

【落地产品,有需要可留言联系,支持项目合作或源码合作】 一、建设背景 以关于城市安全的重要论述为建设纲要,聚焦城市安全重点领域,围绕燃气爆炸、城市内涝、地下管线交互风险、第三方施工破坏、供水爆管、桥梁坍塌、道路塌陷七…...

计算机毕设【开题报告】怎么写?

技巧 1. 标题简洁且具体 技巧:开题报告的标题要简明扼要,并准确表达研究的核心内容。避免使用复杂的术语或过于宽泛的题目。 实用方法:根据你的研究方向,标题应该包括你的系统类型、技术框架或研究对象。例如,“基于…...

Go学习:多重赋值与匿名变量

1. 变量的多重赋值 1.1 基本语法格式 go语言中,可以将多个赋值语句 合并成 一句,比如: a : 10 b : 20 c : 30//a,b,c三个变量的赋值语句可以简练成以下格式a, b, c : 10, 20, 30 1.2 交换变量值 当需要交换两个变量的值时&#…...

【Ubuntu 上搭建 Nginx-RTMP 服务】

本章目录: 环境1. 安装依赖2. 创建 Nginx 编译目录3. 下载 Nginx 和 Nginx-RTMP-Module4. 编译 Nginx 并添加 RTMP 模块5. 验证 Nginx 安装成功6. 配置环境变量7. 修改 Nginx 配置文件8. 启动 Nginx 服务查看 Nginx 是否启动成功查看端口监听状态 8. 常见问题及解决方法1. 缺少…...

使用uniapp 微信小程序一些好用的插件分享

总结一下自己在开发中遇见的一问题,通过引入组件可以快速的解决 1.zxz-uni-data-select 下拉框选择器(添加下拉框检索,多选功能,多选搜索功能,自定义 下拉框插件,使用这个的原因是因为 uniui uview 组件库下拉框太…...

linux centos挂载未分配的磁盘空间

使用到的命令 lshw -class disk -short hostnamectl fdisk /dev/sdb partprobe /dev/sdb mount /dev/sdb2 /opt/fastdfs/ mkfs.ext4 /dev/sdb2 mount -t ext4 /dev/sdb2 /opt/fastdfs/...

C语言凯撒密码程序分享

把刚才编写的程序又加工了一下,变成了程序,发给大家 我用夸克网盘分享了「凯撒密码」,点击链接即可保存。打开「夸克APP」,无需下载在线播放视频,畅享原画5倍速,支持电视投屏。 链接:https://p…...

2025新年源码免费送

2025很开门很开门的源码免费传递。不需要馒头就能获取4套大开门源码。 听泉偷宝,又进来偷我源码啦👊👊👊。欢迎偷源码 🔥🔥🔥 获取免费源码以及更多源码,可以私信联系我 我们常常…...

阿里云ethereum

https://geth.ethereum.org/docs/getting-started/installing-geth#linux-and-mac git clone https://github.com/ethereum/go-ethereum.git git checkout v1.10.11 cd go-ethereum # 阿里云添加goproxy export GOPROXYhttps://mirrors.aliyun.com/goproxy/ make geth创建gene…...

子父组件传值

Angular 2 及以上版本中的父子组件通信方式 在 Angular 2 及以上版本中,父子组件通信主要通过以下几种方式实现: 一、使用Input()进行父向子通信 父组件通过属性绑定的方式将数据传递给子组件,子组件使用Input()装饰器来接收这些数据。 二…...

QT自定义工具条渐变背景颜色一例

使用样式定义: QWidget* toolbar new QWidget(this);toolbar->setObjectName("main_tool");toolbar->setStyleSheet("#main_tool{background: qlineargradient(x1:0 , y1:0 , x2:1 , y2:0,""stop:0 rgba(0,255,0, 0.2),"&q…...

2025最新Facebook广告投放常见问题:如何提高广告效果?

Facebook广告投放已成为众多品牌拓展市场、提升品牌知名度和促进销售增长的关键手段。然而经常有人提出遇到广告没人看、定位不准或者内容不吸引人这些问题。那怎么办呢?别急,下面咱们就来聊聊Facebook广告投放常见问题以及如何提高Facebook广告的效果。…...

双向导航和单向导航

目录 双向导航 单向导航 迁移数据库异常 解决办法 1.导航属性改为空 2.使用 ON DELETE NO ACTION 或 ON UPDATE NO ACTION 选择 双向导航 一对多:一个Article有多个Comment class Article {public long Id { get; set; }public string Title { get; set; }pu…...

Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能

前言 近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中…...

Lambda离线实时分治架构深度解析与实战

一、引言 在大数据技术日新月异的今天,Lambda架构作为一种经典的数据处理模型,在应对大规模数据应用方面展现出了强大的能力。它整合了离线批处理和实时流处理,为需要同时处理批量和实时数据的应用场景提供了成熟的解决方案。本文将对Lambda…...

Spring Boot教程之五十一:Spring Boot – CrudRepository 示例

Spring Boot – CrudRepository 示例 Spring Boot 建立在 Spring 之上,包含 Spring 的所有功能。由于其快速的生产就绪环境,使开发人员能够直接专注于逻辑,而不必费力配置和设置,因此如今它正成为开发人员的最爱。Spring Boot 是…...

jenkins入门6 --拉取代码

Jenkins代码拉取 需要的插件,缺少的安装下 新建一个item,选择freestyle project 源码管理配置如下:需要添加git库地址,和登录git的用户密码 配置好后执行编译,成功后拉取的代码在工作空间里...

CAPL概述与环境搭建

CAPL概述与环境搭建 目录 CAPL概述与环境搭建1. CAPL简介与应用领域1.1 CAPL简介1.2 CAPL的应用领域 2. CANoe/CANalyzer 安装与配置2.1 CANoe/CANalyzer 简介2.2 安装CANoe/CANalyzer2.2.1 系统要求2.2.2 安装步骤 2.3 配置CANoe/CANalyzer2.3.1 配置CAN通道2.3.2 配置CAPL节点…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止

<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet&#xff1a; https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

《基于Apache Flink的流处理》笔记

思维导图 1-3 章 4-7章 8-11 章 参考资料 源码&#xff1a; https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容&#xff08;一&#xff09;CDN 基础概念1. 定义2. 组成部分 &#xff08;二&#xff09;CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 &#xff08;三&#xff09;CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

深度学习习题2

1.如果增加神经网络的宽度&#xff0c;精确度会增加到一个特定阈值后&#xff0c;便开始降低。造成这一现象的可能原因是什么&#xff1f; A、即使增加卷积核的数量&#xff0c;只有少部分的核会被用作预测 B、当卷积核数量增加时&#xff0c;神经网络的预测能力会降低 C、当卷…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

MFC 抛体运动模拟:常见问题解决与界面美化

在 MFC 中开发抛体运动模拟程序时,我们常遇到 轨迹残留、无效刷新、视觉单调、物理逻辑瑕疵 等问题。本文将针对这些痛点,详细解析原因并提供解决方案,同时兼顾界面美化,让模拟效果更专业、更高效。 问题一:历史轨迹与小球残影残留 现象 小球运动后,历史位置的 “残影”…...