MLflow
Log your metrics and models
Libraries
import copy
import subprocess
import time
import lightgbm as lgb
import mlflow
import pandas as pd
import requests
from sklearn.linear_model import LinearRegression
from utilsforecast.data import generate_series
from utilsforecast.losses import rmse, smape
from utilsforecast.evaluation import evaluate
from utilsforecast.feature_engineering import fourier
import mlforecast.flavor
from mlforecast import MLForecast
from mlforecast.lag_transforms import ExponentiallyWeightedMean
from mlforecast.utils import PredictionIntervals
Data setup
freq = 'h'
h = 10
series = generate_series(5, freq=freq)
valid = series.groupby('unique_id', observed=True).tail(h)
train = series.drop(valid.index)
train, X_df = fourier(train, freq=freq, season_length=24, k=2, h=h)
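Because h=10 future values are requested, fourier returns two frames: the training data with the seasonal terms appended and X_df, which holds those same features for the next h timestamps of each series. A quick inspection (output omitted):
# the training frame now contains the fourier terms as extra columns
print(train.head(2))
# X_df holds the future values of those features, which predict will need
print(X_df.head(2))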
Parameters
params = {
    'init': {
        'models': {
            'lgb': lgb.LGBMRegressor(
                n_estimators=50, num_leaves=16, verbosity=-1
            ),
            'lr': LinearRegression(),
        },
        'freq': freq,
        'lags': [24],
        'lag_transforms': {
            1: [ExponentiallyWeightedMean(0.9)],
        },
        'num_threads': 2,
    },
    'fit': {
        'static_features': ['unique_id'],
        'prediction_intervals': PredictionIntervals(n_windows=2, h=h),
    },
}
Logging
If you have a tracking server, you can run mlflow.set_tracking_uri(your_server_uri) to connect to it.
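If you don't set a tracking URI, the runs are stored in a local mlruns directory. For example, to point at a server running locally (the URI below is only an illustration):
mlflow.set_tracking_uri('http://127.0.0.1:8080')
This example just uses the default local store.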
mlflow.set_experiment("mlforecast")
with mlflow.start_run() as run:
    # log the training and validation sets as inputs of the run
    train_ds = mlflow.data.from_pandas(train)
    valid_ds = mlflow.data.from_pandas(valid)
    mlflow.log_input(train_ds, context="training")
    mlflow.log_input(valid_ds, context="validation")

    # log the parameters, replacing the model objects with their names and hyperparameters
    logged_params = copy.deepcopy(params)
    logged_params['init']['models'] = {
        k: (v.__class__.__name__, v.get_params())
        for k, v in params['init']['models'].items()
    }
    mlflow.log_params(logged_params)

    # train and forecast
    mlf = MLForecast(**params['init'])
    mlf.fit(train, **params['fit'])
    preds = mlf.predict(h, X_df=X_df)

    # evaluate on the validation set and log one metric per (metric, model) pair
    eval_result = evaluate(
        valid.merge(preds, on=['unique_id', 'ds']),
        metrics=[rmse, smape],
        agg_fn='mean',
    )
    models = mlf.models_.keys()
    logged_metrics = {}
    for _, row in eval_result.iterrows():
        metric = row['metric']
        for model in models:
            logged_metrics[f'{metric}_{model}'] = row[model]
    mlflow.log_metrics(logged_metrics)

    # log the fitted forecaster as an artifact of the run
    mlforecast.flavor.log_model(model=mlf, artifact_path="model")
    model_uri = mlflow.get_artifact_uri("model")
    run_id = run.info.run_id
/home/ubuntu/repos/mlforecast/.venv/lib/python3.10/site-packages/mlflow/types/utils.py:406: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values <https://www.mlflow.org/docs/latest/models.html#handling-integers-with-missing-values>`_ for more details.
warnings.warn(
2024/08/23 02:57:14 WARNING mlflow.models.model: Input example should be provided to infer model signature if the model signature is not provided when logging the model.
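If you also use the MLflow Model Registry, the logged model can be registered from its run URI (a sketch; this requires a tracking server with registry support, and the registry name used here is just an example):
registered = mlflow.register_model(model_uri=f'runs:/{run_id}/model', name='mlforecast-hourly')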
Load model
loaded_model = mlforecast.flavor.load_model(model_uri=model_uri)
results = loaded_model.predict(h=h, X_df=X_df, ids=[3])
results.head(2)
|   | unique_id | ds                  | lgb      | lr       |
|---|-----------|---------------------|----------|----------|
| 0 | 3         | 2000-01-10 16:00:00 | 0.333308 | 0.243017 |
| 1 | 3         | 2000-01-10 17:00:00 | 0.127424 | 0.249742 |
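The loaded model accepts the same prediction arguments used when it was in memory, so omitting ids returns forecasts for every series (output omitted):
loaded_model.predict(h=h, X_df=X_df).head(2)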
PyFunc
loaded_pyfunc = mlforecast.flavor.pyfunc.load_model(model_uri=model_uri)
# single row dataframe with the arguments for predict
predict_conf = pd.DataFrame(
    [
        {
            "h": h,
            "ids": [0, 2],
            "X_df": X_df,
            "level": [80]
        }
    ]
)
pyfunc_result = loaded_pyfunc.predict(predict_conf)
pyfunc_result.head(2)
|   | unique_id | ds                  | lgb      | lr       | lgb-lo-80 | lgb-hi-80 | lr-lo-80 | lr-hi-80 |
|---|-----------|---------------------|----------|----------|-----------|-----------|----------|----------|
| 0 | 0         | 2000-01-09 20:00:00 | 0.260544 | 0.244128 | 0.140168  | 0.380921  | 0.114001 | 0.374254 |
| 1 | 0         | 2000-01-09 21:00:00 | 0.250096 | 0.247742 | 0.072820  | 0.427372  | 0.047584 | 0.447900 |
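Since level=[80] was requested, one sanity check you might add is measuring how often the held-out values fall inside the 80% interval of one of the models (a sketch using the valid frame defined earlier; the coverage variable is just illustrative):
# fraction of validation points inside the 80% interval of the lgb model
coverage = (
    valid.merge(pyfunc_result, on=['unique_id', 'ds'])
    .assign(inside=lambda df: df['y'].between(df['lgb-lo-80'], df['lgb-hi-80']))
    ['inside']
    .mean()
)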
Model serving
host = 'localhost'
port = '5000'
cmd = f'mlflow models serve -m runs:/{run_id}/model -h {host} -p {port} --env-manager local'
# initialize server
process = subprocess.Popen(cmd.split())
time.sleep(5)
# single row dataframe. must be JSON serializable
predict_conf = pd.DataFrame(
    [
        {
            "h": h,
            "ids": [3, 4],
            "X_df": X_df.astype({'ds': 'str'}).to_dict(orient='list'),
            "level": [95]
        }
    ]
)
payload = {'dataframe_split': predict_conf.to_dict(orient='split', index=False)}
resp = requests.post(f'http://{host}:{port}/invocations', json=payload)
print(pd.DataFrame(resp.json()['predictions']).head(2))
process.terminate()
process.wait(timeout=10)
Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 18430.71it/s]
2024/08/23 02:57:16 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
2024/08/23 02:57:16 INFO mlflow.pyfunc.backend: === Running command 'exec gunicorn --timeout=60 -b localhost:5000 -w 1 ${GUNICORN_CMD_ARGS} -- mlflow.pyfunc.scoring_server.wsgi:app'
[2024-08-23 02:57:16 +0000] [23054] [INFO] Starting gunicorn 22.0.0
[2024-08-23 02:57:16 +0000] [23054] [INFO] Listening at: http://127.0.0.1:5000 (23054)
[2024-08-23 02:57:16 +0000] [23054] [INFO] Using worker: sync
[2024-08-23 02:57:16 +0000] [23055] [INFO] Booting worker with pid: 23055
unique_id ds lgb lr lgb-lo-95 lgb-hi-95 \
0 3 2000-01-10T16:00:00 0.333308 0.243017 0.174073 0.492544
1 3 2000-01-10T17:00:00 0.127424 0.249742 -0.009993 0.264842
lr-lo-95 lr-hi-95
0 0.032451 0.453583
1 0.045525 0.453959
[2024-08-23 02:57:20 +0000] [23054] [INFO] Handling signal: term
[2024-08-23 02:57:20 +0000] [23055] [INFO] Worker exiting (pid: 23055)
[2024-08-23 02:57:21 +0000] [23054] [INFO] Shutting down: Master
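The fixed time.sleep(5) above is just a convenience for the example. A more robust option is to poll the scoring server until it answers before sending requests (a sketch; wait_until_ready is a hypothetical helper and assumes the default MLflow scoring server, which responds to GET /ping once it's up):
def wait_until_ready(url, timeout=30):
    # retry GET requests against the ping endpoint until it responds or the timeout expires
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url).status_code == 200:
                return
        except requests.ConnectionError:
            pass
        time.sleep(0.5)
    raise TimeoutError(f'server at {url} did not become ready')

wait_until_ready(f'http://{host}:{port}/ping')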