图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络
- 0. 前言
- 1. 分层自注意力网络
- 1.1 模型架构
- 1.2 节点级注意力
- 1.3 语义级注意力
- 1.4 预测模块
- 2. 构建分层自注意力网络
- 相关链接
0. 前言
在异构图数据集上,异构图注意力网络的测试准确率为 78.39%
,比之同构版本有了较大提高,但我们还能进一步提高准确率。在本节中,我们将学习一种专门用于处理异构图的图神经网络架构,分层自注意力网络 (hierarchical self-attention network
, HAN
)。我们将介绍其工作原理,以便更好地理解该架构与经典图注意力网络 (Graph Attention Networks,GAT) 之间的区别。最后,使用 PyTorch Geometric
实现此架构,并将结果与其它 GNN
模型进行比较。
1. 分层自注意力网络
1.1 模型架构
在本节中,我们将实现一个专为处理异构图而设计的图神经网络 (Graph Neural Networks, GNN) 模型——分层自注意力网络 (hierarchical self-attention network
, HAN
)。该架构由 Liu
等人于 2021
年提出。HAN
在两个不同层次上使用自注意力:
- 节点级注意力 (
Node-level attention
):了解给定元路径中相邻节点的重要性 - 语义级注意力 (
Semantic-level attention
):了解每个元路径的重要性。这是HAN
的主要特点,它允许我们自动为给定任务选择最佳元路径。例如,在某些任务(如预测玩家人数)中,元路径game-user-game
可能比game-dev-game
更合适
接下来,我们将详细介绍 HAN
的三个主要组件——节点级注意力 (Node-level attention
)、语义级注意力 (Semantic-level attention
) 和预测模块 (prediction module
),HAN
架构如下所示。
)
1.2 节点级注意力
与图注意力网络 (Graph Attention Networks,GAT) 一样,第一步将节点投影到每个元路径的统一特征空间中。然后,用第二个权重矩阵计算同一元路径中的节点对(两个投影节点的连接)的权重,并对这一结果应用非线性函数,然后用 softmax
函数对其进行归一化处理。 j j j 节点对 i i i 节点的归一化注意力分数(重要性)计算如下:
α i j Φ = exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h j ] ) ) ∑ k ∈ N i Φ exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h k ] ) ) \alpha_{ij}^\Phi =\frac {\exp (\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_j]))}{\sum _{k\in \mathcal N_i^{\Phi}}\exp(\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_k]))} αijΦ=∑k∈NiΦexp(σ(aΦT[WΦhi∣∣WΦhk]))exp(σ(aΦT[WΦhi∣∣WΦhj]))
其中, h i h_i hi 表示 i i i 节点的特征, W Φ W_{\Phi} WΦ 是 Φ \Phi Φ 元路径的共享权重矩阵, a Φ a_{\Phi} aΦ 是 Φ \Phi Φ 元路径的注意力权重矩阵, σ σ σ 是非线性激活函数(如 LeakyReLU
), N i Φ \mathcal N_i^{\Phi} NiΦ 是节点(包括其自身)在 Φ \Phi Φ 元路径中的邻居集。使用多头注意力获得最终的嵌入:
Z i = ∣ ∣ k = 1 K σ ( ∑ k ∈ N i α i j Φ ⋅ W Φ h j ) Z_i=||_{k=1}^K\sigma(\sum _{k\in \mathcal N_i}\alpha _{ij}^{\Phi}\cdot W_{\Phi}h_j) Zi=∣∣k=1Kσ(k∈Ni∑αijΦ⋅WΦhj)
1.3 语义级注意力
对于语义级注意力,我们对每个元路径的注意力得分(表示为 β Φ 1 , β Φ 2 , . . . , β Φ p β_{\Phi _1}, β_{\Phi _2}, ... , β_{\Phi _p} βΦ1,βΦ2,...,βΦp )重复类似的过程。对于给定元路径中的每个节点嵌入(表示为 Z Φ p Z_{\Phi _p} ZΦp),都将其馈送到一个多层感知机 (Multilayer Perceptron
, MLP
) 中,应用非线性变换。将这一结果与新的注意力向量 q q q 进行比较,作为相似性度量。我们将这一结果平均化,以计算给定元路径的重要性:
w Φ p = 1 ∣ V ∣ ∑ i ∈ V q T ⋅ tanh ( W ⋅ z i Φ p + b ) w_{\Phi_p}=\frac 1{|V|}\sum_{i\in V}q^T\cdot \tanh(W\cdot z_i^{\Phi_p}+b) wΦp=∣V∣1i∈V∑qT⋅tanh(W⋅ziΦp+b)
其中, W W W (MLP
的权重矩阵)、 b b b (MLP
的偏置)和 q q q (语义级注意力向量)在元路径中是共享的。
必须对这一结果进行归一化处理,以比较不同的语义级注意力得分。使用 softmax
函数来获得最终权重:
β Φ p = exp ( w Φ p ) ∑ k = 1 P exp ( w Φ k ) \beta _{\Phi_p}=\frac {\exp(w_{\Phi_p})}{\sum_{k=1}^P\exp(w_{\Phi_k})} βΦp=∑k=1Pexp(wΦk)exp(wΦp)
将节点级注意力和语义级注意力结合起来得到最终嵌入 Z Z Z:
Z = ∑ p = 1 P β Φ p ⋅ Z Φ p Z=\sum_{p=1}^P\beta_{\Phi_p}\cdot Z_{\Phi_p} Z=p=1∑PβΦp⋅ZΦp
1.4 预测模块
最后一层(如多层感知机 (Multilayer Perceptron
, MLP
) )用于针对特定的下游任务(如节点分类或链接预测)对模型进行微调。
2. 构建分层自注意力网络
接下来,使用 PyTorch Geometric
在 DBLP 数据集上实现分层自注意力网络 (hierarchical self-attention network
, HAN
) 架构。
(1) 首先,导入 HAN
层:
import torch
import torch.nn.functional as F
from torch import nnimport torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear
(2) 加载 DBLP
数据集,并为会议节点引入虚拟特征:
dataset = DBLP('.')
data = dataset[0]
print(data)data['conference'].x = torch.zeros(20, 1)
(3) 使用 HANConv
的 HAN
卷积层和用于最终分类的线性层创建 HAN
类:
class HAN(nn.Module):def __init__(self, dim_in, dim_out, dim_h=128, heads=8):super().__init__()self.han = HANConv(dim_in, dim_h, heads=heads, dropout=0.6, metadata=data.metadata())self.linear = nn.Linear(dim_h, dim_out)
(4) 在 forward()
方法中,我们必须指定需要关注作者:
def forward(self, x_dict, edge_index_dict):out = self.han(x_dict, edge_index_dict)out = self.linear(out['author'])return out
(5) 使用懒初始化 (dim_in=-1)
来初始化模型,因此 PyTorch Geometric
会自动计算每个节点类型的输入大小:
model = HAN(dim_in=-1, dim_out=4)
(6) 实例化 Adam
优化器,并尝试将数据和模型传输到 GPU
:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)
(7) 实现 test()
函数计算分类任务的准确率:
@torch.no_grad()
def test(mask):model.eval()pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()return float(acc)
(8) 对模型进行 101
个 epoch
的训练,与同构图神经网络 (Graph Neural Networks, GNN) 的训练循环唯一不同的是,需要指定关注的作者节点类型:
for epoch in range(101):model.train()optimizer.zero_grad()out = model(data.x_dict, data.edge_index_dict)mask = data['author'].train_maskloss = F.cross_entropy(out[mask], data['author'].y[mask])loss.backward()optimizer.step()if epoch % 20 == 0:train_acc = test(data['author'].train_mask)val_acc = test(data['author'].val_mask)print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')
训练过程如下:
(9) 最后,在测试集上测试训练后的模型:
test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')# Test accuracy: 81.58%
HAN
的测试准确率为 81.58%
,高于异构图注意力网络 (78.39%
)和经典图注意力网络 (Graph Attention Networks,GAT) (73.29%
)。这说明了构建良好的表示方法以聚合不同类型节点和关系的重要性。异构图的技术在很大程度上取决于具体应用,但尝试不同的模型对于构建高性能应能具有重要作用。
相关链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(18)——消息传播神经网络
相关文章:

