使用R-GCN处理异质图ACM的demo
加载和处理数据集
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)
定义 R-GCN 模型
from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)
训练和测试函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
完整代码
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
相关文章:
使用R-GCN处理异质图ACM的demo
加载和处理数据集 import torch from torch_geometric.datasets import HGBDataset from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关…...
征程 6E DISPLAY 功能介绍及上手实践
01 功能概述 本文将带大家一起实现单路、多路 MIPI CSI TX 输出、IDU 回写、IDU oneshot 模式、绑定输出 VPS 数据等功能,此处主要介绍各 sample 的实现与使用方法。 02 软件架构说明 本文中绑定 VPS 输出功能基于 libvio API 实现,调用 libvio 提供的…...
安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析
背景: 经常在学员的vip技术群里经常有很多学员会提问一些不太常见的窗口和input的相关的问题,虽然不太常见,但确实是工作中会遇到的一些问题,所以马哥有必要进行一下记录这些窗口技术知识点。 具体分享技术点: input中…...
【2024最新版】Win10下 Java环境变量配置----适合入门小白
首先,你应该已经安装了 Java 的 JDK 了(如果没有安装JDK,请跳转到此网址:http://www.oracle.com/technetwork/java/javase/downloads/index-jsp-138363.html) 笔者安装的是 jdk-8u91-windows-x64 接下来主要讲怎么配…...
Servlet 生命周期详解及案例演示(SpringMVC底层实现)
文章目录 什么是Servlet?Servlet生命周期简介1. 初始化阶段:init()方法示例代码: 2. 请求处理阶段:service() 和 doGet()、doPost()方法示例代码: 3. 销毁阶段:destroy()方法示例代码: Servlet生…...
2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机
我的界面如下所示 1. 安装 locales sudo apt install locales 2. 生成中文语言环境 sudo locale-gen zh_CN.UTF-8 如果你希望安装繁体中文,可以加入: sudo locale-gen zh_TW.UTF-8 3. 修改 /etc/default/locale 文件 确保有以下内容 LANGzh_CN.UT…...
深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories
在使用 CMake 构建系统时,指定头文件的包含路径是非常常见的一步。对于这个任务,CMake 提供了两种主要的命令:INCLUDE_DIRECTORIES 和 target_include_directories。虽然它们看似类似,但它们的作用范围、应用方式以及适用场景却有…...
【不知道原因的问题】java.lang.AbstractMethodError
项目场景: 提示:这里简述项目相关背景: 遇到了一个问题: java.lang.AbstractMethodError 问题描述 提示:这里描述项目中遇到的问题: 在Java开发中,java.lang.AbstractMethodError是一个常见…...
分布式篇(分布式事务)(持续更新迭代)
一、事务 1. 什么是事务 2. 事务目的 3. 事务的流程 4. 事务四大特性 原子性(Atomicity) 一致性(Consistency) 持久性(Durability) 隔离性(Isolation) 5. MySQL VS Oracle …...
[Linux] 逐层深入理解文件系统 (2)—— 文件重定向
标题:[Linux] 逐层深入理解文件系统 (2)—— 文件重定向 个人主页水墨不写bug (图片来源于网络) 目录 一、文件的读取和写入 二、文件重定向的本质 1.手动模拟重定向的过程——把标准输出重定向到redir.txt 2.重定向…...
html+css+js实现Badge 标记
实现效果: 代码实现: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Badge…...
纯css 轮播图片,鼠标移入暂停 移除继续
核心 滚动: animation: 动画名称 20s linear infinite normal;停止: animation: 动画名称 20s linear infinite paused; 完整例子: html: <div class"carousel-wrapper"><div class"carousel"><div cl…...
iOS GCD的基本使用
一:什么是GCD GCD的全程是:Grand Central Dispatch, 直白的用汉语翻译就是:厉害的中枢调度器. GCD 是iOS 的多线程技术的实现方案,但是它并不是多线程技术,它是“并发解决技术”,是苹果公司研发的,会自动管理线程(这一段定义有点拗口,简单了解就行) GCD会自动管理线程的生命…...
如何设计开发RTSP直播播放器?
技术背景 我们在对接RTSP直播播放器相关技术诉求的时候,好多开发者,除了选用成熟的RTSP播放器外,还想知其然知其所以然,对RTSP播放器的整体开发有个基础的了解,方便方案之作和技术延伸。本文抛砖引玉,做个…...
Java基础系列-一文搞懂自定义排序
java自定义排序 自定义排序的理解: 我们首先看需求:一个二维数组 [[1,3],[8,10],[15,18],[2,6]] 我们的需求是根据集合(二维数组取出来的数据) 左边小的左边这种方式排序 例如1<8 排序方式就是[1,3],[8,10] 此时我们就需…...
扫普通链接二维码打开小程序
1. 2.新增规则(注意下载文件到跟目录下,需要建个文件夹放下载的校验文件) 3.发布 ps:发布后,只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…...
计算机储存与分区
Disk partitioning 盘分区是在辅助储存上创建一个或多个区域,以便可以单独管理每个区域。而这些区域称为分区(partition)。这通常是在为新盘选择分区方案后,需要做的事。 MBR and GPT 分区方案(分区表)有…...
OpenCV之换脸技术:一场面部识别的奇妙之旅
在这个数字化与智能化并进的时代,图像处理技术日益成为连接现实与虚拟世界的桥梁。其中,换脸技术作为一项颇受欢迎且富有挑战性的应用,不仅让人惊叹于技术的魔力,更在娱乐、影视制作等领域展现了无限可能。今天,我们就…...
Linux学习笔记9 文件系统的基础
一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构,Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录,不同的磁盘块和分区挂载在其下,成为某个子目录的子目录,甚至设备也挂…...
Android OpenGL粒子特效
在本篇,我们将开启一个新的项目,探索粒子的世界。粒子是一种基本的图形元素,它们通常被表示为一组点。通过巧妙地组合一些基础的物理效果,我们能够创造出许多令人惊叹的视觉效果。想象一下,我们可以模拟一个水滴从喷泉…...
Word样式与多级列表深度绑定指南:让你的标题编号“活”起来,增删章节不再乱
Word样式与多级列表深度绑定指南:让你的标题编号“活”起来,增删章节不再乱 每次在Word中调整文档结构时,你是否经历过这样的崩溃瞬间:精心排版的章节编号突然乱成一团,原本整齐的"1.1"变成了毫无规律的&quo…...
OpenClaw自动化测试:千问3.5-35B-A3B-FP8多模态任务可靠性验证方法
OpenClaw自动化测试:千问3.5-35B-A3B-FP8多模态任务可靠性验证方法 1. 为什么需要系统性测试多模态模型 上周我在调试一个自动整理图片的OpenClaw工作流时,遇到了诡异的现象——AI助手把会议白板照片里的流程图误识别成了"披萨制作步骤"。这…...
从理论到实践:信道利用率在停止-等待与回退N帧协议中的量化分析与优化
1. 信道利用率的核心概念与实战意义 第一次接触信道利用率这个概念时,我也被各种公式绕得头晕。直到在卫星通信项目中踩过几次坑才真正明白:信道利用率就是衡量你把通信线路"压榨"到什么程度的标尺。想象你租了条高速公路送货,总不…...
2025届最火的五大AI论文网站横评
Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 在生成式人工智能技术于学术写作里被广泛施行当下,维普平台正式推出了AIGC内容检…...
洁净车间PLC数据采集远程监控系统方案
为了维持洁净厂房内的温度、湿度及洁净度等,需要在车间部署多个高精度的温湿度传感器以及监控空气风管的风机、风阀,和监控冷热源管道循环水的压力传感器、电动调节阀等,由PLC控制冷热源机组运行状态,进而为车间洁净度进行自动化管…...
深入解析.ko驱动模块加载报错:unknown symbol问题排查与依赖管理
1. 遇到unknown symbol报错时的心态调整 第一次看到"unknown symbol in module"这个报错时,我正熬夜调试一个摄像头驱动。当时整个人都是懵的——明明编译通过了,为什么加载时会说找不到符号?后来才发现,这是Linux内核驱…...
90%嵌入式工程师必踩坑之volatile关键字,学会它轻松搞定面试官!!!
若想搞定什么是volatile关键字,首先要清楚CPU的变量读取规则:CPU 的运算单元(ALU)无法直接对内存中的变量做运算,内存里的变量(或外设寄存器中的变量)必须先加载到 CPU 内部的通用寄存器&#x…...
测试左移与右移:不仅仅是工作环节的变化
从被动执行到主动防御的质变传统瀑布模型中,测试常被压缩在开发周期末端,被动等待提测、疲于缺陷修复。而测试左移(Shift-Left)与右移(Shift-Right)的核心理念,是通过重构质量保障体系ÿ…...
Build-A-Large-Language-Model-CN:大语言模型训练中的常见问题与解决方案
Build-A-Large-Language-Model-CN:大语言模型训练中的常见问题与解决方案 【免费下载链接】Build-A-Large-Language-Model-CN 《Build a Large Language Model (From Scratch)》是一本深入探讨大语言模型原理与实现的电子书,适合希望深入了解 GPT 等大模…...
OBS绿幕抠像技术解析:chroma_key_filter.effect源码实现与优化
1. 绿幕抠像技术基础与OBS实现原理 绿幕抠像(Chroma Key)是视频处理领域的经典技术,就像魔术师用的隐身斗篷,它能让特定颜色范围(通常是绿色或蓝色)变得透明。我在实际项目中发现,OBS Studio作为…...
