当前位置: 首页 > news >正文

PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测

目录

  • I. 前言
  • II. 数据集说明
  • III. 模型
    • 3.1 GCN
    • 3.2 GraphSAGE
    • 3.3 GAT
  • IV. 训练与测试
  • V. 实验结果

I. 前言

前面已经写了很多关于时间序列预测的文章:

  1. 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
  2. PyTorch搭建LSTM实现时间序列预测(负荷预测)
  3. PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测
  4. PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
  5. PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
  6. PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  7. PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  8. PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  9. PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  10. PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  11. PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
  12. PyTorch-LSTM时间序列预测中如何预测真正的未来值
  13. PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  14. PyTorch搭建ANN实现时间序列预测(风速预测)
  15. PyTorch搭建CNN实现时间序列预测(风速预测)
  16. PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  17. PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
  18. PyTorch时间序列预测系列文章总结(代码使用方法)
  19. TensorFlow搭建LSTM实现时间序列预测(负荷预测)
  20. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
  21. TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
  22. TensorFlow搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  23. TensorFlow搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  24. TensorFlow搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  25. TensorFlow搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  26. TensorFlow搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  27. TensorFlow搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  28. TensorFlow搭建ANN实现时间序列预测(风速预测)
  29. TensorFlow搭建CNN实现时间序列预测(风速预测)
  30. TensorFlow搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  31. PyG搭建图神经网络实现多变量输入多变量输出时间序列预测
  32. PyTorch搭建GNN-LSTM和LSTM-GNN模型实现多变量输入多变量输出时间序列预测
  33. PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测
  34. 时序预测中Attention机制是否真的有效?盘点LSTM/RNN中24种Attention机制+效果对比
  35. 详解Transformer在时序预测中的Encoder和Decoder过程:以负荷预测为例
  36. (PyTorch)TCN和RNN/LSTM/GRU结合实现时间序列预测
  37. PyTorch搭建Informer实现长序列时间序列预测
  38. PyTorch搭建Autoformer实现长序列时间序列预测
  39. PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测

前边已经有两篇文章讲解了如何利用PyG搭建GNN以及GNN-LSTM进行时间序列预测,这两部分内容都只是针对多变量进行预测,即将每个变量当成一个节点,然后利用皮尔逊相关系数构建变量间的邻接矩阵。

上述过程有以下两个问题:
(1)不少人使用时不会安装PyG(后台经常有人询问如何安装PyG),这个其实不困难,具体安装命令为:

pip install torch_scatter torch_sparse torch_cluster torch_spline_conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu113.html

使用时将命令中的torch-1.10.0+cu113.html换成自己的torch版本和CUDA版本即可。

(2)使用PyG有诸多限制,例如PyG中每个节点只能拥有单个变量序列。当然,可以将每个节点的多条变量序列用神经网络或者注意力机制等方式转换为一个变量序列,这里不再细说。

有不少人要求出一期多站点多变量预测,前期由于实习+秋招+毕设缠身,事情较多,所以没来得及写。现在有了一些空闲时间,因此在这篇文章里做一些详细的说明。

II. 数据集说明

本次使用的数据集为交通流量预测领域常见的PEMS系列数据集,包括PEMS03、PEMS04、PEMS07和PEMS08四个数据集。

其中,PEMS04是由307个探测器(节点数)每隔5分钟采集一次数据,共采集59天产生的交通流量数据;PEMS08是由170个探测器每隔5分钟采集一次,共采集62天产生的数据。每个探测器每次采集的数据包含三个维度的特征,分别为:流量、平均速度和平均占有率。因此,数据集的格式应该为一个矩阵,大小为num * num_nodes * 3,其中PSMS04的num=59*24*12=16992num_nodes=307,而PEMS08的num=62*24*12=17856num_nodes=170。PEMS03和PEMS07两个数据集中只包含流量这一个变量,二者的的大小分别为26208*358*1进和28224*883*1

在这篇文章中,使用前2小时的数据预测未来半小时的数据,即历史24个时刻的多个变量预测未来6个时刻的多个变量。

数据处理代码与前面类似:

def nn_seq(args):seq_len = args.seq_lenbatch_size, pred_len = args.batch_size, args.pred_lenroot_path = os.path.abspath(os.path.dirname(os.getcwd()))file_name = args.file_namedata_path = root_path + "/data/" + file_name + "/"npz = np.load(data_path + file_name + ".npz")data = npz['data']  # lens num_nodes, in_featsprint(data.shape)# data = data[:2000]# 3表示:车流量、平均车速、平均车道占用率num_nodes = data.shape[1]# 加载邻接矩阵adj_data = pd.read_csv(data_path + file_name + ".csv")adj_data = adj_data[["from", "to"]].values.tolist()# 找出最大最小值all_elements = [element for row in adj_data for element in row]all_elements = list(set(all_elements))all_elements.sort()print(len(all_elements) == num_nodes)node_dict = dict(zip(all_elements, [x for x in range(num_nodes)]))# print(max_val, min_val)adj = torch.zeros((num_nodes, num_nodes))for src, dst in adj_data:src = node_dict[src]dst = node_dict[dst]adj[src, dst] = adj[dst, src] = 1## splittrain = data[:int(len(data) * 0.6)]val = data[int(len(data) * 0.6):int(len(data) * 0.8)]test = data[int(len(data) * 0.8):]# 归一化 要求在站点内部,对按照时间列进行归一化scalers = []for idx in range(num_nodes):cur_train = train[:, idx, :]cur_val = val[:, idx, :]cur_test = test[:, idx, :]scaler = MinMaxScaler()train[:, idx, :] = scaler.fit_transform(cur_train)val[:, idx, :] = scaler.transform(cur_val)test[:, idx, :] = scaler.transform(cur_test)scalers.append(scaler)def process(dataset, step_size, shuffle):# dataset: num num_nodes dimseq = []for i in tqdm(range(0, len(dataset) - seq_len - pred_len + 1, step_size)):x = dataset[i:i + seq_len]y = dataset[i + seq_len:i + seq_len + pred_len]# tensorx = torch.FloatTensor(x)y = torch.FloatTensor(y)seq.append((x, y))seq = MyDataset(seq)seq = DataLoader(dataset=seq, batch_size=batch_size, shuffle=shuffle, num_workers=0, drop_last=False)return seqDtr = process(train, step_size=1, shuffle=True)Val = process(val, step_size=1, shuffle=True)Dte = process(test, step_size=pred_len, shuffle=False)return Dtr, Val, Dte, adj, scalers

归一化时,由于不同站点间的数据没有太大关联,因此需要单独对每个站点内部的数据进行归一化,这里采用了MinMaxSacler归一化。

III. 模型

在这篇文章中将使用常见的三个GNN模型进行预测,即GCN、GraphSAGE和GAT。

图卷积网络(Graph Convolutional Network,GCN)是最早提出的图神经网络之一,GCN通过在图的邻域内进行信息聚合来学习节点的低维表示。具体来说,GCN利用了拉普拉斯矩阵的特征值分解,通过图傅里叶变换将卷积操作转换为频域上的滤波操作。GCN的核心公式为:
h ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 h ( l ) W ( l ) ) h^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} W^{(l)}) h(l+1)=σ(D~21A~D~21h(l)W(l))

其中, A ~ \tilde{A} A~是带有自环的邻接矩阵, D ~ \tilde{D} D~是对应的度矩阵, h ( l ) h^{(l)} h(l) 是第 l l l层的隐藏状态, W ( l ) W^{(l)} W(l)是权重矩阵, σ \sigma σ是激活函数。

GraphSAGE是由Hamilton等人在2017年提出的,旨在解决大规模图上的节点表示学习问题。GraphSAGE通过采样节点的邻居,并在局部邻域内进行信息聚合,从而生成节点表示。GraphSAGE支持多种聚合方法,包括 Mean Aggregator、LSTM Aggregator 和 Max Pooling Aggregator。GraphSAGE的核心公式为:
h i ( k + 1 ) = σ ( W f ( h i ( k ) , { h j ( k ) ∣ j ∈ N ( i ) } ) ) h_i^{(k+1)} = \sigma(W f(h_i^{(k)}, \{h_j^{(k)} | j \in \mathcal{N}(i)\})) hi(k+1)=σ(Wf(hi(k),{hj(k)jN(i)}))

其中 h i ( k ) h_i^{(k)} hi(k)是第 k k k层节点 i i i的隐藏状态, N ( i ) \mathcal{N}(i) N(i)是节点 i i i的邻居集合, f f f是聚合函数。GraphSAGE通过多层聚合操作,能够有效地捕捉节点的局部结构信息。

