【图卷积网络】GCN基础原理简单python实现
基础原理讲解
应用路径
卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。
基础原理
其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~−1/2A~D~−1/2HlWl)
其中:
- σ \sigma σ 是非线性激活函数
- D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=∑jA~ij
- A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I, A A A是原始邻接矩阵, I I I是单位矩阵
- H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
- W l W^l Wl是第 l l l层的学习权重矩阵
步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。
加自环的邻接矩阵
A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。
图卷积操作

图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X
度矩阵求解
D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=j∑A~ij

标准化
在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸或梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~−1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~−1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~−1/2A~D~−1/2X
简单python实现
基于cora数据集实现节点分类
- cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):x = map[i] ; y = map[j]matrix[x][y] = matrix[y][x] = 1 # 无向图:有引用关系的样本点之间取1# 查看邻接矩阵的元素
print(matrix.shape)
- GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):def __init__(self, in_features, out_features):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features)def forward(self, x, adj):rowsum = torch.sum(adj,dim=1)d_inv_sqrt = torch.pow(rowsum,-0.5)d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0d_mat_inv_sqrt = torch.diag(d_inv_sqrt)adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)out = torch.mm(adj_normalized,x)out = self.linear(out)return out
class GCN(nn.Module):def __init__(self, n_features, n_hidden, n_classes):super(GCN, self).__init__()self.gcn1 = GCNLayer(n_features, n_hidden)self.gcn2 = GCNLayer(n_hidden, n_classes)def forward(self, x, adj):x = self.gcn1(x, adj)x = F.relu(x)x = self.gcn2(x, adj)return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):model.train()features, labels = features.cuda(), labels.cuda()adj = adj.cuda()optimizer.zero_grad()output = model(features, adj)loss = loss_fn(output, labels)loss.backward()optimizer.step()if (epoch + 1) % 20 == 0:print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")
参考
cora数据集及简介
图卷积网络详细介绍
GCN讲解
相关文章:
【图卷积网络】GCN基础原理简单python实现
基础原理讲解 应用路径 卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是…...
【话题】AI是在帮助开发者还是取代他们
大家好,我是全栈小5,欢迎阅读小5的系列文章,这是《话题》系列文章 目录 引言AI在代码生成中的应用AI在错误检测和自动化测试中的作用对开发者职业前景的影响技能需求的变化与适应策略结论文章推荐 引言 随着人工智能(AIÿ…...
精通Perl正则表达式修饰符:提升文本处理能力的艺术
Perl语言以其强大的文本处理能力而闻名,其中正则表达式是其核心特性之一。正则表达式本身非常强大,但Perl提供的修饰符(Modifiers)进一步扩展了正则表达式的灵活性和表达能力。本文将深入探讨Perl中正则表达式修饰符的使用&#x…...
【web前端HTML+CSS+JS】--- HTML学习笔记01
学习链接:黑马程序员pink老师前端入门教程,零基础必看的h5(html5)css3移动端前端视频教程_哔哩哔哩_bilibili 学习文档: Web 开发技术 | MDN (mozilla.org) 一、前后端工作流程 WEB模型:前端用于采集和展示信息,中…...
Go 语言入门(一)
Go Modules依赖包查找机制 下载的第三方的依赖存储在 $GOPATH/pkg/mod 下go install 生成的可执行文件存储在 $GOPATH/bin下依赖查找顺序: 工作目录$GOPATH/pkg/mod$GOPATH/src 一、Go语言基础 1.标识符与关键字 1.1 命名方式 go变量、常量、自定义类型、包…...
爬虫笔记20——票星球抢票脚本的实现
以下内容仅供交流学习使用!!! 思路分析 前面的爬虫笔记一步一步走过来我们的技术水平也有了较大的提升了,现在我们来进行一下票星球抢票实战项目,实现票星球的自动抢票。 我们打开票星球的移动端页面,分…...
DDR3(三)
目录 1 预取1.1 什么是预取1.2 预取有哪些好处1.3 结构框图1.4 总结 2 突发2.1 什么是突发2.2 突发与预取 本文讲解DDR中常见的两个术语:预取和突发,对这两个概念理解的关键在于地址线的低位是否参与译码,具体内容请继续往下看。 1 预取 1.1…...
JDK都出到20多了,你还不会使用JDK8的Stream流写代码吗?
目录 前言 Stream流 是什么? 为什么要用Steam流 常见stream流使用案例 映射 map() & 集合 collect() 单字段映射 多字段映射 映射为其他的对象 映射为 Map 去重 distinct() 过滤 filter() Stream流的其他方法 使用Stream流的弊端 前言 当你某天看…...
QT slots 函数
文章目录 概述小结 概述 在Qt中,slots 是一种特殊的成员函数,它们可以与对象发出的信号连接。当信号被触发时,连接的槽函数会被调用。 来个简单的示例吧,如下图: #include <QObject> #include <QDebug>…...
pycharm如何使用jupyter
目录 配置jupyter新建jupyter文件别人写的方法(在pycharm种安装,在网页中使用) pycharm专业版 配置jupyter 在pycharm终端启动一个conda虚拟环境,输入 conda install jupyter会有很多前置包需要安装: 新建jupyter…...
机器学习——无监督学习(k-means算法)
1、K-Means聚类算法 K表示超参数个数,如分成几个类别,K值就取多少。若无需求,可使用网格搜索找到最佳的K。 步骤: 1、随机设置K个特征空间内的点作为初始聚类中心; 2、对于其他每个点计算到K个中心的距离,…...
强化学习-6 DDPG、PPO、SAC算法
文章目录 1 DPG方法2 DDPG算法3 DDPG算法的优缺点4 TD3算法4.1 双Q网络4.2 延迟更新4.3 噪声正则 5 附15.1 Ornstein-Uhlenbeck (OU) 噪声5.1.1 定义5.1.2 特性5.1.3 直观理解5.1.4 数学性质5.1.5 代码示例5.1.6 总结 6 重要性采样7 PPO算法8 附28.1 重要性采样方差计算8.1.1 公…...
vue3实现多表头列表el-table,拖拽,鼠标滑轮滚动条优化
需求背景解决效果index.vue 需求背景 需要实现多表头列表的用户体验优化 解决效果 index.vue <!--/** * author: liuk * date: 2024-07-03 * describe:**** 多表头列表 */--> <template><el-table ref"tableRef" height"calc(100% - 80px)&qu…...
Micron近期发布了32Gb DDR5 DRAM
Micron Technology近期发布了一项内存技术的重大突破——一款32Gb DDR5 DRAM芯片,这项创新不仅将存储容量翻倍,还显著提升了针对人工智能(AI)、机器学习(ML)、高性能计算(HPC)以及数…...
SQL Server时间转换
第一种:format --转化成年月日 select format( GETDATE(),yyyy-MM-dd) --转化年月日,时分秒,这里的HH指24小时的,hh是12小时的 select format( GETDATE(),yyyy-MM-dd HH:mm:ss) --转化成时分秒的,这里就不一样的&…...
kubernetes集群部署:node节点部署和CRI-O运行时安装(三)
关于CRI-O Kubernetes最初使用Docker作为默认的容器运行时。然而,随着Kubernetes的发展和OCI标准的确立,社区开始寻找更专门化的解决方案,以减少复杂性和提高性能。CRI-O的主要目标是提供一个轻量级的容器运行时,它可以直接运行O…...
03:Spring MVC
文章目录 一:Spring MVC简介1:说说自己对于Spring MVC的了解?1.1:流程说明: 一:Spring MVC简介 Spring MVC就是一个MVC框架,Spring MVC annotation式的开发比Struts2方便,可以直接代…...
玩转springboot之springboot注册servlet
springboot注册servlet 有时候在springboot中依然需要注册servlet,filter,listener,就以servlet为例来进行说明,另外两个也都类似 使用WebServlet注解 在servlet3.0之后,servlet注册支持注解注册,而不需要在…...
推荐好玩的工具之OhMyPosh使用
解除禁止脚本 Set-ExecutionPolicy RemoteSigned 下载Oh My Posh winget install oh-my-posh 或者 Install-Module oh-my-posh -Scope AllUsers 下载Git提示 Install-Module posh-git -Scope CurrentUser 或者 Install-Module posh-git -Scope AllUser 下载命令提示 Install-Mo…...
pydub、ffmpeg 音频文件声道选择转换、采样率更改
快速查看音频通道数和每个通道能力判断具体哪个通道说话;一般能量大的那个算是说话 import wave from pydub import AudioSegment import numpy as npdef read_wav_file(file_path):with wave.open(file_path, rb) as wav_file:params wav_file.getparams()num_cha…...
idea大量爆红问题解决
问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...
python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...
376. Wiggle Subsequence
376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
数据链路层的主要功能是什么
数据链路层(OSI模型第2层)的核心功能是在相邻网络节点(如交换机、主机)间提供可靠的数据帧传输服务,主要职责包括: 🔑 核心功能详解: 帧封装与解封装 封装: 将网络层下发…...
现代密码学 | 椭圆曲线密码学—附py代码
Elliptic Curve Cryptography 椭圆曲线密码学(ECC)是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础,例如椭圆曲线数字签…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果