import datetime

import tempfile
from nbdev import show_doc
from fastcore.test import test_eq, test_fail, test_warns
from window_ops.expanding import expanding_mean
from window_ops.rolling import rolling_mean
from window_ops.shift import shift_array

from mlforecast.callbacks import SaveFeatures
from mlforecast.lag_transforms import ExpandingMean, RollingMean
from mlforecast.target_transforms import Differences, LocalStandardScaler
from mlforecast.utils import generate_daily_series, generate_prices_for_series

Data format

The required input format is a dataframe with at least the following columns:

* unique_id with a unique identifier for each time series
* ds with the datestamp
* y with the values of the series.

Every other column is considered a static feature unless stated otherwise in TimeSeries.fit_transform (through its static_features argument).

# Generate 20 synthetic daily series, each with two static feature columns.
series = generate_daily_series(20, n_static_features=2)
series
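|      | unique_id | ds         | y          | static_0 | static_1 |
|------|-----------|------------|------------|----------|----------|
| 0    | id_00     | 2000-01-01 | 7.404529   | 27       | 53       |
| 1    | id_00     | 2000-01-02 | 35.952624  | 27       | 53       |
| 2    | id_00     | 2000-01-03 | 68.958353  | 27       | 53       |
| 3    | id_00     | 2000-01-04 | 84.994505  | 27       | 53       |
| 4    | id_00     | 2000-01-05 | 113.219810 | 27       | 53       |
| ...  | ...       | ...        | ...        | ...      | ...      |
| 4869 | id_19     | 2000-03-25 | 400.606807 | 97       | 45       |
| 4870 | id_19     | 2000-03-26 | 538.794824 | 97       | 45       |
| 4871 | id_19     | 2000-03-27 | 620.202104 | 97       | 45       |
| 4872 | id_19     | 2000-03-28 | 20.625426  | 97       | 45       |
| 4873 | id_19     | 2000-03-29 | 141.513169 | 97       | 45       |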

For simplicity we’ll just take one time series here.

# Keep only the rows that belong to the first unique_id.
uids = series['unique_id'].unique()
serie = series[series['unique_id'].eq(uids[0])]
serie
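|     | unique_id | ds         | y          | static_0 | static_1 |
|-----|-----------|------------|------------|----------|----------|
| 0   | id_00     | 2000-01-01 | 7.404529   | 27       | 53       |
| 1   | id_00     | 2000-01-02 | 35.952624  | 27       | 53       |
| 2   | id_00     | 2000-01-03 | 68.958353  | 27       | 53       |
| 3   | id_00     | 2000-01-04 | 84.994505  | 27       | 53       |
| 4   | id_00     | 2000-01-05 | 113.219810 | 27       | 53       |
| ... | ...       | ...        | ...        | ...      | ...      |
| 217 | id_00     | 2000-08-05 | 13.263188  | 27       | 53       |
| 218 | id_00     | 2000-08-06 | 38.231981  | 27       | 53       |
| 219 | id_00     | 2000-08-07 | 59.555183  | 27       | 53       |
| 220 | id_00     | 2000-08-08 | 86.986368  | 27       | 53       |
| 221 | id_00     | 2000-08-09 | 119.254810 | 27       | 53       |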

source

TimeSeries

 TimeSeries (freq:Union[int,str], lags:Optional[Iterable[int]]=None,
             lag_transforms:Optional[Dict[int,List[Union[Callable,
             Tuple[Callable,Any]]]]]=None,
             date_features:Optional[Iterable[Union[str,Callable]]]=None,
             num_threads:int=1,
             target_transforms:Optional[List[Union[
             mlforecast.target_transforms.BaseTargetTransform,
             mlforecast.target_transforms._BaseGroupedArrayTargetTransform]]]=None,
             lag_transforms_namer:Optional[Callable]=None)

Utility class for storing and transforming time series data.

The TimeSeries class takes care of defining the transformations to be performed (lags, lag_transforms and date_features). The transformations can be computed using multithreading if num_threads > 1.

def month_start_or_end(dates):
    """Flag the dates that fall on the first or the last day of a month.

    Returns the element-wise OR of ``dates.is_month_start`` and
    ``dates.is_month_end``.
    """
    at_month_start = dates.is_month_start
    at_month_end = dates.is_month_end
    return at_month_start | at_month_end

# Configuration forwarded as keyword arguments to the TimeSeries constructor below.
flow_config = dict(
    freq='W-THU',
    lags=[7],
    lag_transforms={
        # lag -> list of transformations; a tuple is (function, *arguments)
        1: [expanding_mean, (rolling_mean, 7)]
    },
    # date features can be given as strings or as a custom callable
    date_features=['dayofweek', 'week', month_start_or_end]
)

ts = TimeSeries(**flow_config)
ts
TimeSeries(freq=W-THU, transforms=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag1_window_size7'], date_features=['dayofweek', 'week', 'month_start_or_end'], num_threads=1)

The frequency is converted to an offset.

# ts.freq should equal the pandas offset built from the configured frequency alias
test_eq(ts.freq, pd.tseries.frequencies.to_offset(flow_config['freq']))

The date features are stored as they were passed to the constructor.

# the list is compared as passed, i.e. the callable itself is kept rather than its name
test_eq(ts.date_features, flow_config['date_features'])

The transformations are stored as a dictionary where the key is the name of the transformation (name of the column in the dataframe with the computed features), which is built using build_transform_name and the value is a tuple where the first element is the lag it is applied to, then the function and then the function arguments.

# keys are the generated feature (column) names; values describe how each feature is computed
test_eq(
    ts.transforms, 
    {
        'lag7': Lag(7),
        'expanding_mean_lag1': (1, expanding_mean), 
        'rolling_mean_lag1_window_size7': (1, rolling_mean, 7)
        
    }
)

Note that for lags we define the transformation as the identity function applied to its corresponding lag. This is because _transform_series takes the lag as an argument and shifts the array before computing the transformation.


source

TimeSeries.fit_transform

 TimeSeries.fit_transform (data:Union[pandas.core.frame.DataFrame,
                           polars.dataframe.frame.DataFrame], id_col:str,
                           time_col:str, target_col:str,
                           static_features:Optional[List[str]]=None,
                           dropna:bool=True,
                           keep_last_n:Optional[int]=None,
                           max_horizon:Optional[int]=None,
                           return_X_y:bool=False, as_numpy:bool=False)

Add the features to data and save the required information for the predictions step.

If not all features are static, specify which ones are in static_features. If you don’t want to drop rows with null values after the transformations, set dropna=False. If keep_last_n is not None, then that number of observations is kept across all series for updates.

# Daily configuration reused by the remaining examples.
flow_config = dict(
    freq='D',
    lags=[7, 14],
    lag_transforms={
        # two rolling means (window sizes 7 and 14) computed over the lag 2
        2: [
            (rolling_mean, 7),
            (rolling_mean, 14),
        ]
    },
    date_features=['dayofweek', 'month', 'year'],
    num_threads=2
)

ts = TimeSeries(**flow_config)
# the returned dataframe is discarded; the checks that follow inspect the state stored in ts
_ = ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y')

The series values are stored as a GroupedArray in an attribute ga. If the data type of the series values is an int then it is converted to np.float32. This is because lags generate np.nans, so we need a float data type for them.

# element-wise comparison of the stored values against the y column
np.testing.assert_equal(ts.ga.data, series.y.values)

The series ids are stored in an uids attribute.

# unique() yields the ids in order of first appearance
test_eq(ts.uids, series['unique_id'].unique())

For each time series, the last observed date is stored so that predictions start from the last date + the frequency.

# expected value: the maximum ds of each unique_id
test_eq(ts.last_dates, series.groupby('unique_id', observed=True)['ds'].max().values)

The last row of every series, without the y and ds columns, is taken as the static features.

# tail(1) picks the last row of each group; after dropping ds and y only
# unique_id and the static columns remain
pd.testing.assert_frame_equal(
    ts.static_features_,
    series.groupby('unique_id', observed=True).tail(1).drop(columns=['ds', 'y']).reset_index(drop=True),
)

If you pass static_features to TimeSeries.fit_transform then only these are kept.

# refit declaring static_0 as the only static feature
ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', static_features=['static_0'])

# the stored static features are expected to hold the id column plus static_0 only
pd.testing.assert_frame_equal(
    ts.static_features_,
    series.groupby('unique_id', observed=True).tail(1)[['unique_id', 'static_0']].reset_index(drop=True),
)

You can also specify keep_last_n in TimeSeries.fit_transform, which means that after computing the features for training we want to keep only the last n samples of each time series for computing the updates. This saves both memory and time, since the updates are performed by running the transformation functions on all time series again and keeping only the last value (the update).

If you have very long time series and your updates only require a small sample it’s recommended that you set keep_last_n to the minimum number of samples required to compute the updates, which in this case is 15 since we have a rolling mean of size 14 over the lag 2 and in the first update the lag 2 becomes the lag 1. This is because in the first update the lag 1 is the last value of the series (or the lag 0), the lag 2 is the lag 1 and so on.

# minimum history needed for the updates with this configuration
keep_last_n = 15

ts = TimeSeries(**flow_config)
df = ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', keep_last_n=keep_last_n)
# NOTE(review): private helper; presumably prepares the state used by the predict step — confirm
ts._predict_setup()

expected_lags = ['lag7', 'lag14']
expected_transforms = ['rolling_mean_lag2_window_size7', 
                       'rolling_mean_lag2_window_size14']
expected_date_features = ['dayofweek', 'month', 'year']

# features are ordered as lags, then lag transforms, then date features
test_eq(ts.features, expected_lags + expected_transforms + expected_date_features)
# apart from ds and y, the output columns are the static features followed by the computed features
test_eq(ts.static_features_.columns.tolist() + ts.features, df.columns.drop(['ds', 'y']).tolist())
# for every serie we dropped 2 rows because of the lag 2 and 13 more to fill the window of size 14
test_eq(df.shape[0], series.shape[0] - (2 + 13) * ts.ga.n_groups)
# only the last keep_last_n values of each serie remain stored
test_eq(ts.ga.data.size, ts.ga.n_groups * keep_last_n)

TimeSeries.fit_transform requires that the y column doesn’t have any null values. This is because the transformations could propagate them forward, so if you have null values in the y column you’ll get an error.

# work on a copy so the original series stays free of nulls
series_with_nulls = series.copy()
# introduce a single null in the target column (row label 1)
series_with_nulls.loc[1, 'y'] = np.nan
# fitting must fail with an error message that mentions the null values
test_fail(
    lambda: ts.fit_transform(series_with_nulls, id_col='unique_id', time_col='ds', target_col='y'),
    contains='y column contains null values'
)

source

TimeSeries.predict

 TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,
                     List[sklearn.base.BaseEstimator]]], horizon:int,
                     before_predict_callback:Optional[Callable]=None,
                     after_predict_callback:Optional[Callable]=None,
                     X_df:Union[pandas.core.frame.DataFrame,
                     polars.dataframe.frame.DataFrame,NoneType]=None,
                     ids:Optional[List[str]]=None)

Once we have a trained model we can use TimeSeries.predict passing the model and the horizon to get the predictions back.

class DummyModel:
    """Toy model whose prediction is simply the ``lag7`` feature."""

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Return the values of the ``lag7`` column of ``X``."""
        lag7_column = X['lag7']
        return lag7_column.values

horizon = 7
model = DummyModel()
ts = TimeSeries(**flow_config)
ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y')
# models are passed as a dict; the key is used as the name of the predictions column
predictions = ts.predict({'DummyModel': model}, horizon)

grouped_series = series.groupby('unique_id', observed=True)
expected_preds = grouped_series['y'].tail(7)  # the model predicts the lag-7
last_dates = grouped_series['ds'].max()
# predictions should span from one day after the last date up to horizon days after it
expected_dsmin = last_dates + pd.offsets.Day()
expected_dsmax = last_dates + horizon * pd.offsets.Day()
grouped_preds = predictions.groupby('unique_id', observed=True)

np.testing.assert_allclose(predictions['DummyModel'], expected_preds)
pd.testing.assert_series_equal(grouped_preds['ds'].min(), expected_dsmin)
pd.testing.assert_series_equal(grouped_preds['ds'].max(), expected_dsmax)

If we have dynamic features we can pass them to X_df.

class PredictPrice:
    """Toy model that returns the ``price`` feature as its prediction."""

    def predict(self, X):
        """Return the ``price`` column of ``X`` unchanged."""
        price_column = X['price']
        return price_column

# NOTE(review): equal_ends=True presumably makes every serie end on the same date — confirm
series = generate_daily_series(20, n_static_features=2, equal_ends=True)
# reuse static_1 as a product identifier
dynamic_series = series.rename(columns={'static_1': 'product_id'})
# NOTE(review): presumably builds a price for each date and product_id, including future dates — confirm
prices_catalog = generate_prices_for_series(dynamic_series)
# no `on` is given, so the left merge joins on the columns both frames share
series_with_prices = dynamic_series.merge(prices_catalog, how='left')

model = PredictPrice()
ts = TimeSeries(**flow_config)
ts.fit_transform(
    series_with_prices,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
    # price is left out of the static features, so it is treated as a dynamic feature
    static_features=['static_0', 'product_id'],
)
# the dynamic feature values for the forecast dates are supplied through X_df
predictions = ts.predict({'PredictPrice': model}, horizon=1, X_df=prices_catalog)
# the predictions should equal the catalog prices for the predicted (unique_id, ds) pairs
pd.testing.assert_frame_equal(
    predictions.rename(columns={'PredictPrice': 'price'}),
    prices_catalog.merge(predictions[['unique_id', 'ds']])[['unique_id', 'ds', 'price']]
)

source

TimeSeries.update

 TimeSeries.update (df:Union[pandas.core.frame.DataFrame,
                    polars.dataframe.frame.DataFrame])

Update the values of the stored series.