Neural/MLForecast
This example notebook demonstrates the compatibility of HierarchicalForecast’s reconciliation methods with popular machine-learning libraries, specifically NeuralForecast and MLForecast.
The notebook uses NBEATS and XGBRegressor models to create base forecasts for the TourismLarge hierarchical dataset, and then uses HierarchicalForecast to reconcile the base predictions.
References
- Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). “N-BEATS: Neural basis expansion analysis for interpretable time series forecasting”. https://arxiv.org/abs/1905.10437
- Tianqi Chen, Carlos Guestrin (2016). “XGBoost: A Scalable Tree Boosting System”. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD ’16), pp. 785–794. https://doi.org/10.1145/2939672.2939785
- Shanika L. Wickramasuriya, George Athanasopoulos, Rob J. Hyndman (2019). “Optimal Forecast Reconciliation for Hierarchical and Grouped Time Series Through Trace Minimization”. Journal of the American Statistical Association, 114(526), pp. 804–819.
- Syama Sundar Rangapuram, Lucien D. Werner, Konstantinos Benidis, Pedro Mercado, Jan Gasthaus, Tim Januschowski (2021). “End-to-End Learning of Coherent Probabilistic Forecasts for Hierarchical Time Series”. Proceedings of the 38th International Conference on Machine Learning (ICML).
You can run these experiments using CPU or GPU with Google Colab.
1. Installing packages
# %pip install datasetsforecast hierarchicalforecast mlforecast neuralforecast
import numpy as np
import pandas as pd
from datasetsforecast.hierarchical import HierarchicalData
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
from neuralforecast.losses.pytorch import GMM
from mlforecast import MLForecast
from mlforecast.utils import PredictionIntervals
import xgboost as xgb
# obtain hierarchical reconciliation methods, plotting utilities, and evaluation
from hierarchicalforecast.methods import BottomUp, ERM, MinTrace
from hierarchicalforecast.utils import HierarchicalPlot
from hierarchicalforecast.core import HierarchicalReconciliation
from hierarchicalforecast.evaluation import scaled_crps
2. Load hierarchical dataset
This detailed Australian Tourism Dataset comes from the National Visitor Survey, managed by Tourism Research Australia. It comprises 555 monthly series from 1998 to 2016 and is organized both geographically and by purpose of travel. The natural geographical hierarchy consists of seven states, which are further divided into 27 zones and 76 regions. The purpose-of-travel categories are holiday, visiting friends and relatives (VFR), business, and other. MinT (Wickramasuriya et al., 2019), among other hierarchical forecasting studies, has used this dataset in the past. The dataset can be accessed on the MinT reconciliation webpage, although other sources are also available.
Geographical Division | Number of series per division | Number of series per purpose | Total |
---|---|---|---|
Australia | 1 | 4 | 5 |
States | 7 | 28 | 35 |
Zones | 27 | 108 | 135 |
Regions | 76 | 304 | 380 |
Total | 111 | 444 | 555 |
Y_df, S_df, tags = HierarchicalData.load('./data', 'TourismLarge')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
Y_df.head()
 | unique_id | ds | y
---|---|---|---
0 | TotalAll | 1998-01-01 | 45151.071280 |
1 | TotalAll | 1998-02-01 | 17294.699551 |
2 | TotalAll | 1998-03-01 | 20725.114184 |
3 | TotalAll | 1998-04-01 | 25388.612353 |
4 | TotalAll | 1998-05-01 | 20330.035211 |
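The tags dictionary returned by the loader maps each hierarchy level to the identifiers of the series it contains; its counts should match the table above. A quick, purely illustrative inspection:
# tags maps each level name to the series ids belonging to that level
for level_name, series_ids in tags.items():
    print(f'{level_name}: {len(series_ids)} series')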
Visualize the aggregation matrix.
hplot = HierarchicalPlot(S=S_df, tags=tags)
hplot.plot_summing_matrix()
Split the dataframe in train/test splits.
def sort_hier_df(Y_df, S_df):
    # sort series to match the row order of the summing matrix S_df
Y_df.unique_id = Y_df.unique_id.astype('category')
Y_df.unique_id = Y_df.unique_id.cat.set_categories(S_df.index)
Y_df = Y_df.sort_values(by=['unique_id', 'ds'])
return Y_df
Y_df = sort_hier_df(Y_df, S_df)
horizon = 12
Y_test_df = Y_df.groupby('unique_id').tail(horizon)
Y_train_df = Y_df.drop(Y_test_df.index)
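A quick sanity check on the split (the first assertion assumes, as holds for this dataset, that every series runs to the end of the sample):
# every series should contribute exactly `horizon` rows to the test set
assert Y_test_df.groupby('unique_id').size().eq(horizon).all()
assert len(Y_train_df) + len(Y_test_df) == len(Y_df)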
3. Fit and Predict Models
HierarchicalForecast is compatible with many different ML models. Here, we show two examples:

1. NBEATS, an MLP-based deep neural architecture.
2. XGBRegressor, a tree-based architecture.
level = np.arange(0, 100, 2)
qs = [[50-lv/2, 50+lv/2] for lv in level]
quantiles = np.sort(np.concatenate(qs)/100)
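The 50 even levels from 0 to 98 expand into 100 quantiles between 0.01 and 0.99 (level 0 contributes the median twice); level 80, for instance, maps to the 0.10 and 0.90 quantiles. A quick check of the grid we just built:
# 50 two-sided levels yield 100 quantiles
print(len(quantiles))               # 100
print(quantiles[0], quantiles[-1])  # 0.01 0.99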
#fit/predict NBEATS from NeuralForecast
nbeats = NBEATS(h=horizon,
input_size=2*horizon,
loss=GMM(n_components=10, quantiles=quantiles),
scaler_type='robust',
max_steps=2000)
nf = NeuralForecast(models=[nbeats], freq='MS')
nf.fit(df=Y_train_df)
Y_hat_nf = nf.predict()
insample_nf = nf.predict_insample(step_size=horizon)
#fit/predict XGBRegressor from MLForecast
mf = MLForecast(models=[xgb.XGBRegressor()],
freq='MS',
lags=[1,2,12,24],
date_features=['month'],
)
mf.fit(Y_train_df, fitted=True, prediction_intervals=PredictionIntervals(n_windows=10, h=horizon))
Y_hat_mf = mf.predict(horizon, level=level).set_index('unique_id')
insample_mf = mf.forecast_fitted_values()
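Both in-sample frames are passed as Y_df to HierarchicalForecast’s reconcile call below, where they are used to estimate the reconciliation matrices. A quick peek at their columns (purely illustrative):
# inspect the in-sample frames used later as Y_df in reconcile
print(insample_nf.columns.tolist())
print(insample_mf.columns.tolist())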
Y_hat_nf
unique_id | ds | NBEATS | NBEATS-lo-98.0 | NBEATS-lo-96.0 | NBEATS-lo-94.0 | NBEATS-lo-92.0 | NBEATS-lo-90.0 | NBEATS-lo-88.0 | NBEATS-lo-86.0 | NBEATS-lo-84.0 | … | NBEATS-hi-80.0 | NBEATS-hi-82.0 | NBEATS-hi-84.0 | NBEATS-hi-86.0 | NBEATS-hi-88.0 | NBEATS-hi-90.0 | NBEATS-hi-92.0 | NBEATS-hi-94.0 | NBEATS-hi-96.0 | NBEATS-hi-98.0
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
TotalAll | 2016-01-01 | 44525.652344 | 21232.554688 | 26024.839844 | 27435.285156 | 28136.705078 | 28766.150391 | 29569.240234 | 30344.240234 | 31163.099609 | … | 51812.953125 | 52171.792969 | 52628.562500 | 52890.750000 | 53160.312500 | 54025.210938 | 54451.109375 | 55651.007812 | 57686.027344 | 61461.066406 |
TotalAll | 2016-02-01 | 20819.431641 | 18020.289062 | 18314.943359 | 18480.269531 | 18612.464844 | 18695.382812 | 18807.242188 | 18912.910156 | 19027.187500 | … | 22719.998047 | 22802.921875 | 22887.734375 | 23031.005859 | 23133.865234 | 23230.322266 | 23406.496094 | 23622.166016 | 23887.796875 | 24165.496094 |
TotalAll | 2016-03-01 | 23676.291016 | 19303.222656 | 19684.693359 | 19928.400391 | 20150.691406 | 20319.113281 | 20499.980469 | 20632.185547 | 20748.207031 | … | 26215.312500 | 26291.195312 | 26402.853516 | 26578.257812 | 26848.179688 | 27054.107422 | 27310.746094 | 27723.867188 | 28211.294922 | 29011.082031 |
TotalAll | 2016-04-01 | 27978.587891 | 23936.988281 | 24329.892578 | 24532.740234 | 24735.703125 | 24902.812500 | 25165.074219 | 25256.669922 | 25489.455078 | … | 30192.365234 | 30278.451172 | 30339.017578 | 30381.443359 | 30465.722656 | 30574.056641 | 30682.609375 | 30860.427734 | 31032.648438 | 31199.992188 |
TotalAll | 2016-05-01 | 22810.310547 | 20037.218750 | 20194.531250 | 20387.541016 | 20510.244141 | 20594.226562 | 20675.720703 | 20767.025391 | 20876.550781 | … | 24975.916016 | 25149.097656 | 25240.177734 | 25401.996094 | 25577.400391 | 25800.574219 | 26132.904297 | 26559.906250 | 27273.566406 | 28567.857422 |
… | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
GBDOth | 2016-08-01 | 3.384338 | -31.891897 | -15.230768 | -1.954657 | -1.143704 | -0.994592 | -0.947800 | -0.884839 | -0.824748 | … | 9.635074 | 10.517044 | 11.374988 | 12.784556 | 14.568413 | 22.581669 | 37.880905 | 51.512486 | 62.645977 | 81.495415 |
GBDOth | 2016-09-01 | 4.842800 | -41.682514 | -23.578377 | -6.487054 | -1.238661 | -1.024779 | -0.927368 | -0.856639 | -0.758568 | … | 11.743630 | 12.755230 | 14.384780 | 16.579344 | 19.425726 | 36.155537 | 44.394543 | 60.144749 | 78.533859 | 101.363129 |
GBDOth | 2016-10-01 | 4.466261 | -21.124041 | -1.662255 | -1.157058 | -0.949211 | -0.857361 | -0.755605 | -0.699540 | -0.659419 | … | 10.405193 | 11.605769 | 12.686687 | 14.218900 | 19.963741 | 26.705273 | 34.361160 | 51.898552 | 68.361931 | 89.458908 |
GBDOth | 2016-11-01 | 3.689114 | -22.615982 | -11.813770 | -1.530864 | -1.049960 | -0.922807 | -0.868391 | -0.802971 | -0.723462 | … | 8.213260 | 8.837670 | 10.219457 | 12.300932 | 13.135829 | 23.325760 | 37.628525 | 43.993382 | 63.594315 | 84.825226 |
GBDOth | 2016-12-01 | 3.994789 | -38.856083 | -24.361221 | -7.503808 | -1.199999 | -1.003695 | -0.880594 | -0.788414 | -0.737489 | … | 9.881157 | 11.406334 | 12.636977 | 15.831536 | 26.059269 | 32.270000 | 37.316460 | 51.765774 | 68.933304 | 91.916100 |
Y_hat_mf
unique_id | ds | XGBRegressor | XGBRegressor-lo-98 | XGBRegressor-lo-96 | XGBRegressor-lo-94 | XGBRegressor-lo-92 | XGBRegressor-lo-90 | XGBRegressor-lo-88 | XGBRegressor-lo-86 | XGBRegressor-lo-84 | … | XGBRegressor-hi-80 | XGBRegressor-hi-82 | XGBRegressor-hi-84 | XGBRegressor-hi-86 | XGBRegressor-hi-88 | XGBRegressor-hi-90 | XGBRegressor-hi-92 | XGBRegressor-hi-94 | XGBRegressor-hi-96 | XGBRegressor-hi-98
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
TotalAll | 2016-01-01 | 43060.226562 | 38276.974483 | 38677.670530 | 39078.366577 | 39479.062624 | 39879.758671 | 40009.218877 | 40041.809140 | 40074.399403 | … | 45980.873195 | 46013.463459 | 46046.053722 | 46078.643985 | 46111.234248 | 46240.694454 | 46641.390501 | 47042.086548 | 47442.782595 | 47843.478642 |
TotalAll | 2016-02-01 | 18008.296875 | 14687.962868 | 14813.816467 | 14939.670066 | 15065.523666 | 15191.377265 | 15247.400539 | 15278.484410 | 15309.568281 | … | 20644.857726 | 20675.941597 | 20707.025469 | 20738.109340 | 20769.193211 | 20825.216485 | 20951.070084 | 21076.923684 | 21202.777283 | 21328.630882 |
TotalAll | 2016-03-01 | 20694.080078 | 16407.351099 | 16594.149043 | 16780.946987 | 16967.744931 | 17154.542875 | 17209.434677 | 17217.217141 | 17224.999606 | … | 24147.595620 | 24155.378085 | 24163.160550 | 24170.943015 | 24178.725480 | 24233.617281 | 24420.415225 | 24607.213169 | 24794.011113 | 24980.809057 |
TotalAll | 2016-04-01 | 24474.349609 | 20859.120558 | 20978.737726 | 21098.354893 | 21217.972060 | 21337.589227 | 21380.287167 | 21395.513953 | 21410.740739 | … | 27507.504906 | 27522.731693 | 27537.958479 | 27553.185266 | 27568.412052 | 27611.109991 | 27730.727159 | 27850.344326 | 27969.961493 | 28089.578660 |
TotalAll | 2016-05-01 | 19281.087891 | 15045.235849 | 15460.108990 | 15874.982131 | 16289.855271 | 16704.728412 | 16861.927796 | 16927.100837 | 16992.273878 | … | 21439.555822 | 21504.728863 | 21569.901904 | 21635.074945 | 21700.247986 | 21857.447369 | 22272.320510 | 22687.193651 | 23102.066792 | 23516.939933 |
… | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
GBDOth | 2016-08-01 | 11.040442 | -0.720264 | 0.934877 | 2.590017 | 4.245157 | 5.900298 | 6.396993 | 6.479957 | 6.562921 | … | 15.352035 | 15.435000 | 15.517964 | 15.600928 | 15.683892 | 16.180587 | 17.835727 | 19.490868 | 21.146008 | 22.801149 |
GBDOth | 2016-09-01 | 6.440751 | -0.275863 | -0.182214 | -0.088566 | 0.005083 | 0.098732 | 0.123376 | 0.123376 | 0.123376 | … | 12.758126 | 12.758126 | 12.758126 | 12.758126 | 12.758126 | 12.782771 | 12.876419 | 12.970068 | 13.063716 | 13.157365 |
GBDOth | 2016-10-01 | 9.995112 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | … | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 |
GBDOth | 2016-11-01 | 6.747566 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | … | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 |
GBDOth | 2016-12-01 | 7.367904 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | … | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 |
4. Reconcile Predictions
With minimal parsing, we can reconcile the raw output predictions with different HierarchicalForecast reconciliation methods.
reconcilers = [
ERM(method='closed'),
BottomUp(),
MinTrace('ols'),
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
Y_rec_nf = hrec.reconcile(Y_hat_df=Y_hat_nf, Y_df=insample_nf, S=S_df, tags=tags, level=level)
Y_rec_mf = hrec.reconcile(Y_hat_df=Y_hat_mf, Y_df=insample_mf, S=S_df, tags=tags, level=level)
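Each reconciler appends columns named model/reconciler (for example, NBEATS/BottomUp), together with matching -lo- and -hi- quantile columns derived from level. A quick look at the point-forecast columns, just for orientation:
# reconciled point forecasts are named '<model>/<reconciler>'
point_cols = [col for col in Y_rec_nf.columns
              if '/' in col and '-lo-' not in col and '-hi-' not in col]
print(point_cols)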
5. Evaluation
To evaluate, we use a scaled variation of the CRPS, as proposed by Rangapuram et al. (2021), which measures the accuracy of the predicted quantiles $\hat{y}$ compared to the observations $y$.
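For reference, a transcription of the scaled CRPS from Rangapuram et al. (2021), where $\hat{F}_{i,\tau}$ is the predicted quantile function for series $i$ at horizon $\tau$ and $\mathrm{QL}$ is the quantile loss:

$$\mathrm{sCRPS}(\hat{F}_{\tau}, \mathbf{y}_{\tau}) = \frac{2}{N} \sum_{i} \int_{0}^{1} \frac{\mathrm{QL}(\hat{F}_{i,\tau}, y_{i,\tau})_{q}}{\sum_{i} |y_{i,\tau}|} \, dq$$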
rec_model_names_nf = ['NBEATS/BottomUp', 'NBEATS/MinTrace_method-ols', 'NBEATS/ERM_method-closed_lambda_reg-0.01']
rec_model_names_mf = ['XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-ols', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01']
n_quantiles = len(quantiles)
n_series = len(S_df)
y_test = Y_test_df['y'].values.reshape(n_series, horizon)

for Y_rec, model_names in [(Y_rec_nf, rec_model_names_nf), (Y_rec_mf, rec_model_names_mf)]:
    for name in model_names:
        # select the reconciled quantile columns for this model/reconciler pair
        quantile_columns = [col for col in Y_rec.columns if (name+'-lo') in col or (name+'-hi') in col]
        y_rec = Y_rec[quantile_columns].values.reshape(n_series, horizon, n_quantiles)
        scrps = scaled_crps(y=y_test, y_hat=y_rec, quantiles=quantiles)
        print("{:<50} {:.3f}".format(name+":", scrps))
NBEATS/BottomUp: 0.129
NBEATS/MinTrace_method-ols: 0.129
NBEATS/ERM_method-closed_lambda_reg-0.01: 0.179
XGBRegressor/BottomUp: 0.134
XGBRegressor/MinTrace_method-ols: 0.178
XGBRegressor/ERM_method-closed_lambda_reg-0.01: 0.177
6. Visualizations
plot_nf = pd.concat([Y_df.set_index(['unique_id', 'ds']),
Y_rec_nf.set_index('ds', append=True)], axis=1)
plot_nf = plot_nf.reset_index('ds')
plot_mf = pd.concat([Y_df.set_index(['unique_id', 'ds']),
Y_rec_mf.set_index('ds', append=True)], axis=1)
plot_mf = plot_mf.reset_index('ds')
hplot.plot_series(
series='TotalVis',
Y_df=plot_nf,
models=['y', 'NBEATS', 'NBEATS/BottomUp', 'NBEATS/MinTrace_method-ols', 'NBEATS/ERM_method-closed_lambda_reg-0.01'],
level=[80]
)
hplot.plot_series(
series='TotalVis',
Y_df=plot_mf,
models=['y', 'XGBRegressor', 'XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-ols', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01'],
level=[80]
)