图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络 0. 前言1. 分层自注意力网络1.1 模型架构1.2 节点级注意力1.3 语义级注意力1.4 预测模块 2. 构建分层自注意力网络相关链接 0. 前言 在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大…...

基于深度学习的数字识别系统的设计与实现(python、yolov、PyQt5)
💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…...

ChatGPT 提取文档内容,高效制作PPT、论文
随着人工智能生成内容(AIGC)的快速发展,利用先进的技术工具如 ChatGPT 的 RAG(Retrieval-Augmented Generation,检索增强生成)模式,可以显著提升文档内容提取和内容创作的效率。以下将详细介绍如…...

3、等保1.0 与 2.0 的区别
数据来源:3.等保1.0和2.0的区别_哔哩哔哩_bilibili 等保1.0时代VS等保2.0时代五个规定动作:定级、备案、建设整改、等级测评、监督检查工作内容维持5个规定动作,增加风险评估、安全监测、通报预警、事件调查、数据防护自主可控、供应链安全、…...

Angular面试题九
一、在Angular中,你如何管理全局状态或跨组件共享数据?有哪些常见的实现方式? 在Angular中,管理全局状态或跨组件共享数据是应用开发中的一个重要方面。这有助于保持数据的一致性和可维护性,特别是在复杂的应用中。以下…...

