当前位置: 首页 > news >正文

[图神经网络]PyTorch简单实现一个GCN

Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )forward( )两个函数,PyTorch必须额外重构propagate( )message( )函数。

一、环境构建

        ①安装torch_geometric包。

pip install torch_geometric

        ②导入相关库

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid

二、PyG图学习架构

        构建方法:首先继承MessagePassing类,接下来重写构造函数和以下三个方法:

message()      #构建消息传递
aggregate()    #将消息聚合到目标节点
update()       #更新消息节点

        1.构造函数

def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):
参数内容
aggr消息聚合的方式,常见的方法:addmeanminmax
flow

消息传播的方向,source_to_target--从源节点到目标节点

                             target_to_source--从目标节点到源节点

node_dim传播的维度

        2.propagate函数

                该函数为消息传播的启动函数,调用此函数后会依次执行:messageaggregateupdate来完成消息的传递、聚合、更新

                该函数声明如下:

propagate(self, edge_index: Adj, size: Size = None, **kwargs)
参数说明
edge_index边索引
size邻接矩阵的尺寸,若为None则表示方阵
**kwargs额外参数

        3.message函数

                用于构建节点消息,传递给propagatetensor可以映射到中心节点和邻居节点,仅需在相应的变量名后加上_i(中心节点)或_j(邻居节点)即可。

self.propagate(edge_index, x=x):passdef message(self, x_i, x_j, edge_index_i):pass
x_i中心节点构成的特征向量组成的矩阵
x_j邻居节点构成的特征向量组成的矩阵
edge_index_i中心节点的索引

        4.aggregate函数

                消息聚合函数,用以聚合来自邻居的消息,常见的方法有add、sum、mean、max,可以通过super().__init__()中的参数aggr来设定

        5.update函数

                用于更新节点的消息

三、GCN图卷积网络

        GCN网络的原理可见:图卷积神经网络--GCN

        需要注意 torch_scatter无法使用pip install加载可以参见 torch_scatter安装

        1.加载数据集

from torch_geometric.datasets import Planetoiddataset = Planetoid(root='Cora', name='Cora')

                Cora数据集是一个根据科学论文之间相互引用关系构建的图(Graph)数据集合,论文合计7类,共2708篇论文(2708个节点),10556条边。

        2.定义GCN层

class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels, add_self_loops=True, bias=True):super(GCNConv, self).__init__()self.add_self_loops = add_self_loopsself.edge_index = Noneself.linear = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot')if bias:self.bias = nn.Parameter(torch.Tensor(out_channels, 1))self.bias = pyg_nn.inits.glorot(self.bias)else:self.register_parameter('bias', None)# 1.消息传递def message(self, x, edge_index):# 1.对所有节点进行新的空间映射x = self.linear(x) # [num_nodes, feature_size]# 2.添加偏置if self.bias != None:x += self.bias.flatten()# 3.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 4.获得度矩阵deg = degree(col, x.shape[0], x.dtype) # [num_nodes]# 5.度矩阵归一化deg_inv_sqrt = deg.pow(-0.5) # [num_nodes]# 6.计算sqrt(deg(i)) * sqrt(deg(j))norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # [num_nodes]# 7.返回所有边的映射x_j = x[row] # [E, feature_size]# 8.计算归一化后的节点特征x_j = norm.view(-1, 1) * x_j # [E, feature_size]return x_j# 2.消息聚合def aggregate(self, x_j, edge_index):# 1.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 2.聚合邻居特征aggr_out = scatter(x_j, row, dim=0, reduce='sum') # [num_nodes, feature_size]return aggr_out# 3.节点更新def update(self, aggr_out):# 对于GCN没有这个阶段,所以直接返回return aggr_outdef forward(self, x, edge_index):# 2.添加自环信息,考虑自身信息if self.add_self_loops:edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) # [2, E]return self.propagate(edge_index, x=x)

        3.定义GCN网络

