NeuralCF 模型:神经网络协同过滤模型
实验和完整代码
完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main
引言
NeuralCF 模型由新加坡国立大学研究人员于 2017 年提出,其核心思想在于将传统协同过滤方法与深度学习技术相结合,从而更为有效地捕捉用户与物品之间的复杂交互关系。该模型利用神经网络自动学习用户和物品的低维表示,并通过这些表示实现对用户评分的精准预测。
1. NeuralCF模型简介
NeuralCF 模型融合了矩阵分解与深度学习两种方法的优势,采用基于神经网络的结构来建模用户与物品间的非线性交互。传统矩阵分解方法通过计算用户与物品隐向量的内积来进行评分预测,而 NeuralCF 则利用多层感知机(MLP)对用户与物品隐向量进行联合建模。具体而言,模型首先为每个用户和物品分配低维嵌入向量,然后将这些向量进行拼接(concatenate),再输入到深层神经网络中以提取潜在交互特征,最后通过输出层得到预测评分。
2. NeuralCF的模型架构
NeuralCF 模型的架构主要包括以下关键组件:
- 用户与物品嵌入(Embedding)
与传统矩阵分解方法类似,NeuralCF 为每个用户与物品分配低维嵌入向量,分别表征用户兴趣和物品特征。 - 嵌入向量拼接(Concatenation)
在模型中,用户与物品的嵌入向量被拼接为一个更高维度的向量,作为神经网络的输入。这种拼接不仅保留了各自的特征信息,同时为网络提供了学习复杂交互模式的可能性。 - 多层感知机(MLP)
拼接后的向量经过多个全连接层(MLP)的处理,每一层均采用激活函数(通常为 ReLU)引入非线性变换,以便捕捉用户与物品之间更高阶的特征交互。 - 输出层
多层感知机的输出经过一个线性层转换后,最终得到评分预测。在实际应用中,该预测值可以代表二分类问题(例如点击与否)或回归问题(例如具体评分)

2.1 数学模型
1. 用户和物品嵌入(Embedding)
NeuralCF模型首先为每个用户和每个物品分配一个低维度的隐向量。假设有 M M M 个用户和 N N N 个物品,用户 u u u 的隐向量为 p u ∈ R d \mathbf{p_u} \in \mathbb{R}^d pu∈Rd,物品 i i i 的隐向量为 q i ∈ R d \mathbf{q_i} \in \mathbb{R}^d qi∈Rd,其中 d d d 是隐向量的维度。
2. 嵌入向量的拼接
传统的矩阵分解方法直接计算用户和物品隐向量的内积来预测评分,而NeuralCF通过将用户和物品的隐向量拼接(concatenate)在一起,构成一个新的向量:
z = concat ( p u , q i ) ∈ R 2 d \mathbf{z} = \text{concat}(\mathbf{p_u}, \mathbf{q_i}) \in \mathbb{R}^{2d} z=concat(pu,qi)∈R2d
3. 多层感知机(MLP)
将拼接后的向量z输入到包含L层的多层感知机中,每一层的变换公式为:
h l = ReLU ( W l h l − 1 + b l ) , l = 1 , 2 , … , L \mathbf{h_l} = \text{ReLU}(\mathbf{W_l} \mathbf{h_{l-1}} + \mathbf{b_l}), \quad l = 1, 2, \dots, L hl=ReLU(Wlhl−1+bl),l=1,2,…,L
其中,初始输入为 h 0 = z \mathbf{h_0} = \mathbf{z} h0=z, W l \mathbf{W_l} Wl 和偏置 b l \mathbf{b_l} bl分别为第 l层的权重和偏置参数
4. 输出层
经过多层感知机后,最终输出层采用线性变换:
r u i ^ = σ ( W L h L + b L ) \hat{r_{ui}} = \sigma(\mathbf{W_L} \mathbf{h_L} + \mathbf{b_L}) rui^=σ(WLhL+bL)
其中, σ \sigma σ 表示sigmoid激活函数,输出值位于 0 与 1 之间,适用于二分类任务;对于回归任务,则可去除 Sigmoid 激活。
5. 损失函数
针对不同任务,NeuralCF 可采用不同的损失函数:
- 回归问题:通常使用均方误差(MSE):
L = 1 N ∑ ( u , i ) ( r u i − r u i ^ ) 2 \mathcal{L} = \frac{1}{N} \sum_{(u,i)} \left( r_{ui} - \hat{r_{ui}} \right)^2 L=N1(u,i)∑(rui−rui^)2
- 二分类问题, 损失函数为交叉熵:
L = − 1 N ∑ ( u , i ) ( r u i log ( r u i ^ ) + ( 1 − r u i ) log ( 1 − r u i ^ ) ) \mathcal{L} = -\frac{1}{N} \sum_{(u,i)} \left( r_{ui} \log(\hat{r_{ui}}) + (1 - r_{ui}) \log(1 - \hat{r_{ui}}) \right) L=−N1(u,i)∑(ruilog(rui^)+(1−rui)log(1−rui^))
3 NeuralCF混合模型
为进一步提升特征组合能力和非线性表达能力,NeuralCF 在原有架构基础上引入了广义矩阵分解(Generalized Matrix Factorization, GMF)模块。需要指出的是,GMF 与 MLP 部分分别采用独立的嵌入层,这一设计有效提升了模型的灵活性和表现力。

