图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络
- 0. 前言
- 1. 分层自注意力网络
- 1.1 模型架构
- 1.2 节点级注意力
- 1.3 语义级注意力
- 1.4 预测模块
- 2. 构建分层自注意力网络
- 相关链接
0. 前言
在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大提高,但我们还能进一步提高准确率。在本节中,我们将学习一种专门用于处理异构图的图神经网络架构,分层自注意力网络 (hierarchical self-attention network, HAN)。我们将介绍其工作原理,以便更好地理解该架构与经典图注意力网络 (Graph Attention Networks,GAT) 之间的区别。最后,使用 PyTorch Geometric 实现此架构,并将结果与其它 GNN 模型进行比较。
1. 分层自注意力网络
1.1 模型架构
在本节中,我们将实现一个专为处理异构图而设计的图神经网络 (Graph Neural Networks, GNN) 模型——分层自注意力网络 (hierarchical self-attention network, HAN)。该架构由 Liu 等人于 2021 年提出。HAN 在两个不同层次上使用自注意力:
- 节点级注意力 (
Node-level attention):了解给定元路径中相邻节点的重要性 - 语义级注意力 (
Semantic-level attention):了解每个元路径的重要性。这是HAN的主要特点,它允许我们自动为给定任务选择最佳元路径。例如,在某些任务(如预测玩家人数)中,元路径game-user-game可能比game-dev-game更合适
接下来,我们将详细介绍 HAN 的三个主要组件——节点级注意力 (Node-level attention)、语义级注意力 (Semantic-level attention) 和预测模块 (prediction module),HAN 架构如下所示。
)
1.2 节点级注意力
与图注意力网络 (Graph Attention Networks,GAT) 一样,第一步将节点投影到每个元路径的统一特征空间中。然后,用第二个权重矩阵计算同一元路径中的节点对(两个投影节点的连接)的权重,并对这一结果应用非线性函数,然后用 softmax 函数对其进行归一化处理。 j j j 节点对 i i i 节点的归一化注意力分数(重要性)计算如下:
α i j Φ = exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h j ] ) ) ∑ k ∈ N i Φ exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h k ] ) ) \alpha_{ij}^\Phi =\frac {\exp (\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_j]))}{\sum _{k\in \mathcal N_i^{\Phi}}\exp(\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_k]))} αijΦ=∑k∈NiΦexp(σ(aΦT[WΦhi∣∣WΦhk]))exp(σ(aΦT[WΦhi∣∣WΦhj]))
其中, h i h_i hi 表示 i i i 节点的特征, W Φ W_{\Phi} WΦ 是 Φ \Phi Φ 元路径的共享权重矩阵, a Φ a_{\Phi} aΦ 是 Φ \Phi Φ 元路径的注意力权重矩阵, σ σ σ 是非线性激活函数(如 LeakyReLU), N i Φ \mathcal N_i^{\Phi} NiΦ 是节点(包括其自身)在 Φ \Phi Φ 元路径中的邻居集。使用多头注意力获得最终的嵌入:
Z i = ∣ ∣ k = 1 K σ ( ∑ k ∈ N i α i j Φ ⋅ W Φ h j ) Z_i=||_{k=1}^K\sigma(\sum _{k\in \mathcal N_i}\alpha _{ij}^{\Phi}\cdot W_{\Phi}h_j) Zi=∣∣k=1Kσ(k∈Ni∑αijΦ⋅WΦhj)
1.3 语义级注意力
对于语义级注意力,我们对每个元路径的注意力得分(表示为 β Φ 1 , β Φ 2 , . . . , β Φ p β_{\Phi _1}, β_{\Phi _2}, ... , β_{\Phi _p} βΦ1,βΦ2,...,βΦp )重复类似的过程。对于给定元路径中的每个节点嵌入(表示为 Z Φ p Z_{\Phi _p} ZΦp),都将其馈送到一个多层感知机 (Multilayer Perceptron, MLP) 中,应用非线性变换。将这一结果与新的注意力向量 q q q 进行比较,作为相似性度量。我们将这一结果平均化,以计算给定元路径的重要性:
w Φ p = 1 ∣ V ∣ ∑ i ∈ V q T ⋅ tanh ( W ⋅ z i Φ p + b ) w_{\Phi_p}=\frac 1{|V|}\sum_{i\in V}q^T\cdot \tanh(W\cdot z_i^{\Phi_p}+b) wΦp=∣V∣1i∈V∑qT⋅tanh(W⋅ziΦp+b)
其中, W W W (MLP 的权重矩阵)、 b b b (MLP 的偏置)和 q q q (语义级注意力向量)在元路径中是共享的。
必须对这一结果进行归一化处理,以比较不同的语义级注意力得分。使用 softmax 函数来获得最终权重:
β Φ p = exp ( w Φ p ) ∑ k = 1 P exp ( w Φ k ) \beta _{\Phi_p}=\frac {\exp(w_{\Phi_p})}{\sum_{k=1}^P\exp(w_{\Phi_k})} βΦp=∑k=1Pexp(wΦk)exp(wΦp)
将节点级注意力和语义级注意力结合起来得到最终嵌入 Z Z Z:
Z = ∑ p = 1 P β Φ p ⋅ Z Φ p Z=\sum_{p=1}^P\beta_{\Phi_p}\cdot Z_{\Phi_p} Z=p=1∑PβΦp⋅ZΦp
1.4 预测模块
最后一层(如多层感知机 (Multilayer Perceptron, MLP) )用于针对特定的下游任务(如节点分类或链接预测)对模型进行微调。
2. 构建分层自注意力网络
接下来,使用 PyTorch Geometric 在 DBLP 数据集上实现分层自注意力网络 (hierarchical self-attention network, HAN) 架构。
(1) 首先,导入 HAN 层:
import torch
import torch.nn.functional as F
from torch import nnimport torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear
(2) 加载 DBLP 数据集,并为会议节点引入虚拟特征:
dataset = DBLP('.')
data = dataset[0]
print(data)data['conference'].x = torch.zeros(20, 1)

