当前位置: 首页 > news >正文

时间序列预测模型实战案例(十)(个人创新模型)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

本文介绍

本篇博客为大家讲解的是通过组堆叠CNN、GRU、LSTM个数,建立多元预测和单元预测的时间序列预测模型,其效果要比单用GRU、LSTM效果好的多,其结合了CNN的特征提取功能、GRU和LSTM用于处理数据中的时间依赖关系的功能。通过将它们组合在一起,模型可以同时考虑输入数据的空间和时间特征,以更好地进行预测。本篇实战案例中包括->详细的参数讲解、数据集介绍、模型框架原理、训练你个人数据集的教程、以及结果分析。本篇文章的讲解流程为->

预测类型->多元预测、单元预测

开源代码->文末有完整代码块复制粘贴即可运行

适用人群->时间序列建模的初学者、时间序列建模的工作者

模型框架原理 

首先我们来简单介绍一下本模型所用的框架原理,是为什么能够根据输入来预测出未来的值的,也就是数据的输入数据的输出在我们的模型内部到底经过了一个什么样的处理这样一个过程的讲解。首先我们要知道三个概念,也就是本文所用到的CNN、GRU、LSTM三个主要处理结构,下面我们来简单的进行分别介绍。

CNN

CNN我相信大家都已经非常了解了这里只是简单介绍一下其在时间序列预测也就是在数据是一维(1-D)时候的作用机制。

CNN:在时间序列预测中,CNN可以用于提取时间序列数据中的局部模式或特征。通过卷积操作,CNN可以捕捉时间序列数据中的局部相关性,并通过激活函数的非线性变换,提取出高层次的特征表示。

可以看到其和2-D、3-D时候的作用是一样的主要做到的是一个特征提取的工作,唯一的区别可能就是我们的输入数据是一维的所以他会沿着时间序列顺序执行,下面我们来看一个图片来理解其工作原理。 

总结->这个图片代表着一个stride分别为1、2、4和卷积核=3的1维度卷积处理结果,可以看到其是顺序执行不想2-D、3-D卷积哪样需要换行的操作。 

GRU

门控循环单元(GRU)是一种循环神经网络(RNN)单元,用于处理序列数据。GRU相对于传统的RNN单元具有改进的结构,旨在更好地处理长期依赖关系和消除梯度消失问题。

GRU的结构由以下几个关键单元组成:

  1. 更新门(Update Gate):更新门控制着前一时刻的隐藏状态(或记忆)保留多少信息传递到当前时刻的隐藏状态中。它通过观察当前输入和前一时刻的隐藏状态,决定更新的程度。更新门具有一个范围从0到1的值,其中0表示完全忘记先前的隐藏状态,1表示完全保留先前的隐藏状态。

  2. 重置门(Reset Gate):重置门决定如何使用前一时刻的隐藏状态来刷新当前时刻的隐藏状态。它通过控制前一时刻的隐藏状态在当前时刻的作用程度来帮助模型忘记不重要的信息。通过观察当前输入和前一时刻的隐藏状态,重置门输出一个范围从0到1的值,用于调整隐藏状态的刷新程度。

  3. 候选隐藏状态:候选隐藏状态(Candidate Hidden State)是一个候选的更新后的隐藏状态。它是根据当前输入和重置门的输出计算得到的。候选隐藏状态捕捉了当前输入和过去隐藏状态的相关性,并在一定程度上更新了隐藏状态。

  4. 新的隐藏状态(Updated Hidden State):新的隐藏状态由更新门、候选隐藏状态和前一时刻的隐藏状态组合而成。它被用作当前时刻的隐藏状态,并在下一个时刻传递下去。

总结->GRU的关键点是,它通过更新门控制了有关隐藏状态的信息流动,可以决定保留多少过去的信息。而重置门有助于捕捉当前输入和过去隐藏状态之间的相关性,并在一定程度上刷新隐藏状态。这种门控机制允许GRU更好地处理长期依赖关系,并减轻了梯度消失问题,使其能够更有效地处理序列数据。

下面的图片是一个完整的GRU的结构图片。

LSTM 

LSTM在我之前的博客中以及详细的讲过,如果有需要的可以看我之前的博客,这里只做一个间的回顾和概念介绍。

