图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络
- 0. 前言
- 1. 分层自注意力网络
- 1.1 模型架构
- 1.2 节点级注意力
- 1.3 语义级注意力
- 1.4 预测模块
- 2. 构建分层自注意力网络
- 相关链接
0. 前言
在异构图数据集上,异构图注意力网络的测试准确率为 78.39%
,比之同构版本有了较大提高,但我们还能进一步提高准确率。在本节中,我们将学习一种专门用于处理异构图的图神经网络架构,分层自注意力网络 (hierarchical self-attention network
, HAN
)。我们将介绍其工作原理,以便更好地理解该架构与经典图注意力网络 (Graph Attention Networks,GAT) 之间的区别。最后,使用 PyTorch Geometric
实现此架构,并将结果与其它 GNN
模型进行比较。
1. 分层自注意力网络
1.1 模型架构
在本节中,我们将实现一个专为处理异构图而设计的图神经网络 (Graph Neural Networks, GNN) 模型——分层自注意力网络 (hierarchical self-attention network
, HAN
)。该架构由 Liu
等人于 2021
年提出。HAN
在两个不同层次上使用自注意力:
- 节点级注意力 (
Node-level attention
):了解给定元路径中相邻节点的重要性 - 语义级注意力 (
Semantic-level attention
):了解每个元路径的重要性。这是HAN
的主要特点,它允许我们自动为给定任务选择最佳元路径。例如,在某些任务(如预测玩家人数)中,元路径game-user-game
可能比game-dev-game
更合适
接下来,我们将详细介绍 HAN
的三个主要组件——节点级注意力 (Node-level attention
)、语义级注意力 (Semantic-level attention
) 和预测模块 (prediction module
),HAN
架构如下所示。
)
1.2 节点级注意力
与图注意力网络 (Graph Attention Networks,GAT) 一样,第一步将节点投影到每个元路径的统一特征空间中。然后,用第二个权重矩阵计算同一元路径中的节点对(两个投影节点的连接)的权重,并对这一结果应用非线性函数,然后用 softmax
函数对其进行归一化处理。 j j j 节点对 i i i 节点的归一化注意力分数(重要性)计算如下:
α i j Φ = exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h j ] ) ) ∑ k ∈ N i Φ exp ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h k ] ) ) \alpha_{ij}^\Phi =\frac {\exp (\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_j]))}{\sum _{k\in \mathcal N_i^{\Phi}}\exp(\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_k]))} αijΦ=∑k∈NiΦexp(σ(aΦT[WΦhi∣∣WΦhk]))exp(σ(aΦT[WΦhi∣∣WΦhj]))
其中, h i h_i hi 表示 i i i 节点的特征, W Φ W_{\Phi} WΦ 是 Φ \Phi Φ 元路径的共享权重矩阵, a Φ a_{\Phi} aΦ 是 Φ \Phi Φ 元路径的注意力权重矩阵, σ σ σ 是非线性激活函数(如 LeakyReLU
), N i Φ \mathcal N_i^{\Phi} NiΦ 是节点(包括其自身)在 Φ \Phi Φ 元路径中的邻居集。使用多头注意力获得最终的嵌入:
Z i = ∣ ∣ k = 1 K σ ( ∑ k ∈ N i α i j Φ ⋅ W Φ h j ) Z_i=||_{k=1}^K\sigma(\sum _{k\in \mathcal N_i}\alpha _{ij}^{\Phi}\cdot W_{\Phi}h_j) Zi=∣∣k=1Kσ(k∈Ni∑αijΦ⋅WΦhj)
1.3 语义级注意力
对于语义级注意力,我们对每个元路径的注意力得分(表示为 β Φ 1 , β Φ 2 , . . . , β Φ p β_{\Phi _1}, β_{\Phi _2}, ... , β_{\Phi _p} βΦ1,βΦ2,...,βΦp )重复类似的过程。对于给定元路径中的每个节点嵌入(表示为 Z Φ p Z_{\Phi _p} ZΦp),都将其馈送到一个多层感知机 (Multilayer Perceptron
, MLP
) 中,应用非线性变换。将这一结果与新的注意力向量 q q q 进行比较,作为相似性度量。我们将这一结果平均化,以计算给定元路径的重要性:
w Φ p = 1 ∣ V ∣ ∑ i ∈ V q T ⋅ tanh ( W ⋅ z i Φ p + b ) w_{\Phi_p}=\frac 1{|V|}\sum_{i\in V}q^T\cdot \tanh(W\cdot z_i^{\Phi_p}+b) wΦp=∣V∣1i∈V∑qT⋅tanh(W⋅ziΦp+b)
其中, W W W (MLP
的权重矩阵)、 b b b (MLP
的偏置)和 q q q (语义级注意力向量)在元路径中是共享的。
必须对这一结果进行归一化处理,以比较不同的语义级注意力得分。使用 softmax
函数来获得最终权重:
β Φ p = exp ( w Φ p ) ∑ k = 1 P exp ( w Φ k ) \beta _{\Phi_p}=\frac {\exp(w_{\Phi_p})}{\sum_{k=1}^P\exp(w_{\Phi_k})} βΦp=∑k=1Pexp(wΦk)exp(wΦp)
将节点级注意力和语义级注意力结合起来得到最终嵌入 Z Z Z:
Z = ∑ p = 1 P β Φ p ⋅ Z Φ p Z=\sum_{p=1}^P\beta_{\Phi_p}\cdot Z_{\Phi_p} Z=p=1∑PβΦp⋅ZΦp
1.4 预测模块
最后一层(如多层感知机 (Multilayer Perceptron
, MLP
) )用于针对特定的下游任务(如节点分类或链接预测)对模型进行微调。
2. 构建分层自注意力网络
接下来,使用 PyTorch Geometric
在 DBLP 数据集上实现分层自注意力网络 (hierarchical self-attention network
, HAN
) 架构。
(1) 首先,导入 HAN
层:
import torch
import torch.nn.functional as F
from torch import nnimport torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear
(2) 加载 DBLP
数据集,并为会议节点引入虚拟特征:
dataset = DBLP('.')
data = dataset[0]
print(data)data['conference'].x = torch.zeros(20, 1)
(3) 使用 HANConv
的 HAN
卷积层和用于最终分类的线性层创建 HAN
类:
class HAN(nn.Module):def __init__(self, dim_in, dim_out, dim_h=128, heads=8):super().__init__()self.han = HANConv(dim_in, dim_h, heads=heads, dropout=0.6, metadata=data.metadata())self.linear = nn.Linear(dim_h, dim_out)
(4) 在 forward()
方法中,我们必须指定需要关注作者:
def forward(self, x_dict, edge_index_dict):out = self.han(x_dict, edge_index_dict)out = self.linear(out['author'])return out
(5) 使用懒初始化 (dim_in=-1)
来初始化模型,因此 PyTorch Geometric
会自动计算每个节点类型的输入大小:
model = HAN(dim_in=-1, dim_out=4)
(6) 实例化 Adam
优化器,并尝试将数据和模型传输到 GPU
:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)
(7) 实现 test()
函数计算分类任务的准确率:
@torch.no_grad()
def test(mask):model.eval()pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()return float(acc)
(8) 对模型进行 101
个 epoch
的训练,与同构图神经网络 (Graph Neural Networks, GNN) 的训练循环唯一不同的是,需要指定关注的作者节点类型:
for epoch in range(101):model.train()optimizer.zero_grad()out = model(data.x_dict, data.edge_index_dict)mask = data['author'].train_maskloss = F.cross_entropy(out[mask], data['author'].y[mask])loss.backward()optimizer.step()if epoch % 20 == 0:train_acc = test(data['author'].train_mask)val_acc = test(data['author'].val_mask)print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')
训练过程如下:
(9) 最后,在测试集上测试训练后的模型:
test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')# Test accuracy: 81.58%
HAN
的测试准确率为 81.58%
,高于异构图注意力网络 (78.39%
)和经典图注意力网络 (Graph Attention Networks,GAT) (73.29%
)。这说明了构建良好的表示方法以聚合不同类型节点和关系的重要性。异构图的技术在很大程度上取决于具体应用,但尝试不同的模型对于构建高性能应能具有重要作用。
相关链接
图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(18)——消息传播神经网络
相关文章:

图神经网络实战——分层自注意力网络
图神经网络实战——分层自注意力网络 0. 前言1. 分层自注意力网络1.1 模型架构1.2 节点级注意力1.3 语义级注意力1.4 预测模块 2. 构建分层自注意力网络相关链接 0. 前言 在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大…...

基于深度学习的数字识别系统的设计与实现(python、yolov、PyQt5)
💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…...

ChatGPT 提取文档内容,高效制作PPT、论文
随着人工智能生成内容(AIGC)的快速发展,利用先进的技术工具如 ChatGPT 的 RAG(Retrieval-Augmented Generation,检索增强生成)模式,可以显著提升文档内容提取和内容创作的效率。以下将详细介绍如…...
3、等保1.0 与 2.0 的区别
数据来源:3.等保1.0和2.0的区别_哔哩哔哩_bilibili 等保1.0时代VS等保2.0时代五个规定动作:定级、备案、建设整改、等级测评、监督检查工作内容维持5个规定动作,增加风险评估、安全监测、通报预警、事件调查、数据防护自主可控、供应链安全、…...
Angular面试题九
一、在Angular中,你如何管理全局状态或跨组件共享数据?有哪些常见的实现方式? 在Angular中,管理全局状态或跨组件共享数据是应用开发中的一个重要方面。这有助于保持数据的一致性和可维护性,特别是在复杂的应用中。以下…...
(转载)智能指针shared_ptr从C++11到C++20
shared_ptr和动态数组 - apocelipes - 博客园 (cnblogs.com) template<typename T> std::shared_ptr<T> make_shared_array(size_t size) { return std::shared_ptr<T>(new T[size],std::default_delete<T[]>()); } std::shar…...
Ubuntu 上安装 Miniconda
一、下载 Miniconda 打开终端。访问 Anaconda 官方仓库下载页面https://repo.anaconda.com/miniconda/选择Miniconda3-py310_24.7.1-0-Linux-x86_64.sh,进行下载。文件名当中的py310_24.7.1表示,在 conda 的默认的 base 环境中的 Python 版本是3.10&…...

