In this notebook we show how to use StatsForecast and ray to forecast
thousands of time series in less than 6 minutes (M5 dataset). We also
show that StatsForecast outperforms Prophet running on a Spark cluster
using DataBricks, in both time and accuracy.
In this example, we used a ray cluster (AWS) of 11 instances of type
m5.2xlarge (8 cores, 32 GB RAM).
Installing StatsForecast Library
!pip install "statsforecast[ray]" neuralforecast s3fs pyarrow
from time import time
import pandas as pd
from neuralforecast.data.datasets.m5 import M5, M5Evaluation
from statsforecast import StatsForecast
from statsforecast.models import ETS
Download data
The example uses the M5 dataset. It consists of 30,490 bottom-level
time series.
Y_df = pd.read_parquet('s3://m5-benchmarks/data/train/target.parquet')
Y_df = Y_df.rename(columns={
    'item_id': 'unique_id',
    'timestamp': 'ds',
    'demand': 'y'
})
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
|   | unique_id        | ds         | y   |
|---|------------------|------------|-----|
| 0 | FOODS_1_001_CA_1 | 2011-01-29 | 3.0 |
| 1 | FOODS_1_001_CA_1 | 2011-01-30 | 0.0 |
| 2 | FOODS_1_001_CA_1 | 2011-01-31 | 0.0 |
| 3 | FOODS_1_001_CA_1 | 2011-02-01 | 1.0 |
| 4 | FOODS_1_001_CA_1 | 2011-02-02 | 4.0 |
Since the M5 dataset contains intermittent time series, we add a
constant to avoid problems during the training phase. Later, we will
subtract the constant from the forecasts.
constant = 10
Y_df['y'] += constant
Train the model
StatsForecast receives a list of models to fit each time series. Since
we are dealing with daily data, it is beneficial to use 7 as the season
length. Note that we need to pass the ray address to the ray_address
argument.
fcst = StatsForecast(
    df=Y_df,
    models=[ETS(season_length=7, model='ZNA')],
    freq='D',
    ray_address='ray://ADDRESS:10001'
)
init = time()
Y_hat = fcst.forecast(28)
end = time()
print(f'Minutes taken by StatsForecast using: {(end - init) / 60}')
/home/ubuntu/miniconda/envs/ray/lib/python3.7/site-packages/ray/util/client/worker.py:618: UserWarning: More than 10MB of messages have been created to schedule tasks on the server. This can be slow on Ray Client due to communication overhead over the network. If you're running many fine-grained tasks, consider running them inside a single remote function. See the section on "Too fine-grained tasks" in the Ray Design Patterns document for more details: https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.f7ins22n6nyl. If your functions frequently use large objects, consider storing the objects remotely with ray.put. An example of this is shown in the "Closure capture of large / unserializable object" section of the Ray Design Patterns document, available here: https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.1afmymq455wu
UserWarning,
Minutes taken by StatsForecast using: 5.4817593971888225
StatsForecast and ray took only 5.48 minutes to train 30,490 time
series, compared to 18.23 minutes for Prophet and Spark.
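As a quick check on the two timings quoted above, the ratio can be computed directly. The variable names are illustrative only.

```python
statsforecast_minutes = 5.48  # StatsForecast + ray, from the run above
prophet_minutes = 18.23       # Prophet + Spark, as quoted above

# Ratio of the two wall-clock times.
speedup = prophet_minutes / statsforecast_minutes
print(f'Speedup: {speedup:.1f}x')  # prints "Speedup: 3.3x"
```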
We remove the constant.
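One way this step could look is sketched below. It assumes the forecast frame has a `ds` column plus one numeric column per model. The toy frame and the column name `ETS` are hypothetical stand-ins, not the actual output of `fcst.forecast`.

```python
import pandas as pd

constant = 10  # same offset that was added to y before training

# Hypothetical stand-in for the forecast frame: indexed by unique_id,
# with a 'ds' column and one column per model (the name 'ETS' is assumed).
Y_hat_toy = pd.DataFrame(
    {
        'ds': pd.to_datetime(['2016-05-23', '2016-05-24'] * 2),
        'ETS': [13.2, 10.4, 11.0, 12.5],
    },
    index=pd.Index(['A', 'A', 'B', 'B'], name='unique_id'),
)

# Subtract the offset from every forecast column, leaving 'ds' untouched.
model_cols = [c for c in Y_hat_toy.columns if c != 'ds']
Y_hat_toy[model_cols] = Y_hat_toy[model_cols] - constant
```

Selecting every column except `ds` avoids hard-coding the model column name, which may differ between library versions.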
The M5 competition used the weighted root mean squared scaled error
(WRMSSE). You can find details of the metric here.
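For orientation, the per-series building block of the metric is the root mean squared scaled error (RMSSE): the mean squared forecast error divided by the in-sample mean squared error of the one-step naive forecast, under a square root. WRMSSE is a weighted sum of those values across series. The sketch below uses NumPy with made-up numbers and hypothetical weights. It is a simplified illustration: the competition derives the weights from dollar sales and averages over 12 aggregation levels, and this toy version does neither.

```python
import numpy as np

def rmsse(y_train, y_true, y_pred):
    """Root mean squared scaled error for a single series."""
    y_train, y_true, y_pred = (
        np.asarray(a, dtype=float) for a in (y_train, y_true, y_pred)
    )
    # Scale: in-sample MSE of the one-step naive forecast (y_t vs y_{t-1}).
    scale = np.mean(np.diff(y_train) ** 2)
    mse = np.mean((y_true - y_pred) ** 2)
    return np.sqrt(mse / scale)

def wrmsse(scores, weights):
    """Weighted sum of per-series RMSSE values; weights are normalized to sum to 1."""
    weights = np.asarray(weights, dtype=float)
    scores = np.asarray(scores, dtype=float)
    return float(np.sum(weights / weights.sum() * scores))

# Toy data (made up for illustration).
train_a, true_a, pred_a = [1, 3, 2, 4], [3, 5], [3, 3]
train_b, true_b, pred_b = [0, 2, 0, 2], [1, 1], [1, 1]

scores = [rmsse(train_a, true_a, pred_a), rmsse(train_b, true_b, pred_b)]
total = wrmsse(scores, weights=[3.0, 1.0])  # hypothetical weights
```

Lower values are better. A series forecast perfectly, like the second toy series, contributes zero.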
Y_hat = Y_hat.reset_index().set_index(['unique_id', 'ds']).unstack()
Y_hat = Y_hat.droplevel(0, 1).reset_index()
*_, S_df = M5.load('./data')
Y_hat = S_df.merge(Y_hat, how='left', on=['unique_id'])
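The first two lines above reshape the forecasts from long format (one row per series and date) to wide format (one row per series, one column per forecast date). A toy illustration with pandas follows, using made-up values and an assumed model column named `ETS`.

```python
import pandas as pd

# Hypothetical long-format forecasts: index unique_id, columns ds and ETS.
long_df = pd.DataFrame(
    {
        'ds': pd.to_datetime(['2016-05-23', '2016-05-24'] * 2),
        'ETS': [3.2, 0.4, 1.0, 2.5],
    },
    index=pd.Index(['A', 'A', 'B', 'B'], name='unique_id'),
)

# Same steps as above: move the dates from rows into columns...
wide_df = long_df.reset_index().set_index(['unique_id', 'ds']).unstack()
# ...then drop the outer column level (the model name) and restore
# unique_id as a regular column.
wide_df = wide_df.droplevel(0, 1).reset_index()
```

After the reshape, `wide_df` has one row per `unique_id`, so it can be merged with a per-series table on that key.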
M5Evaluation.evaluate(y_hat=Y_hat, directory='./data')
|         | wrmsse   |
|---------|----------|
| Total   | 0.677233 |
| Level1  | 0.435558 |
| Level2  | 0.522863 |
| Level3  | 0.582109 |
| Level4  | 0.488484 |
| Level5  | 0.567825 |
| Level6  | 0.587605 |
| Level7  | 0.662774 |
| Level8  | 0.647712 |
| Level9  | 0.732107 |
| Level10 | 1.013124 |
| Level11 | 0.970465 |
| Level12 | 0.916175 |
StatsForecast is also more accurate than Prophet: its overall WRMSSE is
0.68, against 0.77 obtained by Prophet.