图注意力网络GAT(Graph Attention Networks)通过整合注意力机制实现了对图中不同邻居节点的动态加权。其主要创新之处在于,为每个邻接节点分配一个注意力得分,从而使模型可以聚焦于那些更为重要的邻近节点。GAT的核心公式为:
h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) α i j W h j ( l ) ) h_i^{(l+1)} = \sigma \left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j^{(l)} \right) hi(l+1)=σ jN(i)αijWhj(l)

其中 α i j \alpha_{ij} αij是节点 i i i j j j之间的注意力分数, W W W是权重矩阵。GAT通过自注意力机制,能够更好地捕捉节点之间的关系。

上述三种模型的原理十分简单,下边将依次介绍如何使用三种模型进行多站点、多变量输入、多变量输出的时空预测。

在进行模型讲解之前,先规定一下模型的输入和输出维度。在本文中,模型的输入尺寸为:batch_size * seq_len * num_nodes * in_feats,表示每个站点的多变量历史数据,输出为batch_size * pred_len * num_nodes * in_feats,表示多个站点未来的多变量数据。

3.1 GCN

GCN的代码实现十分优雅简洁,可以先看一下原作者的代码实现:

class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weightoutput = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'

可以看到,GCN的本质就是将归一化后的邻接矩阵和节点特征矩阵执行矩阵乘法,即(num_nodes, num_nodes) * (num_nodes, feats) -> (num_nodes, feats)

因此,对于大小为batch_size * seq_len * num_nodes * in_feats的输入,可以直接对后两个维度进行计算。代码如下:

class GCNConv(nn.Module):def __init__(self, in_features, out_features, bias=True):super(GCNConv, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, x, adj):support = torch.matmul(x, self.weight)# 输入的数据是x = b s n d, adj = n * noutput = torch.einsum("tn,bsnd->bstd", adj, support)   # bsndif self.bias is not None:output + self.biasreturn output

具体来讲,首先将batch_size * seq_len * num_nodes * in_feats利用self.weight变成batch_size * seq_len * num_nodes * out_feats,然后再与归一化后的邻接矩阵相乘,这里用到了torch.einsum()函数来指定参与计算的维度。

接着,便可以基于GCNConv来定义用于多站点。多变量输入、多变量输出的时刻预测GCN模型:

class GCN(torch.nn.Module):def __init__(self, args):super(GCN, self).__init__()self.args = argsself.conv1 = GCNConv(args.in_feats, args.h_feats)self.conv2 = GCNConv(args.h_feats, args.out_feats)self.dropout = 0.5self.fcs = nn.ModuleList()for _ in range(args.in_feats):self.fcs.append(nn.Sequential(nn.Linear(args.seq_len * args.out_feats, args.out_feats),nn.ReLU(),nn.Linear(args.out_feats, args.pred_len)))def forward(self, x, adj):# bsndx = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.conv1(x, adj))x = self.conv2(x, adj)# b s n d  --> b s n 3x = x.permute(0, 2, 1, 3)  # bnsdx = torch.flatten(x, start_dim=2)  # bn s*dpred = []for idx in range(self.args.in_feats):sub_pred = self.fcs[idx](x)   # b n pred_lenpred.append(sub_pred)  # b pred_len 3pred = torch.stack(pred, dim=-1)  # b n pred_len 3# 变成和y一样的维度,即b pred_len num_node 3pred = pred.permute(0, 2, 1, 3)return pred

该模型由2个GCN层和一个预测层组成。输入batch_size * seq_len * num_nodes * in_feats(以下简称bsni)经过两层GCN变成bsno。接着,为了预测所有站点的多个变量,采用多任务学习中的思路,每个变量使用一个线性层进行预测。

预测时,首先将bsno的进行维度交换变成bnso,与LSTM等模型类似,可以将所有时刻的隐状态展开变成一个bn(s*d),然后使用多个线性层得到多个bn(pred_len),然后将多个变量的预测值拼接变成bn(pred_len)(in_feats)。最后,为了与真实值的batch_size * pred_len * num_nodes * in_feats相匹配,需要交换1和2两个维度。

需要注意的是,forward中传入的邻接矩阵是归一化后的邻接矩阵,归一化操作可以参考如下代码:

def normalize_adj(adj):"""归一化邻接矩阵,适用于图卷积网络(GCN)。:param adj: 原始邻接矩阵,形状为 (N, N):return: 归一化后的邻接矩阵,形状为 (N, N)"""# 添加自环adj = adj + torch.eye(adj.size(0))# 计算度矩阵 Ddegree = adj.sum(1)# 计算 D 的逆平方根d_inv_sqrt = torch.pow(degree, -0.5)d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0  # 防止出现无穷大# 构建 D 的逆平方根矩阵d_mat_inv_sqrt = torch.diag(d_inv_sqrt)# 归一化邻接矩阵adj_normalized = d_mat_inv_sqrt @ adj @ d_mat_inv_sqrtreturn adj_normalized

