使用R-GCN处理异质图ACM的demo
加载和处理数据集
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)
定义 R-GCN 模型
from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)
训练和测试函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
完整代码
import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0] # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')] # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type)) # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
相关文章:
使用R-GCN处理异质图ACM的demo
加载和处理数据集 import torch from torch_geometric.datasets import HGBDataset from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关…...
征程 6E DISPLAY 功能介绍及上手实践
01 功能概述 本文将带大家一起实现单路、多路 MIPI CSI TX 输出、IDU 回写、IDU oneshot 模式、绑定输出 VPS 数据等功能,此处主要介绍各 sample 的实现与使用方法。 02 软件架构说明 本文中绑定 VPS 输出功能基于 libvio API 实现,调用 libvio 提供的…...
安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析
背景: 经常在学员的vip技术群里经常有很多学员会提问一些不太常见的窗口和input的相关的问题,虽然不太常见,但确实是工作中会遇到的一些问题,所以马哥有必要进行一下记录这些窗口技术知识点。 具体分享技术点: input中…...
【2024最新版】Win10下 Java环境变量配置----适合入门小白
首先,你应该已经安装了 Java 的 JDK 了(如果没有安装JDK,请跳转到此网址:http://www.oracle.com/technetwork/java/javase/downloads/index-jsp-138363.html) 笔者安装的是 jdk-8u91-windows-x64 接下来主要讲怎么配…...
Servlet 生命周期详解及案例演示(SpringMVC底层实现)
文章目录 什么是Servlet?Servlet生命周期简介1. 初始化阶段:init()方法示例代码: 2. 请求处理阶段:service() 和 doGet()、doPost()方法示例代码: 3. 销毁阶段:destroy()方法示例代码: Servlet生…...
2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机
我的界面如下所示 1. 安装 locales sudo apt install locales 2. 生成中文语言环境 sudo locale-gen zh_CN.UTF-8 如果你希望安装繁体中文,可以加入: sudo locale-gen zh_TW.UTF-8 3. 修改 /etc/default/locale 文件 确保有以下内容 LANGzh_CN.UT…...
深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories
在使用 CMake 构建系统时,指定头文件的包含路径是非常常见的一步。对于这个任务,CMake 提供了两种主要的命令:INCLUDE_DIRECTORIES 和 target_include_directories。虽然它们看似类似,但它们的作用范围、应用方式以及适用场景却有…...
【不知道原因的问题】java.lang.AbstractMethodError
项目场景: 提示:这里简述项目相关背景: 遇到了一个问题: java.lang.AbstractMethodError 问题描述 提示:这里描述项目中遇到的问题: 在Java开发中,java.lang.AbstractMethodError是一个常见…...
分布式篇(分布式事务)(持续更新迭代)
一、事务 1. 什么是事务 2. 事务目的 3. 事务的流程 4. 事务四大特性 原子性(Atomicity) 一致性(Consistency) 持久性(Durability) 隔离性(Isolation) 5. MySQL VS Oracle …...
[Linux] 逐层深入理解文件系统 (2)—— 文件重定向
标题:[Linux] 逐层深入理解文件系统 (2)—— 文件重定向 个人主页水墨不写bug (图片来源于网络) 目录 一、文件的读取和写入 二、文件重定向的本质 1.手动模拟重定向的过程——把标准输出重定向到redir.txt 2.重定向…...
html+css+js实现Badge 标记
实现效果: 代码实现: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Badge…...
纯css 轮播图片,鼠标移入暂停 移除继续
核心 滚动: animation: 动画名称 20s linear infinite normal;停止: animation: 动画名称 20s linear infinite paused; 完整例子: html: <div class"carousel-wrapper"><div class"carousel"><div cl…...
iOS GCD的基本使用
一:什么是GCD GCD的全程是:Grand Central Dispatch, 直白的用汉语翻译就是:厉害的中枢调度器. GCD 是iOS 的多线程技术的实现方案,但是它并不是多线程技术,它是“并发解决技术”,是苹果公司研发的,会自动管理线程(这一段定义有点拗口,简单了解就行) GCD会自动管理线程的生命…...
如何设计开发RTSP直播播放器?
技术背景 我们在对接RTSP直播播放器相关技术诉求的时候,好多开发者,除了选用成熟的RTSP播放器外,还想知其然知其所以然,对RTSP播放器的整体开发有个基础的了解,方便方案之作和技术延伸。本文抛砖引玉,做个…...
Java基础系列-一文搞懂自定义排序
java自定义排序 自定义排序的理解: 我们首先看需求:一个二维数组 [[1,3],[8,10],[15,18],[2,6]] 我们的需求是根据集合(二维数组取出来的数据) 左边小的左边这种方式排序 例如1<8 排序方式就是[1,3],[8,10] 此时我们就需…...
扫普通链接二维码打开小程序
1. 2.新增规则(注意下载文件到跟目录下,需要建个文件夹放下载的校验文件) 3.发布 ps:发布后,只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…...
计算机储存与分区
Disk partitioning 盘分区是在辅助储存上创建一个或多个区域,以便可以单独管理每个区域。而这些区域称为分区(partition)。这通常是在为新盘选择分区方案后,需要做的事。 MBR and GPT 分区方案(分区表)有…...
OpenCV之换脸技术:一场面部识别的奇妙之旅
在这个数字化与智能化并进的时代,图像处理技术日益成为连接现实与虚拟世界的桥梁。其中,换脸技术作为一项颇受欢迎且富有挑战性的应用,不仅让人惊叹于技术的魔力,更在娱乐、影视制作等领域展现了无限可能。今天,我们就…...
Linux学习笔记9 文件系统的基础
一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构,Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录,不同的磁盘块和分区挂载在其下,成为某个子目录的子目录,甚至设备也挂…...
Android OpenGL粒子特效
在本篇,我们将开启一个新的项目,探索粒子的世界。粒子是一种基本的图形元素,它们通常被表示为一组点。通过巧妙地组合一些基础的物理效果,我们能够创造出许多令人惊叹的视觉效果。想象一下,我们可以模拟一个水滴从喷泉…...
观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。
1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj,再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...
技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...
Golang——7、包与接口详解
包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...
6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙
Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...
云安全与网络安全:核心区别与协同作用解析
在数字化转型的浪潮中,云安全与网络安全作为信息安全的两大支柱,常被混淆但本质不同。本文将从概念、责任分工、技术手段、威胁类型等维度深入解析两者的差异,并探讨它们的协同作用。 一、核心区别 定义与范围 网络安全:聚焦于保…...
