Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作
最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。
我觉得对于初学者而言,图神经网络的训练会有2个难点:
①环境配置
②数据集制作
一、环境配置
我最初光想到要给GCNN配环境就觉得有些困难,感觉相比于目标检测、分类识别这些任务用规则数据,图神经网络的模型、数据都是图,所以内心觉得会比较难。
我之前更有一个误区,就是觉得不规则结构的图数据不能用CUDA进行并行加速。实际上,图,在电脑里也是以张量这种规则结构数据存在的,完全能用CUDA进行加速计算,训练GCN前配置CUDA完全OK。
以下是我配置的环境,可用CUDA成功运行link_pred.py
几个关键包的版本:
torch 2.4.1
torch-geometric 2.3.1
torchaudio 2.4.1
torchvision 0.14.0
torchviz 0.0.2pandas 1.0.3
numpy 1.20.0
CUDA: 11.8
注意要先安装好CUDA,显示了:
再安装GPU版本的torch,不然python检测安装的是cpu版本的torch。这时,就得卸载重新安装了
环境配置成功:
print(torch.__version__)
print(torch.cuda.is_available())
如果CUDA环境安装失败,会打印:
2.4.1+cpu
False
其实只安装torch和CUDA还好,如果你的python中有numpy和pandas可能解决版本之间的冲突会耗费不少时间,我就是在numpy和pandas版本上试了很久,最终找到现在的版本是相互兼容的。
CUDA的版本切换可以参考我的另一篇博客:
CUDA版本切换
二、数据集制作
掌握图数据集制作的关键在于掌握slices切片:
for ...data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)data_list.append(data)
data_, slices = self.collate(data_list) # 将不同大小的图数据对齐,填充
torch.save((data_, slices), self.processed_paths[0])
和CNN不同的是,GCN没有样本维度,需要把所有样本拼成一张大图喂给GCN进行训练
数据集生成代码:
#作者:zhouzhichao
#创建时间:2025/5/30
#内容:生成200个样本的PYG数据集import h5py
import hdf5storage
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_samplingbase_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"N = 30
grapg_size = N
train_n = 31
M = 3000class graph_data(InMemoryDataset):def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):# self.Signals = Signals# self.Tp_list = Tp_listself.signals = signalsself.tp_list = tp_listsuper().__init__(root, transform, pre_transform)# self.data, self.slices = torch.load(self.processed_paths[0])self.data = torch.load(self.processed_paths[0])# 返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致@propertydef processed_file_names(self):return ['gcn_data.pt']# 生成数据集所用的方法def process(self):# data_list = []# for k in range(200):# signals = self.Signals[:, :, k]# tp_list = np.array(mat_file[self.Tp_list[0, k]])signals = self.signalstp_list =self.tp_list# tp = Tp[:,:,k]X = torch.tensor(signals, dtype=torch.float)# 所有的边Edge_index = torch.tensor(tp_list, dtype=torch.long)# 所有的边1标签edge_label = np.ones((tp_list.shape[1]))# edge_label = np.zeros((tp_list.shape[1]))Edge_label = torch.tensor(edge_label, dtype=torch.float)neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,num_neg_samples=Edge_index.shape[1], method='sparse')# 拼接正负样本索引# c = 0# for i in range(31):# for i in range(31):# if torch.equal(Edge_index[:, i], neg_edge_index[:, i]):# c = c + 1# print("c: ",c)Edge_label_index = Edge_indexperm = torch.randperm(Edge_index.size(1))Edge_index = Edge_index[:, perm]Edge_index = Edge_index[:, :train_n]Edge_label_index = torch.cat([Edge_label_index, neg_edge_index],dim=-1,)# 拼接正负样本Edge_label = torch.cat([Edge_label,Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)# Edge_label = torch.cat([# Edge_label,# Edge_label.new_ones(neg_edge_index.size(1))# ], dim=0)data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)torch.save(data, self.processed_paths[0])# data_list.append(data)# data_, slices = self.collate(data_list) # 将不同大小的图数据对齐,填充# torch.save((data_, slices), self.processed_paths[0])for snr in [0,20,40]:print("snr: ", snr)mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# mat_file = hdf5storage.loadmat(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# 获取数据集Signals = mat_file["Signals"][()]# signals = np.swapaxes(signals, 1, 0)Tp = mat_file["Tp"][()]Tp_list = mat_file["Tp_list"][()]# tp_list = tp_list - 1# 关闭文件# mat_file.close()# graph_data("gcn_data")# n = Signals.shape[2]n = 10for i in range(n):signals = Signals[:,:,i]tp_list = np.array(mat_file[Tp_list[0, i]])root = "gcn_data-"+str(i)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)graph_data(root, signals = signals, tp_list = tp_list)print("")print("...图数据生成完成...")
训练代码:
#作者:zhouzhichao
#创建时间:25年5月29日
#内容:统计图中有关系节点和无关系节点的GCN特征欧式距离import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\experiment\直推式拓扑推理实验\GCN推理')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())mode = "gcn"class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(Input_L, 1000)self.conv2 = GCNConv(1000, 20)def encode(self, x, edge_index):x1 = self.conv1(x, edge_index)x1_1 = x1.relu()x2 = self.conv2(x1_1, edge_index)x2_2 = x2.relu()return x2_2def decode(self, z, edge_label_index):# 节点和边都是矩阵,不同的计算方法致使:节点->节点,节点->边# nodes_relation = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)# distances = torch.norm(z[edge_label_index[0]] - z[edge_label_index[1]], dim=-1)distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)# print("distance_squared: ",distance_squared)return distance_squareddef decode_all(self, z):prob_adj = z @ z.t() # 得到所有边概率矩阵return (prob_adj > 0).nonzero(as_tuple=False).t() # 返回概率大于0的边,以edge_index的形式@torch.no_grad()def test(self,input_data):model.eval()z = model.encode(input_data.x, input_data.edge_index)out = model.decode(z, input_data.edge_label_index).view(-1)out = 1 - outN = 30
train_n = 31
M = 3000
# snr = -20
# for train_n in range(1,51):
# for M in range(3000, 499, -100):
for snr in [0,20,40]:print("snr: ", snr)for I in range(10):root = "gcn_data-"+str(I)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)gcn_data = graph_data(root)Input_L = gcn_data.x.shape[1]model = Net()# model = Net().to(device)optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()z = model.encode(gcn_data.x, gcn_data.edge_index)# out = model.decode(z, train_data.edge_label_index).view(-1).sigmoid()out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outloss = criterion(out, gcn_data.edge_label)loss.backward()optimizer.step()return lossmin_loss = 99999count = 0#早停for epoch in range(10000):loss = train()if loss<min_loss:min_loss = losscount = 0count = count + 1if count>100:breakprint("epoch: ",epoch," loss: ",round(loss.item(),2), " min_loss: ",round(min_loss.item(),2))z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)list_0 = []list_1 = []for i in range(len(gcn_data.edge_label)):true_label = gcn_data.edge_label[i].item()euclidean_distance_value = out[i].item()if true_label==1:list_1.append(euclidean_distance_value)if true_label==0:list_0.append(euclidean_distance_value)minlength = min(len(list_1), len(list_0))list_1 = random.sample(list_1, minlength)list_0 = random.sample(list_0, minlength)value = list_1 + list_0large_class = list(np.full(len(value), snr))small_class = list(np.full(len(list_1), 1)) + list(np.full(len(list_0), 0))data = {'large_class': large_class,'small_class': small_class,'value': value}# 创建一个 DataFramedf = pd.DataFrame(data)## # 保存到 Excel 文件file_path = 'D:\无线通信网络认知\论文1\大修意见\图聚类、阈值相似性图实验补充\\' + mode + '_similarity_' + str(snr) + 'db_'+str(I)+'.xlsx'df.to_excel(file_path, index=False)
相关文章:

Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作
最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。 我觉得对于初学者而言,图神经网络的训练会有2个难点: ①环境配置 ②数据集制作 一、环境配置 我最初光想…...

