图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络
- 0. 前言
- 1. 分层自注意力网络
- 1.1 模型架构
- 1.2 节点级注意力
- 1.3 语义级注意力
- 1.4 预测模块
- 2. 构建分层自注意力网络
- 相关链接
0. 前言
在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大提高,但我们还能进一步提高准确率。在本节中,我们将学习一种专门用于处理异构图的图神经网络架构,分层自注意力网络 (hierarchical self-attention network, HAN)。我们将介绍其工作原理,以便更好地理解该架构与经典图注意力网络 (Graph Attention Networks,GAT) 之间的区别。最后,使用 PyTorch Geometric 实现此架构,并将结果与其它 GNN 模型进行比较。
1. 分层自注意力网络
1.1 模型架构
在本节中,我们将实现一个专为处理异构图而设计的图神经网络 (Graph Neural Networks, GNN) 模型——分层自注意力网络 (hierarchical self-attention network, HAN)。该架构由 Liu 等人于 2021 年提出。HAN 在两个不同层次上使用自注意力:
- 节点级注意力 (
Node-level attention):了解给定元路径中相邻节点的重要性 - 语义级注意力 (
Semantic-level attention):了解每个元路径的重要性。这是HAN的主要特点,它允许我们自动为给定任务选择最佳元路径。例如,在某些任务(如预测玩家人数)中,元路径game-user-game可能比game-dev-game更合适
接下来,我们将详细介绍 HAN 的三个主要组件——节点级注意力 (Node-level attention)、语义级注意力 (Semantic-level attention) 和预测模块 (prediction module),HAN 架构如下所示。
)
1.2 节点级注意力
与图注意力网络 (Graph Attention Networks,GAT) 一样,第一步将节点投影到每个元路径的统一特征空间中。然后,用第二个权重矩阵计算同一元路径中的节点对(两个投影节点的连接)的权重,并对这一结果应用非线性函数,然后用 softmax 函数对其进行归一化处理。 j j j 节点对 i i i 节点的归一化注意力分数(重要性)计算如下:
α i j Φ = exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h j ] ) ) ∑ k ∈ N i Φ exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h k ] ) ) \alpha_{ij}^\Phi =\frac {\exp (\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_j]))}{\sum _{k\in \mathcal N_i^{\Phi}}\exp(\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_k]))} αijΦ=∑k∈NiΦexp(σ(aΦT[WΦhi∣∣WΦhk]))exp(σ(aΦT[WΦhi∣∣WΦhj]))
其中, h i h_i hi 表示 i i i 节点的特征, W Φ W_{\Phi} WΦ 是 Φ \Phi Φ 元路径的共享权重矩阵, a Φ a_{\Phi} aΦ 是 Φ \Phi Φ 元路径的注意力权重矩阵, σ σ σ 是非线性激活函数(如 LeakyReLU), N i Φ \mathcal N_i^{\Phi} NiΦ 是节点(包括其自身)在 Φ \Phi Φ 元路径中的邻居集。使用多头注意力获得最终的嵌入:
Z i = ∣ ∣ k = 1 K σ ( ∑ k ∈ N i α i j Φ ⋅ W Φ h j ) Z_i=||_{k=1}^K\sigma(\sum _{k\in \mathcal N_i}\alpha _{ij}^{\Phi}\cdot W_{\Phi}h_j) Zi=∣∣k=1Kσ(k∈Ni∑αijΦ⋅WΦhj)
1.3 语义级注意力
对于语义级注意力,我们对每个元路径的注意力得分(表示为 β Φ 1 , β Φ 2 , . . . , β Φ p β_{\Phi _1}, β_{\Phi _2}, ... , β_{\Phi _p} βΦ1,βΦ2,...,βΦp )重复类似的过程。对于给定元路径中的每个节点嵌入(表示为 Z Φ p Z_{\Phi _p} ZΦp),都将其馈送到一个多层感知机 (Multilayer Perceptron, MLP) 中,应用非线性变换。将这一结果与新的注意力向量 q q q 进行比较,作为相似性度量。我们将这一结果平均化,以计算给定元路径的重要性:
w Φ p = 1 ∣ V ∣ ∑ i ∈ V q T ⋅ tanh ( W ⋅ z i Φ p + b ) w_{\Phi_p}=\frac 1{|V|}\sum_{i\in V}q^T\cdot \tanh(W\cdot z_i^{\Phi_p}+b) wΦp=∣V∣1i∈V∑qT⋅tanh(W⋅ziΦp+b)
其中, W W W (MLP 的权重矩阵)、 b b b (MLP 的偏置)和 q q q (语义级注意力向量)在元路径中是共享的。
必须对这一结果进行归一化处理,以比较不同的语义级注意力得分。使用 softmax 函数来获得最终权重:
β Φ p = exp ( w Φ p ) ∑ k = 1 P exp ( w Φ k ) \beta _{\Phi_p}=\frac {\exp(w_{\Phi_p})}{\sum_{k=1}^P\exp(w_{\Phi_k})} βΦp=∑k=1Pexp(wΦk)exp(wΦp)
将节点级注意力和语义级注意力结合起来得到最终嵌入 Z Z Z:
Z = ∑ p = 1 P β Φ p ⋅ Z Φ p Z=\sum_{p=1}^P\beta_{\Phi_p}\cdot Z_{\Phi_p} Z=p=1∑PβΦp⋅ZΦp
1.4 预测模块
最后一层(如多层感知机 (Multilayer Perceptron, MLP) )用于针对特定的下游任务(如节点分类或链接预测)对模型进行微调。
2. 构建分层自注意力网络
接下来,使用 PyTorch Geometric 在 DBLP 数据集上实现分层自注意力网络 (hierarchical self-attention network, HAN) 架构。
(1) 首先,导入 HAN 层:
import torch
import torch.nn.functional as F
from torch import nnimport torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear
(2) 加载 DBLP 数据集,并为会议节点引入虚拟特征:
dataset = DBLP('.')
data = dataset[0]
print(data)data['conference'].x = torch.zeros(20, 1)

