
Cross-Validation with NeuralForecast
Outline:

1. Install NeuralForecast
2. Load and plot the data
3. Train multiple models using cross-validation
4. Evaluate models and select the best for each series
5. Plot cross-validation results
Prerequisites

This guide assumes basic familiarity with `neuralforecast`. For a minimal example, visit the Quick Start.
1. Install NeuralForecast
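You can install `neuralforecast` from PyPI with pip:

```
pip install neuralforecast
```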
2. Load and plot the data
We’ll use pandas to load the hourly dataset from the M4 Forecasting Competition, which has been stored in a parquet file for efficiency.
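A sketch of the loading step; the parquet URL below is the Nixtla datasets mirror used in related tutorials (an assumption here, so adjust it if your copy lives elsewhere):

```python
import pandas as pd

# Load the M4 hourly dataset; ds is an integer time index
df = pd.read_parquet('https://datasets-nixtla.s3.amazonaws.com/m4-hourly.parquet')
df.head()
```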
| | unique_id | ds | y |
|---|---|---|---|
| 0 | H1 | 1 | 605.0 |
| 1 | H1 | 2 | 586.0 |
| 2 | H1 | 3 | 586.0 |
| 3 | H1 | 4 | 559.0 |
| 4 | H1 | 5 | 511.0 |
The input to `neuralforecast` should be a data frame in long format with three columns: `unique_id`, `ds`, and `y`.
- `unique_id` (string, int, or category): A unique identifier for each time series.
- `ds` (int or timestamp): An integer indexing time or a timestamp in format YYYY-MM-DD or YYYY-MM-DD HH:MM:SS.
- `y` (numeric): The target variable to forecast.
To plot the series, we’ll use the `plot_series` function from `utilsforecast.plotting`. `utilsforecast` is a dependency of `neuralforecast`, so it should already be installed.
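A minimal example, plotting a sample of the series:

```python
from utilsforecast.plotting import plot_series

# By default, plot_series draws a random sample of the series in df
plot_series(df)
```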

3. Train multiple models using cross-validation
We’ll train different models from `neuralforecast` using the `cross_validation` method to decide which one performs best on the historical data. To do this, we need to import the `NeuralForecast` class and the models that we want to compare: here, `neuralforecast`’s `MLP`, `NBEATS`, and `NHITS` models.
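These imports come from the `neuralforecast` public API:

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import MLP, NBEATS, NHITS
```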
First, we need to create a list of models and then instantiate the `NeuralForecast` class (see the sketch after the warning below). For each model, we’ll define the following hyperparameters:
- `h`: The forecast horizon. Here, we will use the same horizon as in the M4 competition, which was 48 steps ahead.
- `input_size`: The number of historical observations (lags) that the model uses to make predictions. In this case, it will be twice the forecast horizon.
- `loss`: The loss function to optimize. Here, we’ll use the Multi Quantile Loss (MQLoss) from `neuralforecast.losses.pytorch`.
Warning: The Multi Quantile Loss (MQLoss) is the sum of the quantile losses for each target quantile. The quantile loss for a single quantile measures how well a model has predicted a specific quantile of the actual distribution, penalizing overestimations and underestimations asymmetrically based on the quantile’s value. For more details see here.

While there are other hyperparameters that can be defined for each model, we’ll use the default values for the purposes of this tutorial. To learn more about the hyperparameters of each model, please check out the corresponding documentation.
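A minimal sketch of this setup, continuing from the imports above and assuming the integer `ds` column from section 2 (hence `freq=1`):

```python
from neuralforecast.losses.pytorch import MQLoss

horizon = 48  # forecast horizon used in the M4 hourly competition

# MQLoss defaults to the 80% and 90% levels, producing the median
# plus the lo/hi bounds for each interval
models = [
    MLP(h=horizon, input_size=2 * horizon, loss=MQLoss()),
    NBEATS(h=horizon, input_size=2 * horizon, loss=MQLoss()),
    NHITS(h=horizon, input_size=2 * horizon, loss=MQLoss()),
]

# freq=1 because ds is an integer index rather than a timestamp
nf = NeuralForecast(models=models, freq=1)
```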
The `cross_validation` method takes the following arguments:
- `df`: The data frame in the format described in section 2.
- `n_windows` (int): The number of windows to evaluate. The default is 1; here we’ll use 3.
- `step_size` (int): The number of steps between consecutive windows used to produce the forecasts. In this example, we’ll set `step_size=horizon` to produce non-overlapping forecasts. The following diagram shows how the forecasts are produced based on the `step_size` parameter and the forecast horizon `h` of a model. In this diagram, `step_size=2` and `h=4`.

- `refit` (bool or int): Whether to retrain models for each cross-validation window. If `False`, the models are trained at the beginning and then used to predict each window. If a positive integer, the models are retrained every `refit` windows. The default is `False`, but here we’ll use `refit=1` so that the models are retrained after each window using the data with timestamps up to and including the cutoff.
Note that the `cross_validation` method in `neuralforecast` diverges from other libraries, where models are typically retrained at the start of each window. By default, it trains the models once and then uses them to generate predictions over all the windows, thus reducing the total execution time. For scenarios where the models need to be retrained, you can use the `refit` parameter to specify the number of windows after which the models should be retrained.
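Putting this together, a sketch of the call using the `nf` object and `df` from the previous steps (the `refit` argument assumes a recent `neuralforecast` version):

```python
# Three non-overlapping windows, retraining after every window
cv_df = nf.cross_validation(df, n_windows=3, step_size=horizon, refit=1)
cv_df.head()
```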
| | unique_id | ds | cutoff | MLP-median | MLP-lo-90 | MLP-lo-80 | MLP-hi-80 | MLP-hi-90 | NBEATS-median | NBEATS-lo-90 | NBEATS-lo-80 | NBEATS-hi-80 | NBEATS-hi-90 | NHITS-median | NHITS-lo-90 | NHITS-lo-80 | NHITS-hi-80 | NHITS-hi-90 | y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | H1 | 605 | 604 | 638.964111 | 528.127747 | 546.731812 | 714.415466 | 750.265259 | 623.230896 | 580.549744 | 587.317688 | 647.942505 | 654.148682 | 625.377930 | 556.786926 | 577.746765 | 657.901611 | 670.458069 | 622.0 |
| 1 | H1 | 606 | 604 | 588.216370 | 445.395081 | 483.736542 | 684.394592 | 670.042358 | 552.829407 | 501.618988 | 529.007507 | 593.528564 | 603.152527 | 555.956177 | 511.696350 | 526.399597 | 604.318970 | 622.839722 | 558.0 |
| 2 | H1 | 607 | 604 | 542.242737 | 419.206757 | 439.244476 | 617.775269 | 638.583923 | 495.155548 | 451.871613 | 467.183533 | 550.048950 | 574.697021 | 502.860077 | 462.284668 | 460.950287 | 555.336731 | 571.852722 | 513.0 |
| 3 | H1 | 608 | 604 | 494.055573 | 414.775085 | 427.531647 | 583.965759 | 602.303772 | 465.182556 | 403.593140 | 410.033203 | 500.744019 | 518.277954 | 460.588684 | 406.762390 | 418.040710 | 501.833740 | 515.022095 | 476.0 |
| 4 | H1 | 609 | 604 | 469.330688 | 361.437927 | 378.501373 | 557.875244 | 569.767273 | 441.072388 | 371.541504 | 401.923584 | 483.667877 | 485.047729 | 441.463043 | 393.917725 | 394.483337 | 475.985229 | 499.001373 | 449.0 |
The output of the `cross_validation` method is a data frame that includes the following columns:
- `unique_id`: The unique identifier for each time series.
- `ds`: The timestamp or temporal index.
- `cutoff`: The last timestamp or temporal index used in that cross-validation window.
- `"model"`: Columns with the model’s point forecasts (median) and prediction intervals. By default, the 80% and 90% prediction intervals are included when using the MQLoss.
- `y`: The actual value.
4. Evaluate models and select the best for each series
To evaluate the point forecasts of the models, we’ll use the Root Mean Squared Error (RMSE), defined as the square root of the mean of the squared differences between the actual and the predicted values. For convenience, we’ll use the `evaluate` and `rmse` functions from `utilsforecast`.
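For a series with $n$ evaluated points, actual values $y_t$, and predictions $\hat{y}_t$, this definition reads:

$$\mathrm{RMSE} = \sqrt{\frac{1}{n} \sum_{t=1}^{n} \left( y_t - \hat{y}_t \right)^2}$$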
The `evaluate` function takes the following arguments:
- `df`: The data frame with the forecasts to evaluate.
- `metrics` (list): The metrics to compute.
- `models` (list): Names of the models to evaluate. The default is `None`, which uses all columns after removing `id_col`, `time_col`, and `target_col`.
- `id_col` (str): Column that identifies the unique ids of the series. Default is `unique_id`.
- `time_col` (str): Column with the timestamps or the temporal index. Default is `ds`.
- `target_col` (str): Column with the target variable. Default is `y`.
Since we’re not specifying the `models` to evaluate, we need to exclude the `cutoff` column from the cross-validation data frame.
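A sketch of the evaluation; deriving `best_model` as the column with the lowest RMSE per series is this guide’s convention, not a library feature:

```python
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import rmse

# cutoff is excluded so evaluate doesn't treat it as a model column
evaluation_df = evaluate(cv_df.drop(columns='cutoff'), metrics=[rmse])

# For each series, pick the model with the lowest RMSE
evaluation_df['best_model'] = (
    evaluation_df.drop(columns=['unique_id', 'metric']).idxmin(axis=1)
)
evaluation_df.head(10)
```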
| | unique_id | metric | MLP-median | NBEATS-median | NHITS-median | best_model |
|---|---|---|---|---|---|---|
| 0 | H1 | rmse | 46.654390 | 49.595304 | 47.651201 | MLP-median |
| 1 | H10 | rmse | 24.192081 | 21.580142 | 16.887989 | NHITS-median |
| 2 | H100 | rmse | 171.958998 | 178.820952 | 170.452623 | NHITS-median |
| 3 | H101 | rmse | 331.270162 | 260.021871 | 169.453119 | NHITS-median |
| 4 | H102 | rmse | 440.470939 | 362.602167 | 326.571391 | NHITS-median |
| 5 | H103 | rmse | 9069.937603 | 9267.925257 | 8578.535681 | NHITS-median |
| 6 | H104 | rmse | 189.534415 | 169.017976 | 226.442403 | NBEATS-median |
| 7 | H105 | rmse | 341.029706 | 284.038751 | 262.140145 | NHITS-median |
| 8 | H106 | rmse | 203.723728 | 328.128422 | 298.377068 | MLP-median |
| 9 | H107 | rmse | 212.384943 | 161.445838 | 231.303421 | NBEATS-median |
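To summarize how many series each model wins, we can count the entries in `best_model`; a minimal sketch:

```python
# Count the number of series for which each model performed best
summary_df = (
    evaluation_df.groupby(['metric', 'best_model'])
    .size()
    .rename('num. of unique_ids')
    .reset_index()
    .rename(columns={'best_model': 'model'})
)
summary_df
```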
| | metric | model | num. of unique_ids |
|---|---|---|---|
| 0 | rmse | MLP-median | 2 |
| 1 | rmse | NBEATS-median | 2 |
| 2 | rmse | NHITS-median | 6 |
5. Plot cross-validation results
To visualize the cross-validation results, we will use the `plot_series` function again. We’ll need to rename the `y` column in the cross-validation output to avoid duplicates with the original data frame. We’ll also exclude the `cutoff` column and use the `max_insample_length` argument to plot only the last 300 observations for better visualization.
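A sketch of this step; renaming `y` to `actual` is an arbitrary choice for illustration:

```python
plot_series(
    df,
    cv_df.drop(columns='cutoff').rename(columns={'y': 'actual'}),
    max_insample_length=300,
)
```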

Below are the results for `unique_id='H1'`. There are three cutoffs because we set `n_windows=3`. In this example, we used `refit=1`, so each model is retrained for each window using data with timestamps up to and including the respective cutoff. Additionally, since `step_size` is equal to the forecast horizon, the resulting forecasts are non-overlapping.
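To reproduce this single-series view, `plot_series` accepts an `ids` argument, for example:

```python
plot_series(
    df,
    cv_df.drop(columns='cutoff').rename(columns={'y': 'actual'}),
    ids=['H1'],  # restrict the plot to a single series
    max_insample_length=300,
)
```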


