【Graph Net学习】GNN/GCN代码实战
一、简介
GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。
二、代码
import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Moduleimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport scipy.sparse as sp#配置项
class configs():def __init__(self):# Dataself.data_path = r'E:\code\Graph\data'self.save_model_dir = r'\code\Graph'self.model_name = r'GCN' #GNN/GCNself.seed = 2023self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.batch_size = 64self.epoch = 200self.in_features = 1433 #core ~ feature:1433self.hidden_features = 16 # 隐层数量self.output_features = 8 # core~paper-point~ 8类self.learning_rate = 0.01self.dropout = 0.5self.istrain = Trueself.istest = Truecfg = configs()def seed_everything(seed=2023):random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)seed_everything(seed = cfg.seed)#数据
class Graph_Data_Loader():def __init__(self):self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()self.adj = self.adj.to(cfg.device)self.features = self.features.to(cfg.device)self.labels = self.labels.to(cfg.device)self.idx_train = self.idx_train.to(cfg.device)self.idx_val = self.idx_val.to(cfg.device)self.idx_test = self.idx_test.to(cfg.device)def load_data(self,path=cfg.data_path, dataset="cora"):"""Load citation network dataset (cora only for now)"""print('Loading {} dataset...'.format(dataset))idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = self.encode_onehot(idx_features_labels[:, -1])# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(labels.shape[0], labels.shape[0]),dtype=np.float32)# build symmetric adjacency matrixadj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = self.normalize(features)adj = self.normalize(adj + sp.eye(adj.shape[0]))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)features = torch.FloatTensor(np.array(features.todense()))labels = torch.LongTensor(np.where(labels)[1])adj = self.sparse_mx_to_torch_sparse_tensor(adj)idx_train = torch.LongTensor(idx_train)idx_val = torch.LongTensor(idx_val)idx_test = torch.LongTensor(idx_test)return adj, features, labels, idx_train, idx_val, idx_testdef encode_onehot(self,labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef normalize(self,mx):"""Row-normalize sparse matrix"""rowsum = np.array(mx.sum(1))r_inv = np.power(rowsum, -1).flatten()r_inv[np.isinf(r_inv)] = 0.r_mat_inv = sp.diags(r_inv)mx = r_mat_inv.dot(mx)return mxdef sparse_mx_to_torch_sparse_tensor(self,sparse_mx):"""Convert a scipy sparse matrix to a torch sparse tensor."""sparse_mx = sparse_mx.tocoo().astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))values = torch.from_numpy(sparse_mx.data)shape = torch.Size(sparse_mx.shape)return torch.sparse.FloatTensor(indices, values, shape)#精度评价指标
def accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)#模型
#01-GNN
class GNNLayer(nn.Module):def __init__(self, in_features, output_features):super(GNNLayer, self).__init__()self.linear = nn.Linear(in_features, output_features)def forward(self, adj_matrix, features):hidden_features = torch.matmul(adj_matrix, features) # GNN公式:H' = A * Hhidden_features = self.linear(hidden_features) # 使用线性变换hidden_features = F.relu(hidden_features) # 使用ReLU作为激活函数return hidden_features
class GNN(nn.Module):def __init__(self, in_features, hidden_features, output_features, num_layers=2):super(GNN, self).__init__()#输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layersself.layers = nn.ModuleList([GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i inrange(num_layers)])self.output_layer = nn.Linear(hidden_features, output_features)def forward(self, adj_matrix, features):hidden_features = featuresfor layer in self.layers:hidden_features = layer(adj_matrix, hidden_features)output = self.output_layer(hidden_features)return F.log_softmax(output,dim=1)#02-GCN
class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'class GCN(nn.Module):def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):super(GCN, self).__init__()self.gc1 = GraphConvolution(in_features, hidden_features)self.gc2 = GraphConvolution(hidden_features, output_features)self.dropout = dropoutdef forward(self, adj_matrix, features):x = F.relu(self.gc1(features, adj_matrix))x = F.dropout(x, self.dropout, training=self.training)x = self.gc2(x, adj_matrix)return F.log_softmax(x, dim=1)class graph_run():def train(self):t = time.time()#Create Train Processingall_data = Graph_Data_Loader()#创建一个模型model = eval(cfg.model_name)(in_features=cfg.in_features,hidden_features=cfg.hidden_features,output_features=cfg.output_features).to(cfg.device)optimizer = optim.Adam(model.parameters(),lr=cfg.learning_rate, weight_decay=5e-4)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()output = model(all_data.adj, all_data.features)loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])loss_train.backward()optimizer.step()loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss_train.item()),'acc_train: {:.4f}'.format(acc_train.item()),'loss_val: {:.4f}'.format(loss_val.item()),'acc_val: {:.4f}'.format(acc_val.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth')) # 模型保存def infer(self):#Create Test Processingall_data = Graph_Data_Loader()model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()output = model(all_data.adj,all_data.features)loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])print("Test set results:","loss= {:.4f}".format(loss_test.item()),"accuracy= {:.4f}".format(acc_test.item()))if __name__ == '__main__':mygraph = graph_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()
三、结果与讨论
需要从网上下载cora数据集,数据组织形式如下图。
测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~
Model | Params | GFLOPs |
GNN | 23.352K | 126.258M |
Model | Cora(/train/val/test) |
GNN | 1.0000/0.7800/0.7620 |
GCN | 0.9714/0.7767/0.8290 |
四、展望
未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。
相关文章:

【Graph Net学习】GNN/GCN代码实战
一、简介 GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。 二、代码 …...

RocketMQ 发送顺序消息
文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…...

【面试经典150 | 双指针】判断子序列
文章目录 写在前面Tag题目来源题目解题解题思路方法一:双指针方法二:动态规划 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对…...
自动驾驶信息安全方案
目录 1. 修订历史... 3 2. 概述... 4 2.1 目的... 4 2.2 适用范围... 4 2.3 参考文档... 4 2.4 术语和缩写... 4 3. 安全分析... 5 4. 总体设计... 6 4.1 ACU的安全防护... 7 4.1.1 系统安全引导... 7 4.1.2 密钥安全存储... 8 4.1.3 应…...
【云原生】kubernetes中pod(最小的资源管理组件)
目录 前言 一、pod 1.1pause容器使得Pod中的所有容器可以共享两种资源: 1.2 通常把Pod分为两类 1.2.1自主式Pod 1.2.2控制器管理的Pod 1.3 Pod 容器的分类 1.3.1基础容器(infrastructure container) 1.3.2初始化容器(initc…...
[DB]数据库--lowdb
[DB]数据库--lowdb lowdb基本应用获取数据数据变更写入文件 lodash的使用获取数据lodash方法使用数据变更写入文件 lowdb lowdb ,是一个基于文件存储的非关系型数据库 基于loadsh的轻量级数据库 可用于在json中存储数据,大小一般为0~200M的json文件 方便简单的数…...

Kotlin | 在for、forEach循环中正确的使用break、continue
文章目录 for循环中使用break、continueLabel标签forEach中如何退出循环资料 Kotlin 有三种结构化跳转表达式: return:默认从最直接包围它的函数或者匿名函数返回。break:终止最直接包围它的循环。continue:继续下一次最直接包围…...

【C++】详解std::mutex
2023年9月11日,周一中午开始 2023年9月11日,周一晚上23:25写完 目录 概述头文件std::mutex类的成员类型方法没有std::mutex会产生什么问题问题一:数据竞争问题二:不一致lock和unlock死锁 概述 std::mutex是C标准库中…...

Matlab图像处理-Lab模型
Lab模型 Lab模型是由CIE(国际照明委员会)制定的一种彩色模型。该模型与设备无关,弥补了RGB模型和CMYK模型必须依赖于设备颜色特性的不足; 另外,自然界中的任何颜色都可以在Lab空间中表现出来,也就是说RGB和…...
分布式ETL工具Sqoop实践
Mysql数据准备 1、在node02节点登录Mysql。 mysql -uroot -proot2、新建数据库testdb。 create database testdb;3、新建数据表ts。 use testdb; create table ts(id int, name varchar(10), age int, sex char(1));4、向表中插入数据。 insert into ts values(10001,张三…...

展会预告 | 图扑邀您共聚 IOTE 国际物联网展·深圳站
参展时间:9 月 20 日- 22 日 图扑展位:9 号馆 9B 35-1 参展地址:深圳国际会展中心(宝安新馆) IOTE 2023 第二十届国际物联网展深圳站,将于 9 月 20 日- 22 日在深圳国际会展中心(宝安…...

如何下载安装 WampServer 并结合 cpolar 内网穿透,轻松实现对本地服务的公网访问
文章目录 前言1.WampServer下载安装2.WampServer启动3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 Wamp 是一个 Windows系统下的 Apache PHP Mysql 集成安装环境,是一组常用来…...
iOS添加Mapbox地图库
配置凭据 注册并导航到Account页面。你将需要: 公共访问令牌: 从帐户的tokens页面,你可以复制默认的公共令牌或单击"create a token"按钮来创建新的公共令牌。 带有Downloads:Read范围的秘密访问令牌: 从你帐户的t…...

destoon根据目录下的html文件生成地图索引
因为项目需要,destoon根据目录下的html文件生成地图索引,操作方法,代码如下: <?php $new_array array(); function loopDir($dir,&$new_array,$modurl) {$handle opendir($dir);header("Content-Type:text/xml&qu…...
gRPC之gRPC流
1、gRPC流 从其名称可以理解,流就是持续不断的传输。有一些业务场景请求或者响应的数据量比较大,不适合使用普通的 RPC 调用通过一次请求-响应处理,一方面是考虑数据量大对请求响应时间的影响,另一方面业务场景的设计不一 定需…...
Kafka Shell命令交互
Kafka提供了一个命令行工具,用于管理和与Kafka集群交互。这个命令行工具通常称为Kafka Shell,它允许您执行各种操作,如创建主题、发送和消费消息、查看主题列表等。 以下是一些常用的Kafka Shell命令: 创建主题(Topic): kafka-topics.sh --create --topic my-topic --pa…...
什么是回归测试?
什么是回归测试? 回归测试被定义为一种软件测试类型,以确认最近的程序或代码更改未对现有功能产生不利影响。 回归测试只不过是全部或部分选择已执行的测试用例,然后重新执行以确保现有功能正常运行。 进行此测试是为了确保新代码更改不会…...

ZTMap是如何在相关政策引导下让建筑更加智慧化的?
近几年随着智慧楼宇概念的深入,尤其是在“十四五规划”“新基建”“数字经济”等相关战略和政策的引导下,智慧楼宇也迎来了快速发展期,对推动智慧城市系统的建设越来越重要。那么究竟什么是智慧楼宇呢?智慧楼宇其实就是整合楼宇内…...

Python:函数和代码复用
嗨喽,大家好呀~这里是爱看美女的茜茜呐 👇 👇 👇 更多精彩机密、教程,尽在下方,赶紧点击了解吧~ python源码、视频教程、插件安装教程、资料我都准备好了,直接在文末名片自取就可 1、关于递归函…...

three.js——模型对象的使用材质和方法
模型对象的使用材质和方法 前言效果图1、旋转、缩放、平移,居中的使用1.1 旋转rotation(.rotateX()、.rotateY()、.rotateZ())1.2缩放.scale()1.3平移.translate()1.4居中.center() 2、材质属性.wireframe 前言 BufferGeometry通过.scale()、…...

04-初识css
一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作
一、上下文切换 即使单核CPU也可以进行多线程执行代码,CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短,所以CPU会不断地切换线程执行,从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
代理篇12|深入理解 Vite中的Proxy接口代理配置
在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

Windows安装Miniconda
一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

GO协程(Goroutine)问题总结
在使用Go语言来编写代码时,遇到的一些问题总结一下 [参考文档]:https://www.topgoer.com/%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B/goroutine.html 1. main()函数默认的Goroutine 场景再现: 今天在看到这个教程的时候,在自己的电…...
MinIO Docker 部署:仅开放一个端口
MinIO Docker 部署:仅开放一个端口 在实际的服务器部署中,出于安全和管理的考虑,我们可能只能开放一个端口。MinIO 是一个高性能的对象存储服务,支持 Docker 部署,但默认情况下它需要两个端口:一个是 API 端口(用于存储和访问数据),另一个是控制台端口(用于管理界面…...

基于PHP的连锁酒店管理系统
有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发,数据库mysql,前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

springboot 日志类切面,接口成功记录日志,失败不记录
springboot 日志类切面,接口成功记录日志,失败不记录 自定义一个注解方法 import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;/***…...
离线语音识别方案分析
随着人工智能技术的不断发展,语音识别技术也得到了广泛的应用,从智能家居到车载系统,语音识别正在改变我们与设备的交互方式。尤其是离线语音识别,由于其在没有网络连接的情况下仍然能提供稳定、准确的语音处理能力,广…...