(转载)智能指针shared_ptr从C++11到C++20
shared_ptr和动态数组 - apocelipes - 博客园 (cnblogs.com) template<typename T> std::shared_ptr<T> make_shared_array(size_t size) { return std::shared_ptr<T>(new T[size],std::default_delete<T[]>()); } std::shar…...

Ubuntu 上安装 Miniconda
一、下载 Miniconda 打开终端。访问 Anaconda 官方仓库下载页面https://repo.anaconda.com/miniconda/选择Miniconda3-py310_24.7.1-0-Linux-x86_64.sh,进行下载。文件名当中的py310_24.7.1表示,在 conda 的默认的 base 环境中的 Python 版本是3.10&…...

【Vue系列五】—Vue学习历程的知识分享!
前言 本篇文章讲述前端工程化从模块化到如今的脚手架的发展,以及Webpack、Vue脚手架的详解! 一、模块化 模块化就是把单独的功能封装到模块(文件)中,模块之间相互隔离,但可以通过特定的接口公开内部成员…...

CaLM 因果推理评测体系:如何让大模型更贴近人类认知水平?
CaLM 是什么 CaLM(Causal Evaluation of Language Models,以下简称“CaLM”)是上海人工智能实验室联合同济大学、上海交通大学、北京大学及商汤科技发布首个大模型因果推理开放评测体系及开放平台。首次从因果推理角度提出评估框架ÿ…...

深入探索卷积神经网络(CNN)
深入探索卷积神经网络(CNN) 前言图像的数字表示灰度图像RGB图像 卷积神经网络(CNN)的架构基本组件卷积操作填充(Padding)步幅(Strides) 多通道图像的卷积池化层全连接层 CNN与全连接…...

【C++篇】手撕 C++ string 类:从零实现到深入剖析的模拟之路
文章目录 C string 类的模拟实现:从构造到高级操作前言第一章:为什么要手写 C string 类?1.1 理由与价值 第二章:实现一个简单的 string 类2.1 基本构造与析构2.1.1 示例代码:基础的 string 类实现2.1.2 解读代码 2.2 …...

毕业设计选题:基于ssm+vue+uniapp的校园失物招领小程序
开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…...

[系统设计总结] - Proximity Service算法介绍
问题描述 Proximity Service广泛应用于各种地图相关的服务中比如外卖,大众点评,Uber打车,Google地图中,其中比较关键的是我们根据用户的位置来快速找到附近的餐厅,司机,外卖员也就是就近查询算法。 主流的…...

变压吸附制氧机的应用范围
变压吸附制氧机是一种利用变压吸附技术从空气中分离出氧气的设备。该技术通过吸附剂在不同压力下的吸附与解吸性能,实现了氧气的有效分离和纯化。 工业领域 在工业领域,变压吸附制氧机同样具有广泛的应用。首先,钢铁企业在生产过程中需要大量…...

MATLAB绘图基础8:双变量图形绘制
参考书:《 M A T L A B {\rm MATLAB} MATLAB与学术图表绘制》(关东升)。 8.双变量图形绘制 8.1 散点图 散点图用于显示两个变量间的关系,每个数据点在图上表示为一个点,一个变量在 X {\rm X} X轴,一个变量在 Y {\rm Y} Y轴&#…...

Appium高级话题:混合应用与原生应用测试策略
Appium高级话题:混合应用与原生应用测试策略 在移动应用开发领域,混合应用与原生应用各有千秋,但它们的测试策略却大相径庭。本文旨在深入探讨这两种应用类型的测试挑战,并介绍如何利用自动化测试软件ItBuilder高效解决这些问题&…...

windows源码安装protobuf,opencv,ncnn
安装笔记 cmake 在windows可以使用-G"MinGW Makefiles" 搭配make使用,install出来的lib文件时.a结尾的,适合linux下面使用。所以在windows上若无需求使用-G"NMake Makefiles" 搭配nmake。 但是windows上使用-G"NMake Makefil…...

