In this tutorial we show how to train Transformer-based models and forecast with them: Informer (Zhou, H. et al. 2021), Autoformer (Wu et al. 2021), FEDformer (Zhou, T. et al. 2022), and PatchTST (Nie et al. 2023).
Our implementations of all these models are univariate, meaning that only the autoregressive values of each feature are used for forecasting. We observed that these univariate models are more accurate and faster than their multivariate counterparts.
In this notebook we will show how to:

* Load the ETTm2 benchmark dataset, used in the academic literature.
* Train models.
* Forecast the test set.
The results achieved in this notebook outperform the original self-reported results in the respective papers, at a fraction of the computational cost. Additionally, all models are trained with the default recommended parameters; results can be further improved using our auto models with automatic hyperparameter selection.
You can run these experiments on a GPU using Google Colab.
The `LongHorizon` class will automatically download the complete ETTm2 dataset and process it. It returns three DataFrames: `Y_df` contains the values of the target variables, `X_df` contains exogenous calendar features, and `S_df` contains static features for each time series (none for ETTm2). For this example we will only use `Y_df`.
If you want to use your own data, just replace `Y_df`. Be sure to use the long format, with a structure similar to our dataset.
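A minimal sketch of the loading step, assuming the `datasetsforecast` package is installed:

```python
import pandas as pd
from datasetsforecast.long_horizon import LongHorizon

# Download and process the complete ETTm2 dataset.
Y_df, X_df, S_df = LongHorizon.load(directory='./', group='ETTm2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])

# Preview the first two observations of each series.
Y_df.groupby('unique_id').head(2)
```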
| | unique_id | ds | y |
|---|---|---|---|
| 0 | HUFL | 2016-07-01 00:00:00 | -0.041413 |
| 1 | HUFL | 2016-07-01 00:15:00 | -0.185467 |
| 57600 | HULL | 2016-07-01 00:00:00 | 0.040104 |
| 57601 | HULL | 2016-07-01 00:15:00 | -0.214450 |
| 115200 | LUFL | 2016-07-01 00:00:00 | 0.695804 |
| 115201 | LUFL | 2016-07-01 00:15:00 | 0.434685 |
| 172800 | LULL | 2016-07-01 00:00:00 | 0.434430 |
| 172801 | LULL | 2016-07-01 00:15:00 | 0.428168 |
| 230400 | MUFL | 2016-07-01 00:00:00 | -0.599211 |
| 230401 | MUFL | 2016-07-01 00:15:00 | -0.658068 |
| 288000 | MULL | 2016-07-01 00:00:00 | -0.393536 |
| 288001 | MULL | 2016-07-01 00:15:00 | -0.659338 |
| 345600 | OT | 2016-07-01 00:00:00 | 1.018032 |
| 345601 | OT | 2016-07-01 00:15:00 | 0.980124 |
To evaluate the models we use the `cross_validation` method, which allows users to automatically simulate multiple historic forecasts (on the test set). The `cross_validation` method will use the validation set for hyperparameter selection and early stopping, and will then produce the forecasts for the test set.
First, instantiate each model in the `models` list, specifying the `horizon`, `input_size`, and number of training iterations. (Note: the `FEDformer` model was excluded due to its extremely long training times.)
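A sketch of the instantiation step; the `horizon`, `input_size`, and `max_steps` values below are illustrative assumptions rather than this tutorial's exact settings:

```python
from neuralforecast.models import Informer, Autoformer, PatchTST

horizon = 96  # e.g., forecast 96 steps (24 hours at 15-minute frequency)

models = [
    Informer(h=horizon, input_size=horizon, max_steps=1000,
             val_check_steps=100, early_stop_patience_steps=3),
    Autoformer(h=horizon, input_size=horizon, max_steps=1000,
               val_check_steps=100, early_stop_patience_steps=3),
    PatchTST(h=horizon, input_size=horizon, max_steps=1000,
             val_check_steps=100, early_stop_patience_steps=3),
]
```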
Tip: Check our auto models for automatic hyperparameter optimization.
Instantiate a `NeuralForecast` object with the following required parameters:

* `models`: a list of models.
* `freq`: a string indicating the frequency of the data. (See pandas' available frequencies.)
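A minimal sketch, assuming ETTm2's 15-minute frequency:

```python
from neuralforecast import NeuralForecast

nf = NeuralForecast(models=models, freq='15min')
```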
Second, use the `cross_validation` method, specifying the dataset (`Y_df`), validation size, and test size. The `cross_validation` method will return the forecasts for each model on the test set.
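A sketch of this step; the split sizes below assume the standard ETTm2 benchmark protocol of four months each for validation and testing:

```python
val_size = 11520   # 4 months of 15-minute observations
test_size = 11520  # 4 months of 15-minute observations

# n_windows=None lets test_size define the evaluation window.
Y_hat_df = nf.cross_validation(df=Y_df, val_size=val_size,
                               test_size=test_size, n_windows=None)
Y_hat_df.head()
```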
| | unique_id | ds | cutoff | Informer | Autoformer | PatchTST | y |
|---|---|---|---|---|---|---|---|
| 0 | HUFL | 2017-10-24 00:00:00 | 2017-10-23 23:45:00 | -1.055062 | -0.861487 | -0.860189 | -0.977673 |
| 1 | HUFL | 2017-10-24 00:15:00 | 2017-10-23 23:45:00 | -1.021247 | -0.873399 | -0.865730 | -0.865620 |
| 2 | HUFL | 2017-10-24 00:30:00 | 2017-10-23 23:45:00 | -1.057297 | -0.900345 | -0.944296 | -0.961624 |
| 3 | HUFL | 2017-10-24 00:45:00 | 2017-10-23 23:45:00 | -0.886652 | -0.867466 | -0.974849 | -1.049700 |
| 4 | HUFL | 2017-10-24 01:00:00 | 2017-10-23 23:45:00 | -1.000431 | -0.887454 | -1.008530 | -0.953600 |
Next, we plot the forecasts of the `OT` variable for all models.
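A hedged plotting sketch using matplotlib, with column names taken from the cross-validation output above:

```python
import matplotlib.pyplot as plt

# Restrict to the OT series and overlay each model's forecasts.
plot_df = Y_hat_df[Y_hat_df['unique_id'] == 'OT']

plt.figure(figsize=(10, 4))
plt.plot(plot_df['ds'], plot_df['y'], label='True')
for model in ['Informer', 'Autoformer', 'PatchTST']:
    plt.plot(plot_df['ds'], plot_df[model], label=model)
plt.legend()
plt.xlabel('Date')
plt.ylabel('OT')
plt.show()
```

For reference, the table below reports the mean absolute error (MAE) on the `OT` variable for each forecast horizon, with ARIMA included as a classical baseline.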
| Horizon | PatchTST | Autoformer | Informer | ARIMA |
|---|---|---|---|---|
| 96 | 0.256 | 0.339 | 0.453 | 0.301 |
| 192 | 0.296 | 0.340 | 0.563 | 0.345 |
| 336 | 0.329 | 0.372 | 0.887 | 0.386 |
| 720 | 0.385 | 0.419 | 1.388 | 0.445 |
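Such errors can be computed directly from the cross-validation output; a minimal sketch, assuming the `mae` helper from `neuralforecast.losses.numpy`:

```python
from neuralforecast.losses.numpy import mae

# MAE per model over the whole test set.
for model in ['Informer', 'Autoformer', 'PatchTST']:
    model_mae = mae(Y_hat_df['y'].values, Y_hat_df[model].values)
    print(f'{model}: MAE = {model_mae:.3f}')
```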
Finally, we also recommend NHITS, a model based on feed-forward networks (Challu et al. 2023). It achieves on-par performance with PatchTST at a fraction of the computational cost. The NHITS tutorial is available here.