【图卷积网络】GCN基础原理简单python实现
基础原理讲解
应用路径
卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。
基础原理
其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~−1/2A~D~−1/2HlWl)
其中:
- σ \sigma σ 是非线性激活函数
- D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=∑jA~ij
- A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I, A A A是原始邻接矩阵, I I I是单位矩阵
- H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
- W l W^l Wl是第 l l l层的学习权重矩阵
步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。
加自环的邻接矩阵
A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。
图卷积操作
图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X
度矩阵求解
D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=j∑A~ij
标准化
在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸或梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~−1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~−1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~−1A~D~−1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~−1/2A~D~−1/2X
简单python实现
基于cora数据集实现节点分类
- cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):x = map[i] ; y = map[j]matrix[x][y] = matrix[y][x] = 1 # 无向图:有引用关系的样本点之间取1# 查看邻接矩阵的元素
print(matrix.shape)
- GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):def __init__(self, in_features, out_features):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features)def forward(self, x, adj):rowsum = torch.sum(adj,dim=1)d_inv_sqrt = torch.pow(rowsum,-0.5)d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0d_mat_inv_sqrt = torch.diag(d_inv_sqrt)adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)out = torch.mm(adj_normalized,x)out = self.linear(out)return out
class GCN(nn.Module):def __init__(self, n_features, n_hidden, n_classes):super(GCN, self).__init__()self.gcn1 = GCNLayer(n_features, n_hidden)self.gcn2 = GCNLayer(n_hidden, n_classes)def forward(self, x, adj):x = self.gcn1(x, adj)x = F.relu(x)x = self.gcn2(x, adj)return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):model.train()features, labels = features.cuda(), labels.cuda()adj = adj.cuda()optimizer.zero_grad()output = model(features, adj)loss = loss_fn(output, labels)loss.backward()optimizer.step()if (epoch + 1) % 20 == 0:print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")
参考
cora数据集及简介
图卷积网络详细介绍
GCN讲解
相关文章:

【图卷积网络】GCN基础原理简单python实现
基础原理讲解 应用路径 卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是…...