(3) 使用 HANConv 的 HAN 卷积层和用于最终分类的线性层创建 HAN 类:
class HAN(nn.Module):def __init__(self, dim_in, dim_out, dim_h=128, heads=8):super().__init__()self.han = HANConv(dim_in, dim_h, heads=heads, dropout=0.6, metadata=data.metadata())self.linear = nn.Linear(dim_h, dim_out)
(4) 在 forward() 方法中,我们必须指定需要关注作者:
def forward(self, x_dict, edge_index_dict):out = self.han(x_dict, edge_index_dict)out = self.linear(out['author'])return out
(5) 使用懒初始化 (dim_in=-1) 来初始化模型,因此 PyTorch Geometric 会自动计算每个节点类型的输入大小:
model = HAN(dim_in=-1, dim_out=4)
(6) 实例化 Adam 优化器,并尝试将数据和模型传输到 GPU:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)
(7) 实现 test() 函数计算分类任务的准确率:
@torch.no_grad()
def test(mask):model.eval()pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()return float(acc)
(8) 对模型进行 101 个 epoch 的训练,与同构图神经网络 (Graph Neural Networks, GNN) 的训练循环唯一不同的是,需要指定关注的作者节点类型:
for epoch in range(101):model.train()optimizer.zero_grad()out = model(data.x_dict, data.edge_index_dict)mask = data['author'].train_maskloss = F.cross_entropy(out[mask], data['author'].y[mask])loss.backward()optimizer.step()if epoch % 20 == 0:train_acc = test(data['author'].train_mask)val_acc = test(data['author'].val_mask)print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')
训练过程如下:

(9) 最后,在测试集上测试训练后的模型:
test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')# Test accuracy: 81.58%
HAN 的测试准确率为 81.58%,高于异构图注意力网络 (78.39%)和经典图注意力网络 (Graph Attention Networks,GAT) (73.29%)。这说明了构建良好的表示方法以聚合不同类型节点和关系的重要性。异构图的技术在很大程度上取决于具体应用,但尝试不同的模型对于构建高性能应能具有重要作用。
相关链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(18)——消息传播神经网络
相关文章:
图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络 0. 前言1. 分层自注意力网络1.1 模型架构1.2 节点级注意力1.3 语义级注意力1.4 预测模块 2. 构建分层自注意力网络相关链接 0. 前言 在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大…...
基于深度学习的数字识别系统的设计与实现(python、yolov、PyQt5)
💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…...
ChatGPT 提取文档内容,高效制作PPT、论文
随着人工智能生成内容(AIGC)的快速发展,利用先进的技术工具如 ChatGPT 的 RAG(Retrieval-Augmented Generation,检索增强生成)模式,可以显著提升文档内容提取和内容创作的效率。以下将详细介绍如…...
3、等保1.0 与 2.0 的区别
数据来源:3.等保1.0和2.0的区别_哔哩哔哩_bilibili 等保1.0时代VS等保2.0时代五个规定动作:定级、备案、建设整改、等级测评、监督检查工作内容维持5个规定动作,增加风险评估、安全监测、通报预警、事件调查、数据防护自主可控、供应链安全、…...
Angular面试题九
一、在Angular中,你如何管理全局状态或跨组件共享数据?有哪些常见的实现方式? 在Angular中,管理全局状态或跨组件共享数据是应用开发中的一个重要方面。这有助于保持数据的一致性和可维护性,特别是在复杂的应用中。以下…...
(转载)智能指针shared_ptr从C++11到C++20
shared_ptr和动态数组 - apocelipes - 博客园 (cnblogs.com) template<typename T> std::shared_ptr<T> make_shared_array(size_t size) { return std::shared_ptr<T>(new T[size],std::default_delete<T[]>()); } std::shar…...
Ubuntu 上安装 Miniconda
一、下载 Miniconda 打开终端。访问 Anaconda 官方仓库下载页面https://repo.anaconda.com/miniconda/选择Miniconda3-py310_24.7.1-0-Linux-x86_64.sh,进行下载。文件名当中的py310_24.7.1表示,在 conda 的默认的 base 环境中的 Python 版本是3.10&…...
【Vue系列五】—Vue学习历程的知识分享!
前言 本篇文章讲述前端工程化从模块化到如今的脚手架的发展,以及Webpack、Vue脚手架的详解! 一、模块化 模块化就是把单独的功能封装到模块(文件)中,模块之间相互隔离,但可以通过特定的接口公开内部成员…...
CaLM 因果推理评测体系:如何让大模型更贴近人类认知水平?
CaLM 是什么 CaLM(Causal Evaluation of Language Models,以下简称“CaLM”)是上海人工智能实验室联合同济大学、上海交通大学、北京大学及商汤科技发布首个大模型因果推理开放评测体系及开放平台。首次从因果推理角度提出评估框架ÿ…...
深入探索卷积神经网络(CNN)
深入探索卷积神经网络(CNN) 前言图像的数字表示灰度图像RGB图像 卷积神经网络(CNN)的架构基本组件卷积操作填充(Padding)步幅(Strides) 多通道图像的卷积池化层全连接层 CNN与全连接…...
【C++篇】手撕 C++ string 类:从零实现到深入剖析的模拟之路
文章目录 C string 类的模拟实现:从构造到高级操作前言第一章:为什么要手写 C string 类?1.1 理由与价值 第二章:实现一个简单的 string 类2.1 基本构造与析构2.1.1 示例代码:基础的 string 类实现2.1.2 解读代码 2.2 …...
毕业设计选题:基于ssm+vue+uniapp的校园失物招领小程序
开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…...
[系统设计总结] - Proximity Service算法介绍
问题描述 Proximity Service广泛应用于各种地图相关的服务中比如外卖,大众点评,Uber打车,Google地图中,其中比较关键的是我们根据用户的位置来快速找到附近的餐厅,司机,外卖员也就是就近查询算法。 主流的…...
变压吸附制氧机的应用范围
变压吸附制氧机是一种利用变压吸附技术从空气中分离出氧气的设备。该技术通过吸附剂在不同压力下的吸附与解吸性能,实现了氧气的有效分离和纯化。 工业领域 在工业领域,变压吸附制氧机同样具有广泛的应用。首先,钢铁企业在生产过程中需要大量…...
MATLAB绘图基础8:双变量图形绘制
参考书:《 M A T L A B {\rm MATLAB} MATLAB与学术图表绘制》(关东升)。 8.双变量图形绘制 8.1 散点图 散点图用于显示两个变量间的关系,每个数据点在图上表示为一个点,一个变量在 X {\rm X} X轴,一个变量在 Y {\rm Y} Y轴&#…...
Appium高级话题:混合应用与原生应用测试策略
Appium高级话题:混合应用与原生应用测试策略 在移动应用开发领域,混合应用与原生应用各有千秋,但它们的测试策略却大相径庭。本文旨在深入探讨这两种应用类型的测试挑战,并介绍如何利用自动化测试软件ItBuilder高效解决这些问题&…...
windows源码安装protobuf,opencv,ncnn
安装笔记 cmake 在windows可以使用-G"MinGW Makefiles" 搭配make使用,install出来的lib文件时.a结尾的,适合linux下面使用。所以在windows上若无需求使用-G"NMake Makefiles" 搭配nmake。 但是windows上使用-G"NMake Makefil…...
MicroPython 怎么搭建工程代码
在MicroPython中搭建工程代码可以遵循以下步骤: 1. 准备工作 安装MicroPython固件:确保已经将MicroPython烧录到ESP32开发板中。准备开发环境: 可以使用文本编辑器(如VS Code、Thonny、uPyCraft等)来编写代码。 2.…...
Android studio安装问题及解决方案
Android studio安装问题及解决方案 gradle已经安装好了,但是每次就是找不到gradle的位置,每次要重新下载,很慢,每次都不成功 我尝试用安装android studio时自带的卸载程序,卸载android studio,然后重新下…...
前端面试题(二)
6. 深入 JavaScript this 关键字的指向是什么? this 的指向是在函数执行时决定的。默认情况下,非严格模式下 this 指向全局对象(浏览器中为 window),严格模式下 this 为 undefined。在对象方法中,this 通常…...
广汽埃安品牌车型AION UT在奥地利麦格纳工厂正式量产启动并成功下线 | 美通社头条
、美通社消息:3月18日,广汽欧洲业务发展迎来重要里程碑——旗下埃安品牌车型AION UT在奥地利麦格纳(Magna)工厂正式实现量产启动(SOP)并成功下线,标志着广汽在欧洲本地化战略迈入实质性推进阶段。AION UT是广汽欧洲本地化战略的重要核心车型&…...
Vite 8 架构革新:从双引擎到 Rolldown 统一打包的演进之路
1. Vite 8 架构革新的背景与痛点 如果你用过 Vite 7 或更早版本,一定对它的闪电般开发体验印象深刻。这主要得益于 Vite 独特的双引擎架构:开发时用 esbuild 实现毫秒级启动,生产环境则用 Rollup 保证打包质量。但我在实际项目中发现…...
使用ComfyUI搭建可视化DeOldify工作流
使用ComfyUI搭建可视化DeOldify工作流 想给家里的老照片上色,但觉得写代码太麻烦?或者想把手头的黑白视频变成彩色,却不知道从何下手?今天,我们就来聊聊一个特别有意思的玩法:用ComfyUI这个可视化工具&…...
深入解析visualization_msgs::Marker:从基础到实战应用
1. visualization_msgs::Marker是什么? 如果你正在用ROS做机器人开发,肯定遇到过这样的需求:想让机器人在rviz里显示一些自定义的图形,比如路径规划时的参考线、传感器检测到的障碍物轮廓,甚至是简单的文字提示。这时候…...
专为AI打造的浏览器:内存占用仅为Chrome的1/9、比Chrome快11倍(Docker部署教程,支持飞牛nas等服务器部署)
文章目录 📖 介绍 📖 🏡 演示环境 🏡 📒 轻量级无头浏览器介绍与Docker部署指南 📒 📝 工具介绍 🎯 为什么选择它 🔧 Docker Compose 快速部署 💡 连接进行自动化操作 ⚠️ 注意事项 📊 性能对比 🎯 适用场景 ⚓️ 相关链接 ⚓️ 📖 介绍 📖 在自动…...
AI赋能边缘设备:借助快马平台为树莓派集成图像识别功能
AI赋能边缘设备:借助快马平台为树莓派集成图像识别功能 最近在折腾树莓派项目时,发现很多场景需要用到图像识别功能。比如智能门禁、垃圾分类助手或者简单的安防监控。传统做法需要自己训练模型、处理数据,门槛实在太高。后来发现InsCode(快…...
嵌入式技术学习路径与核心技能解析
嵌入式技术学习路径与资源整合指南1. 嵌入式技术体系概述嵌入式系统作为现代电子设备的核心,其技术栈涵盖从底层硬件到上层软件的完整知识体系。一个合格的嵌入式工程师需要掌握以下核心领域:1.1 基础编程能力C/C语言编程基础数据结构与算法计算机组成原…...
Qt加载OBJ或STL模型文件,支持鼠标移动、缩放、旋转Demo
Qt加载模型文件obj或者stl实例,支持鼠标移动缩放旋转demo最近在捣鼓Qt的3D可视化功能,发现用Qt搞个模型查看器比想象中简单。咱们先整点实际的——做个能加载obj/stl模型,支持鼠标拖拽旋转、平移、缩放的demo。废话不多说,直接撸代…...
颠覆式开源工具GHelper:极简华硕笔记本硬件控制解决方案
颠覆式开源工具GHelper:极简华硕笔记本硬件控制解决方案 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地…...
Linux下adb调试小米手机报错Exception的5种解决方法(附详细排查步骤)
Linux下adb调试小米手机报错Exception的5种深度解决方案 最近在Linux环境下用adb调试小米手机时,不少开发者遇到了Exception occurred while executing put这个让人头疼的错误。作为一名常年与adb打交道的开发者,我深知这种问题一旦出现,轻则…...
