Code Lab - 34
GAT里面有一些地方看的不是太懂(GAT里Multi Attention的具体做法),暂时找了参考代码,留一个疑问
1. 一个通用的GNN Stack
import torch_geometric
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as Fimport torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utilsfrom torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,OptTensor)from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmaxclass GNNStack(torch.nn.Module):def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):super(GNNStack, self).__init__()conv_model = self.build_conv_model(args.model_type)self.convs = nn.ModuleList()self.convs.append(conv_model(input_dim, hidden_dim))#assert(断言) 用于判断一个表达式,在表达式条件为 false 的时候触发异常assert (args.num_layers >= 1), 'Number of layers is not >=1'for l in range(args.num_layers-1):self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))# post-message-passingself.post_mp = nn.Sequential(nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), nn.Linear(hidden_dim, output_dim))self.dropout = args.dropoutself.num_layers = args.num_layersself.emb = embdef build_conv_model(self, model_type):if model_type == 'GraphSage':return GraphSageelif model_type == 'GAT':# When applying GAT with num heads > 1, you need to modify the # input and output dimension of the conv layers (self.convs),# to ensure that the input dim of the next layer is num heads# multiplied by the output dim of the previous layer.# HINT: In case you want to play with multiheads, you need to change the for-loop that builds up self.convs to be# self.convs.append(conv_model(hidden_dim * num_heads, hidden_dim)), # and also the first nn.Linear(hidden_dim * num_heads, hidden_dim) in post-message-passing.return GATdef forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batchfor i in range(self.num_layers):x = self.convs[i](x, edge_index)x = F.relu(x)x = F.dropout(x, p=self.dropout,training=self.training)x = self.post_mp(x)if self.emb == True:return xreturn F.log_softmax(x, dim=1)def loss(self, pred, label):return F.nll_loss(pred, label)
2. 实现GraphSage和GAT
2.1 GraphSage

