当前位置: 首页 > news >正文

使用R-GCN处理异质图ACM的demo

加载和处理数据集

import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0]  # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')]  # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)

定义 R-GCN 模型

from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type))  # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)

训练和测试函数

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

完整代码

import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0]  # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')]  # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type))  # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

相关文章:

使用R-GCN处理异质图ACM的demo

加载和处理数据集 import torch from torch_geometric.datasets import HGBDataset from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关…...

征程 6E DISPLAY 功能介绍及上手实践

01 功能概述 本文将带大家一起实现单路、多路 MIPI CSI TX 输出、IDU 回写、IDU oneshot 模式、绑定输出 VPS 数据等功能,此处主要介绍各 sample 的实现与使用方法。 02 软件架构说明 本文中绑定 VPS 输出功能基于 libvio API 实现,调用 libvio 提供的…...

安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析

背景: 经常在学员的vip技术群里经常有很多学员会提问一些不太常见的窗口和input的相关的问题,虽然不太常见,但确实是工作中会遇到的一些问题,所以马哥有必要进行一下记录这些窗口技术知识点。 具体分享技术点: input中…...

【2024最新版】Win10下 Java环境变量配置----适合入门小白

首先,你应该已经安装了 Java 的 JDK 了(如果没有安装JDK,请跳转到此网址:http://www.oracle.com/technetwork/java/javase/downloads/index-jsp-138363.html) 笔者安装的是 jdk-8u91-windows-x64 接下来主要讲怎么配…...

Servlet 生命周期详解及案例演示(SpringMVC底层实现)

文章目录 什么是Servlet?Servlet生命周期简介1. 初始化阶段:init()方法示例代码: 2. 请求处理阶段:service() 和 doGet()、doPost()方法示例代码: 3. 销毁阶段:destroy()方法示例代码: Servlet生…...

2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机

我的界面如下所示 1. 安装 locales sudo apt install locales 2. 生成中文语言环境 sudo locale-gen zh_CN.UTF-8 如果你希望安装繁体中文,可以加入: sudo locale-gen zh_TW.UTF-8 3. 修改 /etc/default/locale 文件 确保有以下内容 LANGzh_CN.UT…...

深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories

在使用 CMake 构建系统时,指定头文件的包含路径是非常常见的一步。对于这个任务,CMake 提供了两种主要的命令:INCLUDE_DIRECTORIES 和 target_include_directories。虽然它们看似类似,但它们的作用范围、应用方式以及适用场景却有…...

【不知道原因的问题】java.lang.AbstractMethodError

项目场景: 提示:这里简述项目相关背景: 遇到了一个问题: java.lang.AbstractMethodError 问题描述 提示:这里描述项目中遇到的问题: 在Java开发中,java.lang.AbstractMethodError是一个常见…...

分布式篇(分布式事务)(持续更新迭代)

一、事务 1. 什么是事务 2. 事务目的 3. 事务的流程 4. 事务四大特性 原子性(Atomicity) 一致性(Consistency) 持久性(Durability) 隔离性(Isolation) 5. MySQL VS Oracle …...

[Linux] 逐层深入理解文件系统 (2)—— 文件重定向

标题:[Linux] 逐层深入理解文件系统 (2)—— 文件重定向 个人主页水墨不写bug (图片来源于网络) 目录 一、文件的读取和写入 二、文件重定向的本质 1.手动模拟重定向的过程——把标准输出重定向到redir.txt 2.重定向…...

html+css+js实现Badge 标记

实现效果&#xff1a; 代码实现&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Badge…...

纯css 轮播图片,鼠标移入暂停 移除继续

核心 滚动&#xff1a; animation: 动画名称 20s linear infinite normal;停止&#xff1a; animation: 动画名称 20s linear infinite paused; 完整例子&#xff1a; html: <div class"carousel-wrapper"><div class"carousel"><div cl…...

iOS GCD的基本使用

一:什么是GCD GCD的全程是:Grand Central Dispatch, 直白的用汉语翻译就是:厉害的中枢调度器. GCD 是iOS 的多线程技术的实现方案,但是它并不是多线程技术,它是“并发解决技术”,是苹果公司研发的,会自动管理线程(这一段定义有点拗口,简单了解就行) GCD会自动管理线程的生命…...

如何设计开发RTSP直播播放器?

技术背景 我们在对接RTSP直播播放器相关技术诉求的时候&#xff0c;好多开发者&#xff0c;除了选用成熟的RTSP播放器外&#xff0c;还想知其然知其所以然&#xff0c;对RTSP播放器的整体开发有个基础的了解&#xff0c;方便方案之作和技术延伸。本文抛砖引玉&#xff0c;做个…...

Java基础系列-一文搞懂自定义排序

java自定义排序 自定义排序的理解&#xff1a; 我们首先看需求&#xff1a;一个二维数组 [[1,3],[8,10],[15,18],[2,6]] 我们的需求是根据集合&#xff08;二维数组取出来的数据&#xff09; 左边小的左边这种方式排序 例如1<8 排序方式就是[1,3],[8,10] 此时我们就需…...

扫普通链接二维码打开小程序

1. 2.新增规则&#xff08;注意下载文件到跟目录下&#xff0c;需要建个文件夹放下载的校验文件&#xff09; 3.发布 ps&#xff1a;发布后&#xff0c;只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…...

计算机储存与分区

Disk partitioning 盘分区是在辅助储存上创建一个或多个区域&#xff0c;以便可以单独管理每个区域。而这些区域称为分区&#xff08;partition&#xff09;。这通常是在为新盘选择分区方案后&#xff0c;需要做的事。 MBR and GPT 分区方案&#xff08;分区表&#xff09;有…...

OpenCV之换脸技术:一场面部识别的奇妙之旅

在这个数字化与智能化并进的时代&#xff0c;图像处理技术日益成为连接现实与虚拟世界的桥梁。其中&#xff0c;换脸技术作为一项颇受欢迎且富有挑战性的应用&#xff0c;不仅让人惊叹于技术的魔力&#xff0c;更在娱乐、影视制作等领域展现了无限可能。今天&#xff0c;我们就…...

Linux学习笔记9 文件系统的基础

一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构&#xff0c;Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录&#xff0c;不同的磁盘块和分区挂载在其下&#xff0c;成为某个子目录的子目录&#xff0c;甚至设备也挂…...

Android OpenGL粒子特效

在本篇&#xff0c;我们将开启一个新的项目&#xff0c;探索粒子的世界。粒子是一种基本的图形元素&#xff0c;它们通常被表示为一组点。通过巧妙地组合一些基础的物理效果&#xff0c;我们能够创造出许多令人惊叹的视觉效果。想象一下&#xff0c;我们可以模拟一个水滴从喷泉…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK&#xff0c;开始写第二篇的内容了。这篇博客主要能写一下&#xff1a; 如何给一些三方库按照xmake方式进行封装&#xff0c;供调用如何按…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

UE5 学习系列(三)创建和移动物体

这篇博客是该系列的第三篇&#xff0c;是在之前两篇博客的基础上展开&#xff0c;主要介绍如何在操作界面中创建和拖动物体&#xff0c;这篇博客跟随的视频链接如下&#xff1a; B 站视频&#xff1a;s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

Tauri2学习笔记

教程地址&#xff1a;https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引&#xff1a;https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多&#xff0c;我按照Tauri1的教程来学习&…...

用鸿蒙HarmonyOS5实现国际象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的国际象棋小游戏的完整实现代码&#xff0c;使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├── …...

职坐标物联网全栈开发全流程解析

物联网全栈开发涵盖从物理设备到上层应用的完整技术链路&#xff0c;其核心流程可归纳为四大模块&#xff1a;感知层数据采集、网络层协议交互、平台层资源管理及应用层功能实现。每个模块的技术选型与实现方式直接影响系统性能与扩展性&#xff0c;例如传感器选型需平衡精度与…...