React---day6、7
6、组件之间进行数据传递 **6.1 父传子:**props传递属性 父组件: <div><ChildCpn name"蒋乙菥" age"18" height"1,88" /> </div>子组件: export class ChildCpn extends React.Component…...

hook组件-useEffect、useRef
hook组件-useEffect、useRef useEffect 用法及执行机制 WillMount -> render -> DidMount ShouldUpdate -> WillUpdate -> render -> DidUpdate WillUnmount(只有这个安全) WillReceiveProps useEffect(callback) 默认所有依赖都更新useEffect(callback, [])&am…...
功能结构整理
C# Sxer Sxer.Base:基础子功能 Sxer.Base.Debug:打印 Sxer.Utility:工具类 Sxer.CustomFunction:独立功能点开发 Unity...
企业级开发中的 maven-mvnd 应用实践
1. 引言:Maven 在企业级开发中的挑战 1.1 Maven 构建的常见痛点 在大型 Java 项目中,Maven 是主流的构建工具,但随着项目的复杂度增加,其性能瓶颈逐渐显现: 构建速度慢:每次执行 mvn clean install 都需要重新加载插件和依赖。重复构建浪费资源:即使未修改源码,仍会触…...
yolov12毕设前置知识准备 1
1 什么是目标检测呢? 目标检测(Object Detection)主要用于识别图像或视频中特定类型物体的位置,并标注其类别。 简单来说,就是让计算机像人类一样 “看懂” 图像内容,不仅能识别出物体(如人、…...

随机游动算法解决kSAT问题
input:n个变量的k-CNF公式 ouput:该公式的一组满足赋值或宣布没有满足赋值 算法步骤: 随机均匀地初始化赋值 a ∈ { 0 , 1 } n a\in\{0,1\}^n a∈{0,1}n.重复t次(后面会估计这个t): a. 如果在当前赋值下…...

《Discuz! X3.5开发从入门到生态共建》第1章 Discuz! 的前世今生-优雅草卓伊凡
《Discuz! X3.5开发从入门到生态共建》第1章 Discuz! 的前世今生-优雅草卓伊凡 第一节 从康盛创想到腾讯收购:PC时代的辉煌 1.1 Discuz! 的诞生:康盛创想的开源梦想 2001年,中国互联网正处于萌芽阶段,个人网站和论坛开始兴起。…...
azure web app创建分步指南系列之一
什么是 Azure Web 应用? Azure Web 应用是 Azure 应用服务的一部分,是一个完全托管的平台,用于开发、部署和扩展 Web 应用程序。它支持各种编程语言和框架,例如 .NET、Java、Python、PHP 和 Node.js,使开发人员能够构建强大的 Web 应用程序,而无需担心底层基础架构。借助…...
PyTorch实战——基于生成对抗网络生成服饰图像
PyTorch实战——基于生成对抗网络生成服饰图像 0. 前言1. 模型分析与数据准备2. 判别器3. 生成器4. 模型训练5. 模型保存与加载相关链接0. 前言 我们已经学习了生成对抗网络 (Generative Adversarial Network, GAN) 的工作原理,接下来,将学习如何将其应用于生成其他形式的内…...

笔试强训:Day6
一、小红的口罩(贪心优先级队列) 登录—专业IT笔试面试备考平台_牛客网 #include<iostream> #include<queue> #include<vector> using namespace std; int n,k; int main(){//用一个小根堆 每次使用不舒适度最小的cin>>n>&…...
【Hexo】4.Hexo 博客文章进行加密
安装 npm install --save hexo-blog-encrypt1-快速使用 将“ password”添加到您的文章信息头就像这样: password: 123456 ---2-按标签加密 1.修改文章信息头如下: title: Hello World tags: - 加密文章tag date: 2020-03-13 21:12:21 password: muyiio…...
Android --- ObjectAnimator 和 TranslateAnimation有什么区别
文章目录 2. 作用范围和功能2. 动画表现3. 是否修改 View 的属性4. 适用场景5. 性能总结: ObjectAnimator 和 TranslateAnimation 都是 Android 中常用的动画类型,但它们有以下几个关键的区别: 2. 作用范围和功能 ObjectAnimator:…...
小白的进阶之路系列之四----人工智能从初步到精通pytorch自定义数据集下
本篇涵盖的内容 在之前的文章中,我们已经讨论了如何获取数据,转换数据以及如何准备自定义数据集,本篇文章将涵盖更加深入的问题,希望通过详细的代码示例,帮助大家了解PyTorch自定义数据集是如何应对各种复杂实际情况中,数据处理的。 更加详细的,我们将讨论下面一些内容…...
安卓添加设备节点权限和selinux访问权限
# 1 修改设备节点权限及配置属性设置节点值 ## 1.1 修改设备节点权限 ### 1.1.1 不会手动卸载的节点 在system/core/rootdir/init.rc中添加节点权限 在on boot下面添加 chown system system /sys/kernel/usb/host chmod 0664 /sys/kernel/usb/host ### 1.1.2 支持热插拔的…...

谷歌Stitch:AI赋能UI设计,免费高效新利器
在AI技术日新月异的今天,各大科技巨头都在不断刷新我们对智能工具的认知。最近,谷歌在其年度I/O开发者大会期间,除了那些聚光灯下的重磅发布,还悄然上线了一款令人惊喜的AI工具——Stitch。这是一款全新的、完全免费的AI驱动UI&am…...

运营商地址和ip属地一样吗?怎么样更改ip属地地址
在互联网时代,IP属地和运营商地址是两个经常被提及的概念,但它们是否相同?如何更改IP属地地址?这些问题困扰着许多网民。本文将深入探讨这两个概念的区别,并详细介绍更改IP属地地址的方法。 一、运营商地址和IP属地一…...

在QT中,利用charts库绘制FFT图形
第1章 添加charts库 1.1 .pro工程添加chart库 1.1.1 在.pro工程里面添加charts库 1.1.2 在需要使用的地方添加这两个库函数,顺序一点不要搞错,先添加.pro,否则编译器会找不到这两个.h文件。 第2章 Charts关键绘图函数 2.1 QChart 类 QChart 是…...
ChatGPT + 知网 + 知乎,如何高效整合信息写出一篇专业内容?
——写作,不是闭门造车,而是高效聚合 🧠 为什么“信息整合力”才是AI时代的核心写作能力? 现在的写作,不缺工具,也不缺资料,缺的是: 把 scattered info 变成 structured idea 的能力…...

流媒体协议分析:流媒体传输的基石
在流媒体传输过程中,协议的选择至关重要,它决定了数据如何封装、传输和解析,直接影响着视频的播放质量和用户体验。本文将深入分析几种常见的流媒体传输协议,探讨它们的特点、应用场景及优缺点。 协议分类概述 流媒体传输协议根据…...

vscode中让文件夹一直保持展开不折叠
vscode中让文件夹一直保持展开不折叠 问题 很多小伙伴使用vscode发现空文件夹会折叠显示, 让人看起来非常难受, 如下图 解决办法 首先打开设置->setting, 搜索compact Folders, 去掉勾选即可, 如下图所示 效果如下 看起来非常爽 ! ! !...

JAVA-springboot整合Mybatis
SpringBoot从入门到精通-第15章 MyBatis框架 学习MyBatis心路历程 2022年学习java基础时候,想着怎么使用java代码操作数据库,咨询了项目上开发W同事,没有引用框架,操作数据库很麻烦,就帮我写好多行代码,就…...

深度学习pycharm debug
深度学习中,Debug 是定位并解决代码逻辑错误(如张量维度不匹配)、训练异常(如 Loss 波动)、数据问题(如标签错误)的关键手段,通过打印维度、可视化梯度等方法确保模型正常运行、优化…...

MicroPython+L298N+ESP32控制电机转速
要使用MicroPython控制L298N电机驱动板来控制电机的转速,你可以通过PWM(脉冲宽度调制)信号来调节电机速度。L298N是一个双H桥驱动器,可以同时控制两个电机的正反转和速度。 硬件准备: 1. L298N 电机控制板 2. ESP32…...
Hive的存储格式如何优化?
Hive的存储格式对查询性能、存储成本和数据处理效率有显著影响。以下是主流存储格式的特点、选择标准和优化方法: 一、主流存储格式对比 特性ORC(Optimized Row Columnar)ParquetTextFile(默认)SequenceFile数据布局…...

在部署了一台mysql5.7的机器上部署mysql8.0.35
在已部署 MySQL 5.7 的机器上部署 MySQL 8.0.35 的完整指南 在同一台服务器上部署多个 MySQL 版本需要谨慎规划,避免端口冲突和数据混淆。以下是详细的部署步骤: 一、规划配置 端口分配 MySQL 5.7:使用默认端口 3306MySQL 8.0.35࿱…...
OpenCV CUDA模块结构分析与形状描述符------在 GPU 上计算图像的原始矩(spatial moments)函数spatialMoments()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 该函数用于在 GPU 上计算图像的原始矩(spatial moments)。这些矩可用于描述图像中物体的形状特征,如面积、质…...

QT入门学习(一)---新建工程与、信号与槽
一: 新建QT项目 二:QT文件构成 2.1 first.pro 项目管理文件,下面来看代码解析 QT core guigreaterThan(QT_MAJOR_VERSION, 4): QT widgetsCONFIG c11TARGET main# The following define makes your compiler emit warnings if you use # any Qt feature …...

UE5.4.4+Rider2024.3.7开发环境配置
文章目录 一、UE5安装 安装有两种方式一种的源码编译安装、一种是EPIC安装,推荐后者,只需要注册一个EPIC账号就可以一键安装。 二、C环境安装 1.下载VisualStudioSetup 下载链接如下下载 Visual Studio Tools - 免费安装 Windows、Mac、Linux 选择社…...

Windows环境下PHP,在PowerShell控制台输出中文乱码
解决方法: 以管理员运行PowerShell , 输入: chcp 65001 重启控制台;然后就正常输出中文;...