class GraphSage(MessagePassing):def __init__(self, in_channels, out_channels, normalize = True,bias = False, **kwargs): super(GraphSage, self).__init__(**kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.normalize = normalize# self.lin_l is the linear transformation that you apply to embedding for central node.self.lin_l=Linear(in_channels,out_channels) #Wl# self.lin_r is the linear transformation that you apply to aggregated message from neighbors.self.lin_r=Linear(in_channels,out_channels) #Wrself.reset_parameters()def reset_parameters(self):self.lin_l.reset_parameters()self.lin_r.reset_parameters()def forward(self, x, edge_index, size = None):# 调用propagation函数进行消息传递:propagate(edge_index, x=(x_i, x_j), extra=(extra_i, extra_j), size=size)# 我们将只使用邻居节点(x_j)的表示,因此默认情况下我们为中心节点和邻居节点传递与x=(x,x)相同的表示out1 = self.lin_l(x)out2 = self.propagate(edge_index,x = (x,x),size = size)out2 = self.lin_r(out2)out = out1 + out2if self.normalize:out = F.normalize(out)return out# 供propagate调用,对于所有(i,j)边,构造从邻点j到中心点i的信息# x_j表示 所有邻点的特征嵌入矩阵 def message(self, x_j):out = x_jreturn out# 聚合邻居信息def aggregate(self, inputs, index, dim_size = None):# The axis along which to index number of nodes.node_dim = self.node_dimout = torch_scatter.scatter(inputs,index,node_dim,dim_size=dim_size,reduce='mean')return out
2.2 GAT
class GAT(MessagePassing):def __init__(self, in_channels, out_channels, heads = 2,negative_slope = 0.2, dropout = 0., **kwargs):super(GAT, self).__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.negative_slope = negative_slopeself.dropout = dropout# self.lin_l is the linear transformation that you apply to embeddings # Pay attention to dimensions of the linear layers, since we're using multi-head attention.self.lin_l = Linear(in_channels,heads*out_channels) #W_l 这里的in_channels就是已经乘过heads的数字self.lin_r = self.lin_l #W_r# Define the attention parameters \overrightarrow{a_l/r}^T in the above intro.self.att_l = Parameter(torch.Tensor(1, heads, out_channels))self.att_r = Parameter(torch.Tensor(1, heads, out_channels))self.reset_parameters()def reset_parameters(self):nn.init.xavier_uniform_(self.lin_l.weight)nn.init.xavier_uniform_(self.lin_r.weight)nn.init.xavier_uniform_(self.att_l)nn.init.xavier_uniform_(self.att_r)def forward(self, x, edge_index, size = None):H, C = self.heads, self.out_channelsx_l = self.lin_l(x)x_r = self.lin_r(x)x_l = x_l.view(-1,H,C)x_r = x_r.view(-1,H,C)alpha_l = (x_l * self.att_l).sum(axis=1) #*是逐元素相乘(每个特征对应的所有节点一样处理?)。sum的维度是H(聚合)。alpha_r = (x_r * self.att_r).sum(axis=1)out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r),size=size)out = out.view(-1, H * C)return outdef message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):#alpha:[E, C]alpha = alpha_i + alpha_j #leakyrelu的对象alpha = F.leaky_relu(alpha,self.negative_slope)alpha = softmax(alpha, index, ptr, size_i)alpha = F.dropout(alpha, p=self.dropout, training=self.training).unsqueeze(1) #[E,1,C]out = x_j * alpha #通过计算得到的alpha来计算节点信息聚合值(得到h_i^') #[E,H,C]return outdef aggregate(self, inputs, index, dim_size = None):out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce='sum')return out
3. 训练
3.1 优化器
import torch.optim as optimdef build_optimizer(args, params):weight_decay = args.weight_decayfilter_fn = filter(lambda p : p.requires_grad, params)if args.opt == 'adam':optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)elif args.opt == 'sgd':optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)elif args.opt == 'rmsprop':optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)elif args.opt == 'adagrad':optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)if args.opt_scheduler == 'none':return None, optimizerelif args.opt_scheduler == 'step':scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)elif args.opt_scheduler == 'cos':scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)return scheduler, optimizer
3.2 训练
import timeimport networkx as nx
import numpy as np
import torch
import torch.optim as optim
from tqdm import trange
import pandas as pd
import copyfrom torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoaderimport torch_geometric.nn as pyg_nnimport matplotlib.pyplot as pltdef train(dataset, args):print("Node task. test set size:", np.sum(dataset[0]['test_mask'].numpy()))print()test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)# build modelmodel = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, args)scheduler, opt = build_optimizer(args, model.parameters())# trainlosses = []test_accs = []best_acc = 0best_model = Nonefor epoch in trange(args.epochs, desc="Training", unit="Epochs"):total_loss = 0model.train()for batch in loader:opt.zero_grad()pred = model(batch)label = batch.ypred = pred[batch.train_mask]label = label[batch.train_mask]loss = model.loss(pred, label)loss.backward()opt.step()total_loss += loss.item() * batch.num_graphstotal_loss /= len(loader.dataset)losses.append(total_loss)if epoch % 10 == 0:test_acc = test(test_loader, model)test_accs.append(test_acc)if test_acc > best_acc:best_acc = test_accbest_model = copy.deepcopy(model)else:test_accs.append(test_accs[-1])return test_accs, losses, best_model, best_acc, test_loaderdef test(loader, test_model, is_validation=False, save_model_preds=False, model_type=None):test_model.eval()correct = 0# Note that Cora is only one graph!for data in loader:with torch.no_grad():# max(dim=1) returns values, indices tuple; only need indicespred = test_model(data).max(dim=1)[1]label = data.ymask = data.val_mask if is_validation else data.test_mask# node classification: only evaluate on nodes in test setpred = pred[mask]label = label[mask]if save_model_preds:print ("Saving Model Predictions for Model Type", model_type)data = {}data['pred'] = pred.view(-1).cpu().detach().numpy()data['label'] = label.view(-1).cpu().detach().numpy()df = pd.DataFrame(data=data)# Save locally as csvdf.to_csv('CORA-Node-' + model_type + '.csv', sep=',', index=False)correct += pred.eq(label).sum().item()total = 0for data in loader.dataset:total += torch.sum(data.val_mask if is_validation else data.test_mask).item()return correct / totalclass objectview(object):def __init__(self, d):self.__dict__ = d
for args in [{'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01},
]:args = objectview(args)for model in ['GraphSage']:args.model_type = model# Match the dimension.if model == 'GAT':args.heads = 2else:args.heads = 1if args.dataset == 'cora':dataset = Planetoid(root='/tmp/cora', name='Cora')else:raise NotImplementedError("Unknown dataset") test_accs, losses, best_model, best_acc, test_loader = train(dataset, args) print("Maximum test set accuracy: {0}".format(max(test_accs)))print("Minimum loss: {0}".format(min(losses)))# Run test for our best model to save the predictions!test(test_loader, best_model, is_validation=False, save_model_preds=True, model_type=model)print()plt.title(dataset.name)plt.plot(losses, label="training loss" + " - " + args.model_type)plt.plot(test_accs, label="test accuracy" + " - " + args.model_type)plt.legend()plt.show()
相关文章:
Code Lab - 34
GAT里面有一些地方看的不是太懂(GAT里Multi Attention的具体做法),暂时找了参考代码,留一个疑问 1. 一个通用的GNN Stack import torch_geometric import torch import torch_scatter import torch.nn as nn import torch.nn.fun…...
后端返回文件流,前端怎么导出、下载(8种方法可实现)
在前端导出和下载后端返回的文件流时,可以使用以下几种方法: 使用window.open()方法: 在前端使用window.open()方法打开一个新的窗口或标签页,并将后端返回的文件流作为URL传递给该方法。浏览器会自动下载该文件。例如:…...
什么是 ThreadLocal?
ThreadLocal 是 Java 中的一个类,用于在多线程环境下,为每个线程提供独立的变量副本。每个线程可以通过 ThreadLocal 存储和获取数据,而不会影响其他线程的数据。这在某些情况下非常有用,特别是当多个线程需要访问共享数据,但又希望保持数据的隔离性时。 ThreadLocal 主要…...
CANOCO5.0实现冗余分析(RDA)最详细步骤
在地理及生态领域会常使用RDA分析,RDA的实现路径也有很多,今天介绍一下CANOCO软件的实现方法。 1.软件安装 时间调整到2010年 2.数据处理 得有不同的物种或者样点数值,再加上环境因子数据。 3.软件运行 4.结果解读 结果解读主要把握这几点…...
【tkinter 专栏】掷骰子游戏
文章目录 前言本章内容导图1. 需求分析2. 系统功能结构3. 设计流程4. 系统开发环境5. 系统预览6. 窗口布局7. 功能实现用户和电脑选择骰子的点数大小摇骰子过程实现判断游戏结果单击开始按钮进行游戏源代码汇总前言 本专栏将参考《Python GUI 设计 tkinter 从入门到实践》书籍…...
19 NAT穿透|python高级
文章目录 网络通信过程NAT穿透 python高级GIL锁深拷贝与浅拷贝私有化import导入模块工厂模式多继承以及 MRO 顺序烧脑题property属性property装饰器property类属性 魔法属性\_\_doc\_\_\_\_module\_\_ 和 \_\_class\_\_\_\_init\_\_\_\_del\_\_\_\_call\_\_\_\_dict\_\_\_\_str…...
2023常见前端面试题
以下是一些2023年秋招常见的前端面试题及其答案: 1. 请解释一下什么是前端开发? 前端开发是指使用HTML、CSS和JavaScript等技术来构建网页和用户界面的过程。前端开发人员负责将设计师提供的视觉设计转化为可交互的网页,并确保网页在不同设备…...
登录校验-JWT令牌-生成和校验
目录 JWT-生成 具体代码 运行结果如下 JWT-校验 具体代码 运行结果如下 小结 JWT-生成 具体代码 /*** 测试JWT令牌的生成*/Testpublic void TestJWT() {// 设置自定义内容Map<String, Object> claims new HashMap<>();claims.put("id", 1);claims…...
GIT 常用指令
基础指令 $ git init #初始化仓库,在该文件夹创建的为workspace$ git add . #已暂存 [.通配符,全部添加]$ git commit -m "log add file" #提交到仓库,并写了日志 ”log add file“$ git status #查看状态,可查看被修改的文件…...
多目标优化
https://zhuanlan.zhihu.com/p/158705342 概念 单目标优化只有一个优化目标,所以可以比较其好坏。 但是多目标优化,在需要优化多个目标时,容易存在目标之间的冲突,一个目标的优化是以其他目标劣化为代价的,所以我们要…...
odoo的优势
plus,主要是为了能尽早通过开发者审核,加入到chatgpt4 api的开发中去,接入到我们odoo aiCenter中。4的回答,明显比3.5的更聪明了。 可能是由于国内的特殊情况吧,我们的chatgpt模块很受欢迎,我也被问了不少…...
Spring Boot(Vue3+ElementPlus+Axios+MyBatisPlus+Spring Boot 前后端分离)【三】
😀前言 本篇博文是关于Spring Boot(Vue3ElementPlusAxiosMyBatisPlusSpring Boot 前后端分离)【三】的分享,希望你能够喜欢 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我…...
Kali 软件管理
kali 更新 1. 查看发行版本 ┌──(root㉿kali)-[~] └─# lsb_release -a No LSB modules are available. Distributor ID: Kali Description: Kali GNU/Linux Rolling Release: 2023.2 Codename: kali-rolling2. 查看内核版本 ┌──(root㉿kali)-[~] └─…...
加油站【贪心算法】
加油站 在一条环路上有 n 个加油站,其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车,从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[i] 升。你从其中的一个加油站出发,开始时油箱为空。 给定两个整数数组 gas 和…...
java八股文面试[多线程]——死锁、活锁、饥饿
DCL双重锁:TODO 如何预防死锁: 如何查看线程死锁: 知识来源: 【2023年面试】描述一下线程安全活跃态问题,以及竞态条件_哔哩哔哩_bilibili 【2023年面试】如何预防死锁_哔哩哔哩_bilibili 【并发与线程】阿里一面&…...
设计模式——装饰器模式
装饰器模式 装饰器模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其结构。这种类型的设计模式属于结构型模式,它是作为现有的类的一个包装。 装饰器模式通过将对象包装在装饰器类中,以便动态…...
①matlab的命令掌握
目录 输入命令 命名变量 保存和加载变量 使用内置的函数和常量 输入命令 1.您可以通过在命令行窗口中 MATLAB 提示符 (>>) 后输入命令 任务 使用命令 3*5 将数值 3 和 5 相乘。 答案 3*5 2.除非另有指定,否则 MATLAB 会将计算结果存储在一个名为 ans…...
MySQL----索引
一、索引的概念 索引是一个排序的列表,在这个列表中存储着索引的值和包含这个值的数据所在行的物理地址(类似于c语言的链表通过指针指向数据记录的内存地址)。使用索引后可以不用扫描全表来定位某行的数据,而是先通过索引表找到该…...
秒杀系统的业务流程以及优化方案(实现异步秒杀)
先看基本的业务流程 那么我们可以看到整个流程都是一个线程来完成的,这样的话耗时还是很长的,那么可不可以采用多线程去实现呢? 首先我们要思考怎么对业务进行拆分,可以想象一个我们去饭店点餐,会有前台接待ÿ…...
Java实现根据商品ID获取1688商品详情跨境属性数据,1688商品重量数据接口,1688API接口封装方法
要通过1688的API获取商品详情跨境属性数据,您可以使用1688开放平台提供的接口来实现。以下是一种使用Java编程语言实现的示例,展示如何通过1688开放平台API获取商品详情属性数据接口: 首先,确保您已注册成为1688开放平台的开发者…...
BEYOND REALITY Z-Image新手入门:三步生成你的第一张8K写真人像
BEYOND REALITY Z-Image新手入门:三步生成你的第一张8K写真人像 1. 为什么选择BEYOND REALITY Z-Image? 在当前的AI图像生成领域,写实人像一直是最具挑战性的任务之一。传统模型往往难以平衡细节精度与自然感,生成的图片要么过于…...
学术论文解析神器!OpenDataLab MinerU智能文档理解实测体验
学术论文解析神器!OpenDataLab MinerU智能文档理解实测体验 1. 前言:当AI遇见学术论文 对于每一位科研工作者、学生或技术从业者来说,阅读和整理学术论文都是一项既基础又繁重的工作。你是否也曾经历过这样的场景:面对一篇几十页…...
2026硬核拆解:Grok 4.1镜像双版本架构、实时数据与情感智能实战评测
对于追求实时信息获取、个性化交互与创意内容生成的AI用户,2026年xAI推出的Grok 4.1系列(含Thinking与Fast双版本)凭借其独特的实时知识库、可调节的“叛逆风格”与卓越的情感智能,在竞争激烈的大模型市场中开辟了差异化赛道。 若…...
【喜报】义翘神州再获CNAS认可,全面对标2025版药典新标准
义翘神州生物安全检测实验室近日成功通过中国合格评定国家认可委员会(CNAS)的扩项评审及定期监督评审,并已完成全部能力附表更新!这标志着实验室技术能力与质量管理体系持续符合ISO/IEC 17025:2017国际标准的严苛要求,…...
算法审判日:用Git记录定程序员罪孽
一、版本控制的“审判台”在软件质量保障体系中,Git早已超越单纯的版本管理工具,演变为代码行为的“司法档案库”。每一次git commit都是程序员在数字法庭上的宣誓证词,而git blame则成为测试人员追溯缺陷根源的刑侦工具。罪证链条的三重维度…...
4阶段构建企业级离线文档处理平台:从问题诊断到性能优化全指南
4阶段构建企业级离线文档处理平台:从问题诊断到性能优化全指南 【免费下载链接】WeKnora LLM-powered framework for deep document understanding, semantic retrieval, and context-aware answers using RAG paradigm. 项目地址: https://gitcode.com/GitHub_Tr…...
在Jetson Orin Nano上手动编译部署AirSLAM:如何解决TensorRT模型转换(ONNX转Engine)的内存溢出问题
在Jetson Orin Nano上手动编译部署AirSLAM:解决TensorRT模型转换内存溢出的实战指南 1. 边缘设备部署AirSLAM的核心挑战 Jetson Orin Nano作为NVIDIA面向边缘计算推出的高性能模块,其4GB/8GB内存配置在运行复杂视觉SLAM算法时面临严峻的资源约束。AirSLA…...
Bilibili-Evolved:视频播放卡顿解决方案:实现60fps流畅体验的智能优化方法
Bilibili-Evolved:视频播放卡顿解决方案:实现60fps流畅体验的智能优化方法 【免费下载链接】Bilibili-Evolved 强大的哔哩哔哩增强脚本 项目地址: https://gitcode.com/gh_mirrors/bi/Bilibili-Evolved 你是否曾在观看高清动画时遇到画面卡顿&…...
HarmonyOS6 半年磨一剑 - RcCheckbox 组件核心架构与类型系统设计
文章目录前言一、组件整体架构1.1 双组件协作设计1.2 文件结构1.3 装饰器分工二、类型系统深度解析2.1 值类型的宽泛设计2.2 选项配置接口2.3 形状与尺寸类型三、核心参数体系3.1 RcCheckbox 参数全览3.2 RcCheckboxGroup 扩展参数四、内部状态设计4.1 受控模式的双状态机制4.2…...
Echo:预测智能的一小步,通往通用智能的一大步
来源:机器之心大模型能否预测未来?UniPat AI 构建了一套完整的预测智能基础设施,Echo,包含动态评测引擎、面向未来事件的训练范式和预测专用模型 EchoZ-1.0。在其公开的 General AI Prediction Leaderboard 上,EchoZ-1…...
