
LightGBMCV

LightGBMCV(
    freq,
    lags=None,
    lag_transforms=None,
    date_features=None,
    num_threads=1,
    target_transforms=None,
)
Create a LightGBM CV object.

Parameters:
  • freq (str or int): Pandas offset alias, e.g. 'D', 'W-THU', or an integer denoting the frequency of the series. Required.
  • lags (list of int): Lags of the target to use as features. Defaults to None.
  • lag_transforms (dict of int to list of functions): Mapping of target lags to their transformations. Defaults to None.
  • date_features (list of str or callable): Features computed from the dates. Can be pandas date attributes or functions that take the dates as input. Defaults to None.
  • num_threads (int): Number of threads to use when computing the features. Defaults to 1.
  • target_transforms (list of transformers): Transformations that will be applied to the target before computing the features and restored after the forecasting step. Defaults to None.
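A minimal instantiation sketch with illustrative values (the import path mlforecast.lgb_cv is assumed; the full worked example below imports everything it needs):

from mlforecast.lgb_cv import LightGBMCV
from mlforecast.target_transforms import Differences

# hypothetical daily-frequency setup: one and two weeks of lags,
# a day-of-week date feature, and a lag-7 difference on the target
cv = LightGBMCV(
    freq='D',
    lags=[7, 14],
    date_features=['dayofweek'],
    target_transforms=[Differences([7])],
    num_threads=2,
)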

LightGBMCV.fit

fit(
    df,
    n_windows,
    h,
    id_col="unique_id",
    time_col="ds",
    target_col="y",
    step_size=None,
    num_iterations=100,
    params=None,
    static_features=None,
    dropna=True,
    keep_last_n=None,
    eval_every=10,
    weights=None,
    metric="mape",
    verbose_eval=True,
    early_stopping_evals=2,
    early_stopping_pct=0.01,
    compute_cv_preds=False,
    before_predict_callback=None,
    after_predict_callback=None,
    input_size=None,
)
Train boosters simultaneously and assess their performance on the complete forecasting window.

Parameters:
  • df (pandas DataFrame): Series data in long format. Required.
  • n_windows (int): Number of windows to evaluate. Required.
  • h (int): Forecast horizon. Required.
  • id_col (str): Column that identifies each series. Defaults to 'unique_id'.
  • time_col (str): Column that identifies each timestep; its values can be timestamps or integers. Defaults to 'ds'.
  • target_col (str): Column that contains the target. Defaults to 'y'.
  • step_size (int): Step size between each cross validation window. If None it will be equal to h. Defaults to None.
  • num_iterations (int): Maximum number of boosting iterations to run. Defaults to 100.
  • params (dict): Parameters to be passed to the LightGBM Boosters. Defaults to None.
  • static_features (list of str): Names of the features that are static and will be repeated when forecasting. Defaults to None.
  • dropna (bool): Drop rows with missing values produced by the transformations. Defaults to True.
  • keep_last_n (int): Keep only this many records from each series for the forecasting step. Can save time and memory if your features allow it. Defaults to None.
  • eval_every (int): Number of boosting iterations to train before evaluating on the whole forecast window. Defaults to 10.
  • weights (sequence of float): Weights to multiply the metric of each window. If None, all windows have the same weight. Defaults to None.
  • metric (str or callable): Metric used to assess the performance of the models and perform early stopping. Defaults to 'mape'.
  • verbose_eval (bool): Print the metrics of each evaluation. Defaults to True.
  • early_stopping_evals (int): Maximum number of evaluations to run without improvement. Defaults to 2.
  • early_stopping_pct (float): Minimum percentage improvement in metric value over early_stopping_evals evaluations. Defaults to 0.01.
  • compute_cv_preds (bool): Compute predictions for each window after finding the best iteration. Defaults to False.
  • before_predict_callback (callable): Function to call on the features before computing the predictions. This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. The series identifier is on the index. Defaults to None.
  • after_predict_callback (callable): Function to call on the predictions before updating the targets. This function will take a pandas Series with the predictions and should return another one with the same structure. The series identifier is on the index. Defaults to None.
  • input_size (int): Maximum training samples per series in each window. If None, an expanding window is used. Defaults to None.

Returns:
  • list of tuple: List of (boosting rounds, metric value) tuples.
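The returned history makes it easy to recover the round with the best score. A small sketch, assuming a LightGBMCV instance cv plus the train frame and horizon used in the example below:

hist = cv.fit(train, n_windows=2, h=horizon, params={'verbose': -1})
# each entry is (boosting rounds, metric value); pick the best evaluated round
best_rounds, best_metric = min(hist, key=lambda entry: entry[1])
print(f'best iteration: {best_rounds}, metric: {best_metric:.4f}')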

LightGBMCV.predict

predict(
    h, before_predict_callback=None, after_predict_callback=None, X_df=None
)
Compute predictions with each of the trained boosters.

Parameters:
  • h (int): Forecast horizon. Required.
  • before_predict_callback (callable): Function to call on the features before computing the predictions. This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. The series identifier is on the index. Defaults to None.
  • after_predict_callback (callable): Function to call on the predictions before updating the targets. This function will take a pandas Series with the predictions and should return another one with the same structure. The series identifier is on the index. Defaults to None.
  • X_df (DataFrame): Dataframe with the future exogenous features. Should have the id column and the time column. Defaults to None.

Returns:
  • DataFrame: Predictions for each series and timestep, with one column per window.

LightGBMCV.setup

setup(
    df,
    n_windows,
    h,
    id_col="unique_id",
    time_col="ds",
    target_col="y",
    step_size=None,
    params=None,
    static_features=None,
    dropna=True,
    keep_last_n=None,
    weights=None,
    metric="mape",
    input_size=None,
)
Initialize internal data structures to iteratively train the boosters. Use this before calling partial_fit.

Parameters:
  • df (pandas DataFrame): Series data in long format. Required.
  • n_windows (int): Number of windows to evaluate. Required.
  • h (int): Forecast horizon. Required.
  • id_col (str): Column that identifies each series. Defaults to 'unique_id'.
  • time_col (str): Column that identifies each timestep; its values can be timestamps or integers. Defaults to 'ds'.
  • target_col (str): Column that contains the target. Defaults to 'y'.
  • step_size (int): Step size between each cross validation window. If None it will be equal to h. Defaults to None.
  • params (dict): Parameters to be passed to the LightGBM Boosters. Defaults to None.
  • static_features (list of str): Names of the features that are static and will be repeated when forecasting. Defaults to None.
  • dropna (bool): Drop rows with missing values produced by the transformations. Defaults to True.
  • keep_last_n (int): Keep only this many records from each series for the forecasting step. Can save time and memory if your features allow it. Defaults to None.
  • weights (sequence of float): Weights to multiply the metric of each window. If None, all windows have the same weight. Defaults to None.
  • metric (str or callable): Metric used to assess the performance of the models and perform early stopping. Defaults to 'mape'.
  • input_size (int): Maximum training samples per series in each window. If None, an expanding window is used. Defaults to None.

Returns:
  • LightGBMCV: CV object with internal data structures for partial_fit.

LightGBMCV.partial_fit

partial_fit(
    num_iterations, before_predict_callback=None, after_predict_callback=None
)
Train the boosters for some iterations.

Parameters:
  • num_iterations (int): Number of boosting iterations to run. Required.
  • before_predict_callback (callable): Function to call on the features before computing the predictions. This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. The series identifier is on the index. Defaults to None.
  • after_predict_callback (callable): Function to call on the predictions before updating the targets. This function will take a pandas Series with the predictions and should return another one with the same structure. The series identifier is on the index. Defaults to None.

Returns:
  • float: Weighted metric after training for num_iterations.

Example

This shows an example with just 4 series of the M4 dataset. If you want to run it yourself on all of them, you can refer to this notebook.
import random

import pandas as pd
from datasetsforecast.m4 import M4, M4Info
from fastcore.test import test_eq, test_fail
from nbdev import show_doc

from mlforecast.lag_transforms import SeasonalRollingMean
from mlforecast.lgb_cv import LightGBMCV
from mlforecast.target_transforms import Differences
group = 'Hourly'
await M4.async_download('data', group=group)
df, *_ = M4.load(directory='data', group=group)
df['ds'] = df['ds'].astype('int')
ids = df['unique_id'].unique()
random.seed(0)
sample_ids = random.choices(ids, k=4)
sample_df = df[df['unique_id'].isin(sample_ids)]
sample_df
        unique_id    ds     y
86796        H196     1  11.8
86797        H196     2  11.4
86798        H196     3  11.1
86799        H196     4  10.8
86800        H196     5  10.6
...
325235       H413  1004  99.0
325236       H413  1005  88.0
325237       H413  1006  47.0
325238       H413  1007  41.0
325239       H413  1008  34.0
info = M4Info[group]
horizon = info.horizon
valid = sample_df.groupby('unique_id').tail(horizon)
train = sample_df.drop(valid.index)
train.shape, valid.shape
((3840, 3), (192, 3))
What LightGBMCV does is emulate LightGBM's cv function: several Boosters are trained simultaneously on different partitions of the data, that is, one boosting iteration is performed on all of them at a time. This gives us an estimate of the error per iteration, so if we combine it with early stopping we can find the best iteration to train a final model using all the data, or even use the individual models' predictions to compute an ensemble. In order to get a good estimate of the forecasting performance of our model, we compute predictions for the whole test period and compute a metric on that. Since this step can slow down training, there's an eval_every parameter to control it: if eval_every=10 (the default), then every 10 boosting iterations we compute forecasts for the complete window and report the error. There are also two early stopping parameters (a rough sketch of the stopping rule follows the list):
  • early_stopping_evals: how many evaluations of the full window to allow without improvement before stopping training.
  • early_stopping_pct: the minimum percentage improvement we want over these early_stopping_evals evaluations in order to keep training.
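As a rough, hypothetical sketch of that rule (not the library's actual implementation), the check performed after each full-window evaluation could look like this:

def should_stop(scores, early_stopping_evals=2, early_stopping_pct=0.01):
    """Stop when the last `early_stopping_evals` evaluations failed to improve
    the best previous score by at least `early_stopping_pct` (relative)."""
    if len(scores) <= early_stopping_evals:
        return False
    baseline = min(scores[:-early_stopping_evals])
    improvement = (baseline - min(scores[-early_stopping_evals:])) / baseline
    return improvement < early_stopping_pct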
This makes the LightGBMCV class a good tool to quickly test different configurations of the model. Consider the following example, where we try to find out which features improve the performance of our model. We start by using just lags.
static_fit_config = dict(
    n_windows=2,
    h=horizon,
    params={'verbose': -1},
    compute_cv_preds=True,
)
cv = LightGBMCV(
    freq=1,
    lags=[24 * (i+1) for i in range(7)],  # one week of lags
)
hist = cv.fit(train, **static_fit_config)
[LightGBM] [Info] Start training from score 51.745632
[10] mape: 0.590690
[20] mape: 0.251093
[30] mape: 0.143643
[40] mape: 0.109723
[50] mape: 0.102099
[60] mape: 0.099448
[70] mape: 0.098349
[80] mape: 0.098006
[90] mape: 0.098718
Early stopping at round 90
Using best iteration: 80
By setting compute_cv_preds we get the predictions from each model on their corresponding validation fold.
cv.cv_preds_
    unique_id   ds     y    Booster  window
0        H196  865  15.5  15.522924       0
1        H196  866  15.1  14.985832       0
2        H196  867  14.8  14.667901       0
3        H196  868  14.4  14.514592       0
4        H196  869  14.2  14.035793       0
...
187      H413  956  59.0  77.227905       1
188      H413  957  58.0  80.589641       1
189      H413  958  53.0  53.986834       1
190      H413  959  38.0  36.749786       1
191      H413  960  46.0  36.281225       1
The individual models we trained are saved, so calling predict returns the predictions from every model trained.
preds = cv.predict(horizon)
preds
    unique_id    ds   Booster0   Booster1
0        H196   961  15.670252  15.848888
1        H196   962  15.522924  15.697399
2        H196   963  14.985832  15.166213
3        H196   964  14.985832  14.723238
4        H196   965  14.562152  14.451092
...
187      H413  1004  70.695242  65.917620
188      H413  1005  66.216580  62.615788
189      H413  1006  63.896573  67.848598
190      H413  1007  46.922797  50.981950
191      H413  1008  45.006541  42.752819
We can average these predictions and evaluate them.
def evaluate_on_valid(preds):
    preds = preds.copy()
    # average the predictions from all boosters
    preds['final_prediction'] = preds.drop(columns=['unique_id', 'ds']).mean(axis=1)
    merged = preds.merge(valid, on=['unique_id', 'ds'])
    # absolute percentage error, averaged per series and then across series
    merged['abs_err'] = abs(merged['final_prediction'] - merged['y']) / merged['y']
    return merged.groupby('unique_id')['abs_err'].mean().mean()
eval1 = evaluate_on_valid(preds)
eval1
0.11036194712311806
Now, since these series are hourly, maybe we can remove the daily seasonality by taking the 168th (24 * 7) difference, that is, subtracting the value at the same hour one week earlier, so our target becomes z_t = y_t - y_{t-168}. The features will be computed from this differenced target, and when we predict the differences will be automatically added back.
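As a rough sketch of what this transformation does on a single series (a hypothetical helper, not the library's implementation):

def seasonal_difference(y: pd.Series, lag: int = 168) -> pd.Series:
    """z_t = y_t - y_{t-lag}; the first `lag` values have no reference value
    and become NaN, which the feature pipeline drops when dropna=True."""
    return y - y.shift(lag)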
cv2 = LightGBMCV(
    freq=1,
    target_transforms=[Differences([24 * 7])],
    lags=[24 * (i+1) for i in range(7)],
)
hist2 = cv2.fit(train, **static_fit_config)
[LightGBM] [Info] Start training from score 0.519010
[10] mape: 0.089024
[20] mape: 0.090683
[30] mape: 0.092316
Early stopping at round 30
Using best iteration: 10
assert hist2[-1][1] < hist[-1][1]
Nice! We achieve a better score in fewer iterations. Let's see if this improvement translates to the validation set as well.
preds2 = cv2.predict(horizon)
eval2 = evaluate_on_valid(preds2)
eval2
0.08956665504570135
assert eval2 < eval1
Great! Maybe we can try some lag transforms now. We'll try the seasonal rolling mean, which averages the values "every season": if we set season_length=24 and window_size=7 then we average the value at the same hour over the last seven days. A sketch of this computation is shown below.
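A hypothetical sketch of that computation on a single hourly series (assuming a 0-based positional index; this is not the library's code):

import numpy as np

def seasonal_rolling_mean(x, season_length=24, window_size=7):
    """Mean of the value at the same in-season position over the last
    `window_size` seasons, e.g. the same hour over the last 7 days."""
    x = pd.Series(np.asarray(x))
    position_in_season = x.index % season_length
    return x.groupby(position_in_season).transform(
        lambda s: s.rolling(window_size).mean()
    )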
cv3 = LightGBMCV(
    freq=1,
    target_transforms=[Differences([24 * 7])],
    lags=[24 * (i+1) for i in range(7)],
    lag_transforms={
        48: [SeasonalRollingMean(season_length=24, window_size=7)],
    },
)
hist3 = cv3.fit(train, **static_fit_config)
[LightGBM] [Info] Start training from score 0.273641
[10] mape: 0.086724
[20] mape: 0.088466
[30] mape: 0.090536
Early stopping at round 30
Using best iteration: 10
Seems like this is helping as well!
assert hist3[-1][1] < hist2[-1][1]
Does this reflect on the validation set?
preds3 = cv3.predict(horizon)
eval3 = evaluate_on_valid(preds3)
eval3
0.08961279023129345
The validation score is essentially unchanged here. mlforecast also supports date features, but in this case our time column is made of integers, so there aren't many possibilities to explore.

As you can see, this workflow lets you iterate faster and get better estimates of the forecasting performance you can expect from your model. If you're doing hyperparameter tuning, it's useful to be able to run a couple of iterations, assess the performance, and decide whether a particular configuration isn't promising and should be discarded. For example, optuna has pruners that you can call with your current score and that decide whether the trial should be discarded. We'll now show how to do that.

Since the CV requires a bit of setup, like building the LightGBM datasets and the internal features, we have the setup method.
cv4 = LightGBMCV(
    freq=1,
    lags=[24 * (i+1) for i in range(7)],
)
cv4.setup(
    train,
    n_windows=2,
    h=horizon,
    params={'verbose': -1},
)
LightGBMCV(freq=1, lag_features=['lag24', 'lag48', 'lag72', 'lag96', 'lag120', 'lag144', 'lag168'], date_features=[], num_threads=1, bst_threads=8)
Once we have this we can call partial_fit to only train for some iterations and return the score of the forecast window.
score = cv4.partial_fit(10)
score
[LightGBM] [Info] Start training from score 51.745632
0.5906900462828166
This is equal to the first evaluation from our first example.
assert hist[0][1] == score
We can now use this score to decide if this configuration is promising. If we want to we can train some more iterations.
score2 = cv4.partial_fit(20)
This is now equal to our third metric from the first example, since by now we've trained for a total of 30 iterations (10 + 20).
assert hist[2][1] == score2
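Putting this together, a hypothetical optuna objective with pruning could look like the following (assumes optuna is installed; the search space and number of steps are illustrative):

import optuna

def objective(trial):
    cv = LightGBMCV(freq=1, lags=[24 * (i + 1) for i in range(7)])
    cv.setup(
        train,
        n_windows=2,
        h=horizon,
        params={
            'verbose': -1,
            'num_leaves': trial.suggest_int('num_leaves', 15, 127),
        },
    )
    score = None
    for step in range(10):
        score = cv.partial_fit(10)  # train 10 more boosting iterations
        trial.report(score, step)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score

# study = optuna.create_study(direction='minimize', pruner=optuna.pruners.MedianPruner())
# study.optimize(objective, n_trials=20)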

Using a custom metric

The built-in metrics are MAPE and RMSE, which are computed per series and then averaged across all series. If you want to do something different or use a different metric entirely, you can define your own metric like the following:
def weighted_mape(
    y_true: pd.Series,
    y_pred: pd.Series,
    ids: pd.Series,
    dates: pd.Series,
):
    """Weighs the MAPE by the magnitude of the series values"""
    abs_pct_err = abs(y_true - y_pred) / abs(y_true)
    mape_by_serie = abs_pct_err.groupby(ids).mean()
    totals_per_serie = y_pred.groupby(ids).sum()
    series_weights = totals_per_serie / totals_per_serie.sum()
    return (mape_by_serie * series_weights).sum()
_ = LightGBMCV(
    freq=1,
    lags=[24 * (i+1) for i in range(7)],
).fit(
    train,
    n_windows=2,
    h=horizon,
    params={'verbose': -1},
    metric=weighted_mape,
)
[LightGBM] [Info] Start training from score 51.745632
[10] weighted_mape: 0.480353
[20] weighted_mape: 0.218670
[30] weighted_mape: 0.161706
[40] weighted_mape: 0.149992
[50] weighted_mape: 0.149024
[60] weighted_mape: 0.148496
Early stopping at round 60
Using best iteration: 60