隐私计算实训营:SplitRec:当拆分学习遇上推荐系统
拆分学习的概念
拆分学习的核心思想是拆分网络结构。每一个参与方拥有模型结构的一部分,所有参与方的模型合在一起形成一个完整的模型。训练过程中,不同参与方只对本地模型进行正向或者反向传播计算,并将计算结果传递给下一个参与方。多个参与方通过联合模型进行训练直至最终收敛。
一个典型的拆分学习例子:

Alice持有数据和基础模型。Bob只有数据、基础模型和fuse模型。
- Alice使用自己的数据和基础模型得到
hidden0,然后发送给Bob。 - Bob使用自己的数据和基础模型得到
hidden1。 - Agg Layer使用
hidden_0和hidden_1作为输入,并输出聚合后的隐层。 - Bob把聚合后的隐层作为fuse模型的输入,计算得到梯度。
- 梯度被拆分成两部分,分别返回给Alice和Bob。
- Alice和Bob使用各自收到的梯度更新基础模型。
SplitRec
SplitRec是隐语拆分学习针对跨域推荐场景中的模型训练所提供的一系列优化算法和策略。
在传统推荐场景中,用户的数据通常需要上传到中央服务器进行模型训练。而跨域推荐场景是指联合分布在不同域的数据进行分布式训练的推荐场景。例如一个用户在一个短视频平台看了很多短视频,在另一个电商平台被推荐相关的广告,电商平台除了自有数据外,也希望从短视频平台的数据中挖掘相关的信息。同时出于数据安全考虑,各平台数据不能被上传到中央服务器进行集中式的机器学习训练,这种联合分布在不同域的数据进行模型训练的场景很适合用联邦学习中的拆分学习。
跨域推荐模型将不同域的用户数据联合起来建模,相比传统推荐系统收集到的数据更多更丰富,同时由于数据分布在不同域,在精度、效率和安全性上都对模型的训练提出了很多挑战,主要有以下三点:
- 模型效果上,例如DeepFM等复杂模型能否直接放到拆分框架中使用?
- 训练效率上,模型训练中每个 batch 的前反向计算中的通信是否会严重降低训练效率?
- 安全性上,通信的中间数据是否会造成信息泄露,引起安全性问题?
SplitRec 在效果、效率和安全方面对拆分模型训练做了很多优化。
- 模型效果上,SplitRec 提供了拆分 DeepFM、BST、MMoe 等模型的封装。
- 训练效率上,SplitRec 借由隐语拆分学习框架的能力,提供了压缩、流水并行等策略来提升训练效率。
- 安全性上,SplitRec提供了安全聚合、差分隐私等安全策略。同时也提供了一些针对拆分学习的攻击方法,来验证不同攻击手段对拆分模型的影响,后续也会更新相关防御方法。
实践:在隐语中使用拆分 DeepFM 算法
DeepFM算法结合了FM和神经网络的长处,可以同时提升低维和高维特征,相比Wide&Deep模型还免去了特征工程的部分。

