当前位置: 首页 > news >正文

4.将图神经网络应用于大规模图数据(Cluster-GCN)

       到目前为止,我们已经为节点分类任务单独以全批方式训练了图神经网络。特别是,这意味着每个节点的隐藏表示都是并行计算的,并且可以在下一层中重复使用。

       然而,一旦我们想在更大的图上操作,由于内存消耗爆炸,这种方案就不再可行。例如,一个具有大约1000万个节点和128个隐藏特征维度的图已经为每层消耗了大约5GB的GPU内存

       因此,最近有一些努力让图神经网络扩展到更大的图。其中一种方法被称为Cluster-GCN (Chiang et al. (2019),它基于将图预先划分为子图,可以在子图上以小批量的方式进行操作。

       为了展示,让我们从 Planetoid 节点分类基准套件(Yang et al. (2016))中加载PubMed 图:

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]  # Get the first graph object.print()
print(data)
print('===============================================================================================================')# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.3f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

在这里插入图片描述
       可以看出,该图大约有19717个节点。虽然这个数量的节点应该可以轻松地放入GPU内存,但它仍然是一个很好的例子,可以展示如何在PyTorch Geometric中扩展GNN

       Cluster-GCN的工作原理是首先基于图划分算法将图划分为子图。因此,GNN被限制为仅在其特定子图内进行卷积,从而省略了邻域爆炸的问题。
在这里插入图片描述
       然而,在对图进行分区后,会删除一些链接,这可能会由于有偏差的估计而限制模型的性能。为了解决这个问题,Cluster-GCN还将类别之间的连接合并到一个小批量中,这导致了以下随机划分方案
在这里插入图片描述

       这里,颜色表示每批维护的邻接信息(对于每个epoch可能不同)。PyTorch Geometric提供了Cluster-GCN算法的两阶段实现

  1. ClusterData 将一个 Data 对象转换为包含num_parts分区的子图的数据集。
  2. 给定一个用户定义的batch_size, ClusterLoader 实现随机划分方案以创建小批量。

然后,制作小批量的程序如下:

from torch_geometric.loader import ClusterData, ClusterLoadertorch.manual_seed(12345)
cluster_data = ClusterData(data, num_parts=128)  # 1. Create subgraphs.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)  # 2. Stochastic partioning scheme.print()
total_num_nodes = 0
for step, sub_data in enumerate(train_loader):print(f'Step {step + 1}:')print('=======')print(f'Number of nodes in the current batch: {sub_data.num_nodes}')print(sub_data)print()total_num_nodes += sub_data.num_nodesprint(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')

在这里插入图片描述

       在这里,我们将初始图划分为128个分区,并使用32个子图batch_size来形成mini-batches(每个epoch4批)。正如我们所看到的,在一个epoch之后,每个节点都被精确地看到了一次。

       Cluster-GCN的伟大之处在于它不会使GNN模型的实现复杂化。我们构造如下一个简单模型:
在这里插入图片描述
       这种图神经网络的训练与用于图分类任务的图神经网络训练非常相似。我们现在不再以全批处理的方式对图进行操作,而是对每个小批进行迭代,并相互独立地优化每个批:

model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()for sub_data in train_loader:  # Iterate over each mini-batch.out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])  # Compute the loss solely based on the training nodes.loss.backward()  # Derive gradients.optimizer.step()  # Update parameters based on gradients.optimizer.zero_grad()  # Clear gradients.def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)  # Use the class with highest probability.accs = []for mask in [data.train_mask, data.val_mask, data.test_mask]:correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.accs.append(int(correct.sum()) / int(mask.sum()))  # Derive ratio of correct predictions.return accsfor epoch in range(1, 51):loss = train()train_acc, val_acc, test_acc = test()print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

在这里插入图片描述
       在本文中,我们向您介绍了一种将 scale GNNs to large graphs的方法,否则这些图将不适合GPU内存。

本文内容参考:PyG官网

相关文章:

4.将图神经网络应用于大规模图数据(Cluster-GCN)

到目前为止,我们已经为节点分类任务单独以全批方式训练了图神经网络。特别是,这意味着每个节点的隐藏表示都是并行计算的,并且可以在下一层中重复使用。 然而,一旦我们想在更大的图上操作,由于内存消耗爆炸&#xff0c…...

pymongo更新数据

使用 PyMongo,可以通过以下步骤将查询到的记录进行更新: 下面是一个简单的示例代码片段,展示如何向名为users的集合中的所有文档添加一个新字段age。 import pymongo # 连接 MongoDB client pymongo.MongoClient("mongodb://localh…...

手机软件测试规范(含具体用例)

菜单基本功能测试规范一、短消息功能测试规范测试选项操作方法观察与判断结果创建、编辑短消息并发送书写短消息1、分别使用菜单或快捷方式进入书写短消息是否有异常; 2、输入0个字符,选择、输入号码发送,应成功; 3、输入1个中文…...

mysql having的用法

having的用法 having字句可以让我们筛选成组后的各种数据,where字句在聚合前先筛选记录,也就是说作用在group by和having字句前。而 having子句在聚合后对组记录进行筛选。我的理解就是真实表中没有此数据,这些数据是通过一些函数生存。 SQ…...

大数据需要学习哪些内容?

大数据技术的体系庞大且复杂,每年都会涌现出大量新的技术,目前大数据行业所涉及到的核心技术主要就是:数据采集、数据存储、数据清洗、数据查询分析和数据可视化。 Python 已成利器 在大数据领域中大放异彩 Python,成为职场人追求…...

【c++】static和const修饰类的成员变量或成员函数

目录 1、静态成员变量 2、静态成员函数 3、常函数 4、常对象 当我们使用c的关键字static修饰类中的成员变量和成员函数的时候,此时的成员变量和成员函数被称为静态成员。 静态成员包含: 静态成员变量静态成员函数 1、静态成员变量 静态成员变量有…...

DVWA-9.Weak Session IDs

大约 了解会话 ID 通常是在登录后以特定用户身份访问站点所需的唯一内容,如果能够计算或轻松猜测该会话 ID,则攻击者将有一种简单的方法来访问用户帐户,而无需暴力破解密码或查找其他漏洞,例如跨站点脚本。 目的 该模块使用四种…...

Bug序列——容器内给/root目录777权限后无法使用ssh免密登录

Linux——创建容器并将本地调试完全的前后端分离项目打包上传docker运行_北岭山脚鼠鼠的博客-CSDN博客 接着上一篇文章结尾出现403错误时通过赋予/root目录以777权限解决403错误。 chmod 777 /root 现在又出现新的问题,远程ssh无法免密登录了,即使通过…...

华为OD机试真题 JavaScript 实现【服务中心选址】【2023Q1 100分 】

一、题目描述 一个快递公司希望在一条街道建立新的服务中心。公司统计了该街道中所有区域在地图上的位置,并希望能够以此为依据为新的服务中心选址,使服务中心到所有区域的距离的总和最小。 给你一个数组 positions,其中 positions[i] [le…...

<Linux>《OpenSSH 客户端配置文件ssh_config详解》

《OpenSSH 客户端配置文件ssh_config详解》 1、 ssh获取配置数据顺序2、关键字2.1 Host2.2 Match2.3 AddKeysToAgent2.4 AddressFamily2.5 BatchMode2.6 BindAddress2.7 BindInterface2.8 CanonicalDomains2.9 CanonicalizeFallbackLocal2.10 CanonicalizeHostname2.11 Canonic…...

Linux内核中内存管理相关配置项的详细解析8

接前一篇文章:Linux内核中内存管理相关配置项的详细解析7 十一、Enable KSM for page merging 对应配置变量为:CONFIG_KSM。 此项只有选中和不选中两种状态,默认为选中。 内核源码详细解释为: Enable Kernel Samepage Merging:…...

深入浅出Vite:Vite打包与拆分

一、背景 在生产环境下,为了提高页面加载性能,构建工具一般将项目的代码打包(bundle)到一起,这样上线之后只需要请求少量的 JS 文件,大大减少 HTTP 请求。当然,Vite 也不例外,默认情况下 Vite 利用底层打包引擎 Rollup 来完成项目的模块打包。 某种意义上来说,对线上环…...

大数据ETL工具Kettle

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言最近公司在搞大数据数字化,有MES,CIM,WorkFlow等等N多的系统,不同的数据源DB,需要将这些不同的数据源DB里的数据进行整治统一…...

大学物理(上)-期末知识点结合习题复习(4)——质点运动学-动能定理 力做功 保守力与非保守力 势能 机械能守恒定律 完全弹性碰撞

目录 1.力做功 恒力作用下的功 变力的功 2.动能定理 3.保守力与非保守力 4.势能 引力的功与弹力的功 引力势能与弹性势能 5.保守力做功与势能的关系 6.机械能守恒定律 7.完全弹性碰撞 题1 题目描述 题解 题2 题目描述 题解 1.力做功 物体在力作用下移动做功…...

这两个小众的资源搜索工具其实很好用

01 小不点搜索是一个中国网络技术公司开发的网盘搜索引擎,该网站通过与多个主流网盘进行整合,为用户提供一种快速查找和下载文件的方式。小不点搜索因其高效性、便利性和实用性受到了广大用户的喜爱。 在技术实现上,小不点搜索拥有先进的搜…...

Java设计模式(六)— 单例模式1

系列文章目录 单例模式介绍 单例模式之静态常量饿汉式 单例模式之静态代码饿汉式 单例模式之线程不安全懒汉式 文章目录 系列文章目录前言一、单例设计模式介绍二、单例设计模式八种方式三、单例—静态常量饿汉式1.静态常量饿汉式介绍2.静态常量饿汉式案例3.静态常量饿汉式优缺…...

iOS -- isa指针

isa指针:isa指针是一个指向对象所属类或元类的指针。它决定了对象可以调用的方法和属性。isa指针在对象的结构中存在,并且在运行时会被自动设置。isa 指针,表示这个对象是一个什么类。而 Class 类型, 也就是 struct objc_class * …...

【SA8295P 源码分析】14 - Passthrough配置文件 /mnt/vm/images/linux-la.config 内容分析

系列文章汇总见:《【SA8295P 源码分析】00 - 系列文章链接汇总》 本文链接:《【SA8295P 源码分析】14 - Passthrough配置文件 /mnt/vm/images/linux-la.config 内容分析》 透传配置文件位于:qnx.git\apps\qnx_ap\target\hypervisor\gvm\ivi\la\linux-la.config 它是在QNX Ho…...

新型糖基化氨基酸:Fmoc-Thr((Ac4Galβ1-3)Me,Ac4Neu5Acα2-6AcGalNAcα)-OH,化学CAS号174783-92-7

●英文名:Fmoc-Thr((Ac4Galβ1-3)Me,Ac4Neu5Acα2-6AcGalNAcα)-OH ●外观以及性质: Fmoc-Thr((Ac4Galβ1-3)Me,Ac4Neu5Acα2-6AcGalNAcα)-OH中通过对蛋白进行复杂蛋白糖基化修饰,细胞产生了极大丰度的蛋白质类型;通过对各类糖基…...

网络安全(黑客)怎么自学?

最近看到很多问题,都是小白想要转行网络安全行业咨询学习路线和学习资料的,作为一个培训机构,学习路线和免费学习资料肯定是很多的。机构里面的不是顶级的黑阔大佬就是正在学习的同学,也用不上这些内容,每天都在某云盘…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用:实现组件通用属性的渐变过渡效果,提升用户体验。支持属性:width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项: 布局类属性(如宽高)变化时&#…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的,比GNOME简单得多! 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂&#xff…...

SpringCloudGateway 自定义局部过滤器

场景: 将所有请求转化为同一路径请求(方便穿网配置)在请求头内标识原来路径,然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...

网络编程(UDP编程)

思维导图 UDP基础编程(单播) 1.流程图 服务器:短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

深度学习习题2

1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...