The most important training signal is the forecast error, which is the difference between the observed value $y_{\tau}$ and the prediction $\hat{y}_{\tau}$ at time $\tau$:

$$e_{\tau} = y_{\tau}-\hat{y}_{\tau} \qquad \qquad \tau \in \{t+1,\dots,t+H \}$$

The training loss summarizes the forecast errors into the different optimization objectives described below.

All losses are torch.nn.Module objects, which allows PyTorch Lightning to automatically move them across CPU/GPU/TPU devices.
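
For example, a training loss is passed to a model through its `loss` argument. The following is a minimal sketch of that workflow; the model (NHITS), dataset (AirPassengersDF), and hyperparameters are illustrative choices, not prescriptions.

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import MAE
from neuralforecast.utils import AirPassengersDF

# Train NHITS with MAE as the optimization objective (illustrative hyperparameters).
nf = NeuralForecast(
    models=[NHITS(h=12, input_size=24, loss=MAE(), max_steps=100)],
    freq='M',
)
nf.fit(df=AirPassengersDF)
forecasts = nf.predict()
```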


source

BasePointLoss

 BasePointLoss (horizon_weight, outputsize_multiplier, output_names)

*Base class for point loss functions.

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.
outputsize_multiplier: Multiplier for the output size.
output_names: Names of the outputs.
*

1. Scale-dependent Errors

These metrics are on the same scale as the data.

Mean Absolute Error (MAE)


source

MAE.__init__

 MAE.__init__ (horizon_weight=None)

*Mean Absolute Error

Calculates Mean Absolute Error between y and y_hat. MAE measures the relative prediction accuracy of a forecasting method by calculating the deviation of the prediction from the true value at a given time and averaging these deviations over the length of the series.

$$\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} |y_{\tau} - \hat{y}_{\tau}|$$

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.
*


source

MAE.__call__

 MAE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
               mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
mae: tensor (single value).*
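
As a sketch of the call signature above, the loss can also be evaluated directly on tensors shaped (batch, horizon); the toy values below are illustrative.

```python
import torch
from neuralforecast.losses.pytorch import MAE

y     = torch.tensor([[10.0, 12.0, 14.0]])  # actual values, shape (batch, horizon)
y_hat = torch.tensor([[11.0, 12.0, 16.0]])  # predictions

mae = MAE()
loss = mae(y=y, y_hat=y_hat)  # (|-1| + |0| + |-2|) / 3 = 1.0
print(loss)                   # tensor(1.)
```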

Mean Squared Error (MSE)


source

MSE.__init__

 MSE.__init__ (horizon_weight=None)

*Mean Squared Error

Calculates Mean Squared Error between y and y_hat. MSE measures the relative prediction accuracy of a forecasting method by calculating the squared deviation of the prediction from the true value at a given time, and averaging these deviations over the length of the series.

$$\mathrm{MSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} (y_{\tau} - \hat{y}_{\tau})^{2}$$

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.
*


source

MSE.__call__

 MSE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
               mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
mse: tensor (single value).*

Root Mean Squared Error (RMSE)


source

RMSE.__init__

 RMSE.__init__ (horizon_weight=None)

*Root Mean Squared Error

Calculates Root Mean Squared Error between y and y_hat. RMSE measures the relative prediction accuracy of a forecasting method by calculating the squared deviation of the prediction from the observed value at a given time and averaging these deviations over the length of the series. Because the RMSE is on the same scale as the original time series, comparison with other series is possible only if they share a common scale. RMSE has a direct connection to the L2 norm.

$$\mathrm{RMSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \sqrt{\frac{1}{H} \sum^{t+H}_{\tau=t+1} (y_{\tau} - \hat{y}_{\tau})^{2}}$$

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.
*


source

RMSE.__call__

 RMSE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
rmse: tensor (single value).*

2. Percentage errors

These metrics are unit-free, suitable for comparisons across series.

Mean Absolute Percentage Error (MAPE)


source

MAPE.__init__

 MAPE.__init__ (horizon_weight=None)

*Mean Absolute Percentage Error

Calculates Mean Absolute Percentage Error between y and y_hat. MAPE measures the relative prediction accuracy of a forecasting method by calculating the percentage deviation of the prediction from the observed value at a given time and averaging these deviations over the length of the series. The closer an observed value is to zero, the higher the penalty MAPE assigns to the corresponding error.

$$\mathrm{MAPE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{|y_{\tau}|}$$

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Makridakis S., “Accuracy measures: theoretical and practical concerns”.*


source

MAPE.__call__

 MAPE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
mape: tensor (single value).*

Symmetric MAPE (sMAPE)


source

SMAPE.__init__

 SMAPE.__init__ (horizon_weight=None)

*Symmetric Mean Absolute Percentage Error

Calculates Symmetric Mean Absolute Percentage Error between y and y_hat. SMAPE measures the relative prediction accuracy of a forecasting method by calculating the relative deviation of the prediction from the observed value, scaled by the sum of the absolute values of the prediction and the observed value at a given time, and then averaging these deviations over the length of the series. This bounds SMAPE between 0% and 200%, which is desirable compared to the standard MAPE, which may be undefined when the target is zero.

$$\mathrm{sMAPE}_{2}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{|y_{\tau}|+|\hat{y}_{\tau}|}$$

Parameters:
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Makridakis S., “Accuracy measures: theoretical and practical concerns”.*


source

SMAPE.__call__

 SMAPE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                 mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
smape: tensor (single value).*

3. Scale-independent Errors

These metrics measure the relative improvements versus baselines.

Mean Absolute Scaled Error (MASE)


source

MASE.__init__

 MASE.__init__ (seasonality:int, horizon_weight=None)

*Mean Absolute Scaled Error

Calculates the Mean Absolute Scaled Error between y and y_hat. MASE measures the relative prediction accuracy of a forecasting method by comparing the mean absolute errors of the prediction against those of the seasonal naive model. The MASE is a component of the Overall Weighted Average (OWA) metric used in the M4 Competition.

$$\mathrm{MASE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})}$$

Parameters:
seasonality: int. Main frequency of the time series; Hourly 24, Daily 7, Weekly 52, Monthly 12, Quarterly 4, Yearly 1.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Rob J. Hyndman, & Koehler, A. B. “Another look at measures of forecast accuracy”.
Spyros Makridakis, Evangelos Spiliotis, Vassilios Assimakopoulos, “The M4 Competition: 100,000 time series and 61 forecasting methods”.*


source

MASE.__call__

 MASE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                y_insample:torch.Tensor, mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor (batch_size, output_size), Actual values.
y_hat: tensor (batch_size, output_size), Predicted values.
y_insample: tensor (batch_size, input_size), Actual insample Seasonal Naive predictions.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
mase: tensor (single value).*
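
Unlike the point losses above, MASE also requires the in-sample values to build the seasonal naive benchmark. A minimal sketch, assuming the call signature above; tensor values and the choice of seasonality are illustrative.

```python
import torch
from neuralforecast.losses.pytorch import MASE

y          = torch.tensor([[112.0, 118.0, 132.0]])  # actual values over the horizon
y_hat      = torch.tensor([[110.0, 120.0, 130.0]])  # predictions
y_insample = torch.tensor([[100.0, 104.0,  99.0, 102.0, 110.0, 108.0, 115.0, 120.0]])

mase = MASE(seasonality=1)  # seasonality=1 scales by the naive (lag-1) in-sample errors
loss = mase(y=y, y_hat=y_hat, y_insample=y_insample)
print(loss)
```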

Relative Mean Squared Error (relMSE)


source

relMSE.__init__

 relMSE.__init__ (y_train, horizon_weight=None)

*Relative Mean Squared Error

Computes the Relative Mean Squared Error (relMSE), proposed by Hyndman & Koehler (2006) as an alternative to percentage errors, to avoid measure instability.

$$\mathrm{relMSE}(\mathbf{y}, \mathbf{\hat{y}}, \mathbf{\hat{y}}^{naive1}) = \frac{\mathrm{MSE}(\mathbf{y}, \mathbf{\hat{y}})}{\mathrm{MSE}(\mathbf{y}, \mathbf{\hat{y}}^{naive1})}$$

Parameters:
y_train: numpy array, Training values.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
- Hyndman, R. J and Koehler, A. B. (2006). “Another look at measures of forecast accuracy”, International Journal of Forecasting, Volume 22, Issue 4.
- Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. “Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures”. Submitted to the International Journal of Forecasting, working paper available at arXiv.*


source

relMSE.__call__

 relMSE.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                  mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor (batch_size, output_size), Actual values.
y_hat: tensor (batch_size, output_size), Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
relMSE: tensor (single value).*

4. Probabilistic Errors

These methods use statistical approaches for estimating unknown probability distributions using observed data.

Maximum likelihood estimation involves finding the parameter values that maximize the likelihood function, which measures the probability of obtaining the observed data given the parameter values. MLE has good theoretical properties and is efficient when its assumptions are satisfied.

On the non-parametric side, quantile regression weights deviations asymmetrically, paying different attention to under- and over-estimation.

Quantile Loss


source

QuantileLoss.__init__

 QuantileLoss.__init__ (q, horizon_weight=None)

*Quantile Loss

Computes the quantile loss between y and y_hat. QL measures the deviation of a quantile forecast. By weighting the absolute deviation in a non-symmetric way, the loss pays more attention to under- or over-estimation. A common value for q is 0.5 for the deviation from the median (pinball loss).

$$\mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q)}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \Big( (1-q)\,( \hat{y}^{(q)}_{\tau} - y_{\tau} )_{+} + q\,( y_{\tau} - \hat{y}^{(q)}_{\tau} )_{+} \Big)$$

Parameters:
q: float, between 0 and 1. The slope of the quantile loss; in the context of quantile regression, q determines the conditional quantile level.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Roger Koenker and Gilbert Bassett, Jr., “Regression Quantiles”.*


source

QuantileLoss.__call__

 QuantileLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                        mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
quantile_loss: tensor (single value).*
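
The asymmetry of the quantile loss is easy to verify numerically. In this sketch (toy values), a q=0.9 loss penalizes under-estimation nine times more than over-estimation of the same magnitude.

```python
import torch
from neuralforecast.losses.pytorch import QuantileLoss

y     = torch.tensor([[100.0]])
under = torch.tensor([[ 90.0]])  # under-estimates by 10
over  = torch.tensor([[110.0]])  # over-estimates by 10

q90 = QuantileLoss(q=0.9)
print(q90(y=y, y_hat=under))  # q * 10       = 9.0
print(q90(y=y, y_hat=over))   # (1 - q) * 10 = 1.0
```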

Multi Quantile Loss (MQLoss)


source

MQLoss.__init__

 MQLoss.__init__ (level=[80, 90], quantiles=None, horizon_weight=None)

*Multi-Quantile loss

Calculates the Multi-Quantile loss (MQL) between y and y_hat. MQL calculates the average multi-quantile Loss for a given set of quantiles, based on the absolute difference between predicted quantiles and observed values.

$$\mathrm{MQL}(\mathbf{y}_{\tau},[\mathbf{\hat{y}}^{(q_{1})}_{\tau}, ... ,\mathbf{\hat{y}}^{(q_{n})}_{\tau}]) = \frac{1}{n} \sum_{q_{i}} \mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q_{i})}_{\tau})$$

In the limit, MQL makes it possible to measure the accuracy of a full predictive distribution $\mathbf{\hat{F}}_{\tau}$ with the continuous ranked probability score (CRPS). This can be achieved through a numerical integration technique that discretizes the quantiles and treats the CRPS integral with a left Riemann approximation, averaging over uniformly spaced quantiles.

$$\mathrm{CRPS}(y_{\tau}, \mathbf{\hat{F}}_{\tau}) = \int^{1}_{0} \mathrm{QL}(y_{\tau}, \hat{y}^{(q)}_{\tau}) dq$$

Parameters:
level: int list [0,100]. Probability levels for prediction intervals (Default: median).
quantiles: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Roger Koenker and Gilbert Bassett, Jr., “Regression Quantiles”.
James E. Matheson and Robert L. Winkler, “Scoring Rules for Continuous Probability Distributions”.*


source

MQLoss.__call__

 MQLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                  mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
mqloss: tensor (single value).*
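
In practice MQLoss is most often used as a training objective, so that a point-forecast architecture outputs several quantiles at once. A minimal sketch, assuming the usual NeuralForecast workflow; the model, dataset, and output column names are illustrative.

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import MQLoss
from neuralforecast.utils import AirPassengersDF

# level=[80, 90] trains the median together with the 80% and 90% interval quantiles.
model = NHITS(h=12, input_size=24, loss=MQLoss(level=[80, 90]), max_steps=100)
nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=AirPassengersDF)
forecasts = nf.predict()  # quantile columns, e.g. NHITS-median, NHITS-lo-90, NHITS-hi-90
```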

Implicit Quantile Loss (IQLoss)


source

QuantileLayer

 QuantileLayer (num_output:int, cos_embedding_dim:int=128)

*Implicit Quantile Layer from the paper IQN for Distributional Reinforcement Learning (https://arxiv.org/abs/1806.06923) by Dabney et al. 2018.

Code from GluonTS: https://github.com/awslabs/gluonts/blob/61133ef6e2d88177b32ace4afc6843ab9a7bc8cd/src/gluonts/torch/distributions/implicit_quantile_network.py*


source

IQLoss.__init__

 IQLoss.__init__ (cos_embedding_dim=64, concentration0=1.0,
                  concentration1=1.0, horizon_weight=None)

*Implicit Quantile Loss

Computes the quantile loss between y and y_hat, with the quantile q provided as an input to the network. IQL measures the deviation of a quantile forecast. By weighting the absolute deviation in a non-symmetric way, the loss pays more attention to under- or over-estimation.

$$\mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q)}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \Big( (1-q)\,( \hat{y}^{(q)}_{\tau} - y_{\tau} )_{+} + q\,( y_{\tau} - \hat{y}^{(q)}_{\tau} )_{+} \Big)$$

Parameters:
cos_embedding_dim: int=64, dimension of the quantile cosine embedding.
concentration0: float=1.0, concentration parameter of the Beta distribution used to sample quantiles during training.
concentration1: float=1.0, concentration parameter of the Beta distribution used to sample quantiles during training.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Gouttes, Adèle, Kashif Rasul, Mateusz Koren, Johannes Stephan, and Tofigh Naghibi, “Probabilistic Time Series Forecasting with Implicit Quantile Networks”.*


source

IQLoss.__call__

 IQLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                  mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
quantile_loss: tensor (single value).*

DistributionLoss


source

DistributionLoss.__init__

 DistributionLoss.__init__ (distribution, level=[80, 90], quantiles=None,
                            num_samples=1000, return_params=False,
                            **distribution_kwargs)

*DistributionLoss

This PyTorch module wraps the torch.distributions classes, allowing them to interact with NeuralForecast models modularly. It uses the negative log-likelihood as the optimization objective and provides a sample method to empirically generate the quantiles defined by the level list.

Additionally, it implements a distribution transformation that factorizes the scale-dependent likelihood parameters into a base scale and a multiplier efficiently learnable within the network’s non-linearities operating ranges.

Available distributions:
- Poisson
- Normal
- StudentT
- NegativeBinomial
- Tweedie
- Bernoulli (Temporal Classifiers)
- ISQF (Incremental Spline Quantile Function)

Parameters:
distribution: str, identifier of a torch.distributions.Distribution class.
level: float list [0,100], confidence levels for prediction intervals.
quantiles: float list [0,1], alternative to level list, target quantiles.
num_samples: int=1000, number of samples for the empirical quantiles.
return_params: bool=False, whether or not to return the Distribution parameters.

References:
- PyTorch Probability Distributions Package: StudentT.
- David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). “DeepAR: Probabilistic forecasting with autoregressive recurrent networks”. International Journal of Forecasting.
- Park, Youngsuk, Danielle Maddix, François-Xavier Aubet, Kelvin Kan, Jan Gasthaus, and Yuyang Wang (2022). “Learning Quantile Functions without Quantile Crossing for Distribution-free Time Series Forecasting”.*


source

DistributionLoss.sample

 DistributionLoss.sample (distr_args:torch.Tensor,
                          num_samples:Optional[int]=None)

*Construct the empirical quantiles from the estimated Distribution by drawing num_samples independent samples from it.

Parameters
distr_args: Constructor arguments for the underlying Distribution type.
num_samples: int=500, overwrite number of samples for the empirical quantiles.

Returns
samples: tensor, shape [B,H,num_samples].
quantiles: tensor, empirical quantiles defined by levels.
*


source

DistributionLoss.__call__

 DistributionLoss.__call__ (y:torch.Tensor, distr_args:torch.Tensor,
                            mask:Optional[torch.Tensor]=None)

*Computes the negative log-likelihood objective function, used to estimate the following predictive distribution:

$$\mathrm{P}(\mathbf{y}_{\tau}\,|\,\theta) \quad \mathrm{and} \quad -\log(\mathrm{P}(\mathbf{y}_{\tau}\,|\,\theta))$$

where $\theta$ represents the distribution's parameters. It additionally summarizes the objective signal with a weighted average using the mask tensor.

Parameters
y: tensor, Actual values.
distr_args: Constructor arguments for the underlying Distribution type.
loc: Optional tensor, of the same shape as the batch_shape + event_shape of the resulting distribution.
scale: Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns
loss: scalar, weighted loss function against which backpropagation will be performed.
*
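
A minimal sketch of training with a DistributionLoss objective; the model, dataset, and distribution choice are illustrative.

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersDF

# Optimize the Student-t negative log-likelihood; predictions include the
# empirical quantiles implied by `level`.
loss = DistributionLoss(distribution='StudentT', level=[80, 90])
model = NHITS(h=12, input_size=24, loss=loss, max_steps=100)
nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=AirPassengersDF)
forecasts = nf.predict()
```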

Poisson Mixture Mesh (PMM)


source

PMM.__init__

 PMM.__init__ (n_components=10, level=[80, 90], quantiles=None,
               num_samples=1000, return_params=False,
               batch_correlation=False, horizon_correlation=False)

*Poisson Mixture Mesh

This Poisson Mixture statistical model assumes independence across groups of data $\mathcal{G}=\{[g_{i}]\}$, and estimates relationships within the group.

$$\mathrm{P}\left(\mathbf{y}_{[b][t+1:t+H]}\right) = \prod_{[g_{i}] \in \mathcal{G}} \mathrm{P}\left(\mathbf{y}_{[g_{i}][\tau]}\right) = \prod_{\beta\in[g_{i}]} \left(\sum_{k=1}^{K} w_k \prod_{(\beta,\tau) \in [g_i][t+1:t+H]} \mathrm{Poisson}(y_{\beta,\tau}, \hat{\lambda}_{\beta,\tau,k}) \right)$$

Parameters:
n_components: int=10, the number of mixture components.
level: float list [0,100], confidence levels for prediction intervals.
quantiles: float list [0,1], alternative to level list, target quantiles.
return_params: bool=False, whether or not to return the Distribution parameters.
batch_correlation: bool=False, whether or not to model batch correlations.
horizon_correlation: bool=False, whether or not to model horizon correlations.

References:
Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. “Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures”. Submitted to the International Journal of Forecasting, working paper available at arXiv.*


source

PMM.sample

 PMM.sample (distr_args, num_samples=None)

*Construct the empirical quantiles from the estimated Distribution by drawing num_samples independent samples from it.

Parameters
distr_args: Constructor arguments for the underlying Distribution type.
loc: Optional tensor, of the same shape as the batch_shape + event_shape of the resulting distribution.
scale: Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.
num_samples: int=500, overwrites number of samples for the empirical quantiles.

Returns
samples: tensor, shape [B,H,num_samples].
quantiles: tensor, empirical quantiles defined by levels.
*


source

PMM.__call__

 PMM.__call__ (y:torch.Tensor, distr_args:Tuple[torch.Tensor],
               mask:Optional[torch.Tensor]=None)

Call self as a function.

Gaussian Mixture Mesh (GMM)


source

GMM.__init__

 GMM.__init__ (n_components=1, level=[80, 90], quantiles=None,
               num_samples=1000, return_params=False,
               batch_correlation=False, horizon_correlation=False)

*Gaussian Mixture Mesh

This Gaussian Mixture statistical model assumes independence across groups of data $\mathcal{G}=\{[g_{i}]\}$, and estimates relationships within the group.

$$\mathrm{P}\left(\mathbf{y}_{[b][t+1:t+H]}\right) = \prod_{[g_{i}] \in \mathcal{G}} \mathrm{P}\left(\mathbf{y}_{[g_{i}][\tau]}\right) = \prod_{\beta\in[g_{i}]} \left(\sum_{k=1}^{K} w_k \prod_{(\beta,\tau) \in [g_i][t+1:t+H]} \mathrm{Gaussian}(y_{\beta,\tau}, \hat{\mu}_{\beta,\tau,k}, \sigma_{\beta,\tau,k})\right)$$

Parameters:
n_components: int=1, the number of mixture components.
level: float list [0,100], confidence levels for prediction intervals.
quantiles: float list [0,1], alternative to level list, target quantiles.
return_params: bool=False, whether or not to return the Distribution parameters.
batch_correlation: bool=False, whether or not to model batch correlations.
horizon_correlation: bool=False, whether or not to model horizon correlations.

References:
Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. “Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures”. Submitted to the International Journal of Forecasting, working paper available at arXiv.*


source

GMM.sample

 GMM.sample (distr_args, num_samples=None)

*Construct the empirical quantiles from the estimated Distribution by drawing num_samples independent samples from it.

Parameters
distr_args: Constructor arguments for the underlying Distribution type.
loc: Optional tensor, of the same shape as the batch_shape + event_shape of the resulting distribution.
scale: Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.
num_samples: int=500, number of samples for the empirical quantiles.

Returns
samples: tensor, shape [B,H,num_samples].
quantiles: tensor, empirical quantiles defined by levels.
*


source

GMM.__call__

 GMM.__call__ (y:torch.Tensor,
               distr_args:Tuple[torch.Tensor,torch.Tensor],
               mask:Optional[torch.Tensor]=None)

Call self as a function.

Negative Binomial Mixture Mesh (NBMM)


source

NBMM.__init__

 NBMM.__init__ (n_components=1, level=[80, 90], quantiles=None,
                num_samples=1000, return_params=False)

*Negative Binomial Mixture Mesh

This Negative Binomial Mixture statistical model assumes independence across groups of data $\mathcal{G}=\{[g_{i}]\}$, and estimates relationships within the group.

$$\mathrm{P}\left(\mathbf{y}_{[b][t+1:t+H]}\right) = \prod_{[g_{i}] \in \mathcal{G}} \mathrm{P}\left(\mathbf{y}_{[g_{i}][\tau]}\right) = \prod_{\beta\in[g_{i}]} \left(\sum_{k=1}^{K} w_k \prod_{(\beta,\tau) \in [g_i][t+1:t+H]} \mathrm{NBinomial}(y_{\beta,\tau}, \hat{r}_{\beta,\tau,k}, \hat{p}_{\beta,\tau,k})\right)$$

Parameters:
n_components: int=1, the number of mixture components.
level: float list [0,100], confidence levels for prediction intervals.
quantiles: float list [0,1], alternative to level list, target quantiles.
return_params: bool=False, whether or not to return the Distribution parameters.

References:
Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. “Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures”. Submitted to the International Journal of Forecasting, working paper available at arXiv.*


source

NBMM.sample

 NBMM.sample (distr_args, num_samples=None)

*Construct the empirical quantiles from the estimated Distribution by drawing num_samples independent samples from it.

Parameters
distr_args: Constructor arguments for the underlying Distribution type.
loc: Optional tensor, of the same shape as the batch_shape + event_shape of the resulting distribution.
scale: Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.
num_samples: int=500, number of samples for the empirical quantiles.

Returns
samples: tensor, shape [B,H,num_samples].
quantiles: tensor, empirical quantiles defined by levels.
*


source

NBMM.__call__

 NBMM.__call__ (y:torch.Tensor,
                distr_args:Tuple[torch.Tensor,torch.Tensor],
                mask:Optional[torch.Tensor]=None)

Call self as a function.

5. Robustified Errors

These errors come from robust statistics, which focuses on methods resistant to outliers and violations of assumptions, providing reliable estimates and inferences. Robust estimators reduce the impact of outliers, offering more stable results.

Huber Loss


source

HuberLoss.__init__

 HuberLoss.__init__ (delta:float=1.0, horizon_weight=None)

*Huber Loss

The Huber loss, employed in robust regression, is a loss function that is less sensitive to outliers in the data than the squared error loss. This function is also referred to as SmoothL1.

The Huber loss function is quadratic for small errors and linear for large errors, with equal values and slopes of the different sections at the two points where $(y_{\tau}-\hat{y}_{\tau})^{2}=|y_{\tau}-\hat{y}_{\tau}|$.

$$L_{\delta}(y_{\tau},\; \hat{y}_{\tau}) =\begin{cases}{\frac{1}{2}}(y_{\tau}-\hat{y}_{\tau})^{2} & \text{for } |y_{\tau}-\hat{y}_{\tau}|\leq \delta \\ \delta \cdot \left(|y_{\tau}-\hat{y}_{\tau}|-{\frac {1}{2}}\delta \right) & \text{otherwise.}\end{cases}$$

where $\delta$ is a threshold parameter that determines the point at which the loss transitions from quadratic to linear, and can be tuned to control the trade-off between robustness and accuracy in the predictions.

Parameters:
delta: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Huber, Peter J. (1964). “Robust Estimation of a Location Parameter”. Annals of Statistics.*


source

HuberLoss.__call__

 HuberLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                     mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
huber_loss: tensor (single value).*
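
The effect of the threshold is visible on data with a single outlier: the squared error is dominated by it, while the Huber loss lets it contribute only linearly. A sketch with illustrative toy values.

```python
import torch
from neuralforecast.losses.pytorch import MSE, HuberLoss

y     = torch.tensor([[10.0, 10.0, 10.0, 10.0]])
y_hat = torch.tensor([[10.5,  9.5, 10.2, 30.0]])  # last prediction is an outlier

print(MSE()(y=y, y_hat=y_hat))                 # ~100.1, dominated by the outlier
print(HuberLoss(delta=1.0)(y=y, y_hat=y_hat))  # ~4.9, outlier contributes linearly
```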

Tukey Loss


source

TukeyLoss.__init__

 TukeyLoss.__init__ (c:float=4.685, normalize:bool=True)

*Tukey Loss

The Tukey loss function, also known as Tukey’s biweight function, is a robust statistical loss function used in robust statistics. Tukey’s loss exhibits quadratic behavior near the origin, like the Huber loss; however, it is even more robust to outliers as the loss for large residuals remains constant instead of scaling linearly.

The parameter $c$ in Tukey’s loss determines the “saturation” point of the function: higher values of $c$ enhance sensitivity, while lower values increase resistance to outliers.

$$L_{c}(y_{\tau},\; \hat{y}_{\tau}) =\begin{cases}{\frac{c^{2}}{6}} \left[1-\left(\frac{y_{\tau}-\hat{y}_{\tau}}{c}\right)^{2} \right]^{3} & \text{for } |y_{\tau}-\hat{y}_{\tau}|\leq c \\ \frac{c^{2}}{6} & \text{otherwise.} \end{cases}$$

Please note that the Tukey loss function assumes the data are stationary or normalized beforehand. If the error values are excessively large, the optimization may struggle to converge; it is advisable to employ small learning rates.

Parameters:
c: float=4.685, Specifies the Tukey loss’ threshold on which residuals are no longer considered.
normalize: bool=True, Whether normalization is performed within the Tukey loss computation.

References:
Beaton, A. E., and Tukey, J. W. (1974). “The Fitting of Power Series, Meaning Polynomials, Illustrated on Band-Spectroscopic Data.”*


source

TukeyLoss.__call__

 TukeyLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                     mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
tukey_loss: tensor (single value).*

Huberized Quantile Loss


source

HuberQLoss.__init__

 HuberQLoss.__init__ (q, delta:float=1.0, horizon_weight=None)

*Huberized Quantile Loss

The Huberized quantile loss is a modified version of the quantile loss function that combines the advantages of the quantile loss and the Huber loss. It is commonly used in regression tasks, especially when dealing with data that contains outliers or heavy tails.

The Huberized quantile loss between y and y_hat applies the Huber loss in a non-symmetric way. The loss pays more attention to under/over-estimation depending on the quantile parameter $q$, and controls the trade-off between robustness and accuracy in the predictions with the parameter $\delta$.

$$\mathrm{HuberQL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q)}_{\tau}) = (1-q)\, L_{\delta}(y_{\tau},\; \hat{y}^{(q)}_{\tau}) \mathbb{1}\{ \hat{y}^{(q)}_{\tau} \geq y_{\tau} \} + q\, L_{\delta}(y_{\tau},\; \hat{y}^{(q)}_{\tau}) \mathbb{1}\{ \hat{y}^{(q)}_{\tau} < y_{\tau} \}$$

Parameters:
delta: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.
q: float, between 0 and 1. The slope of the quantile loss; in the context of quantile regression, q determines the conditional quantile level.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Huber, Peter J. (1964). “Robust Estimation of a Location Parameter”. Annals of Statistics.
Roger Koenker and Gilbert Bassett, Jr., “Regression Quantiles”.*


source

HuberQLoss.__call__

 HuberQLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                      mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies datapoints to consider in loss.

Returns:
huber_qloss: tensor (single value).*

Huberized MQLoss


source

HuberMQLoss.__init__

 HuberMQLoss.__init__ (level=[80, 90], quantiles=None, delta:float=1.0,
                       horizon_weight=None)

*Huberized Multi-Quantile loss

The Huberized Multi-Quantile loss (HuberMQL) is a modified version of the multi-quantile loss function that combines the advantages of the quantile loss and the Huber loss. HuberMQL is commonly used in regression tasks, especially when dealing with data that contains outliers or heavy tails. The loss function pays more attention to under/over-estimation depending on the quantile list $[q_{1},q_{2},\dots]$ parameter. It controls the trade-off between robustness and prediction accuracy with the parameter $\delta$.

$$\mathrm{HuberMQL}_{\delta}(\mathbf{y}_{\tau},[\mathbf{\hat{y}}^{(q_{1})}_{\tau}, ... ,\mathbf{\hat{y}}^{(q_{n})}_{\tau}]) = \frac{1}{n} \sum_{q_{i}} \mathrm{HuberQL}_{\delta}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q_{i})}_{\tau})$$

Parameters:
level: int list [0,100]. Probability levels for prediction intervals (Default: median).
quantiles: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
delta: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.
horizon_weight: Tensor of size h, weight for each timestamp of the forecasting window.

References:
Huber, Peter J. (1964). “Robust Estimation of a Location Parameter”. Annals of Statistics.
Roger Koenker and Gilbert Bassett, Jr., “Regression Quantiles”.*


source

HuberMQLoss.__call__

 HuberMQLoss.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                       mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
hmqloss: tensor (single value).*

6. Others

Accuracy


source

Accuracy.__init__

 Accuracy.__init__ ()

*Accuracy

Computes the accuracy between the categorical y and y_hat. This metric is only meant for evaluation, as it is not differentiable.

$$\mathrm{Accuracy}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \mathbb{1}\{y_{\tau} == \hat{y}_{\tau}\}$$*


source

Accuracy.__call__

 Accuracy.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                    mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per serie to consider in loss.

Returns:
accuracy: tensor (single value).*

Scaled Continuous Ranked Probability Score (sCRPS)


source

sCRPS.__init__

 sCRPS.__init__ (level=[80, 90], quantiles=None)

*Scaled Continuous Ranked Probability Score

Calculates a scaled variation of the CRPS, as proposed by Rangapuram (2021), to measure the accuracy of predicted quantiles y_hat compared to the observation y.

This metric averages percentage-weighted absolute deviations as defined by the quantile losses.

$$\mathrm{sCRPS}(\mathbf{\hat{y}}^{(q)}_{\tau}, \mathbf{y}_{\tau}) = \frac{2}{N} \sum_{i} \int^{1}_{0} \frac{\mathrm{QL}(\hat{y}^{(q)}_{i,\tau}, y_{i,\tau})_{q}}{\sum_{i} |y_{i,\tau}|} dq$$

where $\hat{y}^{(q)}_{i,\tau}$ is the estimated quantile, and $y_{i,\tau}$ are the realizations of the target variable.

Parameters:
level: int list [0,100]. Probability levels for prediction intervals (Default: median).
quantiles: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.

References:
- Gneiting, Tilmann. (2011). “Quantiles as optimal point forecasts”. International Journal of Forecasting.
- Spyros Makridakis, Evangelos Spiliotis, Vassilios Assimakopoulos, Zhi Chen, Anil Gaba, Ilia Tsetlin, Robert L. Winkler. (2022). “The M5 uncertainty competition: Results, findings and conclusions”. International Journal of Forecasting.
- Syama Sundar Rangapuram, Lucien D Werner, Konstantinos Benidis, Pedro Mercado, Jan Gasthaus, Tim Januschowski. (2021). “End-to-End Learning of Coherent Probabilistic Forecasts for Hierarchical Time Series”. Proceedings of the 38th International Conference on Machine Learning (ICML).*


source

sCRPS.__call__

 sCRPS.__call__ (y:torch.Tensor, y_hat:torch.Tensor,
                 mask:Optional[torch.Tensor]=None)

*Parameters:
y: tensor, Actual values.
y_hat: tensor, Predicted values.
mask: tensor, Specifies date stamps per series to consider in loss.

Returns:
scrps: tensor (single value).*