使用R-GCN处理异质图ACM的demo
加载和处理数据集
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)
定义 R-GCN 模型
from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)
训练和测试函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
完整代码
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
相关文章:
使用R-GCN处理异质图ACM的demo
加载和处理数据集 import torch from torch_geometric.datasets import HGBDataset from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关…...
征程 6E DISPLAY 功能介绍及上手实践
01 功能概述 本文将带大家一起实现单路、多路 MIPI CSI TX 输出、IDU 回写、IDU oneshot 模式、绑定输出 VPS 数据等功能,此处主要介绍各 sample 的实现与使用方法。 02 软件架构说明 本文中绑定 VPS 输出功能基于 libvio API 实现,调用 libvio 提供的…...
安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析
背景: 经常在学员的vip技术群里经常有很多学员会提问一些不太常见的窗口和input的相关的问题,虽然不太常见,但确实是工作中会遇到的一些问题,所以马哥有必要进行一下记录这些窗口技术知识点。 具体分享技术点: input中…...
【2024最新版】Win10下 Java环境变量配置----适合入门小白
首先,你应该已经安装了 Java 的 JDK 了(如果没有安装JDK,请跳转到此网址:http://www.oracle.com/technetwork/java/javase/downloads/index-jsp-138363.html) 笔者安装的是 jdk-8u91-windows-x64 接下来主要讲怎么配…...
Servlet 生命周期详解及案例演示(SpringMVC底层实现)
文章目录 什么是Servlet?Servlet生命周期简介1. 初始化阶段:init()方法示例代码: 2. 请求处理阶段:service() 和 doGet()、doPost()方法示例代码: 3. 销毁阶段:destroy()方法示例代码: Servlet生…...
2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机
我的界面如下所示 1. 安装 locales sudo apt install locales 2. 生成中文语言环境 sudo locale-gen zh_CN.UTF-8 如果你希望安装繁体中文,可以加入: sudo locale-gen zh_TW.UTF-8 3. 修改 /etc/default/locale 文件 确保有以下内容 LANGzh_CN.UT…...
深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories
在使用 CMake 构建系统时,指定头文件的包含路径是非常常见的一步。对于这个任务,CMake 提供了两种主要的命令:INCLUDE_DIRECTORIES 和 target_include_directories。虽然它们看似类似,但它们的作用范围、应用方式以及适用场景却有…...
【不知道原因的问题】java.lang.AbstractMethodError
项目场景: 提示:这里简述项目相关背景: 遇到了一个问题: java.lang.AbstractMethodError 问题描述 提示:这里描述项目中遇到的问题: 在Java开发中,java.lang.AbstractMethodError是一个常见…...
分布式篇(分布式事务)(持续更新迭代)
一、事务 1. 什么是事务 2. 事务目的 3. 事务的流程 4. 事务四大特性 原子性(Atomicity) 一致性(Consistency) 持久性(Durability) 隔离性(Isolation) 5. MySQL VS Oracle …...
[Linux] 逐层深入理解文件系统 (2)—— 文件重定向
标题:[Linux] 逐层深入理解文件系统 (2)—— 文件重定向 个人主页水墨不写bug (图片来源于网络) 目录 一、文件的读取和写入 二、文件重定向的本质 1.手动模拟重定向的过程——把标准输出重定向到redir.txt 2.重定向…...
html+css+js实现Badge 标记
实现效果: 代码实现: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Badge…...
纯css 轮播图片,鼠标移入暂停 移除继续
核心 滚动: animation: 动画名称 20s linear infinite normal;停止: animation: 动画名称 20s linear infinite paused; 完整例子: html: <div class"carousel-wrapper"><div class"carousel"><div cl…...
iOS GCD的基本使用
一:什么是GCD GCD的全程是:Grand Central Dispatch, 直白的用汉语翻译就是:厉害的中枢调度器. GCD 是iOS 的多线程技术的实现方案,但是它并不是多线程技术,它是“并发解决技术”,是苹果公司研发的,会自动管理线程(这一段定义有点拗口,简单了解就行) GCD会自动管理线程的生命…...
如何设计开发RTSP直播播放器?
技术背景 我们在对接RTSP直播播放器相关技术诉求的时候,好多开发者,除了选用成熟的RTSP播放器外,还想知其然知其所以然,对RTSP播放器的整体开发有个基础的了解,方便方案之作和技术延伸。本文抛砖引玉,做个…...
Java基础系列-一文搞懂自定义排序
java自定义排序 自定义排序的理解: 我们首先看需求:一个二维数组 [[1,3],[8,10],[15,18],[2,6]] 我们的需求是根据集合(二维数组取出来的数据) 左边小的左边这种方式排序 例如1<8 排序方式就是[1,3],[8,10] 此时我们就需…...
扫普通链接二维码打开小程序
1. 2.新增规则(注意下载文件到跟目录下,需要建个文件夹放下载的校验文件) 3.发布 ps:发布后,只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…...
计算机储存与分区
Disk partitioning 盘分区是在辅助储存上创建一个或多个区域,以便可以单独管理每个区域。而这些区域称为分区(partition)。这通常是在为新盘选择分区方案后,需要做的事。 MBR and GPT 分区方案(分区表)有…...
OpenCV之换脸技术:一场面部识别的奇妙之旅
在这个数字化与智能化并进的时代,图像处理技术日益成为连接现实与虚拟世界的桥梁。其中,换脸技术作为一项颇受欢迎且富有挑战性的应用,不仅让人惊叹于技术的魔力,更在娱乐、影视制作等领域展现了无限可能。今天,我们就…...
Linux学习笔记9 文件系统的基础
一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构,Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录,不同的磁盘块和分区挂载在其下,成为某个子目录的子目录,甚至设备也挂…...
Android OpenGL粒子特效
在本篇,我们将开启一个新的项目,探索粒子的世界。粒子是一种基本的图形元素,它们通常被表示为一组点。通过巧妙地组合一些基础的物理效果,我们能够创造出许多令人惊叹的视觉效果。想象一下,我们可以模拟一个水滴从喷泉…...
基于大模型的 UI 自动化系统
基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...
docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路
进入2025年以来,尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断,但全球市场热度依然高涨,入局者持续增加。 以国内市场为例,天眼查专业版数据显示,截至5月底,我国现存在业、存续状态的机器人相关企…...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
Linux云原生安全:零信任架构与机密计算
Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块,…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
