Code Lab - 34
GAT里面有一些地方看的不是太懂(GAT里Multi Attention的具体做法),暂时找了参考代码,留一个疑问
1. 一个通用的GNN Stack
import torch_geometric
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as Fimport torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utilsfrom torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,OptTensor)from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmaxclass GNNStack(torch.nn.Module):def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):super(GNNStack, self).__init__()conv_model = self.build_conv_model(args.model_type)self.convs = nn.ModuleList()self.convs.append(conv_model(input_dim, hidden_dim))#assert(断言) 用于判断一个表达式,在表达式条件为 false 的时候触发异常assert (args.num_layers >= 1), 'Number of layers is not >=1'for l in range(args.num_layers-1):self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))# post-message-passingself.post_mp = nn.Sequential(nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), nn.Linear(hidden_dim, output_dim))self.dropout = args.dropoutself.num_layers = args.num_layersself.emb = embdef build_conv_model(self, model_type):if model_type == 'GraphSage':return GraphSageelif model_type == 'GAT':# When applying GAT with num heads > 1, you need to modify the # input and output dimension of the conv layers (self.convs),# to ensure that the input dim of the next layer is num heads# multiplied by the output dim of the previous layer.# HINT: In case you want to play with multiheads, you need to change the for-loop that builds up self.convs to be# self.convs.append(conv_model(hidden_dim * num_heads, hidden_dim)), # and also the first nn.Linear(hidden_dim * num_heads, hidden_dim) in post-message-passing.return GATdef forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batchfor i in range(self.num_layers):x = self.convs[i](x, edge_index)x = F.relu(x)x = F.dropout(x, p=self.dropout,training=self.training)x = self.post_mp(x)if self.emb == True:return xreturn F.log_softmax(x, dim=1)def loss(self, pred, label):return F.nll_loss(pred, label)
2. 实现GraphSage和GAT
2.1 GraphSage

