[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。
一、环境构建
①安装torch_geometric包。
pip install torch_geometric
②导入相关库
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid
二、PyG图学习架构
构建方法:首先继承MessagePassing类,接下来重写构造函数和以下三个方法:
message() #构建消息传递
aggregate() #将消息聚合到目标节点
update() #更新消息节点
1.构造函数
def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):
| 参数 | 内容 |
| aggr | 消息聚合的方式,常见的方法:add、mean、min、max |
| flow | 消息传播的方向,source_to_target--从源节点到目标节点 target_to_source--从目标节点到源节点 |
| node_dim | 传播的维度 |
2.propagate函数
该函数为消息传播的启动函数,调用此函数后会依次执行:message、aggregate、update来完成消息的传递、聚合、更新。
该函数声明如下:
propagate(self, edge_index: Adj, size: Size = None, **kwargs)
| 参数 | 说明 |
| edge_index | 边索引 |
| size | 邻接矩阵的尺寸,若为None则表示方阵 |
| **kwargs | 额外参数 |
3.message函数
用于构建节点消息,传递给propagate的tensor可以映射到中心节点和邻居节点,仅需在相应的变量名后加上_i(中心节点)或_j(邻居节点)即可。
self.propagate(edge_index, x=x):passdef message(self, x_i, x_j, edge_index_i):pass
| x_i | 中心节点构成的特征向量组成的矩阵 |
| x_j | 邻居节点构成的特征向量组成的矩阵 |
| edge_index_i | 中心节点的索引 |
4.aggregate函数
消息聚合函数,用以聚合来自邻居的消息,常见的方法有add、sum、mean、max,可以通过super().__init__()中的参数aggr来设定
5.update函数
用于更新节点的消息
三、GCN图卷积网络
GCN网络的原理可见:图卷积神经网络--GCN
需要注意 torch_scatter无法使用pip install加载可以参见 torch_scatter安装
1.加载数据集
from torch_geometric.datasets import Planetoiddataset = Planetoid(root='Cora', name='Cora')
Cora数据集是一个根据科学论文之间相互引用关系构建的图(Graph)数据集合,论文合计7类,共2708篇论文(2708个节点),10556条边。
2.定义GCN层
class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels, add_self_loops=True, bias=True):super(GCNConv, self).__init__()self.add_self_loops = add_self_loopsself.edge_index = Noneself.linear = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot')if bias:self.bias = nn.Parameter(torch.Tensor(out_channels, 1))self.bias = pyg_nn.inits.glorot(self.bias)else:self.register_parameter('bias', None)# 1.消息传递def message(self, x, edge_index):# 1.对所有节点进行新的空间映射x = self.linear(x) # [num_nodes, feature_size]# 2.添加偏置if self.bias != None:x += self.bias.flatten()# 3.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 4.获得度矩阵deg = degree(col, x.shape[0], x.dtype) # [num_nodes]# 5.度矩阵归一化deg_inv_sqrt = deg.pow(-0.5) # [num_nodes]# 6.计算sqrt(deg(i)) * sqrt(deg(j))norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # [num_nodes]# 7.返回所有边的映射x_j = x[row] # [E, feature_size]# 8.计算归一化后的节点特征x_j = norm.view(-1, 1) * x_j # [E, feature_size]return x_j# 2.消息聚合def aggregate(self, x_j, edge_index):# 1.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 2.聚合邻居特征aggr_out = scatter(x_j, row, dim=0, reduce='sum') # [num_nodes, feature_size]return aggr_out# 3.节点更新def update(self, aggr_out):# 对于GCN没有这个阶段,所以直接返回return aggr_outdef forward(self, x, edge_index):# 2.添加自环信息,考虑自身信息if self.add_self_loops:edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) # [2, E]return self.propagate(edge_index, x=x)
3.定义GCN网络
class GCN(nn.Module):def __init__(self, num_node_features, num_classes):super(GCN, self).__init__()self.conv1 = GCNConv(num_node_features, 16)self.conv2 = GCNConv(16, num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)
4.模型调用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 4.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
相关文章:
[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。 一、环境构建 ①安装torch_geometric包。 pip install torch_geometric …...
Elasticsearch(黑马)
初识elasticsearch . 安装elasticsearch 1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的…...
oracle数据库调整字段类型
oracle数据库更改字段类型比较墨迹,因为如果该字段有值,是不允许直接更改字段类型的。另外oralce不支持在指定的某个字段后面新增一个字段,但是mysql数据可以向指定的字段后面新增一个字段。 mysql向指定字段后面新增一个字段: al…...
面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)
面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 (1)表情识别数据集说明 (2&…...
赛效:如何在线给图片加水印
学会给图片加水印是一个非常实用的技能,可以让你的图片更具保护性和个性化。说到加水印,很多人不知道怎么操作。其实,给图片加水印非常简单,不用下载任何程序,在线就能完成。今天,我将介绍如何使用改图宝在…...
动力节点杜老师Vue笔记——Vue程序初体验
一、Vue程序初体验 我们可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,这些对我们编写Vue程序起不到太大的作用,更何况现在说了一些特点之后,我们也没有办法彻底理解它,因此我们可以先学会用,使…...
ajax上传图片存入到指定的文件夹并回显
html代码: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"js/jquery-2.1.0.js"></script> </head> <body> <form…...
cesium加载cesiumlab切的影像切片和标准TMS瓦片的区别
1.加载cesiumlab切的影像 var labImg viewer.scene.imageryLayers.addImageryProvider( new Cesium.UrlTemplateImageryProvider({url:http://192.168.1.25:8080/DOMtms/{z}/{x}/{y}.png,fileExtension : "png"})); 2.标准TMS瓦片 var labImg viewer.scene.im…...
第二周P9-P22
文章目录第三章 系统总线3.1、总线的基本概念一、为什么要用总线二、什么是总线三、总线上信息的传送四、总线结构的计算机举例1、单总线结构框图2、面向CPU的双总线结构框图3、以存储器为中心的双总线结构图3.2、总线的分类1、片内总线2、系统总线3、通信走线3.3、总线特性及性…...
java反射
文章目录何为反射?反射的应用场景了解么?谈谈反射机制的优缺点优点缺点反射实战获取 Class 对象的四种方式1. 知道具体类的情况下可以使用TargetObject.class:2. 通过 Class.forName()传入类的全路径获取:3. 通过对象实例instance…...
307 Temporary Redirect 解决办法(临时重定向)
背景:java后台服务请求python服务端 java服务报错:Unexpected response status:307 python服务端报错:307 Temporary Redirect 解决:查了好久找不到什么原因,请求路径问题 请求url:http//:w…...
SpringBoot案例
SpringBoot案例5,案例5.1 创建工程5.2 代码拷贝5.3 配置文件5.4 静态资源目标 基于SpringBoot的完成SSM整合项目开发 5,案例 SpringBoot 到这就已经学习完毕,接下来我们将学习 SSM 时做的三大框架整合的案例用 SpringBoot 来实现一下。我们完…...
Android 10.0 系统framework发送悬浮通知的流程分析
1.前言 在android10.0rom定制化开发中,在原生系统的systemui中,状态栏通知,和闹钟,wifi等悬浮通知也是很重要的, 悬浮通知也是系统通知的一种,也是在frameworks中发送出来的通知,接下来就分析下10.0中的悬浮通知的发送 流程,然后就可以实现自己自定义悬浮通知的相关功…...
傅里叶谱方法-傅里叶谱方法求解二维浅水方程组和二维粘性 Burgers 方程及其Matlab程序实现
3.3.2 二维浅水方程组 二维浅水方程组是描述水波运动的基本方程之一。它主要用于描述近岸浅水区域内的波浪、潮汐等水动力学现象。这个方程组由两个偏微分方程组成,一个是质量守恒方程,另一个是动量守恒方程。浅水方程描述了具有自由表面、密度均匀、深…...
算法训练营 - 广度优先BFS
目录 从层序遍历开始 N 叉树的层序遍历 经典BFS最短路模板 经典C queue 数组模拟队列 打印路径 示例1.bfs查找所有连接方块 Cqueue版 数组模拟队列 示例2.从多个位置同时开始BFS 示例3.抽象最短路类(作图关键) 示例4.跨过障碍的最短路 从层序遍历…...
判断两个字符串是否匹配(1个通配符代表一个字符)
目录 判断两个字符串是否匹配(1个通配符代表一个字符) 程序设计 程序分析...
用css画一个csdn程序猿
效果如下: 头部 我们先来拆解一下,程序猿的结构 两只耳朵和头是圆形组成的,耳朵内的红色部分也是圆形 先画头部,利用圆角实现头部形状 借助工具来快速实现圆角效果 https://9elements.github.io/fancy-border-radius/ <div…...
Java多线程编程—wait/notify机制
文章目录1. 不使用wait/notify机制通信的缺点2. 什么是wait/notify机制3. wait/notify机制原理4. wait/notify方法的基本用法5. 线程状态的切换6. interrupt()遇到方法wait()7. notify/notifyAll方法8. wait(long)介绍9. 生产者/消费者模式10. 管道机制11. 利用wait/notify实现…...
Three.js教程:旋转动画、requestAnimationFrame周期性渲染
推荐:将NSDT场景编辑器加入你3D工具链其他工具系列:NSDT简石数字孪生基于WebGL技术开发在线游戏、商品展示、室内漫游往往都会涉及到动画,初步了解three.js可以做什么,深入讲解three.js动画之前,本节课先制作一个简单的…...
租车自驾app开发有什么作用?租车便利出行APP开发
在线租车APP有哪些优势,租车APP开发的基本功能,租车自驾app开发有什么作用?租车便利出行APP开发,租车服务平台小程序有哪些功能,租车软件开发需要多少钱,租车app都有哪些,租车平台定制开发,租车…...
【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...
龙虎榜——20250610
上证指数放量收阴线,个股多数下跌,盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型,指数短线有调整的需求,大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的:御银股份、雄帝科技 驱动…...
【配置 YOLOX 用于按目录分类的图片数据集】
现在的图标点选越来越多,如何一步解决,采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集(每个目录代表一个类别,目录下是该类别的所有图片),你需要进行以下配置步骤&#x…...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...
多模态图像修复系统:基于深度学习的图片修复实现
多模态图像修复系统:基于深度学习的图片修复实现 1. 系统概述 本系统使用多模态大模型(Stable Diffusion Inpainting)实现图像修复功能,结合文本描述和图片输入,对指定区域进行内容修复。系统包含完整的数据处理、模型训练、推理部署流程。 import torch import numpy …...
从面试角度回答Android中ContentProvider启动原理
Android中ContentProvider原理的面试角度解析,分为已启动和未启动两种场景: 一、ContentProvider已启动的情况 1. 核心流程 触发条件:当其他组件(如Activity、Service)通过ContentR…...
认识CMake并使用CMake构建自己的第一个项目
1.CMake的作用和优势 跨平台支持:CMake支持多种操作系统和编译器,使用同一份构建配置可以在不同的环境中使用 简化配置:通过CMakeLists.txt文件,用户可以定义项目结构、依赖项、编译选项等,无需手动编写复杂的构建脚本…...
嵌入式学习之系统编程(九)OSI模型、TCP/IP模型、UDP协议网络相关编程(6.3)
目录 一、网络编程--OSI模型 二、网络编程--TCP/IP模型 三、网络接口 四、UDP网络相关编程及主要函数 编辑编辑 UDP的特征 socke函数 bind函数 recvfrom函数(接收函数) sendto函数(发送函数) 五、网络编程之 UDP 用…...
