当前位置: 首页 > news >正文

图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】

文章目录

    • GAT
    • 代码实现【PyGAT】
      • GraphAttentionLayer【一个图注意力层实现】
      • 用上面实现的单层网络测试
      • 加入Multi-head机制的GAT
      • 对数据集Cora的处理
      • csr_matrix()处理稀疏矩阵
      • encode_onehot()对label编号
      • build graph
      • 邻接矩阵构造
    • GAT的推广

GAT

题:Graph Attention Networks
摘要:
提出了图形注意网络(GAT) ,这是一种基于图结构数据的新型神经网络结构,利用掩蔽的自我注意层来解决基于图卷积或其近似的先前方法的缺点。通过叠加层,节点能够参与其邻域的特征,我们能够(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何代价高昂的矩阵操作(如反演) ,或者依赖于预先知道图的结构。通过这种方法,我们同时解决了基于谱的图形神经网络的几个关键问题,并使我们的模型容易地适用于归纳和转导问题。我们的 GAT 模型已经实现或匹配了四个已建立的转导和归纳图基准的最新结果: Cora,Citeseer 和 Pubmed 引用网络数据集,以及protein-protein interaction dataset(其中测试图在训练期间保持不可见)。

Paper with code 网址,可找到对应论文和github源码,原论文使用TensorFlow实现,本篇主要对Pytorch版本的 PyGAT附详细注释帮助理解和测试。

GitHUb: keras版本实现
Pytorch版本实现 PyGAT

在这里插入图片描述
在这里插入图片描述

截图及下文代码注释参考自视频:GAT详解及代码实现

视频中的eij的实现与源码不同,视频中是先拼接两个W,再与a乘
源码在_prepare_attentional_mechanism_input()函数中先分别与a乘,再拼接

代码实现【PyGAT】

在PyGAT :

  • layers.py中定义Simple GAT layer实现(GraphAttentionLayer)和Sparse version GAT layer实现(SpGraphAttentionLayer)。
  • models.py 实现两个版本加入Multi-head机制
  • trains.py 使用model定义的GAT构建模型进行训练,使用cora数据集

GraphAttentionLayer【一个图注意力层实现】

class GraphAttentionLayer(nn.Module):"""Simple GAT layer, similar to https://arxiv.org/abs/1710.10903"""def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_features#结点向量的特征维度self.out_features = out_features#经过GAT之后的特征维度self.alpha = alpha#dropout参数self.concat = concat#LeakyReLU参数# 定义可训练参数,即论文中的W和aself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain=1.414)# xavier初始化self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain=1.414)# xavier初始化# 定义leakyReLU激活函数self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, h, adj):'''adj图邻接矩阵,维度[N,N]非零即一h.shape: (N, in_features), self.W.shape:(in_features,out_features)Wh.shape: (N, out_features)'''Wh = torch.mm(h, self.W) # 对应eij的计算公式e = self._prepare_attentional_mechanism_input(Wh)#对应LeakyReLU(eij)计算公式zero_vec = -9e15*torch.ones_like(e)#将没有链接的边设置为负无穷attention = torch.where(adj > 0, e, zero_vec)#[N,N]# 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留# 否则需要mask设置为非常小的值,因为softmax的时候这个最小值会不考虑attention = F.softmax(attention, dim=1)# softmax形状保持不变[N,N],得到归一化的注意力全忠!attention = F.dropout(attention, self.dropout, training=self.training)# dropout,防止过拟合h_prime = torch.matmul(attention, Wh)#[N,N].[N,out_features]=>[N,out_features]# 得到由周围节点通过注意力权重进行更新后的表示if self.concat:return F.elu(h_prime)else:return h_primedef _prepare_attentional_mechanism_input(self, Wh):# Wh.shape (N, out_feature)# self.a.shape (2 * out_feature, 1)# Wh1&2.shape (N, 1)# e.shape (N, N)# 先分别与a相乘再进行拼接Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])# broadcast adde = Wh1 + Wh2.Treturn self.leakyrelu(e)def __repr__(self):return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

用上面实现的单层网络测试

x = torch.randn(6,10)
adj=torch.tensor([[0,1,1,0,0,0],[1,0,1,0,0,0],[1,1,0,1,0,0],[0,0,1,0,1,1],[0,0,0,1,0,0,],[0,0,0,1,1,0]])
my_gat = GraphAttentionLayer(10,5,0.2,0.2)
print(my_gat(x,adj))
输出:
tensor([[-0.2965,  2.8110, -0.6680, -0.9643, -0.9882],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],[-0.4981, -0.7515,  1.1159,  0.3546,  1.3592],[ 0.4679,  1.7208,  0.3084, -0.5331, -0.1291],[-0.4375, -0.8778,  1.1767, -0.5869,  1.5154],[-0.2164, -0.5897,  0.4988, -0.3125,  0.6423]], grad_fn=<EluBackward>)

