BaseAuto
*Class for automatic hyperparameter optimization. It builds on top of ray to give access to a wide variety of hyperparameter optimization tools, ranging from classic grid search to Bayesian optimization and the HyperBand algorithm.
The validation loss to be optimized is defined by the config['loss'] dictionary value; the config also contains the rest of the hyperparameter search space.
Note that the success of this hyperparameter optimization relies heavily on a strong correlation between the validation and test periods.*
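The idea of searching a configuration space for the lowest validation loss can be illustrated with a toy random search in plain Python. This is only a sketch of the general technique, not BaseAuto's implementation: the `random_search` helper, the search space entries, and the toy loss are all hypothetical names chosen for illustration.

```python
import random


def random_search(search_space, validation_loss, num_samples=10, seed=0):
    """Sample `num_samples` configurations and keep the one with the lowest loss."""
    rng = random.Random(seed)
    best_config, best_loss = None, float("inf")
    for _ in range(num_samples):
        # Each search-space entry is a function that draws one candidate value.
        config = {name: draw(rng) for name, draw in search_space.items()}
        loss = validation_loss(config)
        if loss < best_loss:
            best_config, best_loss = config, loss
    return best_config, best_loss


# Hypothetical search space (not taken from the library).
search_space = {
    "learning_rate": lambda rng: rng.choice([0.1, 0.01, 0.001]),
    "hidden_size": lambda rng: rng.choice([32, 64, 128]),
}


def toy_validation_loss(config):
    # Toy stand-in for a real validation loss; smallest at lr=0.01, hidden_size=64.
    return abs(config["learning_rate"] - 0.01) + abs(config["hidden_size"] - 64) / 64


best_config, best_loss = random_search(search_space, toy_validation_loss, num_samples=10)
print(best_config, best_loss)
```

In the real class, the search strategy is delegated to the chosen backend, and the loss comes from fitting a model on the training window and scoring it on the validation window.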
Parameter | Type | Default | Details |
---|---|---|---|
cls_model | PyTorch/PyTorchLightning model | | See the neuralforecast.models collection. |
h | int | | Forecast horizon. |
loss | PyTorch module | | Instantiated train loss class from the losses collection. |
valid_loss | PyTorch module | | Instantiated validation loss class from the losses collection. |
config | dict or callable | | Dictionary with ray.tune defined search space, or function that takes an optuna trial and returns a configuration dict. |
search_alg | BasicVariantGenerator | A ray.tune.search.basic_variant.BasicVariantGenerator instance | For ray, see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html. For optuna, see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html. |
num_samples | int | 10 | Number of hyperparameter optimization steps/samples. |
cpus | int | 4 | Number of CPUs to use during optimization. Only used with ray tune. |
gpus | int | 0 | Number of GPUs to use during optimization; by default, all available GPUs are used. Only used with ray tune. |
refit_with_val | bool | False | Whether the refit of the best model should preserve val_size. |
verbose | bool | False | Track progress. |
alias | NoneType | None | Custom name of the model. |
backend | str | ray | Backend to use for searching the hyperparameter space, can be either ‘ray’ or ‘optuna’. |
callbacks | NoneType | None | List of functions to call during the optimization process. ray reference: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html optuna reference: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html |
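
When the optuna backend is used, the table above says config may be a function that takes an optuna trial and returns a configuration dict. A minimal sketch of such a function follows. The hyperparameter names and ranges are assumptions chosen for illustration, and `StubTrial` is a hypothetical stand-in that lets the function be dry-run without optuna installed; with the real backend, the trial object would be supplied by optuna.

```python
def config_fn(trial):
    """Return one sampled configuration dict for a given optuna trial."""
    return {
        # Hyperparameter names and ranges below are illustrative assumptions.
        "input_size": trial.suggest_int("input_size", 12, 48),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "scaler_type": trial.suggest_categorical("scaler_type", ["standard", "robust"]),
    }


class StubTrial:
    """Hypothetical stand-in for an optuna trial; always returns the first option."""

    def suggest_int(self, name, low, high):
        return low

    def suggest_float(self, name, low, high, log=False):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


sample = config_fn(StubTrial())
print(sample)
```

A function like `config_fn` would then be passed as the config argument together with backend set to 'optuna'.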
BaseAuto.fit
*Perform the hyperparameter optimization as specified by the BaseAuto configuration dictionary config.
The optimization is performed on the TimeSeriesDataset using temporal cross validation, with the validation set sequentially preceding the test set.
Parameters:
dataset: NeuralForecast’s TimeSeriesDataset.
val_size: int, size of temporal validation set (needs to be bigger than 0).
test_size: int, size of temporal test set (default 0).
random_seed: int=None, random seed for hyperparameter exploration algorithms (not yet implemented).
Returns:
self: fitted instance of BaseAuto with best hyperparameters and results.*
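
The ordering of the temporal windows described above (train, then validation, then test) can be sketched with plain index arithmetic. This is an illustrative assumption about how val_size and test_size partition a series of a given length, not the library's internal code; `temporal_split` is a hypothetical helper.

```python
def temporal_split(n_timestamps, val_size, test_size=0):
    """Split indices 0..n_timestamps-1 into train, validation, and test ranges.

    The validation window sits immediately before the test window.
    """
    if val_size <= 0:
        raise ValueError("val_size needs to be bigger than 0")
    if val_size + test_size >= n_timestamps:
        raise ValueError("val_size + test_size must leave room for training data")
    test_start = n_timestamps - test_size
    val_start = test_start - val_size
    train = range(0, val_start)
    val = range(val_start, test_start)
    test = range(test_start, n_timestamps)
    return train, val, test


train, val, test = temporal_split(100, val_size=12, test_size=24)
print(train, val, test)
```

With 100 timestamps, the last 24 form the test window and the 12 before them form the validation window used to score each sampled configuration.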
BaseAuto.predict
*Predictions of the best performing model on validation.
Parameters:
dataset: NeuralForecast’s TimeSeriesDataset.
step_size: int, steps between sequential predictions (default 1).
**data_kwarg: additional parameters for the dataset module.
random_seed: int=None, random seed for hyperparameter exploration algorithms (not implemented).
Returns:
y_hat: numpy predictions of the NeuralForecast model.*
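
One plausible reading of step_size is the spacing between the starting points of consecutive forecast windows of length h. The sketch below illustrates that reading only; it is an assumption, not a description of the library's internals, and `forecast_origins` is a hypothetical helper.

```python
def forecast_origins(n_timestamps, first_origin, h, step_size=1):
    """Start indices of forecast windows [s, s + h) that fit inside the series."""
    # The last valid start is n_timestamps - h, so the range stops one past it.
    return list(range(first_origin, n_timestamps - h + 1, step_size))


origins = forecast_origins(n_timestamps=20, first_origin=10, h=4, step_size=2)
print(origins)
```

Under this reading, a larger step_size produces fewer, less overlapping windows, while step_size=1 produces a window at every timestamp.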
References
- James Bergstra, Remi Bardenet, Yoshua Bengio, and Balazs Kegl (2011). “Algorithms for Hyper-Parameter Optimization”. In: Advances in Neural Information Processing Systems. url: https://proceedings.neurips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf
- Kirthevasan Kandasamy, Karun Raju Vysyaraju, Willie Neiswanger, Biswajit Paria, Christopher R. Collins, Jeff Schneider, Barnabas Poczos, Eric P. Xing (2019). “Tuning Hyperparameters without Grad Students: Scalable and Robust Bayesian Optimisation with Dragonfly”. Journal of Machine Learning Research. url: https://arxiv.org/abs/1903.06694
- Lisha Li, Kevin Jamieson, Giulia DeSalvo, Afshin Rostamizadeh, Ameet Talwalkar (2016). “Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization”. Journal of Machine Learning Research. url: https://arxiv.org/abs/1603.06560