【图卷积网络】GCN基础原理简单python实现
基础原理讲解
应用路径
卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。
基础原理
其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~−1/2A~D~−1/2HlWl)
其中:
- σ \sigma σ 是非线性激活函数
- D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=∑jA~ij
- A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I, A A A是原始邻接矩阵, I I I是单位矩阵
- H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
- W l W^l Wl是第 l l l层的学习权重矩阵
步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。
加自环的邻接矩阵
A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。
图卷积操作
图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X
度矩阵求解
D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=j∑A~ij
标准化
在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸或梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~−1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~−1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~−1/2A~D~−1/2X
简单python实现
基于cora数据集实现节点分类
- cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):x = map[i] ; y = map[j]matrix[x][y] = matrix[y][x] = 1 # 无向图:有引用关系的样本点之间取1# 查看邻接矩阵的元素
print(matrix.shape)
- GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):def __init__(self, in_features, out_features):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features)def forward(self, x, adj):rowsum = torch.sum(adj,dim=1)d_inv_sqrt = torch.pow(rowsum,-0.5)d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0d_mat_inv_sqrt = torch.diag(d_inv_sqrt)adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)out = torch.mm(adj_normalized,x)out = self.linear(out)return out
class GCN(nn.Module):def __init__(self, n_features, n_hidden, n_classes):super(GCN, self).__init__()self.gcn1 = GCNLayer(n_features, n_hidden)self.gcn2 = GCNLayer(n_hidden, n_classes)def forward(self, x, adj):x = self.gcn1(x, adj)x = F.relu(x)x = self.gcn2(x, adj)return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):model.train()features, labels = features.cuda(), labels.cuda()adj = adj.cuda()optimizer.zero_grad()output = model(features, adj)loss = loss_fn(output, labels)loss.backward()optimizer.step()if (epoch + 1) % 20 == 0:print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")
参考
cora数据集及简介
图卷积网络详细介绍
GCN讲解
相关文章:

【图卷积网络】GCN基础原理简单python实现
基础原理讲解 应用路径 卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是…...