加入Multi-head机制的GAT

用不同head捕捉不同特征,使模型有更好的拟合能力。

class GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):"""Dense version of GAT."""super(GAT, self).__init__()self.dropout = dropout# 加入Multi-head机制self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]for i, attention in enumerate(self.attentions):self.add_module('attention_{}'.format(i), attention)self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.attentions], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.out_att(x, adj))return F.log_softmax(x, dim=1)

对数据集Cora的处理

在这里插入图片描述
数据集中两个文件,cites:比如上图11行:编号25和编号1114331的文章
content文件:如下图,每篇文章的id、features及类别
在这里插入图片描述

csr_matrix()处理稀疏矩阵

utils.py中对数据进行的处理

#数据是稀疏的,csr_matrix操作从行开始将1的位置取出来,对数据进行压缩

features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])

在这里插入图片描述

encode_onehot()对label编号

有7个类别,通过classes_dict是7*7的对角阵把每个类别映射成不同向量,对所有label进行编号,再将编号转换为one_hot向量
在这里插入图片描述

def encode_onehot(labels):# The classes must be sorted before encoding to enable static class encoding.# In other words, make sure the first class always maps to index 0.classes = sorted(list(set(labels)))classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)return labels_onehot

build graph

见注释:

# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)#获取所有文章ididx_map = {j: i for i, j in enumerate(idx)}#按文章数目,对id重新映射# 读取数据集中文章和文章直接的引用关系edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)# 根据idx_map,将文章引用关系也重新映射edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)# build symmetric adjacency matrix 生成邻接矩阵adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = normalize_features(features)adj = normalize_adj(adj + sp.eye(adj.shape[0]))

邻接矩阵构造

csr_matrix()只记录了(0,1)1,忽略了(1,0)1。所以需要coo_matrix()操作!才能还原出无向图的邻接矩阵!
在这里插入图片描述

本文的一些代码注释及截图还可见视频

一个拓展:

GAT的推广

GAT的推广
GAT仅仅是应用在了单层图结构网络上,我们是否可以将它推广到多层网络结构呢?

这里我们假设一个有N层网络的结构,每层网络都定义了相同的节点,但是节点之间的关系有所差异。举一个简单的例子,假设有一个用户关系型网络,每层网络的节点都被定义成了网络中的用户,网络的第一层视图的关系可以定义为,两个用户之间是否具有好友关系;网络的第二层视图可以定义为,你评论过我的动态;网络的第三层视图可以定义为你转发过我的动态;第四层关系可以定义为,你at过我等等。

通过这样的定义我们就完成了一个多层网络的构建,他们共享相同的节点,但又分别具有不同的邻边,如果我们分别处理每一层视图视图,然后将他们得出的节点表示单纯相加的话,就可能会失去不同视图之间的协作关系,降低分类(预测)的精度。

基于以上观点,我们提出了一种新的方法:首先在每一层单视图中应用GAT进行学习,并计算出每层视图的节点表示。之后再不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重,将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示,链接预测等任务。

同时,因为不同视图共享同样的节点,即使每一层视图都表示了不同的节点关系,最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论,我们在每层GAT的网络参数间引入正则化项来约束参数,使其向互相相近的方向学习。大致的网络流程图如下:

这部分来源于 链接:https://www.jianshu.com/p/d5d366ba1a57 来源:简书

相关文章:

图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】

文章目录GAT代码实现【PyGAT】GraphAttentionLayer【一个图注意力层实现】用上面实现的单层网络测试加入Multi-head机制的GAT对数据集Cora的处理csr_matrix()处理稀疏矩阵encode_onehot()对label编号build graph邻接矩阵构造GAT的推广GAT 题&#xff1a;Graph Attention Netwo…...

项目成本管理中的常见误区及解决方案

