图卷积网络:从理论到实践
图卷积网络(Graph Convolutional Networks, GCNs)彻底改变了基于图的机器学习领域,使得深度学习能够应用于非欧几里得结构,如社交网络、引文网络和分子结构。本文将解释GCN的直观理解、数学原理,并提供代码片段帮助您理解和实现基础的GCN。
图表示法基础
定义图G = (V, E),其中:
- V V V:节点集合
- E E E:边集合
- A ∈ R N × N A \in \mathbb{R}^{N \times N} A∈RN×N:邻接矩阵
- X ∈ R N × F X \in \mathbb{R}^{N \times F} X∈RN×F:节点特征矩阵
其中, N N N是节点数量, F F F是每个节点的输入特征数量。
邻接矩阵
邻接矩阵是表示图中节点之间连接(边)的一种方式。
- 对于具有 N N N个节点的图, A A A是一个 N × N N \times N N×N的矩阵。
- 如果节点 i i i和节点 j j j之间有边,则 A i j = 1 A_{ij} = 1 Aij=1(如果带权重,则为边的权重);否则 A i j = 0 A_{ij} = 0 Aij=0。
- 在无向图中, A A A是对称的( A i j = A j i A_{ij} = A_{ji} Aij=Aji)。
- 例如,一个3节点图,其中节点0连接到节点1和2:
A = [ 0 1 1 1 0 0 1 0 0 ] A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & 0 & 0 \end{bmatrix} A= 011100100
节点特征矩阵
节点特征矩阵存储图中每个节点的特征(属性)。
- N N N是节点数量, F F F是每个节点的特征数量。
- 每一行 X i X_i Xi是节点 i i i的特征向量。
- 例如,如果每个节点有3个特征(比如年龄、收入和组别),共有4个节点:
X = [ 23 50000 1 35 60000 2 29 52000 1 41 58000 3 ] X = \begin{bmatrix} 23 & 50000 & 1 \\ 35 & 60000 & 2 \\ 29 & 52000 & 1 \\ 41 & 58000 & 3 \end{bmatrix} X= 23352941500006000052000580001213 - 这些特征是GCN用来学习的输入。
两者共同构成了图卷积网络的基本输入:
- 邻接矩阵 A A A描述了节点如何连接。
- 节点特征矩阵 X X X描述了每个节点的特征。
GCN层公式(Kipf & Welling, 2016)
GCN层的核心公式如下:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}) H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l))
这个公式包含了很多信息,我们将在下面详细解析:
输入:
- H ( l ) H^{(l)} H(l):上一层的节点特征(对于第一层, H ( 0 ) = X H^{(0)} = X H(0)=X,即输入特征)
- A ~ = A + I \tilde{A} = A + I A~=A+I:添加了自环的邻接矩阵( I I I是单位矩阵)。图中的自环是指节点与自身相连的边。在邻接矩阵中,节点 i i i的自环表示为 A ~ i i = 1 \tilde{A}_{ii} = 1 A~ii=1。添加自环后,我们得到新矩阵: A ~ = A + I \tilde{A} = A + I A~=A+I。这一步很重要,因为我们希望在聚合时保留节点自身的特征。否则,节点只能从邻居获取信息,而丢失了自身特征。
- D ~ \tilde{D} D~: A ~ \tilde{A} A~的对角度矩阵(包含每个节点的连接数,包括自环)
- W ( l ) W^{(l)} W(l):第 l l l层的可训练权重矩阵
- σ \sigma σ:非线性激活函数(如ReLU)
关键操作:
- 消息传递:
- A ~ H ( l ) \tilde{A}H^{(l)} A~H(l):每个节点聚合其邻居的特征向量
- 添加自环( A ~ = A + I \tilde{A} = A + I A~=A+I)确保节点在聚合时包含自身特征
- 归一化:防止特征尺度在层间变化过大,通过节点度进行归一化有助于训练稳定性
- D ~ − 1 / 2 A ~ D ~ − 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~−1/2A~D~−1/2:这步称为对称归一化或重归一化技巧。
- 如果没有归一化,具有许多连接(高度数)的节点在聚合后会有更大的特征值,这可能导致数值不稳定和训练困难。
- D ~ \tilde{D} D~:度矩阵(对角矩阵,其中 D ~ i i = ∑ j A ~ i j \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} D~ii=∑jA~ij)
- D ~ − 1 / 2 \tilde{D}^{-1/2} D~−1/2:度矩阵的逆平方根
- 左乘( D ~ − 1 / 2 A ~ \tilde{D}^{-1/2} \tilde{A} D~−1/2A~):将每一行除以节点度数的平方根。这归一化了每个节点发出消息的影响。
- 右乘( ⋅ D ~ − 1 / 2 \cdot \tilde{D}^{-1/2} ⋅D~−1/2):将每一列除以节点度数的平方根。这归一化了每个节点接收消息的影响。
考虑一个简单的3节点图:
节点0连接到节点1
节点1连接到节点0和2
节点2连接到节点1
添加自环后:
A = [[1, 1, 0],[1, 1, 1],[0, 1, 1]]D = [[2, 0, 0],[0, 3, 0],[0, 0, 2]] # 度数:2, 3, 2D^(-1/2) = [[1/√2, 0, 0 ],[0, 1/√3, 0 ],[0, 0, 1/√2]]
归一化后的矩阵为:
D^(-1/2)AD^(-1/2) = [[1/2, 1/√6, 0 ],[1/√6, 1/3, 1/√6 ],[0, 1/√6, 1/2 ]]
在每一层,节点都会聚合来自其邻居(包括自身)的信息。网络越深,信息传播得越远。每个节点的新表示是其自身特征和邻居特征的加权平均。权重通过训练过程学习得到。归一化确保具有许多邻居的节点不会主导学习过程。
在社交网络中,每个人(节点)都有一些特征(如年龄、兴趣等),GCN层让每个人根据其朋友的信息更新自己的理解。归一化确保受欢迎的人(有很多朋友)不会主导学习过程。
在Cora数据集上实现节点分类的GCN
Cora数据集是一个引文网络,其中节点代表学术论文,边代表引用关系。每篇论文都有一组特征(如作者、标题、摘要)和一个标签(如论文主题)。总共有2,780篇论文(节点)和5,429条引用(边)。每篇论文由一个二进制词向量表示,表示1,433个唯一词典单词的存在(1)或不存在(0)。论文被分为7个类别(如神经网络、概率方法等)。目标是根据每篇论文的特征和引用关系预测其类别。
模型架构
GCN模型有2层:
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16) # 输入到隐藏层self.conv2 = GCNConv(16, dataset.num_classes) # 隐藏层到输出
第一层GCN将输入特征(1,433维)降维到16维。第二层GCN将16维降维到7维(类别数)。
前向传播函数
def forward(self):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index) # 第一层GCNx = F.relu(x) # 非线性激活x = F.dropout(x, training=self.training) # 可选的dropoutx = self.conv2(x, edge_index) # 第二层GCNreturn F.log_softmax(x, dim=1) # 每个类别的对数概率
x = self.conv1(x, edge_index)
做了几件事:它向图中添加自环,计算归一化邻接矩阵 D ~ − 1 / 2 A ~ D ~ − 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~−1/2A~D~−1/2,与输入特征和权重 H ( l ) W ( l ) H^{(l)} W^{(l)} H(l)W(l)相乘,并应用归一化和聚合。基本上,所有复杂的数学运算都由GCNConv层处理了。F.relu(x)
应用ReLU激活函数,F.dropout(x, training=self.training)
应用dropout来防止过拟合。第二层GCN x = self.conv2(x, edge_index)
做同样的事情,但是使用不同的权重 H ( l ) W ( l ) H^{(l)} W^{(l)} H(l)W(l)。
训练过程
model = GCN()
data = dataset[0] # 获取第一个图对象
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()
我们使用带权重衰减的Adam优化器。Adam是一种自适应学习率优化算法,它结合了AdaGrad和RMSProp的优点。它维护每个参数的学习率,并使用梯度的移动平均和梯度平方的移动平均。由于稀疏梯度在GNN中很常见,使用Adam是合理的。
它有两个主要参数:lr
是学习率,weight_decay
是L2正则化参数。权重衰减通过向损失函数添加惩罚项来防止过拟合,并将模型权重推向较小的值,防止任何单个权重变得过大。使用L2时,原始损失 L ( θ ) L(\theta) L(θ)变为 L ( θ ) + λ ∑ θ i 2 L(\theta) + \lambda \sum \theta_i^2 L(θ)+λ∑θi2,其中 λ \lambda λ是权重衰减参数。weight_decay=5e-4
意味着 λ = 0.0005 \lambda = 0.0005 λ=0.0005。它通过保持权重较小来防止过拟合,并使模型对未见过的数据更具泛化能力。
loss = F.nll_loss(...)
是负对数似然损失(NLL),通常用于分类任务。它衡量模型的预测概率与真实标签的匹配程度。对于单个样本,它表示为 − log ( p 真实类别 ) -\log(p_{\text{真实类别}}) −log(p真实类别)。如果模型对正确类别100%确信,则损失为0。data.train_mask
是一个布尔掩码,指示哪些节点在训练集中。data.y
是每个节点的标签。我们只使用train_mask
为True的节点进行训练。val_mask
用于验证的节点,test_mask
用于最终评估的节点。
与许多图数据集一样,标签仅对节点的一个小子集可用,模型通过有监督损失从标记节点学习,并通过图结构从未标记节点学习。因此,这是半监督学习。在Cora数据集中,总共有2,708个节点,其中约140个节点(5%)用于训练,500个用于验证,1000个用于测试。GCN假设相连的节点可能相似。这被称为同质性假设,它被编码到学习算法中。GCN的消息传递直接编码了这些偏差。
模型评估
model.eval()
pred = model().argmax(dim=1) # 获取预测类别
correct = pred[data.test_mask] == data.y[data.test_mask]
accuracy = int(correct.sum()) / int(data.test_mask.sum())
完整代码如下。首先,安装必要的包:
pip install torch-geometric
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加载数据
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定义GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 训练循环
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)for epoch in range(200):model.train()optimizer.zero_grad()out = model()loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()if epoch % 20 == 0:print(f'Epoch {epoch}, Loss: {loss.item():.4f}')# 评估
model.eval()
pred = model().argmax(dim=1)
correct = pred[data.test_mask] == data.y[data.test_mask]
accuracy = int(correct.sum()) / int(data.test_mask.sum())
print(f'测试准确率: {accuracy:.4f}')
运行结果:
Epoch 0, Loss: 1.9515
Epoch 20, Loss: 0.1116
Epoch 40, Loss: 0.0147
Epoch 60, Loss: 0.0142
Epoch 80, Loss: 0.0166
Epoch 100, Loss: 0.0155
Epoch 120, Loss: 0.0137
Epoch 140, Loss: 0.0124
Epoch 160, Loss: 0.0114
Epoch 180, Loss: 0.0107
测试准确率: 0.8100
我们可以看到,模型在只看到少量标记节点的情况下就能达到相当不错的准确率(81%)。这展示了图结构与节点特征结合的力量。在下一篇博客中,我们将介绍EvolveGCN,这是一个可以处理动态图数据的动态GCN模型。
相关文章:

图卷积网络:从理论到实践
图卷积网络(Graph Convolutional Networks, GCNs)彻底改变了基于图的机器学习领域,使得深度学习能够应用于非欧几里得结构,如社交网络、引文网络和分子结构。本文将解释GCN的直观理解、数学原理,并提供代码片段帮助您理…...

ES 学习总结一 基础内容
ElasticSearch学习 一、 初识ES1、 认识与安装2、 倒排索引2.1 正向索引2.2 倒排索引 3、 基本概念3.1 文档和字段3.2 索引和倒排 4 、 IK分词器 二、 操作1、 mapping 映射属性2、 索引库增删改查3、 文档的增删改查3.1 新增文档3.2 查询文档3.3 删除文档3.4 修改文档3.5 批处…...

Maven 构建缓存与离线模式
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c=1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计,Springboot和微服务,熟悉Linux,ESXI虚拟化以及云原生Docker和K8s,热衷于探…...

基于51单片机的光强控制LED灯亮灭
目录 具体实现功能 设计介绍 资料内容 全部内容 资料获取 具体实现功能 具体功能: (1)按下按键K后光敏电阻进行光照检测,LCD1602显示光照强度值; (2)光照值小于15时,上面2个LE…...

【Linux操作系统】基础开发工具(yum、vim、gcc/g++)
文章目录 Linux软件包管理器 - yumLinux下的三种安装方式什么是软件包认识Yum与RPMyum常用指令更新软件安装与卸载查找与搜索清理缓存与重建元数据 yum源更新1. 备份现有的 yum 源配置2. 下载新的 repo 文件3. 清理并重建缓存 Linux编辑器 - vim启动vimVim 的三种主要模式常用操…...
gopool 源码分析
gopool gopool是字节跳动开源节流的gopkg包中协程池的一个实现。 关键结构 协程池: type pool struct {// The name of the poolname string// capacity of the pool, the maximum number of goroutines that are actually working// 协程池的最大容量cap int32…...

【Survival Analysis】【机器学习】【3】 SHAP可解釋 AI
前言: SHAP(SHapley Additive explanations) 是一种基于博弈论的可解释工具。 现在很多高分的 论文里面都会带这种基于SHAP 分析的图,用于评估机器学习模型中特征对预测结果的贡献度. pip install -i https://pypi.tuna.tsinghua.edu.cn/sim…...

ModuleNotFoundError No module named ‘torch_geometric‘未找到
ModuleNotFoundError: No module named torch_geometric’未找到 试了很多方法,都没成功,安装torch对应版本的torch_geometric都不行, 后来发现是pip被设置了环境变量,所有pip文件都给安装在了一个文件夹了 排查建议 1. 检查 p…...
iOS 门店营收表格功能的实现
iOS 门店营收表格功能实现方案 核心功能需求 数据展示:表格形式展示门店/日期维度的营收数据排序功能:支持按营收金额、增长率等排序筛选功能:按日期范围/门店/区域筛选交互操作:点击查看详情、数据刷新数据可视化:关…...
链表题解——环形链表【LeetCode】
141. 环形链表 方法一 核心思想: 使用一个集合 seen 来记录已经访问过的节点。遍历链表,如果当前节点已经存在于集合中,说明链表存在环;否则,将当前节点添加到集合中,继续遍历。如果遍历结束(h…...

Cell-o1:强化学习训练LLM解决单细胞推理问题
细胞类型注释是分析scRNA-seq数据异质性的关键任务。尽管最近的基础模型实现了这一过程的自动化,但它们通常独立注释细胞,未考虑批次水平的细胞背景或提供解释性推理。相比之下,人类专家常基于领域知识为不同细胞簇注释不同的细胞类型。为模拟…...
求解插值多项式及其余项表达式
例 求满足 P ( x j ) f ( x j ) P(x_j) f(x_j) P(xj)f(xj) ( j 0 , 1 , 2 j0,1,2 j0,1,2) 及 P ′ ( x 1 ) f ′ ( x 1 ) P(x_1) f(x_1) P′(x1)f′(x1) 的插值多项式及其余项表达式。 解: 由给定条件,可确定次数不超过3的插值多项式。…...

vue3: bingmap using typescript
项目结构: <template><div class"bing-map-market"><!-- 加载遮罩层 --><div class"loading-overlay" v-show"isLoading || errorMessage"><div class"spinner-container"><div class&qu…...
vue3前端实现导出Excel功能
前端实现导出功能可以使用一些插件 我使用的是xlsx库 1.首先我们需要在vue3的项目中安装xlsx库。可以使用npm 或者 pnpm来进行安装 npm install xlsx或者 pnpm install xlsx2.在vue组件中引入xlsx库 import * as XLSX from xlsx;3.定义导出实例方法 const exportExcel () …...

超大规模芯片验证:基于AMD VP1902的S8-100原型验证系统实测性能翻倍
引言: 随着AI、HPC及超大规模芯片设计需求呈指数级增长原型验证平台已成为芯片设计流程中验证复杂架构、缩短迭代周期的核心工具。然而,传统原型验证系统受限于单芯片容量(通常<5000万门)、多芯片分割效率及系统级联能力&#…...

【工作记录】接口功能测试总结
如何对1个接口进行接口测试 一、单接口功能测试 1、接口文档信息 理解接口文档的内容: 请求URL: https://[ip]:[port]/xxxserviceValidation 请求方法: POST 请求参数: serviceCode(必填), servicePsw(必填) 响应参数: status, token 2、编写测试用例 2.1 正…...

Dubbo Logback 远程调用携带traceid
背景 A项目有调用B项目的服务,A项目使用 logback 且有 MDC 方式做 traceid,调用B项目的时候,traceid 没传递过期,导致有时候不好排查问题和链路追踪 准备工作 因为使用的是 alibaba 的 dubbo 所以需要加入单独的包 <depend…...
【element-ui】el-autocomplete实现 无数据匹配
文章目录 方法一:使用 default 插槽方法二:使用 empty-text 属性(适用于列表类型)总结 在使用 Element UI 的 el-autocomplete 组件时,如果你希望在没有任何数据匹配的情况下显示特定的内容,你可以通过自定…...

NLP学习路线图(二十):FastText
在自然语言处理(NLP)领域,词向量(Word Embedding)是基石般的存在。它将离散的符号——词语——转化为连续的、富含语义信息的向量表示,使得计算机能够“理解”语言。而在众多词向量模型中,FastText 凭借其独特的设计理念和卓越性能,尤其是在处理形态丰富的语言和罕见词…...

力扣面试150题--除法求值
Day 62 题目描述 做法 此题本质是一个图论问题,对于两个字母相除是否存在值,其实就是判断,从一个字母能否通过其他字母到达,做法如下: 遍历所有等式,为每个变量分配唯一的整数索引。初始化一个二维数组 …...
SQL进阶之旅 Day 20:锁与并发控制技巧
【JDK21深度解密 Day 20】锁与并发控制技巧 文章简述 在高并发的数据库环境中,锁与并发控制是保障数据一致性和系统稳定性的核心机制。本文作为“SQL进阶之旅”系列的第20天,深入探讨SQL中的锁机制、事务隔离级别以及并发控制策略。文章从理论基础入手…...

美业破局:AI智能体如何用数据重塑战略决策(5/6)
摘要:文章深入剖析美业现状与挑战,指出其市场规模庞大但竞争激烈,面临获客难、成本高、服务标准化缺失等问题。随后阐述 AI 智能体与数据驱动决策的概念,强调其在美业管理中的重要性。接着详细说明 AI 智能体在美业数据收集、整理…...

生成模型+两种机器学习范式
生成模型:从数据分布到样本创造 生成模型(Generative Model) 是机器学习中一类能够学习数据整体概率分布,并生成新样本的模型。其核心目标是建模输入数据 x 和标签 y 的联合概率分布 P(x,y),即回答 “数据是如何产生的…...

【学习笔记】Python金融基础
Python金融入门 1. 加载数据与可视化1.1. 加载数据1.2. 折线图1.3. 重采样1.4. K线图 / 蜡烛图1.5. 挑战1 2. 计算2.1. 收益 / 回报2.2. 绘制收益图2.3. 累积收益2.4. 波动率2.5. 挑战2 3. 滚动窗口3.1. 创建移动平均线3.2. 绘制移动平均线3.3 Challenge 4. 技术分析4.1. OBV4.…...
在Linux查看电脑的GPU型号
VGA 是指 Video Graphics Array,这是 IBM 于 1987 年推出的一种视频显示标准。 lspci | grep vga 📌 lspci | grep -i vga 的含义 lspci:列出所有连接到 PCI 总线的设备。 grep -i vga:过滤输出,仅显示包含“VGA”字…...

A Execllent Software Project Review and Solutions
The Phoenix Projec: how do we produce software? how many steps? how many people? how much money? you will get it. i am a pretty judge of people…a prank...

windows命令行面板升级Git版本
Date: 2025-06-05 11:41:56 author: lijianzhan Git 是一个 分布式版本控制系统 (DVCS),由 Linux 之父 Linus Torvalds 于 2005 年开发,用于管理 Linux 内核开发。它彻底改变了代码协作和版本管理的方式,现已成为软件开发的事实标准工具&…...
Langgraph实战--自定义embeding
概述 在Langgraph中我想使用第三方的embeding接口来实现文本的embeding。但目前langchain只提供了两个类,一个是AzureOpenAIEmbeddings,一个是:OpenAIEmbeddings。通过ChatOpenAI无法使用第三方的接口,例如:硅基流平台…...

大故障,阿里云核心域名疑似被劫持
2025年6月5日凌晨,阿里云多个服务突发异常,罪魁祸首居然是它自家的“核心域名”——aliyuncs.com。包括对象存储 OSS、内容分发 CDN、镜像仓库 ACR、云解析 DNS 等服务在内,全部受到波及,用户业务连夜“塌房”。 更让人惊讶的是&…...
什么是「镜像」?(Docker Image)
🧊 什么是「镜像」?(Docker Image) 💡 人话解释: Docker 镜像就像是一个装好程序的“快照包”,里面包含了程序本体、依赖库、运行环境,甚至是系统文件。 你可以把镜像理解为&…...