3.2 GraphSAGE

GraphSAGE的本质是将一个节点的邻居节点聚合后再与自身进行拼接变换,单层代码实现如下:

class SAGEConv(nn.Module):def __init__(self, in_features, out_features):super(SAGEConv, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.proj = nn.Linear(in_features, out_features)self.out_proj = nn.Linear(2 * out_features, out_features)def forward(self, x, adj):# 假设有多个站点support = self.proj(x)# 输入的数据是x = b s n d, adj = n * n# 邻居平均 设定一个很小的正数epseps = torch.tensor(1e-8)# 计算每一行的和,并确保不会除以零row_sums = adj.sum(dim=1, keepdim=True)row_sums = torch.max(row_sums, eps)# 对每一行进行规范化normalized_adj = adj / row_sumsoutput = torch.einsum("tn,bsnd->bstd", normalized_adj, support)   # bsnd# catcat_x = torch.cat((support, output), dim=-1)  # bsn 2dz = self.out_proj(cat_x)# normz_norm = z.norm(p=2, dim=-1, keepdim=True)z_norm = torch.where(z_norm == 0, torch.tensor(1.).to(z_norm), z_norm)z = z / z_normreturn z

上述代码采用的是平均聚合。与GCN类似,可以搭建GraphSAGE如下:

class GraphSAGE(torch.nn.Module):def __init__(self, args):super(GraphSAGE, self).__init__()self.args = argsself.conv1 = SAGEConv(args.in_feats, args.h_feats)self.conv2 = SAGEConv(args.h_feats, args.out_feats)self.dropout = 0.5self.fcs = nn.ModuleList()for _ in range(args.in_feats):self.fcs.append(nn.Sequential(nn.Linear(args.seq_len * args.out_feats, args.out_feats),nn.ReLU(),nn.Linear(args.out_feats, args.pred_len)))def forward(self, x, adj):# bsnd# x = F.dropout(x, self.dropout, training=self.training)x = F.relu(self.conv1(x, adj))x = self.conv2(x, adj)# b s n d  --> b s n 3x = x.permute(0, 2, 1, 3)  # bnsdx = torch.flatten(x, start_dim=2)  # bn s*dpred = []for idx in range(self.args.in_feats):sub_pred = self.fcs[idx](x)   # b n pred_lenpred.append(sub_pred)  # b pred_len 3pred = torch.stack(pred, dim=-1)  # b n pred_len 3pred = pred.permute(0, 2, 1, 3)return pred

3.3 GAT

GAT的代码稍显复杂,其本质是将节点的特征和邻居特征进行拼接然后变换得到这条边上的权重,最后再对邻居的特征进行加权。这里可以先参考一下GitHub上的GAT代码:

class GraphAttentionLayer(nn.Module):"""Simple GAT layer, similar to https://arxiv.org/abs/1710.10903"""def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_featuresself.out_features = out_featuresself.alpha = alphaself.concat = concatself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain=1.414)self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain=1.414)self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, h, adj):Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)e = self._prepare_attentional_mechanism_input(Wh)zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)attention = F.softmax(attention, dim=1)attention = F.dropout(attention, self.dropout, training=self.training)h_prime = torch.matmul(attention, Wh)if self.concat:return F.elu(h_prime)else:return h_primedef _prepare_attentional_mechanism_input(self, Wh):# Wh.shape (N, out_feature)# self.a.shape (2 * out_feature, 1)# Wh1&2.shape (N, 1)# e.shape (N, N)Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])# broadcast adde = Wh1 + Wh2.Treturn self.leakyrelu(e)def __repr__(self):return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

上述代码使用了broadcast add技巧来得到每个节点与其他所有节点的权重,然后再使用adj来将不存在边的权重变成一个很小的负数。

基于上述思想,可以将本文的图注意力层定义如下:

class GraphAttentionLayer(nn.Module):def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_featuresself.out_features = out_featuresself.alpha = alphaself.concat = concatself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain=1.414)self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain=1.414)self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, h, adj):# bsnd nnWh = torch.matmul(h, self.W)e = self._prepare_attentional_mechanism_input(Wh)  # bsnn# 掩码操作mask = (adj == 0)# 广播掩码矩阵mask = mask.unsqueeze(0).unsqueeze(0)mask = mask.expand_as(e)# 应用掩码e[mask] = -9e15e = F.softmax(e, dim=1)e = F.dropout(e, self.dropout, training=self.training)h_prime = torch.einsum("bstn,bsnd->bstd", e, Wh)  # bsndif self.concat:return F.elu(h_prime)else:return h_primedef _prepare_attentional_mechanism_input(self, Wh):# Wh.shape (bsz, seq_len, N, out_feature)# self.a.shape (2 * out_feature, 1)# Wh1&2.shape (N, 1)# e.shape (bsz, seq_len, N, N)Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])# broadcast add# 只是最后两个维度相加e = Wh1 + Wh2.permute(0, 1, 3, 2)return self.leakyrelu(e)def __repr__(self):return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

区别在于:

  1. 其一,执行broadcast add时候,只是后两个维度进行操作(e = Wh1 + Wh2.permute(0, 1, 3, 2)),即bsnd+bsdn。
  2. 得到attention矩阵大小为bsnn,而不是二维的nn。因此,同样需要进行广播来实现掩码操作。

最后,基于GATConv,可以搭建一个简易版本(不使用多头注意力机制)的GAT如下:

class GAT(torch.nn.Module):def __init__(self, args):super(GAT, self).__init__()self.args = argsalpha = 0.2self.dropout = args.dropoutself.conv1 = GraphAttentionLayer(args.in_feats,args.h_feats,dropout=self.dropout, alpha=alpha, concat=False)self.conv2 = GraphAttentionLayer(args.h_feats,args.out_feats,dropout=self.dropout, alpha=alpha, concat=False)self.fcs = nn.ModuleList()for _ in range(args.in_feats):self.fcs.append(nn.Sequential(nn.Linear(args.seq_len * args.out_feats, args.out_feats),nn.ReLU(),nn.Linear(args.out_feats, args.pred_len)))def forward(self, x, adj):# bsndx = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.conv1(x, adj))x = self.conv2(x, adj)# b s n d  --> b s n 3x = x.permute(0, 2, 1, 3)  # bnsdx = torch.flatten(x, start_dim=2)  # bn s*dpred = []for idx in range(self.args.in_feats):sub_pred = self.fcs[idx](x)   # b n pred_lenpred.append(sub_pred)  # b pred_len 3pred = torch.stack(pred, dim=-1)  # b n pred_len 3# 变成和y一样的维度,即b pred_len num_node 3pred = pred.permute(0, 2, 1, 3)return pred

当然,也可以使用多头注意力机制:

class GAT(nn.Module):def __init__(self, args):super(GAT, self).__init__()self.args = argsself.dropout = args.dropoutalpha = 0.2self.attentions = nn.ModuleList()for _ in range(args.heads):layer = GraphAttentionLayer(args.in_feats, args.h_feats, dropout=self.dropout, alpha=alpha, concat=True)self.attentions.append(layer)self.out_att = GraphAttentionLayer(args.h_feats * args.heads,args.out_feats,dropout=self.dropout, alpha=alpha, concat=False)# fcself.fcs = nn.ModuleList()for _ in range(args.in_feats):self.fcs.append(nn.Sequential(nn.Linear(args.seq_len * args.out_feats, args.out_feats),nn.ReLU(),nn.Linear(args.out_feats, args.pred_len)))def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)x = F.dropout(x, self.dropout, training=self.training)x = self.out_att(x, adj)# b s n d  --> b s n 3x = x.permute(0, 2, 1, 3)  # bnsdx = torch.flatten(x, start_dim=2)  # bn s*dpred = []for idx in range(self.args.in_feats):sub_pred = self.fcs[idx](x)  # b n pred_lenpred.append(sub_pred)  # b pred_len 3pred = torch.stack(pred, dim=-1)  # b n pred_len 3pred = pred.permute(0, 2, 1, 3)return pred

IV. 训练与测试

训练测试代码与之前差不太多,训练函数定义如下:

def train(args, Dtr, Val, adj, path, model_type):if model_type == "gcn":adj = normalize_adj(adj)model = GCN(args).to(device)elif model_type == "sage":model = GraphSAGE(args).to(device)elif model_type == "gat":model = GAT(args).to(device)else:raise ValueError("model_type has to be one of ('gcn', 'sage', 'gat')")adj = adj.to(device)loss_function = nn.MSELoss().to(device)optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,weight_decay=args.weight_decay)scheduler = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)# trainingmin_epochs = 2best_model = Nonemin_val_loss = np.Inffor epoch in tqdm(range(args.epochs)):model.train()train_loss = []for (seq, label) in Dtr:optimizer.zero_grad()seq = seq.to(device)label = label.to(device)  # b pred_len num_node 3pred = model(seq, adj)  # b pred_len num_node 3# print(label.shape, pred.shape)loss = loss_function(pred, label)loss.backward()optimizer.step()train_loss.append(loss.item())scheduler.step()# validationval_loss = get_val_loss(args, model, Val, adj)if epoch + 1 >= min_epochs and val_loss < min_val_loss:min_val_loss = val_lossbest_model = copy.deepcopy(model)state = {'model': best_model.state_dict()}torch.save(state, path + '/models/' + model_type + '.pkl')print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))state = {'model': best_model.state_dict()}torch.save(state, path + '/models/' + model_type + '.pkl')

测试代码:

@torch.no_grad()
def test(args, Dte, adj, path, model_type, scalers):if model_type == "gcn":adj = normalize_adj(adj)model = GCN(args).to(device)elif model_type == "sage":model = GraphSAGE(args).to(device)elif model_type == "gat":model = GAT(args).to(device)else:raise ValueError("model_type has to be one of ('gcn', 'sage', 'gat')")model.load_state_dict(torch.load(path + '/models/' + model_type + '.pkl')['model'])adj = adj.to(device)y, pred = [], []for seq, label in Dte:seq = seq.to(device)y.append(label)sub_pred = model(seq, adj)  # b pred_len num_node 3pred.append(sub_pred.cpu())#y = torch.concat(y, dim=0)y = y.view(-1, y.size(2), y.size(3)).numpy()pred = torch.concat(pred, dim=0)  # num num_node 3pred = pred.view(-1, pred.size(2), pred.size(3)).numpy()# flatten# scalerfor idx in range(y.shape[1]):y[:, idx, :] = scalers[idx].inverse_transform(y[:, idx, :])pred[:, idx, :] = scalers[idx].inverse_transform(pred[:, idx, :])for idx in range(y.shape[1]):cur_y, cur_pred = y[:, idx, :], pred[:, idx, :]# 输出各种指标print('第{}个站点的指标为:'.format(idx + 1))maes, mses, rmses, mapes, r2s = get_metric(cur_y, cur_pred)print('mae:', maes)print('mse', mses)print('rmse:', rmses)print('mape:', mapes)print('r2:', r2s)# plotfor i in range(cur_y.shape[1]):plt.plot(cur_y[:, i], label="第{}个站点的第{}个变量的真实值".format(idx + 1, i + 1))plt.plot(cur_pred[:, i], label="第{}个站点的第{}个变量的预测值".format(idx + 1, i + 1))plt.legend()plt.show()

测试反归一化时注意在站点内进行归一化。

V. 实验结果

以PEMS04为例,下图展示了一些预测结果:
在这里插入图片描述

相关文章:

PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测

目录 I. 前言II. 数据集说明III. 模型3.1 GCN3.2 GraphSAGE3.3 GAT IV. 训练与测试V. 实验结果 I. 前言 前面已经写了很多关于时间序列预测的文章&#xff1a; 深入理解PyTorch中LSTM的输入和输出&#xff08;从input输入到Linear输出&#xff09;PyTorch搭建LSTM实现时间序列…...

51单片机快速入门之数码管的拓展应用2024/10/15

51单片机快速入门之数码管的拓展应用 在前面的文章中,我们已经了解到数码管的基础应用,今天来讲讲拓展应用 我们知道单个数码管分为以下 但是当我们碰到 如下这种数码管的时候又应该如何去控制呢? 这里就不得不说其拓展应用之-----------扫描显示 扫描显示: 扫描显示,又称…...

vue 音频播放控件封装

<template> <div> <audio @timeupdate="updateProgress" controls ref="audioRef" style="display: none" > <source :src="audioUrl" type="audio/mpeg" /> 您的浏览器不支持音频播放 </audio&…...

秋招面试题记录

嵌入式软件开发 网上搜集的题目 1.Static关键词的作用&#xff1f; static 关键字有三个主要作用&#xff1a; 局部变量&#xff1a;在函数内部&#xff0c;static 局部变量只初始化一次&#xff0c;且在函数调用结束后仍然保留其值。全局变量/函数&#xff1a;在文件内部&a…...

金字塔流(Pyramid Flow): 用于生成人工智能长视频的新文本-视频开源模型

