当前位置: 首页 > news >正文

图神经网络实战(10)——归纳学习

图神经网络实战(10)——归纳学习

    • 0. 前言
    • 1. 转导学习与归纳学习
    • 2. 蛋白质相互作用数据集
    • 3. 构建 GraphSAGE 模型实现归纳学习
    • 小结
    • 系列链接

0. 前言

归纳学习 (Inductive learning) 通过基于已观测训练数据,建立一个通用模型,使模型能够对未见过的节点和图进行归纳预测,而转导学习(Transductive learning, 也称直推学习)是基于所有已经观测到的训练和测试数据构建模型,这种方法是通过已经有标记的节点信息来预测无标记数据节点,因此,在图神经网络 (Graph Neural Networks, GNN)、图卷积网络 (Graph Convolutional Network, GCN)、图注意力网络 (Graph Attention Networks,GAT) 和 GraphSAGE 等节中所构建的模型均属于转导学习模型。在本节中,我们将介绍图数据中的归纳学习和多标签分类,使用 GraphSAGE 模型在蛋白质相互作用 (protein-protein interactions) 数据集执行多标签分类任务,并了解归纳学习的优势和实现方法。

1. 转导学习与归纳学习

在图神经网络 (Graph Neural Networks, GNN)中,可以将学习分为两类,转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning):

  • 在归纳学习中,GNN 在训练过程中只看到训练集中的数据,而在测试过程中需要对未见过的数据进行预测,这属于机器学习中典型的监督学习 (supervised learning)。在这种情况下,标签用来调整 GNN 的参数,模型需要具备良好的泛化能力,能够从有限的样本中推断出普遍适用的规律
  • 在转导学习中,GNN 在训练过程中会看到来自训练集和测试集的数据,它通过对已有的样本进行学习来进行预测和分类。模型只从训练集中学习数据,模型会尝试将已有的样本归类到已知的类别中,并根据这些样本的特征进行预测,标签用于信息扩散。转导学习不是直接从训练集中学习出一般性的规律,而是利用图数据间的相似性或连接性进行预测

我们在之前构建的图神经网络 (Graph Neural Networks, GNN) 和图卷积网络 (Graph Convolutional Network, GCN) 属于转导学习情况。而 GraphSAGE 模型可以在训练过程中使用整个图进行预测 (self(batch.x, batch.edge_index)),然后部分屏蔽这些预测,只使用训练数据计算损失并训练模型 (criterion(out[batch.train_mask], batch.y[batch.train_mask]))。
转导学习只能为固定的图生成嵌入,不能泛化到未见过的节点或图。但由于采用了邻居采样,GraphSAGE 可以在局部水平上对经过剪枝的计算图进行预测,这种情况下属于归纳学习框架,可以应用于具有相同特征模式的任何计算图。

2. 蛋白质相互作用数据集

在 GraphSAGE 一节中,我们已经在 PubMed 数据集上构建 GraphSAGE 模型实现了转导学习。接下来,我们将 GraphSAGE 应用于由 Agrawal 等人提出的蛋白质相互作用 (protein-protein interaction, PPI) 网络数据集。该数据集是 24 个图的集合,其中节点( 21,557 个)是人类蛋白质,边( 342,353 条)是人类细胞中蛋白质之间的连接。用 Gephi 制作的 PPI 图数据集可视化结果如下所示:

PPI 数据集

该数据集的目标是使用 121 个标签进行多标签分类,这意味着每个节点可以具有 0121 个标签。这不同于多类别分类,多类别分类中每个节点只会属于一个类别。接下来,我们使用 PyTorch Geometric (PyG) 实现 GraphSAGE 模型用于对 PPI 数据集执行多标签分类任务。

(1)PPI 数据集加载为三个不同的子集,训练集、验证集和测试集:

import torch
from sklearn.metrics import f1_scorefrom torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE# Load training, evaluation, and test sets
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')

(2) 训练集包含 20 个图,而验证集和测试集只有两个图。对训练集应用邻居采样,为了方便起见,使用 Batch.from_data_list() 将所有训练图统一到一个集合中,然后应用邻居采样:

train_data = Batch.from_data_list(train_dataset)
train_loader = NeighborLoader(train_data, batch_size=2048, shuffle=True, num_neighbors=[20, 10], num_workers=2, persistent_workers=True)

(3) 训练集创建完毕后,使用 DataLoader 类创建批数据,将 batch_size 值定义为 2,与每批图的数量相对应:

val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

(4) 定义设备使批处理能够在 GPU 上进行处理。如果计算机中有 GPU,使用 GPU,否则就使用 CPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3. 构建 GraphSAGE 模型实现归纳学习

使用 PyTorch Geometrictorch_geometric.nn 模块构建 GraphSAGE 模型。

(1) 使用 GraphSAGE() 类初始化一个两层的 GraphSAGE 模型,其中隐藏维度为 512,此外,还需要使用 to(device) 将模型放置在与数据相同的设备上:

model = GraphSAGE(in_channels=train_dataset.num_features,hidden_channels=512,num_layers=2,out_channels=train_dataset.num_classes,
).to(device)

(2) fit() 函数与 GraphSAGE 一节中使用的函数类似,不同之处在于,我们希望尽可能将数据移动到 GPU 上,并且由于每批数据有两个图,因此将损失乘以 2 (data.num_graphs):

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)def fit(loader):model.train()total_loss = 0for data in loader:data = data.to(device)optimizer.zero_grad()out = model(data.x, data.edge_index)loss = criterion(out, data.y)total_loss += loss.item() * data.num_graphsloss.backward()optimizer.step()return total_loss / len(loader.data)

由于 val_loadertest_loader 包含两个图且 batch_size 值为 2,因此在 test() 函数中,两个图位于同一个批数据中,而无需像训练时那样在加载器中循环。

(3) 使用度量指标 F1 分数代替准确率,F1 分数相当于精确度和召回率的调和平均值。但,模型的预测结果是 121 维的实数向量,我们需要将其转换成二进制向量,使用 out > 0 将它们与 data.y 进行比较:

@torch.no_grad()
def test(loader):model.eval()data = next(iter(loader))out = model(data.x.to(device), data.edge_index.to(device))preds = (out > 0).float().cpu()y, pred = data.y.numpy(), preds.numpy()return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

(4) 对模型进行 300epoch 的训练,并打印训练过程中模型在验证数据集上的 F1 分数:

for epoch in range(301):loss = fit(train_loader)val_f1 = test(val_loader)if epoch % 50 == 0:print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')
'''
Epoch   0 | Train Loss: 12.686 | Val F1-score: 0.4866
Epoch  50 | Train Loss: 8.734 | Val F1-score: 0.7963
Epoch 100 | Train Loss: 8.600 | Val F1-score: 0.8098
Epoch 150 | Train Loss: 8.531 | Val F1-score: 0.8202
Epoch 200 | Train Loss: 8.495 | Val F1-score: 0.8230
Epoch 250 | Train Loss: 8.497 | Val F1-score: 0.8255
Epoch 300 | Train Loss: 8.432 | Val F1-score: 0.8290
'''

(5) 最后,计算测试集上的 F1 分数:

print(f'Test F1-score: {test(test_loader):.4f}')# Test F1-score: 0.8527

可以看到,在归纳学习中,模型在 PPI 数据集上训练后的 F1 分数为 0.9360。当增加或减少隐藏维度的大小时,模型的性能会有有大幅改变,我们可以使用不同的值,如 1281,024,并观察训练的后的模型 F1 分数变化。
需要注意的是,在以上代码中,我们并未显式的使用掩码。这是由于实际上,归纳学习是由 PPI 数据集强制实现的;训练数据、验证数据和测试数据位于不同的图和数据加载器中。我们也可以使用 Batch.from_data_list() 将它们合并,然后再使用归纳学习的设定。

小结

在本节中,学习了图神经网络中转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning) 的区别。其中,图神经网络中的归纳学习通常指的是从给定的训练图数据中学习出一个泛化能力强的模型,以便对未知图数据中的节点或边进行预测或分类,而转导学习通常指的是利用训练图数据和测试图数据之间的关联性进行推断,从而对给定的测试图数据进行预测或分类。并且构建了 GraphSAGE 模型在 PPI 数据集上测试了归纳学习,以执行多标签分类任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现

