当前位置: 首页 > news >正文

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):初始化共享底部的三层全连接层初始化任务1的三层全连接层初始化任务2的三层全连接层定义 forward 方法 参数 (self, 输入数据):计算输入数据通过共享底部后的输出从共享底部输出分别计算任务1和任务2的结果返回任务1和任务2的结果生成虚拟样本数据:创建训练集和测试集实例化模型对象
定义损失函数和优化器
训练循环:前向传播: 获取预测值计算每个任务的损失反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass SharedBottomMultiTaskModel(nn.Module):def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):super(SharedBottomMultiTaskModel, self).__init__()# 定义共享底部的三层全连接层self.shared_bottom = nn.Sequential(nn.Linear(input_dim, hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, hidden3_dim),nn.ReLU())# 定义任务1的三层全连接层self.task1_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task1_dim))# 定义任务2的三层全连接层self.task2_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task2_dim))def forward(self, x):# 计算输入数据通过共享底部后的输出shared_output = self.shared_bottom(x)# 从共享底部输出分别计算任务1和任务2的结果task1_output = self.task1_head(shared_output)task2_output = self.task2_head(shared_output)return task1_output, task2_output# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)#print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randn(1, input_dim)  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'任务1预测结果: {pred_task1}')print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

相关文章:

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

目录 一、引言 1.1 往期回顾 1.2 本期概要 二、Shared-Bottom Multi-task Model(SBMM) 2.1 技术原理 2.2 技术优缺点 2.3 业务代码实践 三、总结 一、引言 在朴素的深度学习ctr预估模型中(如DNN),通常以一个行…...

如何使用vue引入three.js

在 Vue.js 项目中引入和使用 Three.js 是一个常见的需求,Three.js 是一个用于在浏览器中创建和显示动画 3D 计算机图形的 JavaScript 库。以下是一个基本的示例,展示如何在 Vue 项目中引入和使用 Three.js。 1. 创建 Vue 项目 如果你还没有一个 Vue 项…...

城市生命线安全综合监管平台

【落地产品,有需要可留言联系,支持项目合作或源码合作】 一、建设背景 以关于城市安全的重要论述为建设纲要,聚焦城市安全重点领域,围绕燃气爆炸、城市内涝、地下管线交互风险、第三方施工破坏、供水爆管、桥梁坍塌、道路塌陷七…...

计算机毕设【开题报告】怎么写?

技巧 1. 标题简洁且具体 技巧:开题报告的标题要简明扼要,并准确表达研究的核心内容。避免使用复杂的术语或过于宽泛的题目。 实用方法:根据你的研究方向,标题应该包括你的系统类型、技术框架或研究对象。例如,“基于…...

Go学习:多重赋值与匿名变量

1. 变量的多重赋值 1.1 基本语法格式 go语言中,可以将多个赋值语句 合并成 一句,比如: a : 10 b : 20 c : 30//a,b,c三个变量的赋值语句可以简练成以下格式a, b, c : 10, 20, 30 1.2 交换变量值 当需要交换两个变量的值时&#…...

【Ubuntu 上搭建 Nginx-RTMP 服务】

本章目录: 环境1. 安装依赖2. 创建 Nginx 编译目录3. 下载 Nginx 和 Nginx-RTMP-Module4. 编译 Nginx 并添加 RTMP 模块5. 验证 Nginx 安装成功6. 配置环境变量7. 修改 Nginx 配置文件8. 启动 Nginx 服务查看 Nginx 是否启动成功查看端口监听状态 8. 常见问题及解决方法1. 缺少…...

使用uniapp 微信小程序一些好用的插件分享