在 "生成式人工智能 "中的文本生成模型和图像生成模型大行其道之后&#xff0c;现在该是文本-视频模型大显身手的时候了&#xff0c;这个列表中的新模型就是 pyramid-flow-sd3&#xff0c;它是一个开源模型&#xff0c;用于从文本或图像生成长达 10 秒的视频&#xf…...

施磊C++ | 进阶学习笔记 | 5.设计模式

五、设计模式 文章目录 五、设计模式1.设计模式三大类型概述一、创建型设计模式二、结构型设计模式三、行为型设计模式 2.设计模式三大原则3.单例模式1.饿汉单例模式2.懒汉单例模式 4.线程安全的懒汉单例模式1.锁双重判断2.简洁的线程安全懒汉单例模式 5.简单工厂(Simple Facto…...

智绘城市地图:使用百度地图 API 实现智能定位

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…...

【稳定性】稳定性建设之变更管理

作者&#xff1a;京东物流 冯志文 背景 在软件开发和运维领域&#xff0c;变更管理是一个至关重要的环节。无论是对现有系统的改进、功能的增加还是修复漏洞&#xff0c;变更都是不可避免的。这些变更可能涉及到软件代码的修改、配置的调整、服务器的扩容、三方jar包的变更等等…...

c语言中字符串函数strlen,strcmp,strcpy,srtcat,strncpy,strncmp,strncat

1.strlen的使用和模拟实现 strlen 用来求字符串的长度&#xff0c;统计\0之前字符的个数。 模拟实现1&#xff1a;计数参数法 模拟实验2&#xff1a;指针方法 模拟实验3&#xff1a;递归方法 2&#xff0c;strcpy 的使用和模拟实现&#xff08;拷贝字符串&#xff09; char*…...

高级SQL技巧

高级SQL技巧涵盖了许多方面&#xff0c;包括但不限于窗口函数、递归查询、公共表表达式&#xff08;CTEs&#xff09;、子查询、集合操作、临时函数、日期时间操作、索引优化等。以下是对这些技巧的详细讲解和示例。 窗口函数 窗口函数是一种特殊的SQL函数&#xff0c;能够在…...

新大话西游图文架设教程

开始架设 1. 架设条件 新大话西游架设需要准备&#xff1a; linux 系统服务器&#xff0c;建议 CentOs 7.6或以上版本游戏源码&#xff0c;。 2. 安装宝塔面板 宝塔是一个服务器运维管理软件&#xff0c;安装命令&#xff1a; yum install -y wget && wget -O in…...

Maven 快速入门

Maven 快速入门 一、简介1、概述2、特点3、工作原理4、常用命令5、生命周期6、优缺点&#x1f388; 面试题 二、安装和配置1、安装2、环境配置3、命令测试是否安装成功4、功能配置5、idea配置本地 maven6、maven 工程依赖包查询网站 三、基于IDEA创建Maven工程1、maven 工程中的…...

OpenCV-人脸检测

文章目录 一、人脸检测流程二、关键方法三、代码示例四、注意事项 OpenCV是一个开源的计算机视觉和机器学习软件库&#xff0c;它提供了多种人脸检测方法&#xff0c;以下是对OpenCV人脸检测的详细介绍&#xff1a; 一、人脸检测流程 人脸检测是识别图像中人脸位置的过程&…...

【重磅升级】基于大数据的股票量化分析与预测系统

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 伴随全球经济一体化和我国经济的快速发展&#xff0c;中国股票市场对世界经济的影响力不断攀升&#xff0c;中国股市已成为全球第二大股票交易市场。在当今的金融市场中&#xff0c;股票价格的波动…...

python全栈学习记录(二十四)元类、异常处理

元类、异常处理 文章目录 元类、异常处理一、元类1.元类控制类的实例化2.属性/方法的查找顺序3.单例 二、异常处理 一、元类 1.元类控制类的实例化 类的__call__方法会在产生的对象被调用时自动触发&#xff0c;args和kwargs就是调用实例时传入的参数&#xff0c;返回值是调用…...

Golang Slice扩容机制及注意事项

Golang Slice扩容机制及注意事项&#xff1a; 在 Go语言中&#xff0c;Slice&#xff08;切片&#xff09;是一种非常灵活且强大的数据结构&#xff0c;它是对数组的抽象&#xff0c;提供了动态数组的功能。Slice 的扩容机制是自动的&#xff0c;但了解其背后的原理对于编写高…...

华为OD机试 - 猜数字 - 暴力枚举(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试真题&#xff08;Python/JS/C/C&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加入华为OD刷题交流群&#xff0c;…...

Flink触发器Trigger

前言 在 Flink 窗口计算模型中&#xff0c;数据先经过 WindowAssigner 分配窗口&#xff0c;然后再经过触发器 Trigger&#xff0c;Trigger 决定了一个窗口何时被 ProcessFunction 处理。每个 WindowAssigner 都有一个默认的 Trigger&#xff0c;如果默认的不满足需求&#xf…...

【操作系统的使用】Linux 系统环境变量与服务管理:设置与控制的艺术

文章目录 系统环境变量与服务管理&#xff1a;设置与控制的艺术一、系统环境变量的设置1.1 临时设置环境变量1.2 永久设置环境变量 二、服务启动类型的设置2.1 查看服务状态2.2 启动和停止服务2.3 设置服务的启动类型2.3.1 设置服务在启动时运行2.3.2 禁用服务在启动时运行2.3.…...

速盾:高防cdn配置中性能优化是什么?

高防CDN配置中的性能优化是指通过调整CDN配置以提升网站的加载速度、响应时间和用户体验。在进行性能优化时&#xff0c;需要考虑多个因素&#xff0c;包括CDN节点的选择和布置、缓存策略、缓存过期时间、预取和预加载、并发连接数和网络延迟等。 首先&#xff0c;CDN节点的选…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06&#xff08;十亿美元&#xff09;。漏洞扫描服务市场行业预计将从 2024 年的 3.48&#xff08;十亿美元&#xff09;增长到 2032 年的 9.54&#xff08;十亿美元&#xff09;。预测期内漏洞扫描服务市场 CAGR&#xff08;增长率&…...

【解密LSTM、GRU如何解决传统RNN梯度消失问题】

解密LSTM与GRU&#xff1a;如何让RNN变得更聪明&#xff1f; 在深度学习的世界里&#xff0c;循环神经网络&#xff08;RNN&#xff09;以其卓越的序列数据处理能力广泛应用于自然语言处理、时间序列预测等领域。然而&#xff0c;传统RNN存在的一个严重问题——梯度消失&#…...

生成 Git SSH 证书

&#x1f511; 1. ​​生成 SSH 密钥对​​ 在终端&#xff08;Windows 使用 Git Bash&#xff0c;Mac/Linux 使用 Terminal&#xff09;执行命令&#xff1a; ssh-keygen -t rsa -b 4096 -C "your_emailexample.com" ​​参数说明​​&#xff1a; -t rsa&#x…...

Docker 本地安装 mysql 数据库

Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker &#xff1b;并安装。 基础操作不再赘述。 打开 macOS 终端&#xff0c;开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...

return this;返回的是谁

一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请&#xff0c;不同级别的经理有不同的审批权限&#xff1a; // 抽象处理者&#xff1a;审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

Unity UGUI Button事件流程

场景结构 测试代码 public class TestBtn : MonoBehaviour {void Start(){var btn GetComponent<Button>();btn.onClick.AddListener(OnClick);}private void OnClick(){Debug.Log("666");}}当添加事件时 // 实例化一个ButtonClickedEvent的事件 [Formerl…...

大数据治理的常见方式

大数据治理的常见方式 大数据治理是确保数据质量、安全性和可用性的系统性方法&#xff0c;以下是几种常见的治理方式&#xff1a; 1. 数据质量管理 核心方法&#xff1a; 数据校验&#xff1a;建立数据校验规则&#xff08;格式、范围、一致性等&#xff09;数据清洗&…...

LangChain【6】之输出解析器:结构化LLM响应的关键工具

文章目录 一 LangChain输出解析器概述1.1 什么是输出解析器&#xff1f;1.2 主要功能与工作原理1.3 常用解析器类型 二 主要输出解析器类型2.1 Pydantic/Json输出解析器2.2 结构化输出解析器2.3 列表解析器2.4 日期解析器2.5 Json输出解析器2.6 xml输出解析器 三 高级使用技巧3…...

网页端 js 读取发票里的二维码信息(图片和PDF格式)

起因 为了实现在报销流程中&#xff0c;发票不能重用的限制&#xff0c;发票上传后&#xff0c;希望能读出发票号&#xff0c;并记录发票号已用&#xff0c;下次不再可用于报销。 基于上面的需求&#xff0c;研究了OCR 的方式和读PDF的方式&#xff0c;实际是可行的&#xff…...