整体上来看。这个模型可以分成两个部分,分别是FM部分以及Deep部分。这两个部分的输入是一样的,并没有像Wide & Deep模型那样做区分。Deep的部分用来训练这些特征的高维的关联,而FM模型会通过隐藏向量V的形式来计算特征之间的二维交叉的信息。
隐语中的DeepFM
拆分的详细过程可以来看这里:
SplitRec:在隐语中使用拆分 DeepFM 算法(Tensorflow 后端) | SecretFlow v1.9.0b1 | 隐语 SecretFlow
环境设置
import secretflow as sf# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
数据集介绍
我们这里将使用最经典的MovieLens数据集来进行演示。 MovieLens是一个开放式的推荐系统数据集,包含了电影评分和电影元数据信息。
我们对数据进行了切分:
- alice: “UserID”, “Gender”, “Age”, “Occupation”, “Zip-code”
- bob: “MovieID”, “Rating”, “Title”, “Genres”, “Timestamp”
下载并处理数据
数据拆分处理
%%capture
%%!
wget https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/movielens/ml-1m.zip
unzip ./ml-1m.zip
# Read the data in dat format and convert it into a dictionary
def load_data(filename, columns):data = {}with open(filename, "r", encoding="unicode_escape") as f:for line in f:ls = line.strip("\n").split("::")data[ls[0]] = dict(zip(columns[1:], ls[1:]))return data
fed_csv = {alice: "alice_ml1m.csv", bob: "bob_ml1m.csv"}
csv_writer_container = {alice: open(fed_csv[alice], "w"), bob: open(fed_csv[bob], "w")}
part_columns = {alice: ["UserID", "Gender", "Age", "Occupation", "Zip-code"],bob: ["MovieID", "Rating", "Title", "Genres", "Timestamp"],
}
for device, writer in csv_writer_container.items():writer.write("ID," + ",".join(part_columns[device]) + "\n")
f = open("ml-1m/ratings.dat", "r", encoding="unicode_escape")users_data = load_data("./ml-1m/users.dat",columns=["UserID", "Gender", "Age", "Occupation", "Zip-code"],
)
movies_data = load_data("./ml-1m/movies.dat", columns=["MovieID", "Title", "Genres"])
ratings_columns = ["UserID", "MovieID", "Rating", "Timestamp"]rating_data = load_data("./ml-1m/ratings.dat", columns=ratings_columns)def _parse_example(feature, columns, index):if "Title" in feature.keys():feature["Title"] = feature["Title"].replace(",", "_")if "Genres" in feature.keys():feature["Genres"] = feature["Genres"].replace("|", " ")values = []values.append(str(index))for c in columns:values.append(feature[c])return ",".join(values)index = 0
num_sample = 1000
for line in f:ls = line.strip().split("::")rating = dict(zip(ratings_columns, ls))rating.update(users_data.get(ls[0]))rating.update(movies_data.get(ls[1]))for device, columns in part_columns.items():parse_f = _parse_example(rating, columns, index)csv_writer_container[device].write(parse_f + "\n")index += 1if num_sample > 0 and index >= num_sample:break
for w in csv_writer_container.values():w.close()
到此就完成了数据的处理和拆分
得到
alice: alice_ml1m.csv
bob: bob_ml1m.csv
! head alice_ml1m.csv
! head bob_ml1m.csv
构造data_builder_dict
# alice
def create_dataset_builder_alice(batch_size=128,repeat_count=5,
):def dataset_builder(x):import pandas as pdimport tensorflow as tfx = [dict(t) if isinstance(t, pd.DataFrame) else t for t in x]x = x[0] if len(x) == 1 else tuple(x)data_set = (tf.data.Dataset.from_tensor_slices(x).batch(batch_size).repeat(repeat_count))return data_setreturn dataset_builder# bob
def create_dataset_builder_bob(batch_size=128,repeat_count=5,
):def _parse_bob(row_sample, label):import tensorflow as tfy_t = label["Rating"]y = tf.expand_dims(tf.where(y_t > 3,tf.ones_like(y_t, dtype=tf.float32),tf.zeros_like(y_t, dtype=tf.float32),),axis=1,)return row_sample, ydef dataset_builder(x):import pandas as pdimport tensorflow as tfx = [dict(t) if isinstance(t, pd.DataFrame) else t for t in x]x = x[0] if len(x) == 1 else tuple(x)data_set = (tf.data.Dataset.from_tensor_slices(x).batch(batch_size).repeat(repeat_count))data_set = data_set.map(_parse_bob)return data_setreturn dataset_builderdata_builder_dict = {alice: create_dataset_builder_alice(batch_size=128,repeat_count=5,),bob: create_dataset_builder_bob(batch_size=128,repeat_count=5,),
}
from secretflow.ml.nn.applications.sl_deep_fm import DeepFMbase, DeepFMfuse
from secretflow.ml.nn import SLModelNUM_USERS = 6040
NUM_MOVIES = 3952
GENDER_VOCAB = ["F", "M"]
AGE_VOCAB = [1, 18, 25, 35, 45, 50, 56]
OCCUPATION_VOCAB = [i for i in range(21)]
GENRES_VOCAB = ["Action","Adventure","Animation","Children's","Comedy","Crime","Documentary","Drama","Fantasy","Film-Noir","Horror","Musical","Mystery","Romance","Sci-Fi","Thriller","War","Western",
]
DeepFMBase有4个参数:
-dnn_units_size: 这个参数需要提供一个list来对dnn部分进行定义,比如[256,32]意思是中间两个隐层分别是256,和32
-dnn_activation: dnn 的激活函数,eg:relu
-preprocess_layer: 需要对输入进行处理,传入一个定义好的keras.preprocesslayer
-fm_embedding_dim: fm vector的维度是多少
# Define alice's basenet
def create_base_model_alice():# Create modeldef create_model():import tensorflow as tfdef preprocess():inputs = {"UserID": tf.keras.Input(shape=(1,), dtype=tf.string),"Gender": tf.keras.Input(shape=(1,), dtype=tf.string),"Age": tf.keras.Input(shape=(1,), dtype=tf.int64),"Occupation": tf.keras.Input(shape=(1,), dtype=tf.int64),}user_id_output = tf.keras.layers.Hashing(num_bins=NUM_USERS, output_mode="one_hot")user_gender_output = tf.keras.layers.StringLookup(vocabulary=GENDER_VOCAB, output_mode="one_hot")user_age_out = tf.keras.layers.IntegerLookup(vocabulary=AGE_VOCAB, output_mode="one_hot")user_occupation_out = tf.keras.layers.IntegerLookup(vocabulary=OCCUPATION_VOCAB, output_mode="one_hot")outputs = {"UserID": user_id_output(inputs["UserID"]),"Gender": user_gender_output(inputs["Gender"]),"Age": user_age_out(inputs["Age"]),"Occupation": user_occupation_out(inputs["Occupation"]),}return tf.keras.Model(inputs=inputs, outputs=outputs)preprocess_layer = preprocess()model = DeepFMbase(dnn_units_size=[256, 32],preprocess_layer=preprocess_layer,)model.compile(loss=tf.keras.losses.binary_crossentropy,optimizer=tf.keras.optimizers.Adam(),metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall(),],)return model # need wrapreturn create_model
# Define bob's basenet
def create_base_model_bob():# Create modeldef create_model():import tensorflow as tf# define preprocess layerdef preprocess():inputs = {"MovieID": tf.keras.Input(shape=(1,), dtype=tf.string),"Genres": tf.keras.Input(shape=(1,), dtype=tf.string),}movie_id_out = tf.keras.layers.Hashing(num_bins=NUM_MOVIES, output_mode="one_hot")movie_genres_out = tf.keras.layers.TextVectorization(output_mode='multi_hot', split="whitespace", vocabulary=GENRES_VOCAB)outputs = {"MovieID": movie_id_out(inputs["MovieID"]),"Genres": movie_genres_out(inputs["Genres"]),}return tf.keras.Model(inputs=inputs, outputs=outputs)preprocess_layer = preprocess()model = DeepFMbase(dnn_units_size=[256, 32],preprocess_layer=preprocess_layer,)model.compile(loss=tf.keras.losses.binary_crossentropy,optimizer=tf.keras.optimizers.Adam(),metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall(),],)return model # need wrapreturn create_model
定义Fusenet
def create_fuse_model():# Create modeldef create_model():import tensorflow as tfmodel = DeepFMfuse(dnn_units_size=[256, 256, 32])model.compile(loss=tf.keras.losses.binary_crossentropy,optimizer=tf.keras.optimizers.Adam(),metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall(),],)return modelreturn create_model
base_model_dict = {alice: create_base_model_alice(), bob: create_base_model_bob()}
model_fuse = create_fuse_model()
from secretflow.data.vertical import read_csv as v_read_csvvdf = v_read_csv({alice: "alice_ml1m.csv", bob: "bob_ml1m.csv"}, keys="ID", drop_keys="ID"
)
label = vdf["Rating"]data = vdf.drop(columns=["Rating", "Timestamp", "Title", "Zip-code"])
data["UserID"] = data["UserID"].astype("string")
data["MovieID"] = data["MovieID"].astype("string")sl_model = SLModel(base_model_dict=base_model_dict,device_y=bob,model_fuse=model_fuse,
)
history = sl_model.fit(data,label,epochs=5,batch_size=128,random_seed=1234,dataset_builder=data_builder_dict,
)
到这里,我们已经使用隐语提供的deepfm封装完成了movieLens数据集上的推荐任务训练。
总结
我们通过movieLens数据集上的推荐任务来演示了如何通过隐语来实现DeepFM。
1.下载并拆分数据集;
2.定义好数据处理的dataloader;
3.定义好数据预处理的preprocesslayer,定义好dnn结构,调用DeepFMBase,DeepFMFuse来进行模型定义;
4.使用SLModel进行训练,预测,评估即可。
相关文章:
隐私计算实训营:SplitRec:当拆分学习遇上推荐系统
拆分学习的概念 拆分学习的核心思想是拆分网络结构。每一个参与方拥有模型结构的一部分,所有参与方的模型合在一起形成一个完整的模型。训练过程中,不同参与方只对本地模型进行正向或者反向传播计算,并将计算结果传递给下一个参与方。多个参…...
存在nginx版本信息泄露(请求头中存在nginx中间件版本信息)
在Nginx的配置文件中,server_tokens指令用于控制Nginx在HTTP响应头中包含的服务器版本信息,默认为true,开启状态。当设置为off时,Nginx将不会在响应头中包含任何服务器版本信息,仅显示“Server: nginx”这一行…...
在js中观察者模式讲解
在JavaScript中,观察者模式(Observer Pattern)是一种设计模式,允许一个对象(被观察者,Subject)维护一个依赖它的对象列表(观察者,Observer),并在它自身状态发生变化时自动通知这些观察者。观察者模式的典型使用场景包括事件系统、数据绑定和实时更新等情况。 一 、…...
java常用面试题-基础知识分享
什么是Java? Java是一种高级编程语言,旨在提供跨平台的解决方案。它是一种面向对象的语言,具有简单、结构化、可移植、可靠、安全等特点。 Java的主要特点是什么? Java的主要特点包括: 简单性:Java的语法…...
iOS——runLoop
什么是runloop RunLoop实际上就是一个对象,这个对象管理了其需要处理的事件和消息,并提供了一个入口函数来执行相应的处理逻辑。线程执行了这个函数后,就会处于这个函数内部的循环中,直到循环结束,函数返回。 RunLoo…...
python: 多模块(.py)中全局变量的导入
文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块&…...
0基础学习爬虫系列:Python环境搭建
1.背景 当前网络资源更新非常快,然后对应自己感兴趣的内容,每天盯着刷网站又太费时间。我在尝试借助Ai,搭建一套自己知识抓取更新提醒的系统,这样可以用极少的时间,关注到自己感兴趣的信息。 其实,这套逻辑…...
Unity Shader实现简单的各向异性渲染(采用各向异性形式的GGX分布)
目录 准备工作 BRDF部分 Unity部分 代码 实现的效果 参考 最近刚结束GAMES202的学习,准备慢慢过渡到GAMES103。GAMES103的作业框架为Unity,并没有接触过,因此准备先学一点Unity的使用。刚好101和202都是渲染相关的,因此先学习…...
React开源框架之Refine
React Refine 是一个基于 React 的开源框架,它旨在帮助开发者快速构建企业级后台管理系统(Admin Panel)。Refine 是由 Retax 演变而来,它提供了一套完整的解决方案,用于构建 CRUD(创建、读取、更新、删除&a…...
【iOS】——渲染原理与离屏渲染
图像渲染流水线(图像渲染流程) 图像渲染流程大致分为四个部分: Application 应用处理阶段:得到图元Geometry 几何处理阶段:处理图元Rasterization 光栅化阶段:图元转换为像素Pixel 像素处理阶段࿱…...
详解CSS
目录 CSS 语法 引入方式 选择器 标签选择器 类选择器 ID选择器 通配符选择器 复合选择器 常用CSS color font-size border width和height padding 外边距 CSS CSS(Cascading Style Sheet),层叠样式表, ⽤于控制⻚⾯的样式. CSS 能够对⽹⻚中元素位置…...
Python执行cmd命令
在Python中执行cmd命令,可以使用内置的subprocess模块。以下是一个简单的例子,展示如何执行一个cmd命令并获取输出。 import subprocess# 要执行的cmd命令 cmd "dir"# 使用subprocess.run来执行命令 result subprocess.run(cmd, shellTrue,…...
基于激光雷达的无人机相互避障
本框架是基于激光雷达的无人机群自主避障代码: 其主体框架利用ORCA算法,他是经典的多智能体相互避障算法,此版本只能规避动态障碍物,不能规避环境形成的静态障碍物我们对ORVA算法稍作修改,使其可以分布式部署ÿ…...
Zookeeper基本原理
1.什么是Zookeeper? Zookeeper是一个开源的分布式协调服务器框架,由Apache软件基金会开发,专为分布式系统设计。它主要用于在分布式环境中管理和协调多个节点之间的配置信息、状态数据和元数据。 Zookeeper采用了观察者模式的设计理念,其核心…...
【生日视频制作】西游记孙悟空师徒提笔毛笔书法横幅AE模板修改文字软件生成器教程特效素材【AE模板】
生日视频制作教程西游记孙悟空师徒提笔毛笔书法横幅AE模板修改文字特效广告生成神器素材祝福玩法AE模板工程 怎么如何做的【生日视频制作】西游记孙悟空师徒提笔毛笔书法横幅AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 下载AE模板 安装AE…...
春日美食汇:基于SpringBoot的订餐平台
2 系统关键技术 2.1JSP技术 JSP(Java脚本页面)是Sun和许多参与建立的公司所提倡的动态web技术。将Java程序添加到传统的web页面HTML文件()。htm,。Html) [1]。 JSP这种能够独立使用的编程语言可以嵌入在html语言里面运行,正因为JSP参照了许多编程语言的特性…...
微信小程序中如何监听元素进入目标元素
Page({onLoad: function(){// 如果目标节点(用选择器 .target-class 指定)进入显示区域以下 100px 时,就会触发回调函数。wx.createIntersectionObserver().relativeToViewport({bottom: 100}).observe(.target-class, (res) > {res.inter…...
华为 HCIP-Datacom H12-821 题库 (6)
有需要题库的可以看主页置顶 V群仅进行学习交流 1.转发表中 FLAG 字段中B 的含义是? A、可用路由 B、静态路由 C、黑洞路由 D、网关路由 答案:C 解析: 可用路由用U 表示,静态路由用 S 表示,黑洞路由用 B 表示&#x…...
常见的pytest二次开发功能
pytest框架的二次开发主要是为了满足特定的测试需求或扩展其功能。以下是一些常见的pytest二次开发的功能及其实例,以及如何进行开发的大致步骤: 常见的pytest二次开发功能 定制化测试报告: 功能描述:pytest默认生成的测试报告可…...
Linux下安装MySQL8.0
一、安装 1.下载安装包 先创建一个mysql目录,在将压缩包下载到此 # 下载tar包 wget https://dev.mysql.com/get/Downloads/MySQL-8.0/mysql-8.0.20-linux-glibc2.12-x86_64.tar.xz等待下载成功 2.解压mysql8.0安装包 tar xvJf mysql-8.0.20-linux-glibc2.12-x86…...
Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
React Native 导航系统实战(React Navigation)
导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...
MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例
一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...
HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...
JDK 17 新特性
#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的ÿ…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...
【笔记】WSL 中 Rust 安装与测试完整记录
#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统:Ubuntu 24.04 LTS (WSL2)架构:x86_64 (GNU/Linux)Rust 版本:rustc 1.87.0 (2025-05-09)Cargo 版本:cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...
