【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程
文章目录
- 基本步骤
- GNN和全连接层(FC)联合训练
- 1. 定义GNN模型类
- 2. 定义FC模型类
- 3. 训练循环中的联合优化
- 解释
- 完整代码
- GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
- 解释
基本步骤
要从GNN(图神经网络)中提取特征,并使用全连接层(FC,Fully Connected Layer)进行后续处理,可以按照以下步骤进行:
-
构建图神经网络模型:选择一种GNN架构,例如GCN(Graph Convolutional Network)、GAT(Graph Attention Network)等。你可以使用深度学习框架(如PyTorch、TensorFlow)来实现。
-
获取节点特征和图结构:准备好节点特征矩阵和邻接矩阵,这些是GNN模型的输入。
-
通过GNN提取特征:
- 设计GNN模型的前向传播过程,将节点特征和邻接矩阵输入GNN层。
- 从GNN层的输出中提取节点的嵌入特征。
-
连接全连接层进行分类或回归:
- 将GNN提取的节点特征作为输入传递给一个或多个全连接层。
- 通过全连接层进行后续的分类、回归等任务。
GNN和全连接层(FC)联合训练
如果GNN和全连接层(FC)分别在不同的类中,并且你希望它们可以联合训练,你可以通过以下步骤实现端到端的训练过程,并确保反向传播能够正确进行:
- 定义GNN和FC模型:分别定义GNN和FC模型类。
- 特征提取与分类:在训练循环中,将GNN提取的特征传递给FC进行分类。
- 联合优化:使用一个优化器来更新两个模型的参数。
以下是具体的实现步骤和代码示例:
1. 定义GNN模型类
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_features
2. 定义FC模型类
class FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out
3. 训练循环中的联合优化
# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = torch_geometric.utils.grid(num_nodes_per_graph)graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
解释
- GNN模型类:
GNN
类定义了一个简单的两层GCN模型,用于特征提取。 - FC模型类:
FC
类定义了一个全连接层模型,用于分类。 - 联合优化:
- 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
- 使用一个优化器来同时优化GNN和FC模型的参数。
- 通过调用
optimizer.zero_grad()
清除梯度,调用loss.backward()
进行反向传播,最后调用optimizer.step()
更新参数。
通过这种方式,尽管GNN和FC模型分别在不同的类中,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。
使用正确的参数来生成随机图。torch_geometric.utils.erdos_renyi_graph需要使用num_nodes和edge_prob参数
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5) # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
如果你想为GNN和全连接层(FC)分别使用不同的优化器和学习率,可以按照以下步骤进行:
- 定义两个优化器:一个用于GNN模型,另一个用于FC模型。
- 分别进行参数更新:在训练循环中,分别对两个模型进行前向传播、损失计算和反向传播,然后使用各自的优化器更新参数。
以下是实现代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float) # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5) # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用两个优化器分别优化GNN和FC模型的参数
optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=1e-3) # GNN使用较高的学习率
optimizer_fc = torch.optim.Adam(fc_model.parameters(), lr=1e-4) # FC使用较低的学习率
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer_gnn.zero_grad()optimizer_fc.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 使用各自的优化器更新参数optimizer_gnn.step()optimizer_fc.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)
解释
- GNN模型类:
GNN
类定义了一个简单的两层GCN模型,用于特征提取。 - FC模型类:
FC
类定义了一个全连接层模型,用于分类。 - 数据生成:使用
torch_geometric.utils.erdos_renyi_graph
生成随机图数据,并确保参数正确。 - 联合优化:
- 定义两个优化器,分别用于GNN和FC模型,并为它们设置不同的学习率。
- 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
- 使用各自的优化器来分别清除梯度、进行反向传播和更新参数。
通过这种方式,尽管GNN和FC模型分别在不同的类中,并使用不同的优化器和学习率,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。
相关文章:

【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程
文章目录 基本步骤GNN和全连接层(FC)联合训练1. 定义GNN模型类2. 定义FC模型类3. 训练循环中的联合优化解释完整代码 GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新解释 基本步骤 要从GNN(图神经网…...

vue3使用方式汇总
1、引入iconfont阿里图库图标: 1.1 进入阿里图标网站: iconfont阿里:https://www.iconfont.cn/ 1.2 添加图标: 1.3 下载代码: 1.4 在vue3中配置代码: 将其代码复制到src/assets/fonts/目录下࿱…...

Turborepo简易教程
参考官网:https://turbo.build/repo/docs 开始 安装全新的项目 pnpm dlx create-turbolatest测试应用包含: 两个可部署的应用三个共享库 运行: pnpm install pnpm dev会启动两个应用web(http://localhost:3000/)、docs(http://localhost…...

初中物理知识点总结(人教版)
初中物理知识点大全 声现象知识归纳 1 .声音的发生:由物体的振动而产生。振动停止,发声也停止。 2.声音的传播:声音靠介质传播。真空不能传声。通常我们听到的声音是靠空气传来的。 3.声速:在空气中传播速度是:340…...
ChatGPT-4o大语言模型优化、本地私有化部署、从0-1搭建、智能体构建等高级进阶
目录 第一章 ChatGPT-4o使用进阶 第二章 大语言模型原理详解 第三章 大语言模型优化 第四章 开源大语言模型及本地部署 第五章 从0到1搭建第一个大语言模型 第六章 智能体(Agent)构建 第七章 大语言模型发展趋势 第八章 总结与答疑讨论 更多应用…...

【开源项目】LocalSend 局域网文件传输工具
【开源项目】LocalSend 局域网文件传输工具 一个免费、开源、跨平台的局域网传输工具 LocalSend 简介 LocalSend 是一个免费的开源跨平台的应用程序,允许用户在不需要互联网连接的情况下,通过本地网络安全地与附近设备共享文件和消息。 项目地址&…...

ARM/Linux嵌入式面经(十一):地平线嵌入式实习
地平线嵌入式实习面经 1.自我介绍 等着,在给大哥们准备了。 2.spi与iic协议可以连接多个设备吗?最多多少个?通讯时序。 这是几个问题,在回答的时候。不要一问就开口,花几秒钟沉吟思考整理一下自己的思路。 这个问题问了几个点?每个点的回答步骤。 是我的话,我会采用以…...

基于Redis的分布式锁
分布式场景下并发安全问题的引发 前面通过加锁解决了单机状态下一人一单的问题,但是当出现了分布式,前面的加锁形式不再适用 ,每个jvm有一个自己的锁监视器,只能被内部线程获取,其他jvm无法使用,那么多台j…...

如何将 Apifox 的自动化测试与 Jenkins 集成?
CI/CD (持续集成/持续交付) 在 API 测试 中的主要目的是为了自动化 API 的验证流程,确保 API 发布到生产环境前的可用性。通过持续集成,我们可以在 API 定义变更时自动执行功能测试,以及时发现潜在问题。 Apifox 支持…...

【FFmpeg】av_write_frame函数
目录 1.av_write_frame1.1 写入pkt(write_packets_common)1.1.1 检查pkt的信息(check_packet)1.1.2 准备输入的pkt(prepare_input_packet)1.1.3 检查码流(check_bitstream)1.1.4 写入…...

【算法专题】双指针算法
1. 移动零 题目分析 对于这类数组分块的问题,我们应该首先想到用双指针的思路来进行处理,因为数组可以通过下标进行访问,所以说我们不用真的定义指针,用下标即可。比如本题就要求将数组划分为零区域和非零区域,我们不…...

Lock与ReentrantLock
在 Java 中,Lock 接口和 ReentrantLock 类提供了比使用 synchronized 方法和代码块更广泛的锁定机制。 简单示例: import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock;public class ReentrantLockExample {pr…...

ARM/Linux嵌入式面经(十三):紫光同芯嵌入式
static关键字 static关键字一文搞懂这个知识点,真的是喜欢考!!! stm32启动时如何配置栈的地址 在STM32启动时配置栈的地址是一个关键步骤,这通常是在启动文件(如startup_stm32fxxx.s,其中xxx代表具体的STM32型号)中完成的。 面试者回答: STM32启动时配置栈的地址主…...

名企面试必问30题(二十四)—— 说说你空窗期做了什么?
回答示例一 在空窗期这段时间,我主要进行了两方面的活动。 一方面,我持续提升自己的专业技能。我系统地学习了最新的软件测试理论和方法,深入研究了自动化测试工具和框架,例如 Selenium、Appium 等,并通过在线课程和实…...

基础权限储存
一、要求: 1、建立用户组shengcan,其id为2000工 2、建立用户组 caiwu,其id为2001 3、建立用户组 jishu,其id 为 2002 4、建立目录/sc,此目录是 shengchan 部门的存储目录,只能被 shengchan 组的成员操作,其他用户没有…...

Could not find a package configuration file provided by “roscpp“ 的参考解决方法
文章目录 写在前面一、问题描述二、解决方法参考链接 写在前面 自己的测试环境: Ubuntu20.04 ROS-Noetic 一、问题描述 编译程序时出现如下报错: -- Could NOT find roscpp (missing: roscpp_DIR) -- Could not find the required component roscpp.…...

运维系列.Nginx配置中的高级指令和流程控制
运维专题 Nginx配置中的高级指令和流程控制 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…...

Virtualbox和ubuntu之间的关系
1、什么是ubuntu Ubuntu 是一个类似于 Windows 的操作系统,但它是基于 Linux 内核开发的开源操作系统 2、什么是Virtualbox VirtualBox 是一款虚拟机软件,使我们可以物理机上创建和运行虚拟机 也就是说,VirtualBox 提供了一个可以安装和运行其他操作系…...

【在Linux世界中追寻伟大的One Piece】HTTPS协议原理
目录 1 -> HTTPS是什么? 2 -> 相关概念 2.1 -> 什么是"加密" 2.2 -> 为什么要加密 2.3 -> 常见的加密方式 2.4 -> 数据摘要 && 数据指纹 2.5 -> 数字签名 3 -> HTTPS的工作过程 3.1 -> 只使用对称加密 3.2 …...

【WebRTC实现点对点视频通话】
介绍 WebRTC (Web Real-Time Communications) 是一个实时通讯技术,也是实时音视频技术的标准和框架。简单来说WebRTC是一个集大成的实时音视频技术集,包含了各种客户端api、音视频编/解码lib、流媒体传输协议、回声消除、安全传输等。对于开发者来说可以…...

【Unity】RPG2D龙城纷争(八)寻路系统
更新日期:2024年7月4日。 项目源码:第五章发布(正式开始游戏逻辑的章节) 索引 简介一、寻路系统二、寻路规则(角色移动)三、寻路规则(角色攻击)四、角色移动寻路1.自定义寻路规则2.寻…...

C++ 函数高级——函数重载——基本语法
作用:函数名可以相同,提高复用性 函数重载满足条件: 1.同一个作用域下 2.函数名称相同 3.函数参数类型不同 或者 个数不同 或者 顺序不同 注意:函数的返回值不可以作为函数重载的条件 示例: 运行结果:...

Go语言实现的端口扫描工具示例
Go语言实现的端口扫描工具示例 创建一个端口扫描工具涉及到网络编程和并发处理,下面是一个简单的Go语言实现的端口扫描工具示例。这个工具会扫描指定IP地址的指定范围内的端口。 请注意,使用端口扫描工具可能会违反某些网络的使用条款,甚至…...

SpringSecurity初始化过程
SpringSecurity初始化过程 SpringSecurity一定是被Spring加载的: web.xml中通过ContextLoaderListener监听器实现初始化 <!-- 初始化web容器--><!--设置配置文件的路径--><context-param><param-name>contextConfigLocation</param-…...

Python爬取股票信息-并进行数据可视化分析,绘股票成交量柱状图
为了使用Python爬取股票信息并进行数据可视化分析,我们可以使用几个流行的库:requests 用于网络请求,pandas 用于数据处理,以及 matplotlib 或 seaborn 用于数据可视化。 步骤 1: 安装必要的库 首先,确保安装了以下P…...

秋招突击——7/4——复习{}——新作{最长公共子序列、编辑距离、买股票最佳时机、跳跃游戏}
文章目录 引言复习新作1143-最长公共子序列个人实现 参考实现编辑距离个人实现参考实现 贪心——买股票的最佳时机个人实现参考实现 贪心——55-跳跃游戏个人实现参考做法 总结 引言 昨天主要是面试,然后剩下的时间都是用来对面试中不会的东西进行查漏补缺ÿ…...

udp发送数据如果超过1个mtu时,抓包所遇到的问题记录说明
最近在测试Syslog udp发送相关功能,测试环境是centos udp头部的数据长度是2个字节,最大传输长度理论上是65535,除去头部这些字节,可以大概的说是64k。 写了一个超过64k的数据(随便用了一个7w字节的buffer)发送demo,打…...

电子电气架构 --- 智能座舱万物互联
我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己,无利益不试图说服别人,是精神上的节…...

笔记本电脑内存不够
笔记本电脑内存不够是众多笔记本用户面临的常见问题,尤其是对于一些需要处理大型文件或者运行复杂软件的用户,这个问题可能会严重影响笔记本的使用体验。那么,我们应该如何解决笔记本电脑内存不够的问题呢?本文将从几个方面进行详…...

Vue项目使用mockjs模拟后端接口
文章目录 操作步骤1. 安装 mockjs 和 vite-plugin-mock2. 安装 axios3. 创建mock路径4. 配置 viteMockConfig5. 编写第一个mock接口6. 创建 createProdMockServer7. 配置 axios8. 编写请求接口9. 在页面中使用 操作步骤 1. 安装 mockjs 和 vite-plugin-mock vite-plugin-mock …...