相关文章:

图神经网络实战(10)——归纳学习

图神经网络实战(10)——归纳学习 0. 前言1. 转导学习与归纳学习2. 蛋白质相互作用数据集3. 构建 GraphSAGE 模型实现归纳学习小结系列链接 0. 前言 归纳学习 (Inductive learning) 通过基于已观测训练数据,建立一个通用模型,使模…...

Python——IO编程

IO在计算机中指Input/Output,也就是输入和输出。由于程序和运行时数据是在内存中驻留,由CPU这个超快的计算核心来执行,涉及到数据交换的地方,通常是磁盘、网络等,就需要IO接口。 比如你打开浏览器,访问新浪…...

什么是网络端口?为什么会有高危端口?

一、什么是网络端口? 网络技术中的端口默认指的是TCP/IP协议中的服务端口,一共有0-65535个端口,比如我们最常见的端口是80端口默认访问网站的端口就是80,你直接在浏览器打开,会发现浏览器默认把80去掉,就是…...

CleanMyMac X v4.14.6中文破解版,让您的电脑像新的一样

小编给您带来CleanMyMac X v4.14.6中文破解版,CleanMyMac X破解版是应用在MacOS上的一款Mac系统清理优化工具,使用cleanmymac x 中文破解版只需两个简单步骤就可以把系统里那些乱七八糟的无用文件统统清理掉,节省宝贵的磁盘空间。 CleanMyMa…...

LeetCode 235. 二叉搜索树的最近公共祖先

LeetCode 235. 二叉搜索树的最近公共祖先 1、题目 题目链接:235. 二叉搜索树的最近公共祖先 给定一个二叉搜索树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个结点 p、q,最近公共祖先表…...

基于ASN.1的RSA算法公私钥存储格式解读

1.概述 RFC5958主要定义非对称密钥的封装语法,RFC5958用于替代RFC5208。非对称算法会涉及到1对公私钥,例如按照RSA算法,公钥是n和e,私钥是d和n。当需要将公私钥保存到文件时,需按照一定的格式保存。本文主要定义公私钥…...

RS2227XN功能和参数介绍及PDF资料

RS2227XN是一款模拟开关/多路复用器 品牌: RUNIC(润石) 封装: MSOP-10 描述: USB2.0高速模拟开关 开关电路: 双刀双掷(DPDT) 通道数: 2 工作电压: 1.8V~5.5V 导通电阻(RonVCC): 10Ω 功能:模拟开关/多路复用器 USB2.0高速模拟开关 工作电压范围:1.8V ~ 5…...

机器人非线性阻抗控制系统

机器人非线性控制系统本质上是一个复杂的控制系统,其状态变量和输出变量相对于输入变量的运动特性不能用线性关系来描述。这种系统的形成基于两类原因:一是被控系统中包含有不能忽略的非线性因素,二是为提高控制性能或简化控制系统结构而人为…...

pandas style添加表格边框,或是只添加下边框等自定义边框样式设置