总结一下自己在开发中遇见的一问题,通过引入组件可以快速的解决 1.zxz-uni-data-select 下拉框选择器(添加下拉框检索,多选功能,多选搜索功能,自定义 下拉框插件,使用这个的原因是因为 uniui uview 组件库下拉框太…...

linux centos挂载未分配的磁盘空间

使用到的命令 lshw -class disk -short hostnamectl fdisk /dev/sdb partprobe /dev/sdb mount /dev/sdb2 /opt/fastdfs/ mkfs.ext4 /dev/sdb2 mount -t ext4 /dev/sdb2 /opt/fastdfs/...

C语言凯撒密码程序分享

把刚才编写的程序又加工了一下,变成了程序,发给大家 我用夸克网盘分享了「凯撒密码」,点击链接即可保存。打开「夸克APP」,无需下载在线播放视频,畅享原画5倍速,支持电视投屏。 链接:https://p…...

2025新年源码免费送

2025很开门很开门的源码免费传递。不需要馒头就能获取4套大开门源码。 听泉偷宝,又进来偷我源码啦👊👊👊。欢迎偷源码 🔥🔥🔥 获取免费源码以及更多源码,可以私信联系我 我们常常…...

阿里云ethereum

https://geth.ethereum.org/docs/getting-started/installing-geth#linux-and-mac git clone https://github.com/ethereum/go-ethereum.git git checkout v1.10.11 cd go-ethereum # 阿里云添加goproxy export GOPROXYhttps://mirrors.aliyun.com/goproxy/ make geth创建gene…...

子父组件传值

Angular 2 及以上版本中的父子组件通信方式 在 Angular 2 及以上版本中,父子组件通信主要通过以下几种方式实现: 一、使用Input()进行父向子通信 父组件通过属性绑定的方式将数据传递给子组件,子组件使用Input()装饰器来接收这些数据。 二…...

QT自定义工具条渐变背景颜色一例

使用样式定义: QWidget* toolbar new QWidget(this);toolbar->setObjectName("main_tool");toolbar->setStyleSheet("#main_tool{background: qlineargradient(x1:0 , y1:0 , x2:1 , y2:0,""stop:0 rgba(0,255,0, 0.2),"&q…...

2025最新Facebook广告投放常见问题:如何提高广告效果?

Facebook广告投放已成为众多品牌拓展市场、提升品牌知名度和促进销售增长的关键手段。然而经常有人提出遇到广告没人看、定位不准或者内容不吸引人这些问题。那怎么办呢?别急,下面咱们就来聊聊Facebook广告投放常见问题以及如何提高Facebook广告的效果。…...

双向导航和单向导航

目录 双向导航 单向导航 迁移数据库异常 解决办法 1.导航属性改为空 2.使用 ON DELETE NO ACTION 或 ON UPDATE NO ACTION 选择 双向导航 一对多:一个Article有多个Comment class Article {public long Id { get; set; }public string Title { get; set; }pu…...

Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能

前言 近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中…...

Lambda离线实时分治架构深度解析与实战

一、引言 在大数据技术日新月异的今天,Lambda架构作为一种经典的数据处理模型,在应对大规模数据应用方面展现出了强大的能力。它整合了离线批处理和实时流处理,为需要同时处理批量和实时数据的应用场景提供了成熟的解决方案。本文将对Lambda…...

Spring Boot教程之五十一:Spring Boot – CrudRepository 示例

Spring Boot – CrudRepository 示例 Spring Boot 建立在 Spring 之上,包含 Spring 的所有功能。由于其快速的生产就绪环境,使开发人员能够直接专注于逻辑,而不必费力配置和设置,因此如今它正成为开发人员的最爱。Spring Boot 是…...

jenkins入门6 --拉取代码

Jenkins代码拉取 需要的插件,缺少的安装下 新建一个item,选择freestyle project 源码管理配置如下:需要添加git库地址,和登录git的用户密码 配置好后执行编译,成功后拉取的代码在工作空间里...

CAPL概述与环境搭建

CAPL概述与环境搭建 目录 CAPL概述与环境搭建1. CAPL简介与应用领域1.1 CAPL简介1.2 CAPL的应用领域 2. CANoe/CANalyzer 安装与配置2.1 CANoe/CANalyzer 简介2.2 安装CANoe/CANalyzer2.2.1 系统要求2.2.2 安装步骤 2.3 配置CANoe/CANalyzer2.3.1 配置CAN通道2.3.2 配置CAPL节点…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...

C++.OpenGL (20/64)混合(Blending)

混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

vulnyx Blogger writeup

信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...

NPOI操作EXCEL文件 ——CAD C# 二次开发

缺点:dll.版本容易加载错误。CAD加载插件时,没有加载所有类库。插件运行过程中用到某个类库,会从CAD的安装目录找,找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库,就用插件程序加载进…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...