class GraphSage(MessagePassing):def __init__(self, in_channels, out_channels, normalize = True,bias = False, **kwargs): super(GraphSage, self).__init__(**kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.normalize = normalize# self.lin_l is the linear transformation that you apply to embedding for central node.self.lin_l=Linear(in_channels,out_channels) #Wl# self.lin_r is the linear transformation that you apply to aggregated message from neighbors.self.lin_r=Linear(in_channels,out_channels) #Wrself.reset_parameters()def reset_parameters(self):self.lin_l.reset_parameters()self.lin_r.reset_parameters()def forward(self, x, edge_index, size = None):# 调用propagation函数进行消息传递:propagate(edge_index, x=(x_i, x_j), extra=(extra_i, extra_j), size=size)# 我们将只使用邻居节点(x_j)的表示,因此默认情况下我们为中心节点和邻居节点传递与x=(x,x)相同的表示out1 = self.lin_l(x)out2 = self.propagate(edge_index,x = (x,x),size = size)out2 = self.lin_r(out2)out = out1 + out2if self.normalize:out = F.normalize(out)return out# 供propagate调用,对于所有(i,j)边,构造从邻点j到中心点i的信息# x_j表示 所有邻点的特征嵌入矩阵 def message(self, x_j):out = x_jreturn out# 聚合邻居信息def aggregate(self, inputs, index, dim_size = None):# The axis along which to index number of nodes.node_dim = self.node_dimout = torch_scatter.scatter(inputs,index,node_dim,dim_size=dim_size,reduce='mean')return out
2.2 GAT
class GAT(MessagePassing):def __init__(self, in_channels, out_channels, heads = 2,negative_slope = 0.2, dropout = 0., **kwargs):super(GAT, self).__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.negative_slope = negative_slopeself.dropout = dropout# self.lin_l is the linear transformation that you apply to embeddings # Pay attention to dimensions of the linear layers, since we're using multi-head attention.self.lin_l = Linear(in_channels,heads*out_channels) #W_l 这里的in_channels就是已经乘过heads的数字self.lin_r = self.lin_l #W_r# Define the attention parameters \overrightarrow{a_l/r}^T in the above intro.self.att_l = Parameter(torch.Tensor(1, heads, out_channels))self.att_r = Parameter(torch.Tensor(1, heads, out_channels))self.reset_parameters()def reset_parameters(self):nn.init.xavier_uniform_(self.lin_l.weight)nn.init.xavier_uniform_(self.lin_r.weight)nn.init.xavier_uniform_(self.att_l)nn.init.xavier_uniform_(self.att_r)def forward(self, x, edge_index, size = None):H, C = self.heads, self.out_channelsx_l = self.lin_l(x)x_r = self.lin_r(x)x_l = x_l.view(-1,H,C)x_r = x_r.view(-1,H,C)alpha_l = (x_l * self.att_l).sum(axis=1) #*是逐元素相乘(每个特征对应的所有节点一样处理?)。sum的维度是H(聚合)。alpha_r = (x_r * self.att_r).sum(axis=1)out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r),size=size)out = out.view(-1, H * C)return outdef message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):#alpha:[E, C]alpha = alpha_i + alpha_j #leakyrelu的对象alpha = F.leaky_relu(alpha,self.negative_slope)alpha = softmax(alpha, index, ptr, size_i)alpha = F.dropout(alpha, p=self.dropout, training=self.training).unsqueeze(1) #[E,1,C]out = x_j * alpha #通过计算得到的alpha来计算节点信息聚合值(得到h_i^') #[E,H,C]return outdef aggregate(self, inputs, index, dim_size = None):out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce='sum')return out
3. 训练
3.1 优化器
import torch.optim as optimdef build_optimizer(args, params):weight_decay = args.weight_decayfilter_fn = filter(lambda p : p.requires_grad, params)if args.opt == 'adam':optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)elif args.opt == 'sgd':optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)elif args.opt == 'rmsprop':optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)elif args.opt == 'adagrad':optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)if args.opt_scheduler == 'none':return None, optimizerelif args.opt_scheduler == 'step':scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)elif args.opt_scheduler == 'cos':scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)return scheduler, optimizer
3.2 训练
import timeimport networkx as nx
import numpy as np
import torch
import torch.optim as optim
from tqdm import trange
import pandas as pd
import copyfrom torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoaderimport torch_geometric.nn as pyg_nnimport matplotlib.pyplot as pltdef train(dataset, args):print("Node task. test set size:", np.sum(dataset[0]['test_mask'].numpy()))print()test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)# build modelmodel = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, args)scheduler, opt = build_optimizer(args, model.parameters())# trainlosses = []test_accs = []best_acc = 0best_model = Nonefor epoch in trange(args.epochs, desc="Training", unit="Epochs"):total_loss = 0model.train()for batch in loader:opt.zero_grad()pred = model(batch)label = batch.ypred = pred[batch.train_mask]label = label[batch.train_mask]loss = model.loss(pred, label)loss.backward()opt.step()total_loss += loss.item() * batch.num_graphstotal_loss /= len(loader.dataset)losses.append(total_loss)if epoch % 10 == 0:test_acc = test(test_loader, model)test_accs.append(test_acc)if test_acc > best_acc:best_acc = test_accbest_model = copy.deepcopy(model)else:test_accs.append(test_accs[-1])return test_accs, losses, best_model, best_acc, test_loaderdef test(loader, test_model, is_validation=False, save_model_preds=False, model_type=None):test_model.eval()correct = 0# Note that Cora is only one graph!for data in loader:with torch.no_grad():# max(dim=1) returns values, indices tuple; only need indicespred = test_model(data).max(dim=1)[1]label = data.ymask = data.val_mask if is_validation else data.test_mask# node classification: only evaluate on nodes in test setpred = pred[mask]label = label[mask]if save_model_preds:print ("Saving Model Predictions for Model Type", model_type)data = {}data['pred'] = pred.view(-1).cpu().detach().numpy()data['label'] = label.view(-1).cpu().detach().numpy()df = pd.DataFrame(data=data)# Save locally as csvdf.to_csv('CORA-Node-' + model_type + '.csv', sep=',', index=False)correct += pred.eq(label).sum().item()total = 0for data in loader.dataset:total += torch.sum(data.val_mask if is_validation else data.test_mask).item()return correct / totalclass objectview(object):def __init__(self, d):self.__dict__ = d
for args in [{'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01},
]:args = objectview(args)for model in ['GraphSage']:args.model_type = model# Match the dimension.if model == 'GAT':args.heads = 2else:args.heads = 1if args.dataset == 'cora':dataset = Planetoid(root='/tmp/cora', name='Cora')else:raise NotImplementedError("Unknown dataset") test_accs, losses, best_model, best_acc, test_loader = train(dataset, args) print("Maximum test set accuracy: {0}".format(max(test_accs)))print("Minimum loss: {0}".format(min(losses)))# Run test for our best model to save the predictions!test(test_loader, best_model, is_validation=False, save_model_preds=True, model_type=model)print()plt.title(dataset.name)plt.plot(losses, label="training loss" + " - " + args.model_type)plt.plot(test_accs, label="test accuracy" + " - " + args.model_type)plt.legend()plt.show()
相关文章:
Code Lab - 34
GAT里面有一些地方看的不是太懂(GAT里Multi Attention的具体做法),暂时找了参考代码,留一个疑问 1. 一个通用的GNN Stack import torch_geometric import torch import torch_scatter import torch.nn as nn import torch.nn.fun…...
后端返回文件流,前端怎么导出、下载(8种方法可实现)
在前端导出和下载后端返回的文件流时,可以使用以下几种方法: 使用window.open()方法: 在前端使用window.open()方法打开一个新的窗口或标签页,并将后端返回的文件流作为URL传递给该方法。浏览器会自动下载该文件。例如:…...
什么是 ThreadLocal?
ThreadLocal 是 Java 中的一个类,用于在多线程环境下,为每个线程提供独立的变量副本。每个线程可以通过 ThreadLocal 存储和获取数据,而不会影响其他线程的数据。这在某些情况下非常有用,特别是当多个线程需要访问共享数据,但又希望保持数据的隔离性时。 ThreadLocal 主要…...
CANOCO5.0实现冗余分析(RDA)最详细步骤
在地理及生态领域会常使用RDA分析,RDA的实现路径也有很多,今天介绍一下CANOCO软件的实现方法。 1.软件安装 时间调整到2010年 2.数据处理 得有不同的物种或者样点数值,再加上环境因子数据。 3.软件运行 4.结果解读 结果解读主要把握这几点…...
【tkinter 专栏】掷骰子游戏
文章目录 前言本章内容导图1. 需求分析2. 系统功能结构3. 设计流程4. 系统开发环境5. 系统预览6. 窗口布局7. 功能实现用户和电脑选择骰子的点数大小摇骰子过程实现判断游戏结果单击开始按钮进行游戏源代码汇总前言 本专栏将参考《Python GUI 设计 tkinter 从入门到实践》书籍…...
19 NAT穿透|python高级
文章目录 网络通信过程NAT穿透 python高级GIL锁深拷贝与浅拷贝私有化import导入模块工厂模式多继承以及 MRO 顺序烧脑题property属性property装饰器property类属性 魔法属性\_\_doc\_\_\_\_module\_\_ 和 \_\_class\_\_\_\_init\_\_\_\_del\_\_\_\_call\_\_\_\_dict\_\_\_\_str…...
2023常见前端面试题
以下是一些2023年秋招常见的前端面试题及其答案: 1. 请解释一下什么是前端开发? 前端开发是指使用HTML、CSS和JavaScript等技术来构建网页和用户界面的过程。前端开发人员负责将设计师提供的视觉设计转化为可交互的网页,并确保网页在不同设备…...
登录校验-JWT令牌-生成和校验
目录 JWT-生成 具体代码 运行结果如下 JWT-校验 具体代码 运行结果如下 小结 JWT-生成 具体代码 /*** 测试JWT令牌的生成*/Testpublic void TestJWT() {// 设置自定义内容Map<String, Object> claims new HashMap<>();claims.put("id", 1);claims…...
GIT 常用指令
基础指令 $ git init #初始化仓库,在该文件夹创建的为workspace$ git add . #已暂存 [.通配符,全部添加]$ git commit -m "log add file" #提交到仓库,并写了日志 ”log add file“$ git status #查看状态,可查看被修改的文件…...
多目标优化
https://zhuanlan.zhihu.com/p/158705342 概念 单目标优化只有一个优化目标,所以可以比较其好坏。 但是多目标优化,在需要优化多个目标时,容易存在目标之间的冲突,一个目标的优化是以其他目标劣化为代价的,所以我们要…...
odoo的优势
plus,主要是为了能尽早通过开发者审核,加入到chatgpt4 api的开发中去,接入到我们odoo aiCenter中。4的回答,明显比3.5的更聪明了。 可能是由于国内的特殊情况吧,我们的chatgpt模块很受欢迎,我也被问了不少…...
Spring Boot(Vue3+ElementPlus+Axios+MyBatisPlus+Spring Boot 前后端分离)【三】
😀前言 本篇博文是关于Spring Boot(Vue3ElementPlusAxiosMyBatisPlusSpring Boot 前后端分离)【三】的分享,希望你能够喜欢 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我…...
Kali 软件管理
kali 更新 1. 查看发行版本 ┌──(root㉿kali)-[~] └─# lsb_release -a No LSB modules are available. Distributor ID: Kali Description: Kali GNU/Linux Rolling Release: 2023.2 Codename: kali-rolling2. 查看内核版本 ┌──(root㉿kali)-[~] └─…...
加油站【贪心算法】
加油站 在一条环路上有 n 个加油站,其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车,从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[i] 升。你从其中的一个加油站出发,开始时油箱为空。 给定两个整数数组 gas 和…...
java八股文面试[多线程]——死锁、活锁、饥饿
DCL双重锁:TODO 如何预防死锁: 如何查看线程死锁: 知识来源: 【2023年面试】描述一下线程安全活跃态问题,以及竞态条件_哔哩哔哩_bilibili 【2023年面试】如何预防死锁_哔哩哔哩_bilibili 【并发与线程】阿里一面&…...
设计模式——装饰器模式
装饰器模式 装饰器模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其结构。这种类型的设计模式属于结构型模式,它是作为现有的类的一个包装。 装饰器模式通过将对象包装在装饰器类中,以便动态…...
①matlab的命令掌握
目录 输入命令 命名变量 保存和加载变量 使用内置的函数和常量 输入命令 1.您可以通过在命令行窗口中 MATLAB 提示符 (>>) 后输入命令 任务 使用命令 3*5 将数值 3 和 5 相乘。 答案 3*5 2.除非另有指定,否则 MATLAB 会将计算结果存储在一个名为 ans…...
MySQL----索引
一、索引的概念 索引是一个排序的列表,在这个列表中存储着索引的值和包含这个值的数据所在行的物理地址(类似于c语言的链表通过指针指向数据记录的内存地址)。使用索引后可以不用扫描全表来定位某行的数据,而是先通过索引表找到该…...
秒杀系统的业务流程以及优化方案(实现异步秒杀)
先看基本的业务流程 那么我们可以看到整个流程都是一个线程来完成的,这样的话耗时还是很长的,那么可不可以采用多线程去实现呢? 首先我们要思考怎么对业务进行拆分,可以想象一个我们去饭店点餐,会有前台接待ÿ…...
Java实现根据商品ID获取1688商品详情跨境属性数据,1688商品重量数据接口,1688API接口封装方法
要通过1688的API获取商品详情跨境属性数据,您可以使用1688开放平台提供的接口来实现。以下是一种使用Java编程语言实现的示例,展示如何通过1688开放平台API获取商品详情属性数据接口: 首先,确保您已注册成为1688开放平台的开发者…...
SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
三维GIS开发cesium智慧地铁教程(5)Cesium相机控制
一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点: 路径验证:确保相对路径.…...
UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...
【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...
Spring是如何解决Bean的循环依赖:三级缓存机制
1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间互相持有对方引用,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...
处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的
修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...
Oracle11g安装包
Oracle 11g安装包 适用于windows系统,64位 下载路径 oracle 11g 安装包...
mac:大模型系列测试
0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...
ZYNQ学习记录FPGA(一)ZYNQ简介
一、知识准备 1.一些术语,缩写和概念: 1)ZYNQ全称:ZYNQ7000 All Pgrammable SoC 2)SoC:system on chips(片上系统),对比集成电路的SoB(system on board) 3)ARM:处理器…...
