[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。
一、环境构建
①安装torch_geometric包。
pip install torch_geometric
②导入相关库
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid
二、PyG图学习架构
构建方法:首先继承MessagePassing类,接下来重写构造函数和以下三个方法:
message() #构建消息传递
aggregate() #将消息聚合到目标节点
update() #更新消息节点
1.构造函数
def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):
参数 | 内容 |
aggr | 消息聚合的方式,常见的方法:add、mean、min、max |
flow | 消息传播的方向,source_to_target--从源节点到目标节点 target_to_source--从目标节点到源节点 |
node_dim | 传播的维度 |
2.propagate函数
该函数为消息传播的启动函数,调用此函数后会依次执行:message、aggregate、update来完成消息的传递、聚合、更新。
该函数声明如下:
propagate(self, edge_index: Adj, size: Size = None, **kwargs)
参数 | 说明 |
edge_index | 边索引 |
size | 邻接矩阵的尺寸,若为None则表示方阵 |
**kwargs | 额外参数 |
3.message函数
用于构建节点消息,传递给propagate的tensor可以映射到中心节点和邻居节点,仅需在相应的变量名后加上_i(中心节点)或_j(邻居节点)即可。
self.propagate(edge_index, x=x):passdef message(self, x_i, x_j, edge_index_i):pass
x_i | 中心节点构成的特征向量组成的矩阵 |
x_j | 邻居节点构成的特征向量组成的矩阵 |
edge_index_i | 中心节点的索引 |
4.aggregate函数
消息聚合函数,用以聚合来自邻居的消息,常见的方法有add、sum、mean、max,可以通过super().__init__()中的参数aggr来设定
5.update函数
用于更新节点的消息
三、GCN图卷积网络
GCN网络的原理可见:图卷积神经网络--GCN
需要注意 torch_scatter无法使用pip install加载可以参见 torch_scatter安装
1.加载数据集
from torch_geometric.datasets import Planetoiddataset = Planetoid(root='Cora', name='Cora')
Cora数据集是一个根据科学论文之间相互引用关系构建的图(Graph)数据集合,论文合计7类,共2708篇论文(2708个节点),10556条边。
2.定义GCN层
class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels, add_self_loops=True, bias=True):super(GCNConv, self).__init__()self.add_self_loops = add_self_loopsself.edge_index = Noneself.linear = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot')if bias:self.bias = nn.Parameter(torch.Tensor(out_channels, 1))self.bias = pyg_nn.inits.glorot(self.bias)else:self.register_parameter('bias', None)# 1.消息传递def message(self, x, edge_index):# 1.对所有节点进行新的空间映射x = self.linear(x) # [num_nodes, feature_size]# 2.添加偏置if self.bias != None:x += self.bias.flatten()# 3.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 4.获得度矩阵deg = degree(col, x.shape[0], x.dtype) # [num_nodes]# 5.度矩阵归一化deg_inv_sqrt = deg.pow(-0.5) # [num_nodes]# 6.计算sqrt(deg(i)) * sqrt(deg(j))norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # [num_nodes]# 7.返回所有边的映射x_j = x[row] # [E, feature_size]# 8.计算归一化后的节点特征x_j = norm.view(-1, 1) * x_j # [E, feature_size]return x_j# 2.消息聚合def aggregate(self, x_j, edge_index):# 1.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 2.聚合邻居特征aggr_out = scatter(x_j, row, dim=0, reduce='sum') # [num_nodes, feature_size]return aggr_out# 3.节点更新def update(self, aggr_out):# 对于GCN没有这个阶段,所以直接返回return aggr_outdef forward(self, x, edge_index):# 2.添加自环信息,考虑自身信息if self.add_self_loops:edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) # [2, E]return self.propagate(edge_index, x=x)
3.定义GCN网络
class GCN(nn.Module):def __init__(self, num_node_features, num_classes):super(GCN, self).__init__()self.conv1 = GCNConv(num_node_features, 16)self.conv2 = GCNConv(16, num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)
4.模型调用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 4.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
相关文章:
[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。 一、环境构建 ①安装torch_geometric包。 pip install torch_geometric …...

Elasticsearch(黑马)
初识elasticsearch . 安装elasticsearch 1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的…...
oracle数据库调整字段类型
oracle数据库更改字段类型比较墨迹,因为如果该字段有值,是不允许直接更改字段类型的。另外oralce不支持在指定的某个字段后面新增一个字段,但是mysql数据可以向指定的字段后面新增一个字段。 mysql向指定字段后面新增一个字段: al…...

面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)
面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 (1)表情识别数据集说明 (2&…...

赛效:如何在线给图片加水印
学会给图片加水印是一个非常实用的技能,可以让你的图片更具保护性和个性化。说到加水印,很多人不知道怎么操作。其实,给图片加水印非常简单,不用下载任何程序,在线就能完成。今天,我将介绍如何使用改图宝在…...

动力节点杜老师Vue笔记——Vue程序初体验
一、Vue程序初体验 我们可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,这些对我们编写Vue程序起不到太大的作用,更何况现在说了一些特点之后,我们也没有办法彻底理解它,因此我们可以先学会用,使…...

ajax上传图片存入到指定的文件夹并回显
html代码: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"js/jquery-2.1.0.js"></script> </head> <body> <form…...
cesium加载cesiumlab切的影像切片和标准TMS瓦片的区别
1.加载cesiumlab切的影像 var labImg viewer.scene.imageryLayers.addImageryProvider( new Cesium.UrlTemplateImageryProvider({url:http://192.168.1.25:8080/DOMtms/{z}/{x}/{y}.png,fileExtension : "png"})); 2.标准TMS瓦片 var labImg viewer.scene.im…...

第二周P9-P22
文章目录第三章 系统总线3.1、总线的基本概念一、为什么要用总线二、什么是总线三、总线上信息的传送四、总线结构的计算机举例1、单总线结构框图2、面向CPU的双总线结构框图3、以存储器为中心的双总线结构图3.2、总线的分类1、片内总线2、系统总线3、通信走线3.3、总线特性及性…...
java反射
文章目录何为反射?反射的应用场景了解么?谈谈反射机制的优缺点优点缺点反射实战获取 Class 对象的四种方式1. 知道具体类的情况下可以使用TargetObject.class:2. 通过 Class.forName()传入类的全路径获取:3. 通过对象实例instance…...
307 Temporary Redirect 解决办法(临时重定向)
背景:java后台服务请求python服务端 java服务报错:Unexpected response status:307 python服务端报错:307 Temporary Redirect 解决:查了好久找不到什么原因,请求路径问题 请求url:http//:w…...

SpringBoot案例
SpringBoot案例5,案例5.1 创建工程5.2 代码拷贝5.3 配置文件5.4 静态资源目标 基于SpringBoot的完成SSM整合项目开发 5,案例 SpringBoot 到这就已经学习完毕,接下来我们将学习 SSM 时做的三大框架整合的案例用 SpringBoot 来实现一下。我们完…...
Android 10.0 系统framework发送悬浮通知的流程分析
1.前言 在android10.0rom定制化开发中,在原生系统的systemui中,状态栏通知,和闹钟,wifi等悬浮通知也是很重要的, 悬浮通知也是系统通知的一种,也是在frameworks中发送出来的通知,接下来就分析下10.0中的悬浮通知的发送 流程,然后就可以实现自己自定义悬浮通知的相关功…...

傅里叶谱方法-傅里叶谱方法求解二维浅水方程组和二维粘性 Burgers 方程及其Matlab程序实现
3.3.2 二维浅水方程组 二维浅水方程组是描述水波运动的基本方程之一。它主要用于描述近岸浅水区域内的波浪、潮汐等水动力学现象。这个方程组由两个偏微分方程组成,一个是质量守恒方程,另一个是动量守恒方程。浅水方程描述了具有自由表面、密度均匀、深…...

算法训练营 - 广度优先BFS
目录 从层序遍历开始 N 叉树的层序遍历 经典BFS最短路模板 经典C queue 数组模拟队列 打印路径 示例1.bfs查找所有连接方块 Cqueue版 数组模拟队列 示例2.从多个位置同时开始BFS 示例3.抽象最短路类(作图关键) 示例4.跨过障碍的最短路 从层序遍历…...
判断两个字符串是否匹配(1个通配符代表一个字符)
目录 判断两个字符串是否匹配(1个通配符代表一个字符) 程序设计 程序分析...

用css画一个csdn程序猿
效果如下: 头部 我们先来拆解一下,程序猿的结构 两只耳朵和头是圆形组成的,耳朵内的红色部分也是圆形 先画头部,利用圆角实现头部形状 借助工具来快速实现圆角效果 https://9elements.github.io/fancy-border-radius/ <div…...

Java多线程编程—wait/notify机制
文章目录1. 不使用wait/notify机制通信的缺点2. 什么是wait/notify机制3. wait/notify机制原理4. wait/notify方法的基本用法5. 线程状态的切换6. interrupt()遇到方法wait()7. notify/notifyAll方法8. wait(long)介绍9. 生产者/消费者模式10. 管道机制11. 利用wait/notify实现…...
Three.js教程:旋转动画、requestAnimationFrame周期性渲染
推荐:将NSDT场景编辑器加入你3D工具链其他工具系列:NSDT简石数字孪生基于WebGL技术开发在线游戏、商品展示、室内漫游往往都会涉及到动画,初步了解three.js可以做什么,深入讲解three.js动画之前,本节课先制作一个简单的…...
租车自驾app开发有什么作用?租车便利出行APP开发
在线租车APP有哪些优势,租车APP开发的基本功能,租车自驾app开发有什么作用?租车便利出行APP开发,租车服务平台小程序有哪些功能,租车软件开发需要多少钱,租车app都有哪些,租车平台定制开发,租车…...

网络编程(Modbus进阶)
思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...
Java 语言特性(面试系列2)
一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...
SkyWalking 10.2.0 SWCK 配置过程
SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...

黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...

从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路
进入2025年以来,尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断,但全球市场热度依然高涨,入局者持续增加。 以国内市场为例,天眼查专业版数据显示,截至5月底,我国现存在业、存续状态的机器人相关企…...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...

linux arm系统烧录
1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...
python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)
更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...
【HTTP三个基础问题】
面试官您好!HTTP是超文本传输协议,是互联网上客户端和服务器之间传输超文本数据(比如文字、图片、音频、视频等)的核心协议,当前互联网应用最广泛的版本是HTTP1.1,它基于经典的C/S模型,也就是客…...