【话题】AI是在帮助开发者还是取代他们
大家好,我是全栈小5,欢迎阅读小5的系列文章,这是《话题》系列文章 目录 引言AI在代码生成中的应用AI在错误检测和自动化测试中的作用对开发者职业前景的影响技能需求的变化与适应策略结论文章推荐 引言 随着人工智能(AIÿ…...

精通Perl正则表达式修饰符:提升文本处理能力的艺术
Perl语言以其强大的文本处理能力而闻名,其中正则表达式是其核心特性之一。正则表达式本身非常强大,但Perl提供的修饰符(Modifiers)进一步扩展了正则表达式的灵活性和表达能力。本文将深入探讨Perl中正则表达式修饰符的使用&#x…...

【web前端HTML+CSS+JS】--- HTML学习笔记01
学习链接:黑马程序员pink老师前端入门教程,零基础必看的h5(html5)css3移动端前端视频教程_哔哩哔哩_bilibili 学习文档: Web 开发技术 | MDN (mozilla.org) 一、前后端工作流程 WEB模型:前端用于采集和展示信息,中…...

Go 语言入门(一)
Go Modules依赖包查找机制 下载的第三方的依赖存储在 $GOPATH/pkg/mod 下go install 生成的可执行文件存储在 $GOPATH/bin下依赖查找顺序: 工作目录$GOPATH/pkg/mod$GOPATH/src 一、Go语言基础 1.标识符与关键字 1.1 命名方式 go变量、常量、自定义类型、包…...

爬虫笔记20——票星球抢票脚本的实现
以下内容仅供交流学习使用!!! 思路分析 前面的爬虫笔记一步一步走过来我们的技术水平也有了较大的提升了,现在我们来进行一下票星球抢票实战项目,实现票星球的自动抢票。 我们打开票星球的移动端页面,分…...

DDR3(三)
目录 1 预取1.1 什么是预取1.2 预取有哪些好处1.3 结构框图1.4 总结 2 突发2.1 什么是突发2.2 突发与预取 本文讲解DDR中常见的两个术语:预取和突发,对这两个概念理解的关键在于地址线的低位是否参与译码,具体内容请继续往下看。 1 预取 1.1…...

JDK都出到20多了,你还不会使用JDK8的Stream流写代码吗?
目录 前言 Stream流 是什么? 为什么要用Steam流 常见stream流使用案例 映射 map() & 集合 collect() 单字段映射 多字段映射 映射为其他的对象 映射为 Map 去重 distinct() 过滤 filter() Stream流的其他方法 使用Stream流的弊端 前言 当你某天看…...

QT slots 函数
文章目录 概述小结 概述 在Qt中,slots 是一种特殊的成员函数,它们可以与对象发出的信号连接。当信号被触发时,连接的槽函数会被调用。 来个简单的示例吧,如下图: #include <QObject> #include <QDebug>…...

pycharm如何使用jupyter
目录 配置jupyter新建jupyter文件别人写的方法(在pycharm种安装,在网页中使用) pycharm专业版 配置jupyter 在pycharm终端启动一个conda虚拟环境,输入 conda install jupyter会有很多前置包需要安装: 新建jupyter…...

机器学习——无监督学习(k-means算法)
1、K-Means聚类算法 K表示超参数个数,如分成几个类别,K值就取多少。若无需求,可使用网格搜索找到最佳的K。 步骤: 1、随机设置K个特征空间内的点作为初始聚类中心; 2、对于其他每个点计算到K个中心的距离,…...

强化学习-6 DDPG、PPO、SAC算法
文章目录 1 DPG方法2 DDPG算法3 DDPG算法的优缺点4 TD3算法4.1 双Q网络4.2 延迟更新4.3 噪声正则 5 附15.1 Ornstein-Uhlenbeck (OU) 噪声5.1.1 定义5.1.2 特性5.1.3 直观理解5.1.4 数学性质5.1.5 代码示例5.1.6 总结 6 重要性采样7 PPO算法8 附28.1 重要性采样方差计算8.1.1 公…...

vue3实现多表头列表el-table,拖拽,鼠标滑轮滚动条优化
需求背景解决效果index.vue 需求背景 需要实现多表头列表的用户体验优化 解决效果 index.vue <!--/** * author: liuk * date: 2024-07-03 * describe:**** 多表头列表 */--> <template><el-table ref"tableRef" height"calc(100% - 80px)&qu…...

Micron近期发布了32Gb DDR5 DRAM
Micron Technology近期发布了一项内存技术的重大突破——一款32Gb DDR5 DRAM芯片,这项创新不仅将存储容量翻倍,还显著提升了针对人工智能(AI)、机器学习(ML)、高性能计算(HPC)以及数…...

SQL Server时间转换
第一种:format --转化成年月日 select format( GETDATE(),yyyy-MM-dd) --转化年月日,时分秒,这里的HH指24小时的,hh是12小时的 select format( GETDATE(),yyyy-MM-dd HH:mm:ss) --转化成时分秒的,这里就不一样的&…...

kubernetes集群部署:node节点部署和CRI-O运行时安装(三)
关于CRI-O Kubernetes最初使用Docker作为默认的容器运行时。然而,随着Kubernetes的发展和OCI标准的确立,社区开始寻找更专门化的解决方案,以减少复杂性和提高性能。CRI-O的主要目标是提供一个轻量级的容器运行时,它可以直接运行O…...

03:Spring MVC
文章目录 一:Spring MVC简介1:说说自己对于Spring MVC的了解?1.1:流程说明: 一:Spring MVC简介 Spring MVC就是一个MVC框架,Spring MVC annotation式的开发比Struts2方便,可以直接代…...

玩转springboot之springboot注册servlet
springboot注册servlet 有时候在springboot中依然需要注册servlet,filter,listener,就以servlet为例来进行说明,另外两个也都类似 使用WebServlet注解 在servlet3.0之后,servlet注册支持注解注册,而不需要在…...

推荐好玩的工具之OhMyPosh使用
解除禁止脚本 Set-ExecutionPolicy RemoteSigned 下载Oh My Posh winget install oh-my-posh 或者 Install-Module oh-my-posh -Scope AllUsers 下载Git提示 Install-Module posh-git -Scope CurrentUser 或者 Install-Module posh-git -Scope AllUser 下载命令提示 Install-Mo…...

pydub、ffmpeg 音频文件声道选择转换、采样率更改
快速查看音频通道数和每个通道能力判断具体哪个通道说话;一般能量大的那个算是说话 import wave from pydub import AudioSegment import numpy as npdef read_wav_file(file_path):with wave.open(file_path, rb) as wav_file:params wav_file.getparams()num_cha…...

0803实操-Windows Server系统管理
Windows Server系统管理 系统管理与基础配置 查看系统信息、更改计算机名称 网络配置 启用网络发现 Windows启用网络发现是指在网络设置中启用一个功能,该功能允许您的计算机在网络上识别和访问其他设备和计算机。具体来说,启用网络发现后ÿ…...

使用Java构建物联网应用的最佳实践
使用Java构建物联网应用的最佳实践 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 随着物联网(IoT)技术的快速发展,越来越…...

价格预言机的使用总结(一):Chainlink篇
文章首发于公众号:Keegan小钢 前言 价格预言机已经成为了 DeFi 中不可获取的基础设施,很多 DeFi 应用都需要从价格预言机来获取稳定可信的价格数据,包括借贷协议 Compound、AAVE、Liquity ,也包括衍生品交易所 dYdX、PERP 等等。…...

【Pyhton】读取寄存器数据到MySQL数据库
目录 步骤 modsim32软件配置 Navicat for MySQL 代码实现 步骤 安装必要的库:确保安装了pymodbus和pymysql。 配置Modbus连接:设置Modbus从站的IP地址、端口(对于TCP)或串行通信参数(对于RTU)。 连接M…...

jmeter-beanshell学习3-beanshell获取请求报文和响应报文
前后两个报文,后面报文要用前面报文的响应结果,这个简单,正则表达式或者json提取器,都能实现。但是如果后面报文要用前面请求报文的内容,感觉有点难。最早时候把随机数写在自定义变量,前后两个接口都用这个…...

【C++】B树及其实现
写目录 一、B树的基本概念1.引入2.B树的概念 二、B树的实现1.B树的定义2.B树的查找3.B树的插入操作4.B树的删除5.B树的遍历6.B树的高度7.整体代码 三、B树和B*树1.B树2.B*树3.总结 一、B树的基本概念 1.引入 我们已经学习过二叉排序树、AVL树和红黑树三种树形查找结构&#x…...

C++(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例
C(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例 文章目录 C(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例1、概述2、实现效果3、主要代码4、源码地址 更多精彩内容👉个人内容分类汇总 👈👉GIS开发 👈 1、概述 支持多线程加…...

CTFShow的RE题(三)
数学不及格 strtol 函数 long strtol(char str, char **endptr, int base); 将字符串转换为长整型 就是解这个方程组了 主要就是 v4, v9的关系, 3v9-(v10v11v12)62d10d4673 v4 v12 v11 v10 0x13A31412F8C 得到 3*v9v419D024E75FF(1773860189695) 重点&…...

WordPress主题开发进群付费主题v1.1.2 多种引流方式
全新前端UI界面,多种前端交互特效让页面不再单调,进群页面群成员数,群成员头像名称,每次刷新页面随机更新不重复,最下面评论和点赞也是如此随机刷新不重复 进群页面简介,群聊名称,群内展示&…...

SAP中的 UPDATA TASK 和 BACKGROUND TASK
前言: 记录这篇文章起因是调查生产订单报工问题引申出来的一个问题,后来再次调查后了解了其中缘由,大概记录以下,如有不对,欢迎指正。问题原贴如下: SAP CO11N BAPI_PRODORDCONF_CREATE_TT连续报工异步更…...