StemGNN
The Spectral Temporal Graph Neural Network (StemGNN) is a graph-based multivariate time-series forecasting model.
StemGNN jointly learns temporal dependencies and inter-series correlations in the spectral domain by combining the Graph Fourier Transform (GFT) and the Discrete Fourier Transform (DFT).
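The two transforms can be illustrated with a small plain-Python sketch (no NumPy or PyTorch). It assumes a fixed, symmetric adjacency matrix W for a hypothetical 3-node path graph, whereas StemGNN infers its graph with the Latent Correlation layer. The GFT here is computed exactly from the eigenvectors of the normalized Laplacian L = I - D^(-1/2) W D^(-1/2), and the DFT is applied along the time axis of each graph-frequency component. Both transforms are then inverted to recover the input. All helper names are hypothetical, and the sketch is not meant to reproduce the NeuralForecast layer.

```python
import cmath
import math


def normalized_laplacian(W):
    """L = I - D^{-1/2} W D^{-1/2} for a symmetric, non-negative weight matrix W."""
    n = len(W)
    degrees = [sum(row) for row in W]
    inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in degrees]
    return [[(1.0 if i == j else 0.0) - inv_sqrt[i] * W[i][j] * inv_sqrt[j]
             for j in range(n)] for i in range(n)]


def jacobi_eigh(A, max_sweeps=50, tol=1e-24):
    """Cyclic Jacobi eigendecomposition of a symmetric matrix.

    Returns (eigenvalues, U); column k of U is the eigenvector of eigenvalue k.
    """
    n = len(A)
    A = [list(map(float, row)) for row in A]
    U = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for _ in range(max_sweeps):
        if sum(A[i][j] ** 2 for i in range(n) for j in range(n) if i != j) < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if A[p][q] == 0.0:
                    continue
                # Rotation angle that zeroes the (p, q) entry of J^T A J.
                theta = (A[q][q] - A[p][p]) / (2.0 * A[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                for k in range(n):  # A <- A J and U <- U J (columns p and q)
                    a_kp, a_kq = A[k][p], A[k][q]
                    A[k][p], A[k][q] = c * a_kp - s * a_kq, s * a_kp + c * a_kq
                    u_kp, u_kq = U[k][p], U[k][q]
                    U[k][p], U[k][q] = c * u_kp - s * u_kq, s * u_kp + c * u_kq
                for k in range(n):  # A <- J^T A (rows p and q)
                    a_pk, a_qk = A[p][k], A[q][k]
                    A[p][k], A[q][k] = c * a_pk - s * a_qk, s * a_pk + c * a_qk
    return [A[i][i] for i in range(n)], U


def gft(U, x):
    """Graph Fourier Transform of one node signal: x_hat = U^T x."""
    n = len(U)
    return [sum(U[i][k] * x[i] for i in range(n)) for k in range(n)]


def igft(U, x_hat):
    """Inverse GFT: x = U x_hat (U is orthogonal)."""
    n = len(U)
    return [sum(U[i][k] * x_hat[k] for k in range(n)) for i in range(n)]


def dft(x):
    """Discrete Fourier Transform: X[k] = sum_t x[t] * exp(-2*pi*i*k*t/T)."""
    T = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / T) for t in range(T))
            for k in range(T)]


def idft(X):
    """Inverse DFT: x[t] = (1/T) * sum_k X[k] * exp(2*pi*i*k*t/T)."""
    T = len(X)
    return [sum(X[k] * cmath.exp(2j * math.pi * k * t / T) for k in range(T)) / T
            for t in range(T)]


def spectral_roundtrip(W, Y):
    """GFT across series, DFT along time, then both inverses. Y is n_series x T."""
    _, U = jacobi_eigh(normalized_laplacian(W))
    n, T = len(Y), len(Y[0])
    # GFT of the node signal at every time step -> T x n graph-frequency coefficients.
    graph_coeffs = [gft(U, [Y[i][t] for i in range(n)]) for t in range(T)]
    # DFT along time for every graph-frequency component -> n x T complex spectrum.
    spectrum = [dft([graph_coeffs[t][k] for t in range(T)]) for k in range(n)]
    # Invert: IDFT along time (keep the real part), then IGFT at every time step.
    time_coeffs = [[v.real for v in idft(spectrum[k])] for k in range(n)]
    nodes = [igft(U, [time_coeffs[k][t] for k in range(n)]) for t in range(T)]
    return spectrum, [[nodes[t][i] for t in range(T)] for i in range(n)]


# Hypothetical example: a fixed 3-node path graph and three short series.
W = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
Y = [[1.0, 2.0, 3.0, 4.0], [2.0, 0.0, 2.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
eigenvalues, _ = jacobi_eigh(normalized_laplacian(W))
spectrum, Y_reconstructed = spectral_roundtrip(W, Y)
```

Because U is orthogonal and the DFT is invertible, the round trip returns Y up to floating-point error. In a learned model, the interesting work would presumably happen on the spectrum between the forward and inverse transforms.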
This method demonstrated state-of-the-art performance on geo-temporal datasets such as Solar, METR-LA, and PEMS-BAY, and
References
- Defu Cao, Yujing Wang, Juanyong Duan, Ce Zhang, Xia Zhu, Congrui Huang, Yunhai Tong, Bixiong Xu, Jing Bai, Jie Tong, Qi Zhang (2020). “Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting”.
GLU
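The page lists GLU only by name. Assuming it follows the standard gated linear unit, the layer computes two affine maps of the same input, a value path and a gate path, and multiplies the value element-wise by the sigmoid of the gate. The plain-Python sketch below shows that computation for a single vector; the helper names and weights are hypothetical, and the actual layer is presumably a PyTorch module with learned parameters.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def linear(W, b, x):
    """Affine map W x + b; W has one row per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + bias for row, bias in zip(W, b)]


def glu(x, W_left, b_left, W_right, b_right):
    """Gated linear unit: (W_left x + b_left) * sigmoid(W_right x + b_right), element-wise."""
    values = linear(W_left, b_left, x)
    gates = linear(W_right, b_right, x)
    return [v * sigmoid(g) for v, g in zip(values, gates)]


# Hypothetical weights: identity value path and an all-zero gate (sigmoid(0) = 0.5).
x = [1.0, 2.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
out = glu(x, identity, [0.0, 0.0], zeros, [0.0, 0.0])  # [0.5, 1.0]
```

A strongly positive gate lets the value pass almost unchanged and a strongly negative one suppresses it, which is what makes the unit a learnable filter.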
StockBlockLayer
StemGNN
Parameters:
h: int, forecast horizon.
input_size: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
n_series: int, number of time series.
futr_exog_list: str list, future exogenous columns.
hist_exog_list: str list, historic exogenous columns.
stat_exog_list: str list, static exogenous columns.
n_stacks: int=2, number of stacks in the model.
multi_layer: int=5, multiplier for the FC hidden size in StemGNN blocks.
dropout_rate: float=0.5, dropout rate.
leaky_rate: float=0.2, alpha for the LeakyReLU layer in the Latent Correlation layer.
loss: PyTorch module, instantiated train loss class from the losses collection.
valid_loss: PyTorch module=loss, instantiated validation loss class from the losses collection.
max_steps: int=1000, maximum number of training steps.
learning_rate: float=1e-3, learning rate between (0, 1).
num_lr_decays: int=-1, number of learning rate decays, evenly distributed across max_steps.
early_stop_patience_steps: int=-1, number of validation iterations before early stopping.
val_check_steps: int=100, number of training steps between every validation loss check.
batch_size: int, number of windows in each batch.
valid_batch_size: int=None, number of different series in each validation and test batch; if None, uses batch_size.
windows_batch_size: int=32, number of windows to sample in each training batch, default uses all.
inference_windows_batch_size: int=32, number of windows to sample in each inference batch, -1 uses all.
start_padding_enabled: bool=False, if True, the model pads the time series with zeros at the beginning, by input size.
step_size: int=1, step size between each window of temporal data.
scaler_type: str=‘robust’, type of scaler for temporal input normalization, see temporal scalers.
random_seed: int, random seed for the PyTorch initializer and NumPy generators.
drop_last_loader: bool=False, if True, TimeSeriesDataLoader drops the last non-full batch.
alias: str, optional, custom name of the model.
optimizer: subclass of ‘torch.optim.Optimizer’, optional, user-specified optimizer instead of the default choice (Adam).
optimizer_kwargs: dict, optional, parameters used by the user-specified optimizer.
lr_scheduler: subclass of ‘torch.optim.lr_scheduler.LRScheduler’, optional, user-specified lr_scheduler instead of the default choice (StepLR).
lr_scheduler_kwargs: dict, optional, parameters used by the user-specified lr_scheduler.
dataloader_kwargs: dict, optional, parameters passed into the PyTorch Lightning dataloader by the TimeSeriesDataLoader.
**trainer_kwargs: keyword trainer arguments inherited from PyTorch Lightning’s trainer.
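Some windowing and schedule parameters are easier to read with a worked example. In the plain-Python sketch below, the hypothetical helper make_windows reproduces the input_size example (y=[1,2,3,4] with input_size=2 gives inputs [1,2]) and shows how h, step_size, and start_padding_enabled change the windows. The hypothetical lr_at_step spreads num_lr_decays drops evenly across max_steps. Neither is the TimeSeriesDataLoader or the library's scheduler, and both the halving factor gamma=0.5 and the treatment of -1 as "no decay" are assumptions.

```python
def make_windows(y, input_size, h, step_size=1, start_padding_enabled=False):
    """Hypothetical helper: split one series into (inputs, targets) windows.

    Each window holds input_size past values followed by h future values;
    consecutive windows start step_size points apart. With start padding,
    input_size zeros are prepended, so the first observations become targets too.
    """
    if start_padding_enabled:
        y = [0] * input_size + list(y)
    windows = []
    for start in range(0, len(y) - input_size - h + 1, step_size):
        inputs = y[start:start + input_size]
        targets = y[start + input_size:start + input_size + h]
        windows.append((inputs, targets))
    return windows


def lr_at_step(step, learning_rate, max_steps, num_lr_decays, gamma=0.5):
    """Step-decay schedule with num_lr_decays drops spread evenly over max_steps.

    Assumptions of this sketch: gamma is the factor applied at each drop, and any
    num_lr_decays <= 0 (such as the default -1) disables decay.
    """
    if num_lr_decays <= 0:
        return learning_rate
    decay_every = max(max_steps // num_lr_decays, 1)
    return learning_rate * gamma ** (step // decay_every)


y = [1, 2, 3, 4]
windows = make_windows(y, input_size=2, h=1)  # [([1, 2], [3]), ([2, 3], [4])]
padded = make_windows(y, input_size=2, h=1, start_padding_enabled=True)
```

With max_steps=1000 and num_lr_decays=3, this sketch lowers the rate at steps 333, 666, and 999.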
StemGNN.fit
Fit.
The fit method optimizes the neural network’s weights using the initialization parameters (learning_rate, windows_batch_size, …) and the loss function as defined during the initialization. Within fit we use a PyTorch Lightning Trainer that inherits the initialization’s self.trainer_kwargs; to customize its inputs, see PL’s trainer arguments.
The method is designed to be compatible with SKLearn-like classes and in particular with the StatsForecast library.
By default, the model does not save training checkpoints, to protect disk memory; to get them, set enable_checkpointing=True in __init__.
Parameters:
dataset: NeuralForecast’s TimeSeriesDataset, see documentation.
val_size: int, validation size for temporal cross-validation.
random_seed: int=None, random seed for the PyTorch initializer and NumPy generators; overrides the one set in model.__init__.
test_size: int, test size for temporal cross-validation.
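The val_size and test_size arguments describe a temporal hold-out rather than a shuffled one. Assuming the test points are the most recent observations and the validation points immediately precede them, the split looks like the sketch below; temporal_split is a hypothetical helper, not part of NeuralForecast.

```python
def temporal_split(y, val_size, test_size):
    """Hypothetical helper: chronological train/validation/test split of one series.

    The last test_size points are held out for testing and the val_size points
    just before them for validation; observations are never shuffled.
    """
    if val_size < 0 or test_size < 0 or val_size + test_size >= len(y):
        raise ValueError("need non-negative sizes with val_size + test_size < len(y)")
    n_train = len(y) - val_size - test_size
    return y[:n_train], y[n_train:n_train + val_size], y[n_train + val_size:]


train, valid, test = temporal_split(list(range(10)), val_size=2, test_size=3)
# train = [0, 1, 2, 3, 4], valid = [5, 6], test = [7, 8, 9]
```

Keeping the order intact means the model is always validated and tested on data that comes after everything it was trained on.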
StemGNN.predict
Predict.
Neural network prediction with PL’s Trainer execution of predict_step.
Parameters:
dataset: NeuralForecast’s TimeSeriesDataset, see documentation.
test_size: int=None, test size for temporal cross-validation.
step_size: int=1, step size between each window.
random_seed: int=None, random seed for the PyTorch initializer and NumPy generators; overrides the one set in model.__init__.
quantiles: list of floats, optional (default=None), target quantiles to predict.
**data_module_kwargs: PL’s TimeSeriesDataModule arguments, see documentation.
Usage Examples
Train the model and forecast future values with the predict method.
Use cross_validation to forecast multiple historic values.
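Cross-validation repeats the forecast from several past cutoffs, so every forecast can be compared with values that were actually observed. The sketch below shows a rolling-origin scheme under stated assumptions: windows of length h, the last one ending at the final observation, and consecutive cutoffs step_size points apart. rolling_origin_splits, its n_windows argument, and the default step_size=h are hypothetical choices for illustration, not the library's cross_validation signature.

```python
def rolling_origin_splits(y, h, n_windows, step_size=None):
    """Hypothetical helper: (history, future) pairs for rolling-origin evaluation.

    The last window ends at the final observation; each earlier cutoff lies
    step_size points before the next one (step_size defaults to h).
    """
    step_size = h if step_size is None else step_size
    splits = []
    for i in range(n_windows):
        cutoff = len(y) - h - (n_windows - 1 - i) * step_size
        if cutoff <= 0:
            raise ValueError("series too short for the requested windows")
        splits.append((y[:cutoff], y[cutoff:cutoff + h]))
    return splits


splits = rolling_origin_splits(list(range(24)), h=4, n_windows=3)
cutoffs = [len(history) for history, _ in splits]  # [12, 16, 20]
```

With step_size smaller than h the forecast windows overlap, which yields more evaluation points from the same history.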