做过项目的人都明白&#xff0c;项目实施时间一般很长&#xff0c;在实施期间总有很多项目结果不尽人意的问题。要使一个项目取得成功&#xff0c;就要结合很多因素一起才能作用&#xff0c;其中做好项目成本的管理就是最重要的步骤之一&#xff0c;下面列出了常见的项目成本管…...

墨天轮2022年度数据库获奖名单

2022年&#xff0c;国家相继从高位部署、省级试点布局、地市重点深入三个维度&#xff0c;颁布了多项中国数据库行业发展的利好政策。但是我们也能清晰地看到&#xff0c;中国数据库行业发展之路道阻且长&#xff0c;而道路上的“拦路虎”之一则是生态。中国数据库的发展需要多…...

仓储调度|库存管理系统

技术&#xff1a;Java、JSP等摘要&#xff1a;随着电子商务技术和网络技术的快速发展&#xff0c;现代物流技术也在不断进步。物流技术是指与物流要素活动有关的所有专业技术的总称&#xff0c;包括各种操作方法、管理技能等&#xff0c;物流业采用某些现代信息技术方面的成功经…...

Canvas入门-01

导读&#xff1a; 读完全文需要2min。通过这篇文章&#xff0c;你可以了解到以下内容&#xff1a; Canvas标签基本属性如何使用Canvas画矩形、圆形、线条、曲线、笑脸&#x1f60a; 如果你曾经了解过Canvas&#xff0c;可以对照目录回忆一下能否回答上来 毕竟带着问题学习最有效…...

运算符优先级

醋坛酸味罐&#xff0c;位落跳福豆 醋&#xff1a;初等运算符&#xff1a; () [] -> . 坛&#xff1a;单目运算符&#xff1a; - ~ – * & ! sizeof 右结合 酸&#xff1a;算术运算符&#xff1a; - * / % 味&#xff1a;位移运算符&#xff1a;>> << …...

微信小程序使用scss编译wxss文件的配置步骤

文章目录1、在 vscode 中搜索 easysass 插件并安装2、在微信开发工具中导入安装的easysass插件3、修改 spook.easysass-0.0.6/package.json 文件中的配置4、重启开发者工具&#xff0c;就可用使用了微信小程序开发者工具集成了 vscode 编辑器&#xff0c;可以使用 vscode 中众多…...

一步一步教你如何使用 Visual Studio Code 编译一段 C# 代码

以下是一步一步教你如何使用 Visual Studio Code 编写使用 C# 语言输出当前日期和时间的代码&#xff1a; 1、下载并安装 .NET SDK。您可以从 Microsoft 官网下载并安装它。 2、打开 Visual Studio Code&#xff0c;并安装 C# 扩展。您可以在 Visual Studio Code 中通过扩展菜…...

vue-cli中的环境变量注意点

在客户端侧代码中使用环境变量只有以 VUE_APP_ 开头的变量会被 webpack.DefinePlugin 静态嵌入到客户端侧的包中。你可以在应用的代码中这样访问它们&#xff1a;console.log(process.env.VUE_APP_SECRET)在构建过程中&#xff0c;process.env.VUE_APP_SECRET 将会被相应的值所…...

2.3数据类型

文章目录1. 命名规则2.字符3.数字4.日期5.图片1. 命名规则 字段名必须以字母开头&#xff0c;尽量不要使用拼音长度不能超过30个字符&#xff08;不同数据库&#xff0c;不同版本会有不同&#xff09;不能使用SQL的保留字&#xff0c;如where,order,group只能使用如下字符a-z、…...

Kafka基本概念

什么是Kafka Kafka是一个消息系统。它可以集中收集生产者的消息&#xff0c;并由消费者按需获取。在Kafka中&#xff0c;也将消息称为日志(log)。 一个系统&#xff0c;若仅有一类或者少量的消息&#xff0c;可直接进行发送和接收。 随着业务量日益复杂&#xff0c;消息的种类…...

使用QueryBuilders、NativeSearchQuery实现复杂查询

使用QueryBuilders、NativeSearchQuery实现复杂查询 本文继续前面文章《ElasticSearch系列&#xff08;二&#xff09;springboot中集成使用ElasticSearch的Demo》&#xff0c;在前文中&#xff0c;我们介绍了使用springdata做一些简单查询&#xff0c;但是要实现一些高级的组…...

taobao.open.account.update( Open Account数据更新 )