添加表格边框 可以使用如下程序添加表格: import dataframe_image as dfi import pandas as pd import numpy as npdf pd.DataFrame(np.random.random(size(10, 5))) df_style df.style.set_properties(**{text-align: center,border-color: black,border-width…...

OpenHarmony 3GPP协议开发深度剖析——一文读懂RIL

市面上关于终端(手机)操作系统在 3GPP 协议开发的内容太少了,即使 Android 相关的学习文档都很少,Android 协议开发书籍我是没有见过的。可能是市场需求的缘故吧,现在市场上还是前后端软件开发从业人员最多&#xff0c…...

windows部署腾讯tmagic-editor02-Runtime

创建editor项目 将上一教程中的hello-world复制过来,改名hello-editor 创建runtime项目 和hello-editor同级 pnpm create vite删除src/components/HelloWorld.vue 按钮需要用的ts types依赖 pnpm add tmagic/schema tmagic/stage实现runtime 将hello-editor中…...

“分块”算法的基本要素及 build() 函数的构建细节

【“分块”算法知识点】 ● 分块是用线段树的分区思想改良的暴力法。代码比线段树简单。效率比普通暴力法高。分块适合求解 m=n=10^5 规模的问题,或 m*sqrt(n)≈10^7 的问题。其中,n 为元素个数,m 为操作次数。 ● “分块”算法的基本要素 (1)块的大小用 block 表示。通常…...

畅捷通TPlus keyEdit.aspx、KeyInfoList.aspx SQL注入漏洞复现

前言 免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文章仅供学习用途使用。 一、产…...

Ubuntu22 下配置 Qt5 环境

1. Qt 简介 Qt5 中的新功能,可以看到各个版本的情况Whats New in Qt 5 | Qt 5.15 Qt 源文件网址Index of /archive/qt 2. 安装 Qt Creator cd 到安装包所在目录,进行软件安装。赋予可执行权限,加上 sudo 权限进入安装,这样会安…...

普通人也能创业!轻资产短视频带货项目,引领普通人实现创业梦想

在这个信息爆炸的时代,创业似乎成为了越来越多人的梦想。然而,传统的创业模式 keJ0277 往往伴随着高昂的资金投入和复杂的管理流程,让许多普通人望而却步。然而,现在有一种轻资产短视频带货项目正在悄然兴起,它以其低…...

【Maven】Nexus简单使用

1、安装配置介绍Nexus私服: 安装配置指路上一篇详细教程博客 【Maven】Nexus私服简介_下载安装_登录-CSDN博客 简单介绍原有仓库类型: proxy代理仓库:代理远程仓库,访问全球中央仓库或其他公共仓库,将资源存储在私…...

winform嵌入excel 设置父窗体分辨率不是100% 嵌入excel分辨率变成双倍大小

在WinForms应用程序中嵌入Excel时,遇到分辨率问题可能是由于DPI缩放导致的。Windows 10及更高版本默认启用了DPI缩放,以便在高分辨率显示器上显示更清晰的内容。这可能会导致嵌入的应用程序(如Excel)看起来变大或变小。 解决方案 …...

前端系列-4 promise与async/await与fetch/axios使用方式

背景: 本文介绍promise使用方式,以及以Promise为基础的async/await用法和fetch/axios使用方式,主要以案例的方式进行。 1.promise 1.1 promise介绍 javascript是单线程执行的,异步编程的本质是事件机制和函数回调。当执行阻塞…...

微信公众号自定义分销商城小程序源码系统 带完整的安装代码吧以及系统部署搭建教程

系统概述 微信公众号自定义分销商城小程序源码系统是一款功能强大的电商解决方案,它集成了商品管理、订单处理、支付接口、分销管理等多种功能。该系统支持自定义界面设计,商家可根据自身需求调整商城的页面布局和风格,打造独特的品牌形象。…...

在另外一个页面,让另外一个页面弹框显示操作(调佣公共的弹框)vue

大概意思是,登录弹框在另外一个页面中,而当前页面不存在,在当前页面中判断如果token不存在,就弹框出登录的弹框 最后一行 window.location.href … 如果当前用户已登录,则执行后续操作(注意此处,可不要)...

大数据学习栈记——Neo4j的安装与使用

本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...

OpenLayers 可视化之热力图

注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节&#xff0c;供应链协同管理在供应链上下游企业之间建立紧密的合作关系&#xff0c;通过信息共享、资源整合、业务协同等方式&#xff0c;实现供应链的全面管理和优化&#xff0c;提高供应链的效率和透明度&#xff0c;降低供应链的成…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store&#xff1a; 我们在使用异步的时候理应是要使用中间件的&#xff0c;但是configureStore 已经自动集成了 redux-thunk&#xff0c;注意action里面要返回函数 import { configureS…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)

macos brew国内镜像加速方法 brew install 加速formula.jws.json下载慢加速 &#x1f37a; 最新版brew安装慢到怀疑人生&#xff1f;别怕&#xff0c;教你轻松起飞&#xff01; 最近Homebrew更新至最新版&#xff0c;每次执行 brew 命令时都会自动从官方地址 https://formulae.…...