class GCN(nn.Module):def __init__(self, num_node_features, num_classes):super(GCN, self).__init__()self.conv1 = GCNConv(num_node_features, 16)self.conv2 = GCNConv(16, num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)

        4.模型调用

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 4.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test  Accuracy: {:.4f}'.format(acc_test), 'Test  Loss: {:.4f}'.format(loss_test))

相关文章:

[图神经网络]PyTorch简单实现一个GCN

Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。 一、环境构建 ①安装torch_geometric包。 pip install torch_geometric …...

Elasticsearch(黑马)

初识elasticsearch ​​. 安装elasticsearch 1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的…...

oracle数据库调整字段类型

oracle数据库更改字段类型比较墨迹,因为如果该字段有值,是不允许直接更改字段类型的。另外oralce不支持在指定的某个字段后面新增一个字段,但是mysql数据可以向指定的字段后面新增一个字段。 mysql向指定字段后面新增一个字段: al…...

面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)

面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 (1)表情识别数据集说明 (2&…...

赛效:如何在线给图片加水印

学会给图片加水印是一个非常实用的技能,可以让你的图片更具保护性和个性化。说到加水印,很多人不知道怎么操作。其实,给图片加水印非常简单,不用下载任何程序,在线就能完成。今天,我将介绍如何使用改图宝在…...

动力节点杜老师Vue笔记——Vue程序初体验

一、Vue程序初体验 我们可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,这些对我们编写Vue程序起不到太大的作用,更何况现在说了一些特点之后,我们也没有办法彻底理解它,因此我们可以先学会用,使…...

ajax上传图片存入到指定的文件夹并回显

html代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"js/jquery-2.1.0.js"></script> </head> <body> <form…...

cesium加载cesiumlab切的影像切片和标准TMS瓦片的区别

1.加载cesiumlab切的影像 var labImg viewer.scene.imageryLayers.addImageryProvider( new Cesium.UrlTemplateImageryProvider({url:http://192.168.1.25:8080/DOMtms/{z}/{x}/{y}.png,fileExtension : "png"})); 2.标准TMS瓦片 var labImg viewer.scene.im…...

第二周P9-P22

文章目录第三章 系统总线3.1、总线的基本概念一、为什么要用总线二、什么是总线三、总线上信息的传送四、总线结构的计算机举例1、单总线结构框图2、面向CPU的双总线结构框图3、以存储器为中心的双总线结构图3.2、总线的分类1、片内总线2、系统总线3、通信走线3.3、总线特性及性…...

java反射

文章目录何为反射&#xff1f;反射的应用场景了解么&#xff1f;谈谈反射机制的优缺点优点缺点反射实战获取 Class 对象的四种方式1. 知道具体类的情况下可以使用TargetObject.class&#xff1a;2. 通过 Class.forName()传入类的全路径获取&#xff1a;3. 通过对象实例instance…...

307 Temporary Redirect 解决办法(临时重定向)

背景&#xff1a;java后台服务请求python服务端 java服务报错&#xff1a;Unexpected response status&#xff1a;307 python服务端报错&#xff1a;307 Temporary Redirect 解决&#xff1a;查了好久找不到什么原因&#xff0c;请求路径问题 请求url&#xff1a;http//:w…...

SpringBoot案例

SpringBoot案例5&#xff0c;案例5.1 创建工程5.2 代码拷贝5.3 配置文件5.4 静态资源目标 基于SpringBoot的完成SSM整合项目开发 5&#xff0c;案例 SpringBoot 到这就已经学习完毕&#xff0c;接下来我们将学习 SSM 时做的三大框架整合的案例用 SpringBoot 来实现一下。我们完…...

Android 10.0 系统framework发送悬浮通知的流程分析

1.前言 在android10.0rom定制化开发中,在原生系统的systemui中,状态栏通知,和闹钟,wifi等悬浮通知也是很重要的, 悬浮通知也是系统通知的一种,也是在frameworks中发送出来的通知,接下来就分析下10.0中的悬浮通知的发送 流程,然后就可以实现自己自定义悬浮通知的相关功…...

傅里叶谱方法-傅里叶谱方法求解二维浅水方程组和二维粘性 Burgers 方程及其Matlab程序实现

3.3.2 二维浅水方程组 二维浅水方程组是描述水波运动的基本方程之一。它主要用于描述近岸浅水区域内的波浪、潮汐等水动力学现象。这个方程组由两个偏微分方程组成&#xff0c;一个是质量守恒方程&#xff0c;另一个是动量守恒方程。浅水方程描述了具有自由表面、密度均匀、深…...

算法训练营 - 广度优先BFS

目录 从层序遍历开始 N 叉树的层序遍历 经典BFS最短路模板 经典C queue 数组模拟队列 打印路径 示例1.bfs查找所有连接方块 Cqueue版 数组模拟队列 示例2.从多个位置同时开始BFS 示例3.抽象最短路类&#xff08;作图关键&#xff09; 示例4.跨过障碍的最短路 从层序遍历…...

​​​​​​​判断两个字符串是否匹配(1个通配符代表一个字符)

目录 判断两个字符串是否匹配(1个通配符代表一个字符) 程序设计 程序分析...

用css画一个csdn程序猿

效果如下&#xff1a; 头部 我们先来拆解一下&#xff0c;程序猿的结构 两只耳朵和头是圆形组成的&#xff0c;耳朵内的红色部分也是圆形 先画头部&#xff0c;利用圆角实现头部形状 借助工具来快速实现圆角效果 https://9elements.github.io/fancy-border-radius/ <div…...

Java多线程编程—wait/notify机制

文章目录1. 不使用wait/notify机制通信的缺点2. 什么是wait/notify机制3. wait/notify机制原理4. wait/notify方法的基本用法5. 线程状态的切换6. interrupt()遇到方法wait()7. notify/notifyAll方法8. wait(long)介绍9. 生产者/消费者模式10. 管道机制11. 利用wait/notify实现…...

Three.js教程:旋转动画、requestAnimationFrame周期性渲染

推荐&#xff1a;将NSDT场景编辑器加入你3D工具链其他工具系列&#xff1a;NSDT简石数字孪生基于WebGL技术开发在线游戏、商品展示、室内漫游往往都会涉及到动画&#xff0c;初步了解three.js可以做什么&#xff0c;深入讲解three.js动画之前&#xff0c;本节课先制作一个简单的…...

租车自驾app开发有什么作用?租车便利出行APP开发

在线租车APP有哪些优势&#xff0c;租车APP开发的基本功能&#xff0c;租车自驾app开发有什么作用?租车便利出行APP开发&#xff0c;租车服务平台小程序有哪些功能&#xff0c;租车软件开发需要多少钱&#xff0c;租车app都有哪些&#xff0c;租车平台定制开发&#xff0c;租车…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例&#xff1a;使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例&#xff1a;使用OpenAI GPT-3进…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

visual studio 2022更改主题为深色

visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中&#xff0c;选择 环境 -> 常规 &#xff0c;将其中的颜色主题改成深色 点击确定&#xff0c;更改完成...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

【JVM】Java虚拟机(二)——垃圾回收

目录 一、如何判断对象可以回收 &#xff08;一&#xff09;引用计数法 &#xff08;二&#xff09;可达性分析算法 二、垃圾回收算法 &#xff08;一&#xff09;标记清除 &#xff08;二&#xff09;标记整理 &#xff08;三&#xff09;复制 &#xff08;四&#xff…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

DeepSeek源码深度解析 × 华为仓颉语言编程精粹——从MoE架构到全场景开发生态

前言 在人工智能技术飞速发展的今天&#xff0c;深度学习与大模型技术已成为推动行业变革的核心驱动力&#xff0c;而高效、灵活的开发工具与编程语言则为技术创新提供了重要支撑。本书以两大前沿技术领域为核心&#xff0c;系统性地呈现了两部深度技术著作的精华&#xff1a;…...