&#xffe5;开放平台免费API不需用户授权 Open Account数据更新 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 公共响应参数: 响应参数 点击获取key和secret 请求示例 TaobaoClient client new DefaultTaobaoClient(url, appkey, sec…...

PT100铂电阻温度传感器

PT100温度传感器又叫做铂热电阻。     热电阻是中低温区&#xfe61;常用的一种温度检测器。它的主要特点是测量精度高&#xff0c;性能稳定。其中铂热电阻的测量精确度是&#xfe61;高的&#xff0c;它不仅广泛应用于工业测温&#xff0c;而且被制成标准的基准仪。金属热…...

蓝桥杯-本质上升序列

没有白走的路&#xff0c;每一步都算数&#x1f388;&#x1f388;&#x1f388; 题目描述&#xff1a; 小蓝特别喜欢单调递增的事物 在一个字符串中如果取出若干个字符&#xff0c;按照在原来字符串中的顺序排列在一起&#xff0c;组成的新的字符串如果是单调递增的&#xf…...

synchronized锁重入验证

文章目录synchronized锁重入验证1. 可重入锁2. synchronized锁重入2.1 本类同步方法内部调用本类其它同步方法2.2 子类同步方法内部调用父类的同步方法2.3 A类的同步方法内部调用B类的同步方法3. synchronized修饰方法写法synchronized锁重入验证 1. 可重入锁 可重入锁&#…...

超简单的计数排序!!

假设给定混乱数据为&#xff1a;3&#xff0c;0&#xff0c;1&#xff0c;3&#xff0c;6&#xff0c;5&#xff0c;4&#xff0c;2&#xff0c;1&#xff0c;9。 下面我们将通过使用计数排序的思想来完成对上面数据的排序。(先不谈负数) 计数排序 该排序的思路和它的名字一样…...

发现新大陆——原来软件开发根本不需要会编码(看我10分钟应用上线)

目录 一、前言 二、官网基础功能及搭建 三、体验过程 01、连接数据源 02、设计表单 03、流程设计 04、图表呈现 05、组织架构设置 五、效率评价 六、小结 一、前言 众所周知&#xff0c;每家公司在发展过程中都需要构建大量的内部系统&#xff0c; 如运营使用的用户…...

【Leedcode】栈和队列必备的面试题(第二期)

【Leedcode】栈和队列必备的面试题&#xff08;第二期&#xff09; 文章目录【Leedcode】栈和队列必备的面试题&#xff08;第二期&#xff09;一、题目&#xff08;用两个队列实现栈&#xff09;二、思路图解1.定义两个队列2.初始化两个队列3.往两个队列中放入数据4.两个队列出…...

Elasticsearch实战之(商品搜索API实现)

Elasticsearch实战之&#xff08;商品搜索API实现&#xff09; 1、案例介绍 某医药电商H5商城基于Elasticsearch实现商品搜索 2、案例分析 2.1、数据来源 商品库 - 平台运营维护商品库 - 供应商维护 2.2、数据同步 2.2.1、同步双写 写入 MySQL&#xff0c;直接也同步往…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

九天毕昇深度学习平台 | 如何安装库?

pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子&#xff1a; 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf

FTP 客服管理系统 实现kefu123登录&#xff0c;不允许匿名访问&#xff0c;kefu只能访问/data/kefu目录&#xff0c;不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

OD 算法题 B卷【正整数到Excel编号之间的转换】

文章目录 正整数到Excel编号之间的转换 正整数到Excel编号之间的转换 excel的列编号是这样的&#xff1a;a b c … z aa ab ac… az ba bb bc…yz za zb zc …zz aaa aab aac…; 分别代表以下的编号1 2 3 … 26 27 28 29… 52 53 54 55… 676 677 678 679 … 702 703 704 705;…...

【Linux】自动化构建-Make/Makefile

前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具&#xff1a;make/makfile 1.背景 在一个工程中源文件不计其数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…...

Matlab实现任意伪彩色图像可视化显示

Matlab实现任意伪彩色图像可视化显示 1、灰度原始图像2、RGB彩色原始图像 在科研研究中&#xff0c;如何展示好看的实验结果图像非常重要&#xff01;&#xff01;&#xff01; 1、灰度原始图像 灰度图像每个像素点只有一个数值&#xff0c;代表该点的​​亮度&#xff08;或…...