3.2 GMF广义矩阵分解
广义矩阵分解模型扩展了传统矩阵分解方法,通过引入不同的用户与物品交互方式来建模。与经典矩阵分解方法通过内积计算用户与物品之间的相似性不同,GMF 采用元素积(Hadamard 乘积)来刻画二者间的交互关系:
ϕ 1 ( p u , q i ) = p u ⊙ q i \phi_1(p_u, q_i) = p_u \odot q_i ϕ1(pu,qi)=pu⊙qi
其中, p u p_u pu 和 q i q_i qi 是用户和物品的嵌入向量, ⊙ \odot ⊙ 是元素积操作。
3.4 GMF和MLP的融合
为了解决共享嵌入层的限制,本方法提出了让GMF和MLP分别学习独立的嵌入层,并通过连接它们的最后一层隐藏层进行融合。具体而言,GMF和MLP的输出通过以下公式进行联合建模:
- GMF 部分
ϕ G M F = p u G ⊙ q i G , \phi_{GMF} = p_u^G \odot q_i^G, ϕGMF=puG⊙qiG,
其中, p u G p_u^G puG 和 q i G q_i^G qiG 分别表示GMF部分的用户和物品嵌入向量。
- MLP 部分
通过多层非线性变换,MLP 部分的用户与物品嵌入向量先进行拼接,再逐层传递,形式上可描述为:
ϕ M L P = a L ( W L T ( a L − 1 ( . . . a 2 ( W 2 T [ p u M q i M ] + b 2 ) . . . ) ) + b L ) , \phi_{MLP} = a_L(W_L^T (a_{L-1}(...a_2(W_2^T [p_u^M \quad q_i^M] + b_2)...)) + b_L), ϕMLP=aL(WLT(aL−1(...a2(W2T[puMqiM]+b2)...))+bL),
其中, p u M p_u^M puM 和 q i M q_i^M qiM 分别表示MLP部分的用户和物品嵌入向量; a L ( ⋅ ) a_L(\cdot) aL(⋅) 是激活函数, W L W_L WL 和 b L b_L bL 是MLP的权重和偏置参数。
- 融合与预测
最后,GMF和MLP的输出通过全连接层进行融合并计算最终预测:
y ^ u i = σ ( h T [ ϕ G M F ϕ M L P ] ) \hat{y}_{ui} = \sigma(h^T [\phi_{GMF} \quad \phi_{MLP}]) y^ui=σ(hT[ϕGMFϕMLP])
其中, σ ( ⋅ ) \sigma(\cdot) σ(⋅) 是Sigmoid激活函数, h T h^T hT 是融合层的权重。
该融合策略使得模型可以分别从不同角度捕捉用户与物品的特征,并通过联合表示进一步提升预测准确性与模型灵活性。
4.代码实现
以下代码段展示了基于 PyTorch 的 NeuralCF 模型实现,包括模型配置、数据集构建与模型定义。
模型配置与数据集构建
class Config:num_users = 1000num_items = 2000embed_dim = 16hidden_dims = [64, 32, 16]batch_size = 32lr = 0.001num_epochs = 30# 自定义数据集类
class CFDataset(Dataset):def __init__(self, num_samples=10000):# 生成示例数据(实际使用时替换为真实数据)self.user_ids = np.random.randint(0, Config.num_users, size=num_samples)self.item_ids = np.random.randint(0, Config.num_items, size=num_samples)self.labels = np.random.randint(0, 2, size=num_samples).astype(np.float32)def __len__(self):return len(self.user_ids)def __getitem__(self, idx):return (torch.tensor(self.user_ids[idx], dtype=torch.long),torch.tensor(self.item_ids[idx], dtype=torch.long),torch.tensor(self.labels[idx], dtype=torch.float))
NeuralCF 模型实现
class NeuralCF(nn.Module):def __init__(self, Config):super().__init__()# 定义用户和物品的隐向量self.user_embed_gmf = nn.Embedding(Config.num_users, Config.embed_dim) # GMF用户隐向量self.item_embed_gmf = nn.Embedding(Config.num_items, Config.embed_dim) # GMF物品隐向量self.user_embed_mlp = nn.Embedding(Config.num_users, Config.embed_dim) # MLP用户隐向量self.item_embed_mlp = nn.Embedding(Config.num_items, Config.embed_dim) # MLP物品隐向量# MLP层input_dim = 2 * Config.embed_dimmlp_layers = []for output_dim in Config.hidden_dims:mlp_layers.append(nn.Linear(input_dim, output_dim))mlp_layers.append(nn.ReLU())input_dim = output_dimself.mlp = nn.Sequential(*mlp_layers)# 输出层total_dim = Config.embed_dim + Config.hidden_dims[-1] # GMF + MLP层维度self.fc = nn.Sequential(nn.Linear(total_dim, 1),nn.Sigmoid())def forward(self, user_ids, item_ids):# 获取用户和物品的隐向量user_emb_gmf = self.user_embed_gmf(user_ids)item_emb_gmf = self.item_embed_gmf(item_ids)user_emb_mlp = self.user_embed_mlp(user_ids)item_emb_mlp = self.item_embed_mlp(item_ids)# GMF: 逐元素乘积gmf = user_emb_gmf * item_emb_gmf# MLP: 拼接并通过多层感知机concat_emb = torch.cat([user_emb_mlp, item_emb_mlp], dim=1)mlp = self.mlp(concat_emb)# 拼接GMF和MLP的结果neuralcf_emb = torch.cat([mlp, gmf], dim=1)# 输出层output = self.fc(neuralcf_emb).squeeze()return output
5. NeuralCF的优势
NeuralCF 模型通过引入深度神经网络,有效突破了传统矩阵分解方法的线性限制,能够捕捉用户与物品之间的复杂非线性交互。其主要优势包括:
- 非线性建模能力:利用多层神经网络对用户与物品的隐向量进行非线性组合,充分发掘潜在高阶交互信息。
- 架构灵活性:模型结构可以根据实际问题需求灵活调整隐藏层层数和神经元数量,适应不同数据规模与复杂度。
- 优异的泛化性能:深度学习框架使得 NeuralCF 在处理稀疏数据时能够更好地防止过拟合,提升了模型的泛化能力。
Reference
[1]. He, X., Liao, L., Zhang, H., Nie, L., Hu, X., & Chua, T.-S. (2017). Neural Collaborative Filtering. In Proceedings of the 26th International Conference on World Wide Web (WWW ’17), 173–182. ACM.
[2]. 王喆 《深度学习推荐系统》
相关文章:
NeuralCF 模型:神经网络协同过滤模型
实验和完整代码 完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main 引言 NeuralCF 模型由新加坡国立大学研究人员于 2017 年提出,其核心思想在于将传统协同过滤方法与深度学习技术相结…...
第二十三章 MySQL锁之表锁
目录 一、概述 二、语法 三、特点 一、概述 表级锁,每次操作锁住整张表。锁定粒度大,发生锁冲突的概率最高,并发度最低。应用在MyISAM、InnoDB、BDB等存储引擎中。 对于表级锁,主要分为以下三类: 1. 表锁 2. 元数…...
【Uniapp-Vue3】获取用户状态栏高度和胶囊按钮高度
在项目目录下创建一个utils文件,并在里面创建一个system.js文件。 在system.js中配置如下代码: const SYSTEM_INFO uni.getSystemInfoAsync();// 返回状态栏高度 export const getStatusBarHeight ()> SYSTEM_INFO.statusBarHeight || 15;// 返回胶…...
04树 + 堆 + 优先队列 + 图(D1_树(D10_决策树))
目录 一、引言 二、算法原理 三、算法实现 四、知识小结 一、引言 决策树算法是一种常用的机器学习算法,可用于分类和回归问题。它基于特征之间的条件判断来构 建一棵树,树的每个节点代表一个特征,每个叶节点代表一个类别或回归值。决策…...
通向AGI之路:人工通用智能的技术演进与人类未来
文章目录 引言:当机器开始思考一、AGI的本质定义与技术演进1.1 从专用到通用:智能形态的范式转移1.2 AGI发展路线图二、突破AGI的五大技术路径2.1 神经符号整合(Neuro-Symbolic AI)2.2 世界模型架构(World Models)2.3 具身认知理论(Embodied Cognition)三、AGI安全:价…...
将ollama迁移到其他盘(eg:F盘)
文章目录 1.迁移ollama的安装目录2.修改环境变量3.验证 背景:在windows操作系统中进行操作 相关阅读 :本地部署deepseek模型步骤 1.迁移ollama的安装目录 因为ollama默认安装在C盘,所以只能安装好之后再进行手动迁移位置。 # 1.迁移Ollama可…...
为AI聊天工具添加一个知识系统 之86 详细设计之27 数据处理:ETL
本文要点 ETL 数据提取 作为 数据项目的起点。数据的整个三部曲--里程碑式的发展进程: ETL : 1分形 Type()-层次Broker / 2完形 Method() - 维度Delegate /3 整形 Class() - 容器 Agent 1变象。变象 脸谱Extractor - 缠度(物理 皮肤缠度…...
Java自定义IO密集型和CPU密集型线程池
文章目录 前言线程池各类场景描述常见场景案例设计思路公共类自定义工厂类-MyThreadFactory自定义拒绝策略-RejectedExecutionHandlerFactory自定义阻塞队列-TaskQueue(实现 核心线程->最大线程数->队列) 场景1:CPU密集型场景思路&…...
使用开源项目:pdf2docx,让PDF转换为Word
目录 1.安装python 2.安装 pdf2docx 3.使用 pdf2docx 转换 PDF 到 Word pdf2docx:GitCode - 全球开发者的开源社区,开源代码托管平台 环境:windows电脑 1.安装python Download Python | Python.org 最好下载3.8以上的版本 安装时记得选择上&#…...
蓝桥杯思维训练营(四)
文章目录 小红打怪494.目标和 小红打怪 小红打怪 思路分析:可以看到ai的范围较大,如果我们直接一个个进行暴力遍历的话,会超时。当我们的攻击的次数越大的时候,怪物的血量就会越少,这里就有一个单调的规律在里面&…...
尝试把clang-tidy集成到AWTK项目
前言 项目经过一段时间的耕耘终于进入了团队开发阶段,期间出现了很多问题,其中一个就是开会讨论团队的代码风格规范,目前项目代码风格比较混乱,有的模块是驼峰,有的模块是匈牙利,后面经过讨论,…...
【学习笔记】深度学习网络-正则化方法
作者选择了由 Ian Goodfellow、Yoshua Bengio 和 Aaron Courville 三位大佬撰写的《Deep Learning》(人工智能领域的经典教程,深度学习领域研究生必读教材),开始深度学习领域学习,深入全面的理解深度学习的理论知识。 在之前的文章中介绍了深度学习中用…...
介绍一下Mybatis的底层原理(包括一二级缓存)
表面上我们的就是Sql语句和我们的java对象进行映射,然后Mapper代理然后调用方法来操作数据库 底层的话我们就涉及到Sqlsession和Configuration 首先说一下SqlSession, 它可以被视为与数据库交互的一个会话,用于执行 SQL 语句(Ex…...
WordPress使用(1)
1. 概述 WordPress是一个开源博客框架,配合不同主题,可以有多种展现方式,博客、企业官网、CMS系统等,都可以很好的实现。 官网:博客工具、发布平台和内容管理系统 – WordPress.org China 简体中文,这里可…...
BUUCTF_[安洵杯 2019]easy_web(preg_match绕过/MD5强碰撞绕过/代码审计)
打开靶场,出现下面的静态html页面,也没有找到什么有价值的信息。 查看页面源代码 在url里发现了img传参还有cmd 求img参数 这里先从img传参入手,这里我发现img传参好像是base64的样子 进行解码,解码之后还像是base64的样子再次进…...
C基础寒假练习(4)
输入带空格的字符串,求单词个数、 #include <stdio.h> // 计算字符串长度的函数 size_t my_strlen(const char *str) {size_t len 0;while (str[len] ! \0) {len;}return len; }int main() {char str[100];printf("请输入一个字符串: ");fgets(…...
git error: invalid path
git clone GitHub - guanpengchn/awesome-books: :books: 开发者推荐阅读的书籍 在windows上想把这个仓库拉取下来,发现本地git仓库创建 但只有一个.git隐藏文件夹,其他文件都处于删除状态。 问题: Cloning into awesome-books... remote:…...
MySQL 事务实现原理( 详解 )
MySQL 主要是通过: 锁、Redo Log、Undo Log、MVCC来实现事务 事务的隔离性利用锁机制实现 原子性、一致性和持久性由事务的 redo 日志和undo 日志来保证。 Redo Log(重做日志):记录事务对数据库的所有修改,在崩溃时恢复未提交的更改,保证事务…...
git基础使用--1--版本控制的基本概念
文章目录 git基础使用--1--版本控制的基本概念1.版本控制的需求背景,即为啥需要版本控制2. 集中式版本控制SVN3. 分布式版本控制 Git4. SVN和Git的比较 git基础使用–1–版本控制的基本概念 1.版本控制的需求背景,即为啥需要版本控制 先说啥叫版本&…...
Spring RESTful API 设计与实现
Spring RESTful API的设计与实现极大地提升了开发效率和系统可维护性,通过遵循RESTful设计原则,使得API结构清晰、行为一致,便于扩展和维护。它在构建微服务架构中扮演着核心角色,支持松耦合的通信,同时通过标准的HTTP协议和数据格式增强了系统的互操作性。结合Spring Sec…...
Unity飞行代码 超仿真 保姆级教程
本文使用Rigidbody控制飞机,基本不会穿模。 效果 飞行效果 这是一条优雅的广告 如果你也在开发飞机大战等类型的飞行游戏,欢迎在主页搜索博文并参考。 搜索词:Unity游戏(Assault空对地打击)开发。 脚本编写 首先是完整代码。 using System.Co…...
【自学笔记】Git的重点知识点-持续更新
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 Git基础知识Git高级操作与概念Git常用命令 总结 Git基础知识 Git简介 Git是一种分布式版本控制系统,用于记录文件内容的改动,便于开发者追踪…...
力扣73矩阵置零
给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 输入:matrix [[0,1,2,0],[3,4,5,2],[…...
登录认证(5):过滤器:Filter
统一拦截 上文我们提到(登录认证(4):令牌技术),现在大部分项目都使用JWT令牌来进行会话跟踪,来完成登录功能。有了JWT令牌可以标识用户的登录状态,但是完整的登录逻辑如图所示&…...
python算法和数据结构刷题[1]:数组、矩阵、字符串
一画图二伪代码三写代码 LeetCode必刷100题:一份来自面试官的算法地图(题解持续更新中)-CSDN博客 算法通关手册(LeetCode) | 算法通关手册(LeetCode) (itcharge.cn) 面试经典 150 题 - 学习计…...
详解u3d之AssetBundle
一.AssetBundle的概念 “AssetBundle”可以指两种不同但相关的东西。 1.1 AssetBundle指的是u3d在磁盘上生成的存放资源的目录 目录包含两种类型文件(下文简称AB包): 一个序列化文件,其中包含分解为各个对象并写入此单个文件的资源。资源文件&#x…...
接口测试通用测试用例
接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。 测试的重点是检查数据的交换,传递和控制管理过程,以及系统间的相互逻辑依赖关系等。 现在很多系统前后端架构是分离的,从安全层面来说,只依赖前段进行限…...
深入理解 C# 与.NET 框架
.NET学习资料 .NET学习资料 .NET学习资料 一、引言 在现代软件开发领域,C# 与.NET 框架是构建 Windows、Web、移动及云应用的强大工具。C# 作为一种面向对象的编程语言,而.NET 框架则是一个综合性的开发平台,它们紧密结合,为开…...
CSS 图像、媒体和表单元素的样式化指南
CSS 图像、媒体和表单元素的样式化指南 1. 替换元素:图像和视频1.1 调整图像大小示例代码:调整图像大小 1.2 使用 object-fit 控制图像显示示例代码:使用 object-fit 2. 布局中的替换元素示例代码:Grid 布局中的图像 3. 表单元素的…...
【BUUCTF杂项题】荷兰宽带数据泄露、九连环
一.荷兰宽带数据泄露 打开发现是一个.bin为后缀的二进制文件,因为提示宽带数据泄露,考虑是宽带路由器方向的隐写 补充:大多数现代路由器都可以让您备份一个文件路由器的配置文件,软件RouterPassView可以读取这个路由配置文件。 用…...