【Vue系列五】—Vue学习历程的知识分享!
前言 本篇文章讲述前端工程化从模块化到如今的脚手架的发展,以及Webpack、Vue脚手架的详解! 一、模块化 模块化就是把单独的功能封装到模块(文件)中,模块之间相互隔离,但可以通过特定的接口公开内部成员…...

CaLM 因果推理评测体系:如何让大模型更贴近人类认知水平?
CaLM 是什么 CaLM(Causal Evaluation of Language Models,以下简称“CaLM”)是上海人工智能实验室联合同济大学、上海交通大学、北京大学及商汤科技发布首个大模型因果推理开放评测体系及开放平台。首次从因果推理角度提出评估框架ÿ…...

深入探索卷积神经网络(CNN)
深入探索卷积神经网络(CNN) 前言图像的数字表示灰度图像RGB图像 卷积神经网络(CNN)的架构基本组件卷积操作填充(Padding)步幅(Strides) 多通道图像的卷积池化层全连接层 CNN与全连接…...

【C++篇】手撕 C++ string 类:从零实现到深入剖析的模拟之路
文章目录 C string 类的模拟实现:从构造到高级操作前言第一章:为什么要手写 C string 类?1.1 理由与价值 第二章:实现一个简单的 string 类2.1 基本构造与析构2.1.1 示例代码:基础的 string 类实现2.1.2 解读代码 2.2 …...

毕业设计选题:基于ssm+vue+uniapp的校园失物招领小程序
开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…...

[系统设计总结] - Proximity Service算法介绍
问题描述 Proximity Service广泛应用于各种地图相关的服务中比如外卖,大众点评,Uber打车,Google地图中,其中比较关键的是我们根据用户的位置来快速找到附近的餐厅,司机,外卖员也就是就近查询算法。 主流的…...

