【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程
文章目录
- 基本步骤
- GNN和全连接层(FC)联合训练
- 1. 定义GNN模型类
- 2. 定义FC模型类
- 3. 训练循环中的联合优化
- 解释
- 完整代码
- GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
- 解释
基本步骤
要从GNN(图神经网络)中提取特征,并使用全连接层(FC,Fully Connected Layer)进行后续处理,可以按照以下步骤进行:
-
构建图神经网络模型:选择一种GNN架构,例如GCN(Graph Convolutional Network)、GAT(Graph Attention Network)等。你可以使用深度学习框架(如PyTorch、TensorFlow)来实现。
-
获取节点特征和图结构:准备好节点特征矩阵和邻接矩阵,这些是GNN模型的输入。
-
通过GNN提取特征:
- 设计GNN模型的前向传播过程,将节点特征和邻接矩阵输入GNN层。
- 从GNN层的输出中提取节点的嵌入特征。
-
连接全连接层进行分类或回归:
- 将GNN提取的节点特征作为输入传递给一个或多个全连接层。
- 通过全连接层进行后续的分类、回归等任务。
GNN和全连接层(FC)联合训练
如果GNN和全连接层(FC)分别在不同的类中,并且你希望它们可以联合训练,你可以通过以下步骤实现端到端的训练过程,并确保反向传播能够正确进行:
- 定义GNN和FC模型:分别定义GNN和FC模型类。
- 特征提取与分类:在训练循环中,将GNN提取的特征传递给FC进行分类。
- 联合优化:使用一个优化器来更新两个模型的参数。
以下是具体的实现步骤和代码示例:
1. 定义GNN模型类
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_features
2. 定义FC模型类
class FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out
3. 训练循环中的联合优化
# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = torch_geometric.utils.grid(num_nodes_per_graph)graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
解释
- GNN模型类:
GNN
类定义了一个简单的两层GCN模型,用于特征提取。 - FC模型类:
FC
类定义了一个全连接层模型,用于分类。 - 联合优化:
- 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
- 使用一个优化器来同时优化GNN和FC模型的参数。
- 通过调用
optimizer.zero_grad()
清除梯度,调用loss.backward()
进行反向传播,最后调用optimizer.step()
更新参数。
通过这种方式,尽管GNN和FC模型分别在不同的类中,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。
使用正确的参数来生成随机图。torch_geometric.utils.erdos_renyi_graph需要使用num_nodes和edge_prob参数
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5) # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
如果你想为GNN和全连接层(FC)分别使用不同的优化器和学习率,可以按照以下步骤进行:
- 定义两个优化器:一个用于GNN模型,另一个用于FC模型。
- 分别进行参数更新:在训练循环中,分别对两个模型进行前向传播、损失计算和反向传播,然后使用各自的优化器更新参数。
以下是实现代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5) # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用两个优化器分别优化GNN和FC模型的参数
optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=1e-3) # GNN使用较高的学习率
optimizer_fc = torch.optim.Adam(fc_model.parameters(), lr=1e-4) # FC使用较低的学习率
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer_gnn.zero_grad()optimizer_fc.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 使用各自的优化器更新参数optimizer_gnn.step()optimizer_fc.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
解释
- GNN模型类:
GNN
类定义了一个简单的两层GCN模型,用于特征提取。 - FC模型类:
FC
类定义了一个全连接层模型,用于分类。 - 数据生成:使用
torch_geometric.utils.erdos_renyi_graph
生成随机图数据,并确保参数正确。 - 联合优化:
- 定义两个优化器,分别用于GNN和FC模型,并为它们设置不同的学习率。
- 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
- 使用各自的优化器来分别清除梯度、进行反向传播和更新参数。
通过这种方式,尽管GNN和FC模型分别在不同的类中,并使用不同的优化器和学习率,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。
相关文章:
【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程
文章目录 基本步骤GNN和全连接层(FC)联合训练1. 定义GNN模型类2. 定义FC模型类3. 训练循环中的联合优化解释完整代码 GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新解释 基本步骤 要从GNN(图神经网…...

vue3使用方式汇总
1、引入iconfont阿里图库图标: 1.1 进入阿里图标网站: iconfont阿里:https://www.iconfont.cn/ 1.2 添加图标: 1.3 下载代码: 1.4 在vue3中配置代码: 将其代码复制到src/assets/fonts/目录下࿱…...

Turborepo简易教程
参考官网:https://turbo.build/repo/docs 开始 安装全新的项目 pnpm dlx create-turbolatest测试应用包含: 两个可部署的应用三个共享库 运行: pnpm install pnpm dev会启动两个应用web(http://localhost:3000/)、docs(http://localhost…...

初中物理知识点总结(人教版)
初中物理知识点大全 声现象知识归纳 1 .声音的发生:由物体的振动而产生。振动停止,发声也停止。 2.声音的传播:声音靠介质传播。真空不能传声。通常我们听到的声音是靠空气传来的。 3.声速:在空气中传播速度是:340…...
ChatGPT-4o大语言模型优化、本地私有化部署、从0-1搭建、智能体构建等高级进阶
目录 第一章 ChatGPT-4o使用进阶 第二章 大语言模型原理详解 第三章 大语言模型优化 第四章 开源大语言模型及本地部署 第五章 从0到1搭建第一个大语言模型 第六章 智能体(Agent)构建 第七章 大语言模型发展趋势 第八章 总结与答疑讨论 更多应用…...

【开源项目】LocalSend 局域网文件传输工具
【开源项目】LocalSend 局域网文件传输工具 一个免费、开源、跨平台的局域网传输工具 LocalSend 简介 LocalSend 是一个免费的开源跨平台的应用程序,允许用户在不需要互联网连接的情况下,通过本地网络安全地与附近设备共享文件和消息。 项目地址&…...
ARM/Linux嵌入式面经(十一):地平线嵌入式实习
地平线嵌入式实习面经 1.自我介绍 等着,在给大哥们准备了。 2.spi与iic协议可以连接多个设备吗?最多多少个?通讯时序。 这是几个问题,在回答的时候。不要一问就开口,花几秒钟沉吟思考整理一下自己的思路。 这个问题问了几个点?每个点的回答步骤。 是我的话,我会采用以…...

基于Redis的分布式锁
分布式场景下并发安全问题的引发 前面通过加锁解决了单机状态下一人一单的问题,但是当出现了分布式,前面的加锁形式不再适用 ,每个jvm有一个自己的锁监视器,只能被内部线程获取,其他jvm无法使用,那么多台j…...

如何将 Apifox 的自动化测试与 Jenkins 集成?
CI/CD (持续集成/持续交付) 在 API 测试 中的主要目的是为了自动化 API 的验证流程,确保 API 发布到生产环境前的可用性。通过持续集成,我们可以在 API 定义变更时自动执行功能测试,以及时发现潜在问题。 Apifox 支持…...

【FFmpeg】av_write_frame函数
目录 1.av_write_frame1.1 写入pkt(write_packets_common)1.1.1 检查pkt的信息(check_packet)1.1.2 准备输入的pkt(prepare_input_packet)1.1.3 检查码流(check_bitstream)1.1.4 写入…...

【算法专题】双指针算法
1. 移动零 题目分析 对于这类数组分块的问题,我们应该首先想到用双指针的思路来进行处理,因为数组可以通过下标进行访问,所以说我们不用真的定义指针,用下标即可。比如本题就要求将数组划分为零区域和非零区域,我们不…...

Lock与ReentrantLock
在 Java 中,Lock 接口和 ReentrantLock 类提供了比使用 synchronized 方法和代码块更广泛的锁定机制。 简单示例: import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock;public class ReentrantLockExample {pr…...
ARM/Linux嵌入式面经(十三):紫光同芯嵌入式
static关键字 static关键字一文搞懂这个知识点,真的是喜欢考!!! stm32启动时如何配置栈的地址 在STM32启动时配置栈的地址是一个关键步骤,这通常是在启动文件(如startup_stm32fxxx.s,其中xxx代表具体的STM32型号)中完成的。 面试者回答: STM32启动时配置栈的地址主…...
名企面试必问30题(二十四)—— 说说你空窗期做了什么?
回答示例一 在空窗期这段时间,我主要进行了两方面的活动。 一方面,我持续提升自己的专业技能。我系统地学习了最新的软件测试理论和方法,深入研究了自动化测试工具和框架,例如 Selenium、Appium 等,并通过在线课程和实…...
基础权限储存
一、要求: 1、建立用户组shengcan,其id为2000工 2、建立用户组 caiwu,其id为2001 3、建立用户组 jishu,其id 为 2002 4、建立目录/sc,此目录是 shengchan 部门的存储目录,只能被 shengchan 组的成员操作,其他用户没有…...
Could not find a package configuration file provided by “roscpp“ 的参考解决方法
文章目录 写在前面一、问题描述二、解决方法参考链接 写在前面 自己的测试环境: Ubuntu20.04 ROS-Noetic 一、问题描述 编译程序时出现如下报错: -- Could NOT find roscpp (missing: roscpp_DIR) -- Could not find the required component roscpp.…...

运维系列.Nginx配置中的高级指令和流程控制
运维专题 Nginx配置中的高级指令和流程控制 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…...
Virtualbox和ubuntu之间的关系
1、什么是ubuntu Ubuntu 是一个类似于 Windows 的操作系统,但它是基于 Linux 内核开发的开源操作系统 2、什么是Virtualbox VirtualBox 是一款虚拟机软件,使我们可以物理机上创建和运行虚拟机 也就是说,VirtualBox 提供了一个可以安装和运行其他操作系…...

【在Linux世界中追寻伟大的One Piece】HTTPS协议原理
目录 1 -> HTTPS是什么? 2 -> 相关概念 2.1 -> 什么是"加密" 2.2 -> 为什么要加密 2.3 -> 常见的加密方式 2.4 -> 数据摘要 && 数据指纹 2.5 -> 数字签名 3 -> HTTPS的工作过程 3.1 -> 只使用对称加密 3.2 …...

【WebRTC实现点对点视频通话】
介绍 WebRTC (Web Real-Time Communications) 是一个实时通讯技术,也是实时音视频技术的标准和框架。简单来说WebRTC是一个集大成的实时音视频技术集,包含了各种客户端api、音视频编/解码lib、流媒体传输协议、回声消除、安全传输等。对于开发者来说可以…...

接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...

循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...
React Native在HarmonyOS 5.0阅读类应用开发中的实践
一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强,React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 (1)使用React Native…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

图表类系列各种样式PPT模版分享
图标图表系列PPT模版,柱状图PPT模版,线状图PPT模版,折线图PPT模版,饼状图PPT模版,雷达图PPT模版,树状图PPT模版 图表类系列各种样式PPT模版分享:图表系列PPT模板https://pan.quark.cn/s/20d40aa…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...

初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...