MicroPython 怎么搭建工程代码
在MicroPython中搭建工程代码可以遵循以下步骤: 1. 准备工作 安装MicroPython固件:确保已经将MicroPython烧录到ESP32开发板中。准备开发环境: 可以使用文本编辑器(如VS Code、Thonny、uPyCraft等)来编写代码。 2.…...

Android studio安装问题及解决方案
Android studio安装问题及解决方案 gradle已经安装好了,但是每次就是找不到gradle的位置,每次要重新下载,很慢,每次都不成功 我尝试用安装android studio时自带的卸载程序,卸载android studio,然后重新下…...

前端面试题(二)
6. 深入 JavaScript this 关键字的指向是什么? this 的指向是在函数执行时决定的。默认情况下,非严格模式下 this 指向全局对象(浏览器中为 window),严格模式下 this 为 undefined。在对象方法中,this 通常…...

【C++】stack和queue的使用及模拟实现
stack就是栈的意思,这个结构遵循后进先出(LIFO)的原则,可以将栈想象为一个子弹夹,先进去的子弹后出来。 queue就是队列的意思,这个结构遵循先进先出(FIFO)的原则,可以将对列想象成我们排队买饭的场景,先排…...

MongoDB解说
MongoDB 是一个流行的开源 NoSQL 数据库,它使用了一种被称为文档存储的数据库模型。 与传统的关系型数据库管理系统(RDBMS)不同,MongoDB 不使用表格来存储数据,而是使用了一种更为灵活的格式——JSON 样式的文档。 这…...

问:JAVA中唤醒阻塞的线程有哪些?
在Java中,唤醒阻塞线程的方法有多种,以下是常见的线程唤醒方法。 唤醒方法 使用notify()和notifyAll()方法 synchronized (obj) {obj.notify(); // 唤醒单个等待线程// obj.notifyAll(); // 唤醒所有等待线程 }使用interrupt()方法 Thread thread n…...

Github Webhook触发Jenkins自动构建
1.功能说明 Github Webhook可以触发Jenkins自动构建,通过配置Github Webhook,每次代码变更之后(例如push操作),Webhook会自动通知Jenkins服务器,Jenkins会自动执行预定义的构建任务(如Jenkins …...

ESP32-WROOM-32 [创建AP站点-客户端-TCP透传]
简介 基于ESP32-WROOM-32 开篇(刚买), 本篇讲的是基于固件 ESP32-WROOM-32-AT-V3.4.0.0(内含用户指南, 有AT指令说明)的TCP透传设置与使用 设备连接 TTL转USB线, 接ESP32 板 的 GND,RX2, TX2 指令介绍 注意,下面指…...

新闻文本分类识别系统Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+TensorFlow+Django网页界面
一、介绍 文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集(“体育类”, “财经类”, “房产类”, “家居类”, “教育类”, “科技类”, “时尚类”, “时政类”, “游戏类”, “娱乐类”),然…...

Java使用Map数据结构配合函数式接口存储方法引用
Java使用Map数据结构配合函数式接口存储方法引用 背景 需求中存在这样一直情况 一个国家下面有很多的州 每个州对应的计算日期方法是不同的 这个时候 就面临 可能会有很多if else 为了后期维护尽量还是不想采用这个方式,那么就可以使用策略模式 但是 使用策略带来的…...

LeetCode:2207. 字符串中最多数目的子序列(Java)
目录 2207. 字符串中最多数目的子序列 题目描述: 实现代码与解析: 遍历: 原理思路: 2207. 字符串中最多数目的子序列 题目描述: 给你一个下标从 0 开始的字符串 text 和另一个下标从 0 开始且长度为 2 的字符串 p…...

win10开机自启动方案总汇
win10开机自启动方案总汇 一、开始文件目录添加二、添加注册表启动程序三、服务启动3.1. 将程序注册为服务使用命令行创建服务设置服务启动类型启动服务 3.2. 使用 Windows 服务管理器配置服务3.3. 删除服务 四、定时任务或程序4.1 设置程序自启动(使用任务计划程序…...

【自动驾驶】基于车辆几何模型的横向控制算法 | Stanley 算法详解与编程实现
写在前面: 🌟 欢迎光临 清流君 的博客小天地,这里是我分享技术与心得的温馨角落。📝 个人主页:清流君_CSDN博客,期待与您一同探索 移动机器人 领域的无限可能。 🔍 本文系 清流君 原创之作&…...