时间序列预测模型实战案例(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

LSTM(长短期记忆,Long Short-Term Memory)是一种用于处理序列数据的深度学习模型属于循环神经网络(RNN)的一种变体,其使用一种类似于搭桥术结构的RNN单元。相对于普通的RNN,LSTM引入了门控机制,能够更有效地处理长期依赖和短期记忆问题,是RNN网络中最常使用的Cell之一,其网络结构如下图。

模型执行流程

上面以及经过了三个主要结构的基本原理和结构,这里主要讲一下在模型中其它的结构如何搭配三个主要的结构(CNN、GRU、LSTM)实现回归问题的解决(时间序列预测就是一个回归问题)

  • 模型的数据输入首先是经过卷积神经网络(CNN)。它从输入数据开始,进行卷积操作,并使用 ReLU 激活函数进行非线性变换。然后对卷积输出进行一个 reshape 操作,将维度重新排列。所有输出被添加到输出列表中,最后使用 `Concat` 操作将它们在维度 2 上连接起来。连接后的结果经过一个 dropout 操作生成最终的特征输出
  • 接下来经过GRU层。首先将堆叠的所有循环单元按顺序连接在一起,并在每个循环单元之后添加一个 dropout 操作,将结果进行输出。
  • 再下面是LSTM层。同样将循环单元按顺序连接在一起,并在每个循环单元之后添加 dropout 操作。这一部分与之前的GRU 类似。
  • 接着是自回归(Autoregressive)组件。它对输入数据 `X` 的每个特征维度分别进行全连接操作,将维度减小到 1
  • 最后定义了预测组件的代码。通过一个全连接层将输入的维度变换为与输入特征向量的维度相同。然后将该输出与自回归组件的输出进行相加,得到最终的模型输出。使用线性回归损失函数计算模型输出与标签的损失。
  • 最终,函数返回损失值,输入数据的名称列表以及标签数据的名称列表。

模型的结构图下->

上述的过程讲述的就是模型中如下的代码流程可以参考着讲解和代码进行阅读。

def sym_gen(train_iter, q, filter_list, num_filter, dropout, rcells, skiprcells, seasonal_period, time_interval):input_feature_shape = train_iter.provide_data[0][1]X = mx.symbol.Variable(train_iter.provide_data[0].name)Y = mx.sym.Variable(train_iter.provide_label[0].name)# reshape data before applying convolutional layer (takes 4D shape incase you ever work with images)conv_input = mx.sym.reshape(data=X, shape=(0, 1, q, -1))################ CNN Component###############outputs = []for i, filter_size in enumerate(filter_list):# pad input array to ensure number output rows = number input rows after applying kernelpadi = mx.sym.pad(data=conv_input, mode="constant", constant_value=0,pad_width=(0, 0, 0, 0, filter_size - 1, 0, 0, 0))convi = mx.sym.Convolution(data=padi, kernel=(filter_size, input_feature_shape[2]), num_filter=num_filter)acti = mx.sym.Activation(data=convi, act_type='relu')trans = mx.sym.reshape(mx.sym.transpose(data=acti, axes=(0, 2, 1, 3)), shape=(0, 0, 0))outputs.append(trans)cnn_features = mx.sym.Concat(*outputs, dim=2)cnn_reg_features = mx.sym.Dropout(cnn_features, p=dropout)################ GRU Component###############stacked_rnn_cells = mx.rnn.SequentialRNNCell()for i, recurrent_cell in enumerate(rcells):stacked_rnn_cells.add(recurrent_cell)stacked_rnn_cells.add(mx.rnn.DropoutCell(dropout))outputs, states = stacked_rnn_cells.unroll(length=q, inputs=cnn_reg_features, merge_outputs=False)rnn_features = outputs[-1] #only take value from final unrolled cell for use later##################### LSTM Component####################stacked_rnn_cells = mx.rnn.SequentialRNNCell()for i, recurrent_cell in enumerate(skiprcells):stacked_rnn_cells.add(recurrent_cell)stacked_rnn_cells.add(mx.rnn.DropoutCell(dropout))outputs, states = stacked_rnn_cells.unroll(length=q, inputs=cnn_reg_features, merge_outputs=False)# Take output from cells p steps apartp = int(seasonal_period / time_interval)output_indices = list(range(0, q, p))outputs.reverse()skip_outputs = [outputs[i] for i in output_indices]skip_rnn_features = mx.sym.concat(*skip_outputs, dim=1)########################### Autoregressive Component##########################auto_list = []for i in list(range(input_feature_shape[2])):time_series = mx.sym.slice_axis(data=X, axis=2, begin=i, end=i+1)fc_ts = mx.sym.FullyConnected(data=time_series, num_hidden=1)auto_list.append(fc_ts)ar_output = mx.sym.concat(*auto_list, dim=1)####################### Prediction Component######################neural_components = mx.sym.concat(*[rnn_features, skip_rnn_features], dim=1)neural_output = mx.sym.FullyConnected(data=neural_components, num_hidden=input_feature_shape[2])model_output = neural_output + ar_outputloss_grad = mx.sym.LinearRegressionOutput(data=model_output, label=Y)return loss_grad, [v.name for v in train_iter.provide_data], [v.name for v in train_iter.provide_label]

数据集介绍

上面简单的介绍了模型的原理,下面的部分就是开始正式的实战讲解了,首先介绍的是我本次实战中举例用到的数据集部分截图如下,其主要预测列为OT列代表的含义是油温

参数讲解

下面我来介绍模型的主要参数,如果你想要使用自己的数据集进行预测,那么这个过程需要自信的看其中一些参数的讲解会涉及到如何替换个人数据集的介绍

首先先列出所有参数后进行讲解,参数如下->

parser = argparse.ArgumentParser(description="CNN-GRU-LSTM for multivariate time series forecasting",formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data-dir', type=str, default='./', help='relative path to input data')
parser.add_argument('--data_name',type=str, default='ETTh1-Test.csv', help='Input Model File Name')
parser.add_argument('--max-records', type=int, default=None, help='total records before data split')
parser.add_argument('--q', type=int, default=24*7, help='number of historical measurements included in each training example')
parser.add_argument('--horizon', type=int, default=4, help='number of measurements ahead to predict')
parser.add_argument('--splits', type=str, default="0.6,0.2", help='fraction of data to use for train & validation. remainder used for test.')
parser.add_argument('--batch-size', type=int, default=128, help='the batch size.')
parser.add_argument('--filter-list', type=str, default="6,12,18", help='unique filter sizes')
parser.add_argument('--num-filters', type=int, default=100, help='number of each filter size')
parser.add_argument('--recurrent-state-size', type=int, default=100, help='number of hidden units in each unrolled recurrent cell')
parser.add_argument('--seasonal-period', type=int, default=24, help='time between seasonal measurements')
parser.add_argument('--time-interval', type=int, default=1, help='time between each measurement')
parser.add_argument('--gpus', type=str, default='', help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ')
parser.add_argument('--optimizer', type=str, default='adam', help='the optimizer type')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument('--dropout', type=float, default=0.2, help='dropout rate for network')
parser.add_argument('--num-epochs', type=int, default=100, help='max num of epochs')
parser.add_argument('--save-period', type=int, default=20, help='save checkpoint for every n epochs')
parser.add_argument('--model_prefix', type=str, default='electricity_model', help='prefix for saving model params')

上面的就是本模型用到的所有参数,下面来分别进行讲解。

CNN-GRU-LSTM模型参数详解
参数名称参数类型参数讲解
1data-dirstr数据数据集的目录,注意不要到具体的文件要都文件的目录名称!
2data-namestr文件的具体名称。
3max-recordsint数据分割之前的总记录数,可以选择仅使用部分数据进行训练和评估,默认为None表示使用全部数据。
4qint每个训练样本中包含的历史测量值的数量
5horizonint预测的未来测量值的数量,默认为4,这个就是你预测未来多少个点的数据
6splitsstr数据集划分的比例,用于训练、验证和测试集的分割,默认为"0.6,0.2",表示60%的数据用于训练,20%用于验证,剩余部分用于测试。
7batch-sizeint一次往模型里输入多少个数据
8filter-liststr这个是CNN的卷积核大小,默认为"6,12,18",表示使用6、12和18的卷积核大小,这个可以根据你数据集的间隔来,如果你觉得你数据集的数据比较平缓可以设置大一点如果比较极端那么设置就小一点,如果你不知道也可以看我末尾推荐的文章里面有详细讲解。
9num-filtersint每个过滤器大小的滤波器数量,默认为100。
10recurrent-state-sizeint每个未滚动展开循环单元中隐藏单位的数量,默认为100。
11seasonal-periodint季节性测量之间的时间间隔,默认为24,表示每24小时进行一次季节性测量。这个数据也很重要就是你数据具有的季节性。这个参数不同的数据集都不一样,如果你想知道如何测量你数据中的季节性、周期性等因素可以看我的其它博客文章的末尾会分享里面有详细的讲解。
12time-intervalint每个测量之间的时间间隔,默认为1。
13gpusstr要使用的GPU列表,例如"0"或"0,2,5",为空表示使用CPU。
14optimizerstr优化器的类型,默认为"adam"。
15lrfloat初始学习率,默认为0.001。
16dropoutfloat网络的dropout率,默认为0.2。
17num-epochsint最大的训练轮数,默认为100。
18svae-periodint每隔n个训练轮保存一次模型检查点,默认为20。
19model-prefixstr保存模型参数的前缀,默认为"electricity_model"。

模型训练

到此为止模型的准备工作以及全部做好了,经过参数的讲解和数据集的准备,可以开始训练模型了。

环境介绍

在正式开始训练之前介绍一下本模型用到的模块版本如下->

python=3.6

mxnet

numpy

pandas

tqdm

训练代码讲解

我们的程序入口代码汇总如下->

if __name__ == '__main__':# parse argsargs = parser.parse_args()args.splits = list(map(float, args.splits.split(',')))args.filter_list = list(map(int, args.filter_list.split(',')))# Check valid argsif not max(args.filter_list) <= args.q:raise AssertionError("no filter can be larger than q")if not args.q >= math.ceil(args.seasonal_period / args.time_interval):raise AssertionError("size of skip connections cannot exceed q")# Build data iteratorstrain_iter, val_iter, test_iter = build_iters(args.data_dir, args.max_records, args.q, args.horizon, args.splits, args.batch_size)# Choose cells for recurrent layers: each cell will take the output of the previous cell in the listrcells = [mx.rnn.GRUCell(num_hidden=args.recurrent_state_size)]skiprcells = [mx.rnn.LSTMCell(num_hidden=args.recurrent_state_size)]# Define network symbolsymbol, data_names, label_names = sym_gen(train_iter, args.q, args.filter_list, args.num_filters,args.dropout, rcells, skiprcells, args.seasonal_period, args.time_interval)Train = True# train cnn modelif Train:module = train(symbol, train_iter, val_iter, data_names, label_names)predict(symbol, train_iter, val_iter, test_iter, data_names, label_names)

下面对其中的代码分别进行讲解!!!

    args = parser.parse_args()args.splits = list(map(float, args.splits.split(',')))args.filter_list = list(map(int, args.filter_list.split(',')))# Check valid argsif not max(args.filter_list) <= args.q:raise AssertionError("no filter can be larger than q")if not args.q >= math.ceil(args.seasonal_period / args.time_interval):raise AssertionError("size of skip connections cannot exceed q")

这一部分就是一些参数的解析部分了,检测参数是否有一些不符合规定的输入,不涉及到代码的流程,没什么好讲的给大家。

# Build data iteratorstrain_iter, val_iter, test_iter = build_iters(args.data_dir, args.max_records, args.q, args.horizon, args.splits, args.batch_size)

这是构建训练集、验证集、测试集的数据加载器,需要注意的是时间序列是以滚动的形式构建数据加载器的。

    # Choose cells for recurrent layers: each cell will take the output of the previous cell in the listrcells = [mx.rnn.GRUCell(num_hidden=args.recurrent_state_size)]skiprcells = [mx.rnn.LSTMCell(num_hidden=args.recurrent_state_size)]

这里是定义GRU和LSTM的地方如果你想要修改其它的RNN单元就可以在这里修改进行其它尝试,毕竟GRU和LSTM以及存在许多年了现在有许多更高效效果更好的RNN单元存在。

    # Define network symbolsymbol, data_names, label_names = sym_gen(train_iter, args.q, args.filter_list, args.num_filters,args.dropout, rcells, skiprcells, args.seasonal_period, args.time_interval)

这一步就是构建网络结构了,其中sym_gen是我们定义的方法。

    Train = True# train cnn modelif Train:module = train(symbol, train_iter, val_iter, data_names, label_names)predict(symbol, train_iter, val_iter, test_iter, data_names, label_names)

进行训练和预测,Train=True时进行训练和预测,Train=False时候只进行预测不训练模型。 

训练模型 

下面我们开始正式的训练,运行程序文件,控制台进行输出如下。

训练完成后,模型会自动保存在该目录下->

模型预测

我们进行模型的预测主要观察我们想看的特征列"OT"列预测结果如下。

同时会将所有的结果和真实值输出到控制台并生成csv文件。

 保存到同级目录下的输出结果和折线图如下->

结果分析

可以说结果还可以接受,这个模型的设计还是算成功的,当然精度还有待提升,后续的话可以更改一些结构,或者添加一些其它的网络层,这里我们再来展示一下其它几列的预测结果。

PS->需要注意的是我的训练数据只用了三百多条能达到这个精度我还是比较满意的。

LULL特征预测结果如下图->

LUFL特征预测结果如下-> 

训练个人数据集所需修改

下面来讲一下训练你个人数据集需要什么修改,其实主要修改的主要是参数部分,大部分的代码Bug我以及修复好了,所以下面来讲一下。

parser.add_argument('--data-dir', type=str, default='./', help='relative path to input data')
parser.add_argument('--data_name',type=str, default='ETTh1-Test.csv', help='Input Model File Name')
parser.add_argument('--q', type=int, default=24*7, help='number of histrical measurements included in each training example')
parser.add_argument('--filter-list', type=str, default="6,12,18", help='unique filter sizes')
parser.add_argument('--seasonal-period', type=int, default=24, help='time between seasonal measurements')
parser.add_argument('--time-interval', type=int, default=1, help='time between each measurement')

这个是模型中需要你修改的参数部分,具体的修改和修改意见我在参数讲解部分以及提到了,大家可以回去参照这修改之后就可以运行该模型训练自己的数据集了。 

项目完整代码分析 

项目的完整代码如下->大家可以进行复制运行即可。 

import argparse
import logging
import math
import os
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
import pandas as pd
from tqdm import tqdm# 将matplotlib的日志级别设置为警告级别
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG)# 参数设置部分
parser = argparse.ArgumentParser(description="CNN-GRU-LSTM for multivariate time series forecasting",formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data-dir', type=str, default='./', help='relative path to input data')
parser.add_argument('--data_name', type=str, default='ETTh1-Test.csv', help='Input Model File Name')
parser.add_argument('--max-records', type=int, default=None, help='total records before data split')
parser.add_argument('--q', type=int, default=24 * 7,help='number of histrical measurements included in each training example')
parser.add_argument('--horizon', type=int, default=4, help='number of measurements ahead to predict')
parser.add_argument('--splits', type=str, default="0.6,0.2",help='fraction of data to use for train & validation. remainder used for test.')
parser.add_argument('--batch-size', type=int, default=128, help='the batch size.')
parser.add_argument('--filter-list', type=str, default="6,12,18", help='unique filter sizes')
parser.add_argument('--num-filters', type=int, default=100, help='number of each filter size')
parser.add_argument('--recurrent-state-size', type=int, default=100,help='number of hidden units in each unrolled recurrent cell')
parser.add_argument('--seasonal-period', type=int, default=24, help='time between seasonal measurements')
parser.add_argument('--time-interval', type=int, default=1, help='time between each measurement')
parser.add_argument('--gpus', type=str, default='',help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ')
parser.add_argument('--optimizer', type=str, default='adam', help='the optimizer type')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument('--dropout', type=float, default=0.2, help='dropout rate for network')
parser.add_argument('--num-epochs', type=int, default=100, help='max num of epochs')
parser.add_argument('--save-period', type=int, default=20, help='save checkpoint for every n epochs')
parser.add_argument('--model_prefix', type=str, default='electricity_model', help='prefix for saving model params')def rse(label, pred):"""computes the root relative squared error (condensed using standard deviation formula)"""numerator = np.sqrt(np.mean(np.square(label - pred), axis=None))denominator = np.std(label, axis=None)return numerator / denominatordef rae(label, pred):"""computes the relative absolute error (condensed using standard deviation formula)"""numerator = np.mean(np.abs(label - pred), axis=None)denominator = np.mean(np.abs(label - np.mean(label, axis=None)), axis=None)return numerator / denominatordef corr(label, pred):"""computes the empirical correlation coefficient"""numerator1 = label - np.mean(label, axis=0)numerator2 = pred - np.mean(pred, axis=0)numerator = np.mean(numerator1 * numerator2, axis=0)denominator = np.std(label, axis=0) * np.std(pred, axis=0)return np.mean(numerator / denominator)def get_custom_metrics():""":return: mxnet metric object"""_rse = mx.metric.create(rse)_rae = mx.metric.create(rae)_corr = mx.metric.create(corr)return mx.metric.create([_rae, _rse, _corr])def evaluate(pred, label):return {"RAE": rae(label, pred), "RSE": rse(label, pred), "CORR": corr(label, pred)}def build_iters(data_dir, max_records, q, horizon, splits, batch_size):"""Load & generate training examples from multivariate time series data:return: data iters & variables required to define network architecture"""# Read in data as numpy arraydf = pd.read_csv(os.path.join(data_dir, "ETTh1-Test.csv"), sep=",", )feature_df = df.iloc[:, 1:].fillna(0).astype(float)x = feature_df.valuesx = x[:max_records] if max_records else x# Construct training examples based on horizon and windowx_ts = np.zeros((x.shape[0] - q, q, x.shape[1]))y_ts = np.zeros((x.shape[0] - q, x.shape[1]))for n in range(x.shape[0]):if n + 1 < q:continueelif n + 1 + horizon > x.shape[0]:continueelse:y_n = x[n + horizon, :]x_n = x[n + 1 - q:n + 1, :]x_ts[n - q] = x_ny_ts[n - q] = y_n# Split into training and testing datatraining_examples = int(x_ts.shape[0] * splits[0])valid_examples = int(x_ts.shape[0] * splits[1])x_train, y_train = x_ts[:training_examples], \y_ts[:training_examples]x_valid, y_valid = x_ts[training_examples:training_examples + valid_examples], \y_ts[training_examples:training_examples + valid_examples]x_test, y_test = x_ts[training_examples + valid_examples:], \y_ts[training_examples + valid_examples:]# build iterators to feed batches to networktrain_iter = mx.io.NDArrayIter(data=x_train,label=y_train,batch_size=batch_size)val_iter = mx.io.NDArrayIter(data=x_valid,label=y_valid,batch_size=batch_size)test_iter = mx.io.NDArrayIter(data=x_test,label=y_test,batch_size=batch_size)return train_iter, val_iter, test_iterdef sym_gen(train_iter, q, filter_list, num_filter, dropout, rcells, skiprcells, seasonal_period, time_interval):input_feature_shape = train_iter.provide_data[0][1]X = mx.symbol.Variable(train_iter.provide_data[0].name)Y = mx.sym.Variable(train_iter.provide_label[0].name)# reshape data before applying convolutional layer (takes 4D shape incase you ever work with images)conv_input = mx.sym.reshape(data=X, shape=(0, 1, q, -1))################ CNN Component###############outputs = []for i, filter_size in enumerate(filter_list):# pad input array to ensure number output rows = number input rows after applying kernelpadi = mx.sym.pad(data=conv_input, mode="constant", constant_value=0,pad_width=(0, 0, 0, 0, filter_size - 1, 0, 0, 0))convi = mx.sym.Convolution(data=padi, kernel=(filter_size, input_feature_shape[2]), num_filter=num_filter)acti = mx.sym.Activation(data=convi, act_type='relu')trans = mx.sym.reshape(mx.sym.transpose(data=acti, axes=(0, 2, 1, 3)), shape=(0, 0, 0))outputs.append(trans)cnn_features = mx.sym.Concat(*outputs, dim=2)cnn_reg_features = mx.sym.Dropout(cnn_features, p=dropout)################ GRU Component###############stacked_rnn_cells = mx.rnn.SequentialRNNCell()for i, recurrent_cell in enumerate(rcells):stacked_rnn_cells.add(recurrent_cell)stacked_rnn_cells.add(mx.rnn.DropoutCell(dropout))outputs, states = stacked_rnn_cells.unroll(length=q, inputs=cnn_reg_features, merge_outputs=False)rnn_features = outputs[-1]  # only take value from final unrolled cell for use later##################### LSTM Component####################stacked_rnn_cells = mx.rnn.SequentialRNNCell()for i, recurrent_cell in enumerate(skiprcells):stacked_rnn_cells.add(recurrent_cell)stacked_rnn_cells.add(mx.rnn.DropoutCell(dropout))outputs, states = stacked_rnn_cells.unroll(length=q, inputs=cnn_reg_features, merge_outputs=False)# Take output from cells p steps apartp = int(seasonal_period / time_interval)output_indices = list(range(0, q, p))outputs.reverse()skip_outputs = [outputs[i] for i in output_indices]skip_rnn_features = mx.sym.concat(*skip_outputs, dim=1)########################### Autoregressive Component##########################auto_list = []for i in list(range(input_feature_shape[2])):time_series = mx.sym.slice_axis(data=X, axis=2, begin=i, end=i + 1)fc_ts = mx.sym.FullyConnected(data=time_series, num_hidden=1)auto_list.append(fc_ts)ar_output = mx.sym.concat(*auto_list, dim=1)####################### Prediction Component######################neural_components = mx.sym.concat(*[rnn_features, skip_rnn_features], dim=1)neural_output = mx.sym.FullyConnected(data=neural_components, num_hidden=input_feature_shape[2])model_output = neural_output + ar_outputloss_grad = mx.sym.LinearRegressionOutput(data=model_output, label=Y)return loss_grad, [v.name for v in train_iter.provide_data], [v.name for v in train_iter.provide_label]def train(symbol, train_iter, val_iter, data_names, label_names):devs = mx.cpu() if args.gpus is None or args.gpus is '' else [mx.gpu(int(i)) for i in args.gpus.split(',')]module = mx.mod.Module(symbol, data_names=data_names, label_names=label_names, context=devs)module.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)module.init_params(mx.initializer.Uniform(0.1))module.init_optimizer(optimizer=args.optimizer, optimizer_params={'learning_rate': args.lr})for epoch in tqdm(range(1, args.num_epochs + 1), desc="Epochs"):train_iter.reset()val_iter.reset()for batch in tqdm(train_iter, desc="Batches", leave=False):module.forward(batch, is_train=True)  # compute predictionsmodule.backward()  # compute gradientsmodule.update()  # update parameterstrain_pred = module.predict(train_iter).asnumpy()train_label = train_iter.label[0][1].asnumpy()print('\nMetrics: Epoch %d, Training %s' % (epoch, evaluate(train_pred, train_label)))val_pred = module.predict(val_iter).asnumpy()val_label = val_iter.label[0][1].asnumpy()print('Metrics: Epoch %d, Validation %s' % (epoch, evaluate(val_pred, val_label)))if epoch % args.save_period == 0 and epoch > 1:module.save_checkpoint(prefix=os.path.join("../models/", args.model_prefix), epoch=epoch,save_optimizer_states=False)if epoch == args.num_epochs:module.save_checkpoint(prefix=os.path.join("../models/", args.model_prefix), epoch=epoch,save_optimizer_states=False)return moduledef predict(symbol, train_iter, val_iter, test_iter, data_names, label_names):devs = mx.cpu() if args.gpus is None or args.gpus is '' else [mx.gpu(int(i)) for i in args.gpus.split(',')]module = mx.mod.Module(symbol, data_names=data_names, label_names=label_names, context=devs)module.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)module.init_params(mx.initializer.Uniform(0.1))module.init_optimizer(optimizer=args.optimizer, optimizer_params={'learning_rate': args.lr})# 加载模型参数params_file = "../models/electricity_model-0100.params"  # 参数文件的路径module.load_params(params_file)# 将模型转换为评估模式test_iter.reset()test_pred = module.predict(test_iter).asnumpy()test_label = test_iter.label[0][1].asnumpy()pre_results = []real_results = []for i in range(len(test_pred)):# 这里你想看那个列的图形就画出那个列的即可pre_results.append(test_pred[i][4])real_results.append(test_label[i][4])print("预测值:", pre_results)print("真实值:", real_results)df = pd.DataFrame({'real': real_results, 'forecast': pre_results})df.to_csv('results.csv', index=False)# 创建一个新的图形plt.figure(figsize=(10, 6))# 绘制预测值曲线,使用蓝色实线plt.plot(pre_results, color='blue', linestyle='-', linewidth=2, label='Predicted')# 绘制真实值曲线,使用红色虚线plt.plot(real_results, color='red', linestyle='--', linewidth=2, label='True')# 添加标题和轴标签plt.title('Predicted vs True Values', fontsize=16)plt.xlabel('Time', fontsize=12)plt.ylabel('Value', fontsize=12)# 添加图例plt.legend(loc='upper left')# 显示网格线plt.grid(True, linestyle='--', alpha=0.5)# 保存图形plt.savefig('line_plot.png')# 显示图形plt.show()print(test_label, test_pred)if __name__ == '__main__':# parse argsargs = parser.parse_args()args.splits = list(map(float, args.splits.split(',')))args.filter_list = list(map(int, args.filter_list.split(',')))# Check valid argsif not max(args.filter_list) <= args.q:raise AssertionError("no filter can be larger than q")if not args.q >= math.ceil(args.seasonal_period / args.time_interval):raise AssertionError("size of skip connections cannot exceed q")# Build data iteratorstrain_iter, val_iter, test_iter = build_iters(args.data_dir, args.max_records, args.q, args.horizon, args.splits,args.batch_size)# Choose cells for recurrent layers: each cell will take the output of the previous cell in the listrcells = [mx.rnn.GRUCell(num_hidden=args.recurrent_state_size)]skiprcells = [mx.rnn.LSTMCell(num_hidden=args.recurrent_state_size)]# Define network symbolsymbol, data_names, label_names = sym_gen(train_iter, args.q, args.filter_list, args.num_filters,args.dropout, rcells, skiprcells, args.seasonal_period,args.time_interval)Train = True# train cnn modelif Train:module = train(symbol, train_iter, val_iter, data_names, label_names)predict(symbol, train_iter, val_iter, test_iter, data_names, label_names)

项目的目录结构如下-> 

全文总结

到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98。

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测模型实战案例(六)深入理解机器学习ARIMA包括差分和相关性分析

时间序列预测模型实战案例(五)基于双向LSTM横向搭配单向LSTM进行回归问题解决

时间序列预测模型实战案例(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

时间序列预测模型实战案例(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

【全网首发】(MTS-Mixers)(Python)(Pytorch)最新由华为发布的时间序列预测模型实战案例(一)(包括代码讲解)实现企业级预测精度包括官方代码BUG修复Transform模型

时间序列预测模型实战案例(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

如果大家有不懂的也可以评论区留言一些报错什么的大家可以讨论讨论看到我也会给大家解答如何解决!

最后希望大家工作顺利学业有成!

相关文章:

时间序列预测模型实战案例(十)(个人创新模型)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

本文介绍 本篇博客为大家讲解的是通过组堆叠CNN、GRU、LSTM个数&#xff0c;建立多元预测和单元预测的时间序列预测模型&#xff0c;其效果要比单用GRU、LSTM效果好的多&#xff0c;其结合了CNN的特征提取功能、GRU和LSTM用于处理数据中的时间依赖关系的功能。通过将它们组合在…...

【有源码】基于uniapp的农场管理小程序springboot基于微信小程序的农场检测系统(源码 调试 lw 开题报告ppt)

&#x1f495;&#x1f495;作者&#xff1a;计算机源码社 &#x1f495;&#x1f495;个人简介&#xff1a;本人七年开发经验&#xff0c;擅长Java、Python、PHP、.NET、微信小程序、爬虫、大数据等&#xff0c;大家有这一块的问题可以一起交流&#xff01; &#x1f495;&…...

商城系统分布式下单

一、锁定库存的sql select * from ware where id{id} and total-lock>0 update ware set locklock{num} where id{id} and total-lock>{num} 二、下单服务要用分布式事务&#xff0c;因为seat的二阶段提交要说很多资源&#xff0c;会造成处理变成串行化&#xff0c;高并发…...

Java自学第5课:Java web开发环境概述,更换Eclipse版本

1 Java web开发环境 前面我们讲了java基本开发环境&#xff0c;但最终还是要转到web来的&#xff0c;先看下怎么搭建开发环境。 这个图就是大概讲了下开发和应用环境&#xff0c;其实很简单&#xff0c;对于一台裸机&#xff0c;win7 系统的&#xff0c;首先第1步&#xff0c;…...

[网鼎杯 2020 青龙组]AreUSerialz

[网鼎杯 2020 青龙组]AreUSerialz <?phpinclude("flag.php");highlight_file(__FILE__);class FileHandler {protected $op;protected $filename;protected $content;function __construct() {$op "1";$filename "/tmp/tmpfile";$content…...

使用Kotlin与Unirest库抓取音频文件的技术实践

目录 摘要 一、Kotlin与Unirest库概述 二、使用Kotlin和Unirest抓取音频文件 1、添加Unirest依赖 2、发送HTTP请求获取音频文件 3、保存音频文件 三、完整代码示例 四、注意事项 结论 摘要 本文详细阐述了如何使用Kotlin编程语言与Unirest库抓取网络上的音频文件。首…...

gdb调试常用命令

基本命令 1&#xff09;进入GDB  #gdb test test是要调试的程序&#xff0c;由gcc test.c -g -o test生成。进入后提示符变为(gdb) 。 2&#xff09;查看源码  (gdb) l 源码会进行行号提示。 如果需要查看在其他文件中定义的函数&#xff0c;在l后加上函数名即可定位到这…...

CH11_重构API

将查询函数和修改函数分离&#xff08;Separate Query from Modifier&#xff09; function getTotalOutstandingAndSendBill() {const result customer.invoices.reduce((total, each) > each.amount total, 0);sendBill();return result; }function totalOutstanding() …...

UPLOAD-LABS1

less1 (js验证) 我们上传PHP的发现不可以&#xff0c;只能是jpg&#xff0c;png&#xff0c;gif&#xff08;白名单限制了&#xff09; 我们可以直接去修改限制 在查看器中看到使用了onsubmit这个函数&#xff0c;触发了鼠标的单击事件&#xff0c;在表单提交后马上调用了re…...

WordPress相关文章推荐

首先 WordPress 本身并没有相关文章的推荐功能&#xff0c;网站之所以需要这样的功能出于两个原因&#xff0c;一方面是推荐相关的内容越优质&#xff0c;访客的留存和继续阅读将会增强&#xff0c;同样从优化角度来说会更加有利于搜索引擎抓取时对页面质量的提升&#xff0c;毕…...

【QML】Qt和QML获取操作系统类型

1. Qt获取系统类型 //方法 QSysInfo::productType()//举例&#xff1a; if(QSysInfo::productType() "windows") {qDebug() << "windows system"; }官方说明&#xff1a; [static] QString QSysInfo::productType() Returns the product name of …...

CSS 显示、定位、布局、浮动

一、CSS 显示&#xff1a; CSS display属性设置元素应如何显示&#xff1b;CSS visibility属性指定元素应可见还是隐藏。隐藏元素可以通过display属性设置为“none”&#xff0c;也可以通过visibility属性设置为“hidden”。两者的区别&#xff1a;visibility:hidden可以隐藏某…...

Java 学习笔记

文章目录 一、集合1.1 List1.1.1 ArrayList1.1.2 Vector1.1.3 LinkedList 1.2 Deque1.3 Set1.4 Map1.4.1 HashMap1.4.2 LinkedHashMap 1.5 注意事项 二、函数式接口和 Lambda 表达式三、方法引用3.1 静态方法引用3.2 实例方法引用3.2 特定类型的方法引用3.4 构造器引用 四、Str…...

项目实战:优化Servlet,把所有围绕Fruit操作的Servlet封装成一个Servlet

1、FruitServlet 这些Servlet都是围绕着Fruit进行的把所有对水果增删改查的Servlet放到一个Servlet里面&#xff0c;让tomcat实例化一个Servlet对象 package com.csdn.fruit.servlet; import com.csdn.fruit.dto.PageInfo; import com.csdn.fruit.dto.PageQueryParam; import c…...

Go语言函数参数

文章目录 Go语言函数参数1. **函数参数的定义**&#xff1a;2. **参数的数量**&#xff1a;3. **参数的数据类型**&#xff1a;4. **参数的命名**&#xff1a;5. **参数的传递**&#xff1a;6. **参数的传递方式**&#xff1a;7. **空白标识符**&#xff1a; Go语言函数参数 在…...

【遍历二叉树的非递归算法,二叉树的层次遍历】

文章目录 遍历二叉树的非递归算法二叉树的层次遍历 遍历二叉树的非递归算法 先序遍历序列建立二叉树的二叉链表 中序遍历非递归算法 二叉树中序遍历的非递归算法的关键&#xff1a;在中序遍历过某个结点的整个左子树后&#xff0c;如何找到该结点的根以及右子树。 基本思想&a…...

数模之线性规划

线性规划 优化类问题&#xff1a;有限的资源&#xff0c;最大的收益 例子: 华强去水果摊找茬&#xff0c;水果摊上共3个瓜&#xff0c;华强总共有40点体力值,每劈一个瓜能带来40点挑衅值,每挑一个瓜问“你这瓜保熟吗”能带来30点挑衅值,劈瓜消耗20点体力值&#xff0c;问话消耗…...

【C++】AVL树的4中旋转调整

文章目录 前提一、AVL树的结构定义二、AVL的插入&#xff08;重点&#xff09;1. 插入的结点在较高左子树的左侧&#xff08;右单旋&#xff09;2. 新节点插入较高右子树的右侧&#xff08;左单旋&#xff09;3.新结点插入较高右子树的左侧&#xff08;先右单旋再左单旋&#x…...

【MATLAB源码-第69期】基于matlab的LDPC码,turbo码,卷积码误码率对比,码率均为1/3,BPSK调制。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 本文章介绍了卷积码、Turbo码和LDPC码。以相同的码率仿真这三种编码&#xff0c;并对比其误码率性能 信源输出的数据符号&#xff08;二进制&#xff09;是相互独立和等概率的&#xff1b; 信道是加性白高斯噪声信道&#…...

Java获取时间戳、字符串和Date对象的相互转换、日期时间格式化、获取年月日

获取时间戳&#xff08;自1970年1月1日经历的毫秒数值&#xff09; package org.example;import java.util.Date;public class Main {public static void main(String[] args) {Date date1 new Date(1699540662210L);System.out.println(date1.getTime());Date date2 new Dat…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API&#xff0c;用于在函数组件中使用 state 和其他 React 特性&#xff08;例如生命周期方法、context 等&#xff09;。Hooks 通过简洁的函数接口&#xff0c;解决了状态与 UI 的高度解耦&#xff0c;通过函数式编程范式实现更灵活 Rea…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具&#xff0c;可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板&#xff0c;允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板&#xff0c;并通…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...