使用R-GCN处理异质图ACM的demo
加载和处理数据集
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)
定义 R-GCN 模型
from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)
训练和测试函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
完整代码
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
相关文章:

使用R-GCN处理异质图ACM的demo
加载和处理数据集 import torch from torch_geometric.datasets import HGBDataset from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关…...

征程 6E DISPLAY 功能介绍及上手实践
01 功能概述 本文将带大家一起实现单路、多路 MIPI CSI TX 输出、IDU 回写、IDU oneshot 模式、绑定输出 VPS 数据等功能,此处主要介绍各 sample 的实现与使用方法。 02 软件架构说明 本文中绑定 VPS 输出功能基于 libvio API 实现,调用 libvio 提供的…...

安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析
背景: 经常在学员的vip技术群里经常有很多学员会提问一些不太常见的窗口和input的相关的问题,虽然不太常见,但确实是工作中会遇到的一些问题,所以马哥有必要进行一下记录这些窗口技术知识点。 具体分享技术点: input中…...

【2024最新版】Win10下 Java环境变量配置----适合入门小白
首先,你应该已经安装了 Java 的 JDK 了(如果没有安装JDK,请跳转到此网址:http://www.oracle.com/technetwork/java/javase/downloads/index-jsp-138363.html) 笔者安装的是 jdk-8u91-windows-x64 接下来主要讲怎么配…...

Servlet 生命周期详解及案例演示(SpringMVC底层实现)
文章目录 什么是Servlet?Servlet生命周期简介1. 初始化阶段:init()方法示例代码: 2. 请求处理阶段:service() 和 doGet()、doPost()方法示例代码: 3. 销毁阶段:destroy()方法示例代码: Servlet生…...

2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机
我的界面如下所示 1. 安装 locales sudo apt install locales 2. 生成中文语言环境 sudo locale-gen zh_CN.UTF-8 如果你希望安装繁体中文,可以加入: sudo locale-gen zh_TW.UTF-8 3. 修改 /etc/default/locale 文件 确保有以下内容 LANGzh_CN.UT…...

深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories
在使用 CMake 构建系统时,指定头文件的包含路径是非常常见的一步。对于这个任务,CMake 提供了两种主要的命令:INCLUDE_DIRECTORIES 和 target_include_directories。虽然它们看似类似,但它们的作用范围、应用方式以及适用场景却有…...

【不知道原因的问题】java.lang.AbstractMethodError
项目场景: 提示:这里简述项目相关背景: 遇到了一个问题: java.lang.AbstractMethodError 问题描述 提示:这里描述项目中遇到的问题: 在Java开发中,java.lang.AbstractMethodError是一个常见…...

分布式篇(分布式事务)(持续更新迭代)
一、事务 1. 什么是事务 2. 事务目的 3. 事务的流程 4. 事务四大特性 原子性(Atomicity) 一致性(Consistency) 持久性(Durability) 隔离性(Isolation) 5. MySQL VS Oracle …...

[Linux] 逐层深入理解文件系统 (2)—— 文件重定向
标题:[Linux] 逐层深入理解文件系统 (2)—— 文件重定向 个人主页水墨不写bug (图片来源于网络) 目录 一、文件的读取和写入 二、文件重定向的本质 1.手动模拟重定向的过程——把标准输出重定向到redir.txt 2.重定向…...

html+css+js实现Badge 标记
实现效果: 代码实现: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Badge…...

纯css 轮播图片,鼠标移入暂停 移除继续
核心 滚动: animation: 动画名称 20s linear infinite normal;停止: animation: 动画名称 20s linear infinite paused; 完整例子: html: <div class"carousel-wrapper"><div class"carousel"><div cl…...

iOS GCD的基本使用
一:什么是GCD GCD的全程是:Grand Central Dispatch, 直白的用汉语翻译就是:厉害的中枢调度器. GCD 是iOS 的多线程技术的实现方案,但是它并不是多线程技术,它是“并发解决技术”,是苹果公司研发的,会自动管理线程(这一段定义有点拗口,简单了解就行) GCD会自动管理线程的生命…...

如何设计开发RTSP直播播放器?
技术背景 我们在对接RTSP直播播放器相关技术诉求的时候,好多开发者,除了选用成熟的RTSP播放器外,还想知其然知其所以然,对RTSP播放器的整体开发有个基础的了解,方便方案之作和技术延伸。本文抛砖引玉,做个…...

Java基础系列-一文搞懂自定义排序
java自定义排序 自定义排序的理解: 我们首先看需求:一个二维数组 [[1,3],[8,10],[15,18],[2,6]] 我们的需求是根据集合(二维数组取出来的数据) 左边小的左边这种方式排序 例如1<8 排序方式就是[1,3],[8,10] 此时我们就需…...

扫普通链接二维码打开小程序
1. 2.新增规则(注意下载文件到跟目录下,需要建个文件夹放下载的校验文件) 3.发布 ps:发布后,只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…...

计算机储存与分区
Disk partitioning 盘分区是在辅助储存上创建一个或多个区域,以便可以单独管理每个区域。而这些区域称为分区(partition)。这通常是在为新盘选择分区方案后,需要做的事。 MBR and GPT 分区方案(分区表)有…...

OpenCV之换脸技术:一场面部识别的奇妙之旅
在这个数字化与智能化并进的时代,图像处理技术日益成为连接现实与虚拟世界的桥梁。其中,换脸技术作为一项颇受欢迎且富有挑战性的应用,不仅让人惊叹于技术的魔力,更在娱乐、影视制作等领域展现了无限可能。今天,我们就…...

Linux学习笔记9 文件系统的基础
一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构,Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录,不同的磁盘块和分区挂载在其下,成为某个子目录的子目录,甚至设备也挂…...

Android OpenGL粒子特效
在本篇,我们将开启一个新的项目,探索粒子的世界。粒子是一种基本的图形元素,它们通常被表示为一组点。通过巧妙地组合一些基础的物理效果,我们能够创造出许多令人惊叹的视觉效果。想象一下,我们可以模拟一个水滴从喷泉…...

5 -《本地部署开源大模型》在Ubuntu 22.04系统下ChatGLM3-6B高效微调实战
在Ubuntu 22.04系统下ChatGLM3-6B高效微调实战 无论是在单机单卡(一台机器上只有一块GPU)还是单机多卡(一台机器上有多块GPU)的硬件配置上启动ChatGLM3-6B模型,其前置环境配置和项目文件是相同的。如果大家对配置过程还…...

dpkg:错误:另外一个进程已经为dpkg前端锁加锁
一、 问题描述 在新装ubuntu系统时,我们常常会遇见dpkg的错误,dpkg:错误:另外一个进程已经为dpkg前端锁加锁,如下图。 二、问题解决 方法一 先执行sudo rm /var/lib/dpkg/lock-frontend然后再继续安装软件包,如果出现问题dpkg:…...

基于SSM服装定制系统的设计
管理员账户功能包括:系统首页,个人中心,用户管理,服装类型管理,服装信息管理,服装定制管理,留言反馈,系统管理 前台账号功能包括:系统首页,个人中心…...

RK3588开发笔记-usb3.0 xhci-hcd控制器挂死问题解决
目录 前言 一、问题现象 二、问题分析 三、问题排查 总结 前言 在使用 RK3588 芯片进行开发的过程中,我遇到了 USB 3.0 xHCI-HCD 控制器外接5G通讯模块偶尔挂死的问题。这个问题导致 USB 设备失去响应,且不能恢复,需要重启整个系统才能恢复使用,针对该问题进行大量测试以…...

深入解析TCP/IP协议:网络通信的基石
1. 引言 TCP/IP 协议是现代计算机网络的核心,它为互联网上的设备提供了通信的基础。在网络通信中,TCP/IP 协议栈是无处不在的,无论是个人设备的浏览器请求,还是大型分布式系统的内部通信,都依赖于它的稳定、高效与可靠…...

基于微信小程序的汽车预约维修系统(lw+演示+源码+运行)
基于微信小程序的汽车预约维修系统 摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了基于微信小程序的汽车预约维修系统的开发全过程。通过分析基于微信小程序的汽车预约维修系统管理的不足,创建了…...

wifi、热点密码破解 - python
乐子脚本,有点小慢,试过多线程,系统 wifi 连接太慢了,需要时间确认,多线程的话系统根本反应不过来。 也就可以试试破解别人的热点,一般都是 123456 这样的傻鸟口令 # coding:utf-8 import pywifi from pyw…...

bean的实例化2024年10月17日
跟不上为基础 1.你的java学习路线 2. 3.课程 注解的装配 contoller调用service用的是注解装配...

告别ELK,APO提供基于ClickHouse开箱即用的高效日志方案——APO 0.6.0发布
ELK一直是日志领域的主流产品,但是ElasticSearch的成本很高,查询效果随着数据量的增加越来越慢。业界已经有很多公司,比如滴滴、B站、Uber、Cloudflare都已经使用ClickHose作为ElasticSearch的替代品,都取得了不错的效果ÿ…...

Excel使用技巧:定位Ctrl+G +公式+原位填充 Ctrl+Enter快速填充数据(处理合并单元格)
Excel的正确用法: Excel是个数据库,不要随意合并单元格。 数据输入的时候一定要按照行列输入,中间不要留空,不然就没有关联。 定位CtrlG 公式原位填充 CtrlEnter快速填充数据 如果把合并的单元格 取消合并,只有第一…...