(3) 使用 HANConv 的 HAN 卷积层和用于最终分类的线性层创建 HAN 类:
class HAN(nn.Module):def __init__(self, dim_in, dim_out, dim_h=128, heads=8):super().__init__()self.han = HANConv(dim_in, dim_h, heads=heads, dropout=0.6, metadata=data.metadata())self.linear = nn.Linear(dim_h, dim_out)
(4) 在 forward() 方法中,我们必须指定需要关注作者:
def forward(self, x_dict, edge_index_dict):out = self.han(x_dict, edge_index_dict)out = self.linear(out['author'])return out
(5) 使用懒初始化 (dim_in=-1) 来初始化模型,因此 PyTorch Geometric 会自动计算每个节点类型的输入大小:
model = HAN(dim_in=-1, dim_out=4)
(6) 实例化 Adam 优化器,并尝试将数据和模型传输到 GPU:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)
(7) 实现 test() 函数计算分类任务的准确率:
@torch.no_grad()
def test(mask):model.eval()pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()return float(acc)
(8) 对模型进行 101 个 epoch 的训练,与同构图神经网络 (Graph Neural Networks, GNN) 的训练循环唯一不同的是,需要指定关注的作者节点类型:
for epoch in range(101):model.train()optimizer.zero_grad()out = model(data.x_dict, data.edge_index_dict)mask = data['author'].train_maskloss = F.cross_entropy(out[mask], data['author'].y[mask])loss.backward()optimizer.step()if epoch % 20 == 0:train_acc = test(data['author'].train_mask)val_acc = test(data['author'].val_mask)print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')
训练过程如下:

(9) 最后,在测试集上测试训练后的模型:
test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')# Test accuracy: 81.58%
HAN 的测试准确率为 81.58%,高于异构图注意力网络 (78.39%)和经典图注意力网络 (Graph Attention Networks,GAT) (73.29%)。这说明了构建良好的表示方法以聚合不同类型节点和关系的重要性。异构图的技术在很大程度上取决于具体应用,但尝试不同的模型对于构建高性能应能具有重要作用。
相关链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(18)——消息传播神经网络
相关文章:
图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络 0. 前言1. 分层自注意力网络1.1 模型架构1.2 节点级注意力1.3 语义级注意力1.4 预测模块 2. 构建分层自注意力网络相关链接 0. 前言 在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大…...
基于深度学习的数字识别系统的设计与实现(python、yolov、PyQt5)
💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…...
ChatGPT 提取文档内容,高效制作PPT、论文
随着人工智能生成内容(AIGC)的快速发展,利用先进的技术工具如 ChatGPT 的 RAG(Retrieval-Augmented Generation,检索增强生成)模式,可以显著提升文档内容提取和内容创作的效率。以下将详细介绍如…...
3、等保1.0 与 2.0 的区别
数据来源:3.等保1.0和2.0的区别_哔哩哔哩_bilibili 等保1.0时代VS等保2.0时代五个规定动作:定级、备案、建设整改、等级测评、监督检查工作内容维持5个规定动作,增加风险评估、安全监测、通报预警、事件调查、数据防护自主可控、供应链安全、…...
Angular面试题九
一、在Angular中,你如何管理全局状态或跨组件共享数据?有哪些常见的实现方式? 在Angular中,管理全局状态或跨组件共享数据是应用开发中的一个重要方面。这有助于保持数据的一致性和可维护性,特别是在复杂的应用中。以下…...
(转载)智能指针shared_ptr从C++11到C++20
shared_ptr和动态数组 - apocelipes - 博客园 (cnblogs.com) template<typename T> std::shared_ptr<T> make_shared_array(size_t size) { return std::shared_ptr<T>(new T[size],std::default_delete<T[]>()); } std::shar…...
Ubuntu 上安装 Miniconda
一、下载 Miniconda 打开终端。访问 Anaconda 官方仓库下载页面https://repo.anaconda.com/miniconda/选择Miniconda3-py310_24.7.1-0-Linux-x86_64.sh,进行下载。文件名当中的py310_24.7.1表示,在 conda 的默认的 base 环境中的 Python 版本是3.10&…...
【Vue系列五】—Vue学习历程的知识分享!
前言 本篇文章讲述前端工程化从模块化到如今的脚手架的发展,以及Webpack、Vue脚手架的详解! 一、模块化 模块化就是把单独的功能封装到模块(文件)中,模块之间相互隔离,但可以通过特定的接口公开内部成员…...
CaLM 因果推理评测体系:如何让大模型更贴近人类认知水平?
CaLM 是什么 CaLM(Causal Evaluation of Language Models,以下简称“CaLM”)是上海人工智能实验室联合同济大学、上海交通大学、北京大学及商汤科技发布首个大模型因果推理开放评测体系及开放平台。首次从因果推理角度提出评估框架ÿ…...
深入探索卷积神经网络(CNN)
深入探索卷积神经网络(CNN) 前言图像的数字表示灰度图像RGB图像 卷积神经网络(CNN)的架构基本组件卷积操作填充(Padding)步幅(Strides) 多通道图像的卷积池化层全连接层 CNN与全连接…...
【C++篇】手撕 C++ string 类:从零实现到深入剖析的模拟之路
文章目录 C string 类的模拟实现:从构造到高级操作前言第一章:为什么要手写 C string 类?1.1 理由与价值 第二章:实现一个简单的 string 类2.1 基本构造与析构2.1.1 示例代码:基础的 string 类实现2.1.2 解读代码 2.2 …...
毕业设计选题:基于ssm+vue+uniapp的校园失物招领小程序
开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…...
[系统设计总结] - Proximity Service算法介绍
问题描述 Proximity Service广泛应用于各种地图相关的服务中比如外卖,大众点评,Uber打车,Google地图中,其中比较关键的是我们根据用户的位置来快速找到附近的餐厅,司机,外卖员也就是就近查询算法。 主流的…...
变压吸附制氧机的应用范围
变压吸附制氧机是一种利用变压吸附技术从空气中分离出氧气的设备。该技术通过吸附剂在不同压力下的吸附与解吸性能,实现了氧气的有效分离和纯化。 工业领域 在工业领域,变压吸附制氧机同样具有广泛的应用。首先,钢铁企业在生产过程中需要大量…...
MATLAB绘图基础8:双变量图形绘制
参考书:《 M A T L A B {\rm MATLAB} MATLAB与学术图表绘制》(关东升)。 8.双变量图形绘制 8.1 散点图 散点图用于显示两个变量间的关系,每个数据点在图上表示为一个点,一个变量在 X {\rm X} X轴,一个变量在 Y {\rm Y} Y轴&#…...
Appium高级话题:混合应用与原生应用测试策略
Appium高级话题:混合应用与原生应用测试策略 在移动应用开发领域,混合应用与原生应用各有千秋,但它们的测试策略却大相径庭。本文旨在深入探讨这两种应用类型的测试挑战,并介绍如何利用自动化测试软件ItBuilder高效解决这些问题&…...
windows源码安装protobuf,opencv,ncnn
安装笔记 cmake 在windows可以使用-G"MinGW Makefiles" 搭配make使用,install出来的lib文件时.a结尾的,适合linux下面使用。所以在windows上若无需求使用-G"NMake Makefiles" 搭配nmake。 但是windows上使用-G"NMake Makefil…...
MicroPython 怎么搭建工程代码
在MicroPython中搭建工程代码可以遵循以下步骤: 1. 准备工作 安装MicroPython固件:确保已经将MicroPython烧录到ESP32开发板中。准备开发环境: 可以使用文本编辑器(如VS Code、Thonny、uPyCraft等)来编写代码。 2.…...
Android studio安装问题及解决方案
Android studio安装问题及解决方案 gradle已经安装好了,但是每次就是找不到gradle的位置,每次要重新下载,很慢,每次都不成功 我尝试用安装android studio时自带的卸载程序,卸载android studio,然后重新下…...
前端面试题(二)
6. 深入 JavaScript this 关键字的指向是什么? this 的指向是在函数执行时决定的。默认情况下,非严格模式下 this 指向全局对象(浏览器中为 window),严格模式下 this 为 undefined。在对象方法中,this 通常…...
未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?
编辑:陈萍萍的公主一点人工一点智能 未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战,在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...
Android Wi-Fi 连接失败日志分析
1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
dify打造数据可视化图表
一、概述 在日常工作和学习中,我们经常需要和数据打交道。无论是分析报告、项目展示,还是简单的数据洞察,一个清晰直观的图表,往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server,由蚂蚁集团 AntV 团队…...
QT3D学习笔记——圆台、圆锥
类名作用Qt3DWindow3D渲染窗口容器QEntity场景中的实体(对象或容器)QCamera控制观察视角QPointLight点光源QConeMesh圆锥几何网格QTransform控制实体的位置/旋转/缩放QPhongMaterialPhong光照材质(定义颜色、反光等)QFirstPersonC…...
处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...
关于uniapp展示PDF的解决方案
在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项: 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库: npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...
Web后端基础(基础知识)
BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程序的逻辑和数据都存储在服务端。 优点:维护方便缺点:体验一般 CS架构:Client/Server,客户端/服务器架构模式。需要单独…...
c++第七天 继承与派生2
这一篇文章主要内容是 派生类构造函数与析构函数 在派生类中重写基类成员 以及多继承 第一部分:派生类构造函数与析构函数 当创建一个派生类对象时,基类成员是如何初始化的? 1.当派生类对象创建的时候,基类成员的初始化顺序 …...
GraphQL 实战篇:Apollo Client 配置与缓存
GraphQL 实战篇:Apollo Client 配置与缓存 上一篇:GraphQL 入门篇:基础查询语法 依旧和上一篇的笔记一样,主实操,没啥过多的细节讲解,代码具体在: https://github.com/GoldenaArcher/graphql…...