变压吸附制氧机的应用范围
变压吸附制氧机是一种利用变压吸附技术从空气中分离出氧气的设备。该技术通过吸附剂在不同压力下的吸附与解吸性能,实现了氧气的有效分离和纯化。 工业领域 在工业领域,变压吸附制氧机同样具有广泛的应用。首先,钢铁企业在生产过程中需要大量…...

MATLAB绘图基础8:双变量图形绘制
参考书:《 M A T L A B {\rm MATLAB} MATLAB与学术图表绘制》(关东升)。 8.双变量图形绘制 8.1 散点图 散点图用于显示两个变量间的关系,每个数据点在图上表示为一个点,一个变量在 X {\rm X} X轴,一个变量在 Y {\rm Y} Y轴&#…...
Appium高级话题:混合应用与原生应用测试策略
Appium高级话题:混合应用与原生应用测试策略 在移动应用开发领域,混合应用与原生应用各有千秋,但它们的测试策略却大相径庭。本文旨在深入探讨这两种应用类型的测试挑战,并介绍如何利用自动化测试软件ItBuilder高效解决这些问题&…...
windows源码安装protobuf,opencv,ncnn
安装笔记 cmake 在windows可以使用-G"MinGW Makefiles" 搭配make使用,install出来的lib文件时.a结尾的,适合linux下面使用。所以在windows上若无需求使用-G"NMake Makefiles" 搭配nmake。 但是windows上使用-G"NMake Makefil…...
MicroPython 怎么搭建工程代码
在MicroPython中搭建工程代码可以遵循以下步骤: 1. 准备工作 安装MicroPython固件:确保已经将MicroPython烧录到ESP32开发板中。准备开发环境: 可以使用文本编辑器(如VS Code、Thonny、uPyCraft等)来编写代码。 2.…...

Android studio安装问题及解决方案
Android studio安装问题及解决方案 gradle已经安装好了,但是每次就是找不到gradle的位置,每次要重新下载,很慢,每次都不成功 我尝试用安装android studio时自带的卸载程序,卸载android studio,然后重新下…...
前端面试题(二)
6. 深入 JavaScript this 关键字的指向是什么? this 的指向是在函数执行时决定的。默认情况下,非严格模式下 this 指向全局对象(浏览器中为 window),严格模式下 this 为 undefined。在对象方法中,this 通常…...

利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

Python:操作 Excel 折叠
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

Python实现prophet 理论及参数优化
文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...

基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...

Ubuntu Cursor升级成v1.0
0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开,快捷键也不好用,当看到 Cursor 升级后,还是蛮高兴的 1. 下载 Cursor 下载地址:https://www.cursor.com/cn/downloads 点击下载 Linux (x64) ,…...

QT开发技术【ffmpeg + QAudioOutput】音乐播放器
一、 介绍 使用ffmpeg 4.2.2 在数字化浪潮席卷全球的当下,音视频内容犹如璀璨繁星,点亮了人们的生活与工作。从短视频平台上令人捧腹的搞笑视频,到在线课堂中知识渊博的专家授课,再到影视平台上扣人心弦的高清大片,音…...
Python常用模块:time、os、shutil与flask初探
一、Flask初探 & PyCharm终端配置 目的: 快速搭建小型Web服务器以提供数据。 工具: 第三方Web框架 Flask (需 pip install flask 安装)。 安装 Flask: 建议: 使用 PyCharm 内置的 Terminal (模拟命令行) 进行安装,避免频繁切换。 PyCharm Terminal 配置建议: 打开 Py…...

CSS3相关知识点
CSS3相关知识点 CSS3私有前缀私有前缀私有前缀存在的意义常见浏览器的私有前缀 CSS3基本语法CSS3 新增长度单位CSS3 新增颜色设置方式CSS3 新增选择器CSS3 新增盒模型相关属性box-sizing 怪异盒模型resize调整盒子大小box-shadow 盒子阴影opacity 不透明度 CSS3 新增背景属性ba…...
字符串哈希+KMP
P10468 兔子与兔子 #include<bits/stdc.h> using namespace std; typedef unsigned long long ull; const int N 1000010; ull a[N], pw[N]; int n; ull gethash(int l, int r){return a[r] - a[l - 1] * pw[r - l 1]; } signed main(){ios::sync_with_stdio(false), …...