【话题】AI是在帮助开发者还是取代他们
大家好,我是全栈小5,欢迎阅读小5的系列文章,这是《话题》系列文章 目录 引言AI在代码生成中的应用AI在错误检测和自动化测试中的作用对开发者职业前景的影响技能需求的变化与适应策略结论文章推荐 引言 随着人工智能(AIÿ…...
精通Perl正则表达式修饰符:提升文本处理能力的艺术
Perl语言以其强大的文本处理能力而闻名,其中正则表达式是其核心特性之一。正则表达式本身非常强大,但Perl提供的修饰符(Modifiers)进一步扩展了正则表达式的灵活性和表达能力。本文将深入探讨Perl中正则表达式修饰符的使用&#x…...

【web前端HTML+CSS+JS】--- HTML学习笔记01
学习链接:黑马程序员pink老师前端入门教程,零基础必看的h5(html5)css3移动端前端视频教程_哔哩哔哩_bilibili 学习文档: Web 开发技术 | MDN (mozilla.org) 一、前后端工作流程 WEB模型:前端用于采集和展示信息,中…...
Go 语言入门(一)
Go Modules依赖包查找机制 下载的第三方的依赖存储在 $GOPATH/pkg/mod 下go install 生成的可执行文件存储在 $GOPATH/bin下依赖查找顺序: 工作目录$GOPATH/pkg/mod$GOPATH/src 一、Go语言基础 1.标识符与关键字 1.1 命名方式 go变量、常量、自定义类型、包…...

爬虫笔记20——票星球抢票脚本的实现
以下内容仅供交流学习使用!!! 思路分析 前面的爬虫笔记一步一步走过来我们的技术水平也有了较大的提升了,现在我们来进行一下票星球抢票实战项目,实现票星球的自动抢票。 我们打开票星球的移动端页面,分…...

DDR3(三)
目录 1 预取1.1 什么是预取1.2 预取有哪些好处1.3 结构框图1.4 总结 2 突发2.1 什么是突发2.2 突发与预取 本文讲解DDR中常见的两个术语:预取和突发,对这两个概念理解的关键在于地址线的低位是否参与译码,具体内容请继续往下看。 1 预取 1.1…...

JDK都出到20多了,你还不会使用JDK8的Stream流写代码吗?
目录 前言 Stream流 是什么? 为什么要用Steam流 常见stream流使用案例 映射 map() & 集合 collect() 单字段映射 多字段映射 映射为其他的对象 映射为 Map 去重 distinct() 过滤 filter() Stream流的其他方法 使用Stream流的弊端 前言 当你某天看…...
QT slots 函数
文章目录 概述小结 概述 在Qt中,slots 是一种特殊的成员函数,它们可以与对象发出的信号连接。当信号被触发时,连接的槽函数会被调用。 来个简单的示例吧,如下图: #include <QObject> #include <QDebug>…...

pycharm如何使用jupyter
目录 配置jupyter新建jupyter文件别人写的方法(在pycharm种安装,在网页中使用) pycharm专业版 配置jupyter 在pycharm终端启动一个conda虚拟环境,输入 conda install jupyter会有很多前置包需要安装: 新建jupyter…...

机器学习——无监督学习(k-means算法)
1、K-Means聚类算法 K表示超参数个数,如分成几个类别,K值就取多少。若无需求,可使用网格搜索找到最佳的K。 步骤: 1、随机设置K个特征空间内的点作为初始聚类中心; 2、对于其他每个点计算到K个中心的距离,…...
强化学习-6 DDPG、PPO、SAC算法
文章目录 1 DPG方法2 DDPG算法3 DDPG算法的优缺点4 TD3算法4.1 双Q网络4.2 延迟更新4.3 噪声正则 5 附15.1 Ornstein-Uhlenbeck (OU) 噪声5.1.1 定义5.1.2 特性5.1.3 直观理解5.1.4 数学性质5.1.5 代码示例5.1.6 总结 6 重要性采样7 PPO算法8 附28.1 重要性采样方差计算8.1.1 公…...
vue3实现多表头列表el-table,拖拽,鼠标滑轮滚动条优化
需求背景解决效果index.vue 需求背景 需要实现多表头列表的用户体验优化 解决效果 index.vue <!--/** * author: liuk * date: 2024-07-03 * describe:**** 多表头列表 */--> <template><el-table ref"tableRef" height"calc(100% - 80px)&qu…...

Micron近期发布了32Gb DDR5 DRAM
Micron Technology近期发布了一项内存技术的重大突破——一款32Gb DDR5 DRAM芯片,这项创新不仅将存储容量翻倍,还显著提升了针对人工智能(AI)、机器学习(ML)、高性能计算(HPC)以及数…...
SQL Server时间转换
第一种:format --转化成年月日 select format( GETDATE(),yyyy-MM-dd) --转化年月日,时分秒,这里的HH指24小时的,hh是12小时的 select format( GETDATE(),yyyy-MM-dd HH:mm:ss) --转化成时分秒的,这里就不一样的&…...

kubernetes集群部署:node节点部署和CRI-O运行时安装(三)
关于CRI-O Kubernetes最初使用Docker作为默认的容器运行时。然而,随着Kubernetes的发展和OCI标准的确立,社区开始寻找更专门化的解决方案,以减少复杂性和提高性能。CRI-O的主要目标是提供一个轻量级的容器运行时,它可以直接运行O…...

03:Spring MVC
文章目录 一:Spring MVC简介1:说说自己对于Spring MVC的了解?1.1:流程说明: 一:Spring MVC简介 Spring MVC就是一个MVC框架,Spring MVC annotation式的开发比Struts2方便,可以直接代…...
玩转springboot之springboot注册servlet
springboot注册servlet 有时候在springboot中依然需要注册servlet,filter,listener,就以servlet为例来进行说明,另外两个也都类似 使用WebServlet注解 在servlet3.0之后,servlet注册支持注解注册,而不需要在…...

推荐好玩的工具之OhMyPosh使用
解除禁止脚本 Set-ExecutionPolicy RemoteSigned 下载Oh My Posh winget install oh-my-posh 或者 Install-Module oh-my-posh -Scope AllUsers 下载Git提示 Install-Module posh-git -Scope CurrentUser 或者 Install-Module posh-git -Scope AllUser 下载命令提示 Install-Mo…...

pydub、ffmpeg 音频文件声道选择转换、采样率更改
快速查看音频通道数和每个通道能力判断具体哪个通道说话;一般能量大的那个算是说话 import wave from pydub import AudioSegment import numpy as npdef read_wav_file(file_path):with wave.open(file_path, rb) as wav_file:params wav_file.getparams()num_cha…...
Java多线程实现之Callable接口深度解析
Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...

基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...
Rust 异步编程
Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

听写流程自动化实践,轻量级教育辅助
随着智能教育工具的发展,越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式,也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建,…...
Xen Server服务器释放磁盘空间
disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...

嵌入式学习笔记DAY33(网络编程——TCP)
一、网络架构 C/S (client/server 客户端/服务器):由客户端和服务器端两个部分组成。客户端通常是用户使用的应用程序,负责提供用户界面和交互逻辑 ,接收用户输入,向服务器发送请求,并展示服务…...

Chrome 浏览器前端与客户端双向通信实战
Chrome 前端(即页面 JS / Web UI)与客户端(C 后端)的交互机制,是 Chromium 架构中非常核心的一环。下面我将按常见场景,从通道、流程、技术栈几个角度做一套完整的分析,特别适合你这种在分析和改…...