- Jan Golda, Krzysztof Kudrynski. “NVIDIA, Deep Learning Forecasting Examples”.
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister. “Temporal Fusion Transformers for interpretable multi-horizon time series forecasting”.

1. Temporal Fusion Decoder
TFT
Bases: BaseModel
The Temporal Fusion Transformer (TFT) architecture is a Sequence-to-Sequence
model that combines static, historic, and future available data to predict a
univariate target. The method combines gating layers, an LSTM recurrent encoder,
an interpretable multi-head attention layer, and a multi-step forecasting
strategy decoder.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
h | int | Forecast horizon. | required |
input_size | int | autoregressive input size (lookback length); e.g. for y=[1,2,3,4] and input_size=2, y_[t-2:t]=[1,2]. | required |
tgt_size | int | target size. | 1 |
stat_exog_list | str list | static continuous columns. | None |
hist_exog_list | str list | historic continuous columns. | None |
futr_exog_list | str list | future continuous columns. | None |
hidden_size | int | units of embeddings and encoders. | 128 |
n_head | int | number of attention heads in temporal fusion decoder. | 4 |
attn_dropout | float | dropout of fusion decoder’s attention layer. | 0.0 |
grn_activation | str | activation for the GRN module, one of [‘ReLU’, ‘Softplus’, ‘Tanh’, ‘SELU’, ‘LeakyReLU’, ‘Sigmoid’, ‘ELU’, ‘GLU’]. | ‘ELU’ |
n_rnn_layers | int | number of RNN layers. | 1 |
rnn_type | str | recurrent neural network (RNN) layer type from [“lstm”,“gru”]. | ‘lstm’ |
one_rnn_initial_state | bool | if True, initialize all RNN layers with the same initial states computed from static covariates. | False |
dropout | float | dropout of the input VSNs. | 0.1 |
loss | PyTorch module | instantiated train loss class from losses collection. | MAE() |
valid_loss | PyTorch module | instantiated valid loss class from losses collection. | None |
max_steps | int | maximum number of training steps. | 1000 |
learning_rate | float | Learning rate between (0, 1). | 0.001 |
num_lr_decays | int | Number of learning rate decays, evenly distributed across max_steps. | -1 |
early_stop_patience_steps | int | Number of validation iterations before early stopping. | -1 |
val_check_steps | int | Number of training steps between every validation loss check. | 100 |
batch_size | int | number of different series in each batch. | 32 |
valid_batch_size | int | number of different series in each validation and test batch. | None |
windows_batch_size | int | number of windows sampled from the rolled data in each training batch. | 1024 |
inference_windows_batch_size | int | number of windows to sample in each inference batch, -1 uses all. | 1024 |
start_padding_enabled | bool | if True, the model pads the beginning of each time series with input_size zeros. | False |
training_data_availability_threshold | Union[float, List[float]] | minimum fraction of valid data points required for training windows. A single float applies to both insample and outsample; a list of two floats specifies [insample_fraction, outsample_fraction]. The default 0.0 keeps any window with at least one valid data point. | 0.0 |
step_size | int | step size between each window of temporal data. | 1 |
scaler_type | str | type of scaler for temporal input normalization; see temporal scalers. | ‘robust’ |
random_seed | int | random seed initialization for replicability. | 1 |
drop_last_loader | bool | if True TimeSeriesDataLoader drops last non-full batch. | False |
alias | str | optional, custom name of the model. | None |
optimizer | Subclass of ‘torch.optim.Optimizer’ | optional, user-specified optimizer instead of the default choice (Adam). | None |
optimizer_kwargs | dict | optional, parameters used by the user-specified optimizer. | None |
lr_scheduler | Subclass of ‘torch.optim.lr_scheduler.LRScheduler’ | optional, user-specified lr_scheduler instead of the default choice (StepLR). | None |
lr_scheduler_kwargs | dict | optional, parameters used by the user-specified lr_scheduler. | None |
dataloader_kwargs | dict | optional, parameters passed into the PyTorch Lightning dataloader by the TimeSeriesDataLoader. | None |
**trainer_kwargs | dict | keyword trainer arguments inherited from PyTorch Lightning’s trainer. | |
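
The following minimal sketch makes input_size, h, and step_size concrete. It rolls one series into training windows of input_size past values followed by h targets. The function make_windows is hypothetical. It ignores scaling, exogenous columns, and availability masks, and it is not NeuralForecast's internal windowing code.

```python
def make_windows(y, input_size, h, step_size=1, pad_start=False):
    """Roll one series into (inputs, targets) training windows.

    Each window holds `input_size` past values followed by `h` future values;
    consecutive windows start `step_size` positions apart. With `pad_start=True`
    the series is first left-padded with `input_size` zeros, loosely mirroring
    start_padding_enabled (a real implementation would also mask the padding).
    """
    if pad_start:
        y = [0] * input_size + list(y)
    window_len = input_size + h
    windows = []
    for start in range(0, len(y) - window_len + 1, step_size):
        inputs = y[start:start + input_size]
        targets = y[start + input_size:start + window_len]
        windows.append((inputs, targets))
    return windows


for inputs, targets in make_windows([1, 2, 3, 4, 5, 6], input_size=2, h=2):
    print(inputs, "->", targets)
# [1, 2] -> [3, 4]
# [2, 3] -> [4, 5]
# [3, 4] -> [5, 6]
```

Increasing step_size thins out the overlapping windows. Padding lets the earliest observations appear as targets even when there is not yet a full lookback window.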
TFT.fit
Fit method: optimizes the neural network’s weights using the
initialization parameters (learning_rate, windows_batch_size, …)
and the loss function defined at initialization.
Within fit we use a PyTorch Lightning Trainer that
inherits the initialization’s self.trainer_kwargs; to customize
its inputs, see PL’s trainer arguments.
The method is designed to be compatible with SKLearn-like classes,
in particular with the StatsForecast library.
By default the model does not save training checkpoints, to save
disk space; to enable them, set enable_checkpointing=True in __init__.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dataset | TimeSeriesDataset | NeuralForecast’s TimeSeriesDataset, see documentation. | required |
val_size | int | Validation size for temporal cross-validation. | 0 |
random_seed | int | Random seed for PyTorch initializer and NumPy generators; overrides model.__init__’s. | None |
test_size | int | Test size for temporal cross-validation. | 0 |
Returns:
| Type | Description |
|---|---|
| None | |
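
The val_size and test_size arguments describe a temporal split, in which held-out points always come after the training points. The sketch below handles a single series. It assumes the test points are the last test_size observations and that the validation points immediately precede them. temporal_split is a hypothetical helper, not part of the library.

```python
def temporal_split(y, val_size=0, test_size=0):
    """Split one series into contiguous train, validation, and test segments.

    The test segment is the last `test_size` points and the validation segment
    is the `val_size` points right before it. Order is never shuffled, so no
    later observation ends up in an earlier segment.
    """
    if val_size < 0 or test_size < 0:
        raise ValueError("val_size and test_size must be non-negative")
    n_train = len(y) - val_size - test_size
    if n_train <= 0:
        raise ValueError("series too short for the requested split")
    train = y[:n_train]
    val = y[n_train:n_train + val_size]
    test = y[n_train + val_size:]
    return train, val, test


train, val, test = temporal_split(list(range(10)), val_size=2, test_size=3)
print(train, val, test)
# [0, 1, 2, 3, 4] [5, 6] [7, 8, 9]
```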
TFT.predict
Trainer execution of predict_step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dataset | TimeSeriesDataset | NeuralForecast’s TimeSeriesDataset, see documentation. | required |
test_size | int | Test size for temporal cross-validation. | None |
step_size | int | Step size between each window. | 1 |
random_seed | int | Random seed for PyTorch initializer and NumPy generators; overrides model.__init__’s. | None |
quantiles | list | Target quantiles to predict. | None |
h | int | Prediction horizon; if None, uses the model’s fitted horizon. | None |
explainer_config | dict | configuration for explanations. | None |
**data_module_kwargs | dict | PL’s TimeSeriesDataModule args, see documentation. | |
Returns:
| Type | Description |
|---|---|
| None | |
TFT.feature_importances
Returns:
| Type | Description |
|---|---|
dict | A dictionary containing the feature importances for each feature type. The keys are ‘hist_vsn’, ‘future_vsn’, and ‘static_vsn’, and the values are pandas DataFrames with the corresponding feature importances. |
TFT.attention_weights
TFT.feature_importance_correlations
Usage Example
2. TFT Architecture
The first step of TFT is to embed the original inputs $\mathbf{x}^{(s)}$, $\mathbf{x}^{(h)}$, $\mathbf{x}^{(f)}$ into a high-dimensional space $\mathbf{E}^{(s)}$, $\mathbf{E}^{(h)}$, $\mathbf{E}^{(f)}$. Each embedding is then gated by a variable selection network (VSN). The static embedding $\mathbf{E}^{(s)}$ is used as context for variable selection and as the initial condition of the LSTM. Finally, the encoded variables are fed into the multi-head attention decoder.
2.1 Static Covariate Encoder
The StaticCovariateEncoder transforms the static embedding $\mathbf{E}^{(s)}$ into four contexts $c_s$, $c_e$, $c_h$, $c_c$. Here $c_s$ is the temporal variable selection context, $c_e$ is the TemporalFusionDecoder enrichment context, and $c_h$, $c_c$ are the LSTM’s initial hidden and cell states for the TemporalCovariateEncoder.
2.2 Temporal Covariate Encoder
The TemporalCovariateEncoder encodes the historic embeddings $\mathbf{E}^{(h)}$, after variable selection with context $c_s$, using an LSTM initialized with $(c_h, c_c)$. An analogous process is repeated for the future data, with the main difference that $\mathbf{E}^{(f)}$ contains the future available information.
2.3 Temporal Fusion Decoder
The TemporalFusionDecoder enriches the LSTM’s outputs with the static context $c_e$. It then applies an interpretable multi-head attention layer and a multi-step adapter.
3. Interpretability
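The attention-weight views in this section come from the decoder's interpretable multi-head attention. In the TFT paper, the heads share their value projection, so their attention matrices can be averaged into a single matrix. The sketch below makes three simplifications: the per-head query and key vectors are assumed to be already projected, the causal mask is omitted, and all helper names are hypothetical. It is not NeuralForecast's implementation.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Scaled dot-product attention weights for one head.

    Row i holds how strongly query step i attends to each key step; rows sum to 1.
    """
    scale = math.sqrt(len(keys[0]))
    return [
        softmax([sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys])
        for q in queries
    ]


def interpretable_attention(per_head_queries, per_head_keys):
    """Average the per-head attention matrices into one matrix.

    With values shared across heads (as in TFT), this averaged matrix is the one
    applied to the values, so it can be inspected as a single attention pattern.
    """
    heads = [attention_weights(q, k) for q, k in zip(per_head_queries, per_head_keys)]
    n_rows, n_cols = len(heads[0]), len(heads[0][0])
    return [
        [sum(h[i][j] for h in heads) / len(heads) for j in range(n_cols)]
        for i in range(n_rows)
    ]


def mean_attention(weights):
    """Average over query steps: one score per attended (key) step."""
    return [sum(row[j] for row in weights) / len(weights) for j in range(len(weights[0]))]


# Two heads, two query steps (e.g. forecast steps), three key steps, dimension 2.
qs = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, 1.0]]]
ks = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]]
A = interpretable_attention(qs, ks)
print(mean_attention(A))
```

Averaging the rows of the combined matrix gives the mean attention view. A single row A[i] corresponds to the attention of one specific future time step.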
3.1 Attention Weights
3.1.1 Mean attention
3.1.2 Attention of all future time steps
3.1.3 Attention of a specific future time step
3.2 Feature Importance
3.2.1 Global feature importance
Static variable importances
Past variable importances
Future variable importances
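One simple way to turn VSN selection weights into global importances like those listed above is to average them over all samples and time steps. The sketch assumes the weights are available as nested lists shaped [sample][time][variable]. The function name and the example variable names are hypothetical, and NeuralForecast may aggregate differently.

```python
def global_importance(selection_weights, names):
    """Average VSN selection weights over every sample and time step.

    selection_weights: nested lists shaped [sample][time][variable], where each
        innermost list is a softmax output (sums to 1).
    names: one variable name per weight column.
    Returns {name: mean weight}; the values also sum to 1.
    """
    totals = [0.0] * len(names)
    n_steps = 0
    for sample in selection_weights:
        for step in sample:
            for j, weight in enumerate(step):
                totals[j] += weight
            n_steps += 1
    return {name: total / n_steps for name, total in zip(names, totals)}


# Hypothetical weights for two samples, two time steps, two variables.
weights = [
    [[0.7, 0.3], [0.5, 0.5]],
    [[0.6, 0.4], [0.6, 0.4]],
]
print(global_importance(weights, ["price", "promo"]))
```

Static variables have no time axis, so for them each sample can be wrapped as a single time step.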
3.2.2 Variable importances over time
Future variable importance over time
Importance of each future covariate at each future time step
Past variable importance over time
Past variable importance over time weighted by attention
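The attention-weighted view can be read as distributing each past step's attention score among the variables the VSN selected at that step. The minimal sketch below assumes one mean attention score per past step and a [time][variable] matrix of selection weights. The helper names are hypothetical, and this is not the library's code.

```python
def attention_weighted_importance(mean_attn, selection_weights):
    """Split each past step's attention score among the variables selected there.

    mean_attn: one attention score per past time step.
    selection_weights: [time][variable] VSN weights, each row summing to 1.
    Returns [time][variable]; row t sums to mean_attn[t].
    """
    return [[a * w for w in row] for a, row in zip(mean_attn, selection_weights)]


def variable_totals(contributions):
    """Sum the contributions of each variable across all past steps."""
    return [sum(column) for column in zip(*contributions)]


contrib = attention_weighted_importance([0.2, 0.8], [[0.5, 0.5], [0.25, 0.75]])
print(contrib)
print(variable_totals(contrib))
```

Each selection row sums to 1, so the contributions at a step add back up to that step's attention score. This is the decomposition described next.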
Decomposition of the importance of each time step based on the importance of each variable at that time step
3.2.3 Variable importance correlations over time
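Correlations between variables' importance over time can be sketched with a plain Pearson correlation on each pair of importance series. The dictionary layout and the helper names are assumptions for illustration.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length series; NaN if either is constant."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx == 0 or sy == 0:
        return float("nan")
    return cov / (sx * sy)


def importance_correlations(importance_over_time):
    """importance_over_time: {variable: [importance at t0, t1, ...]}.

    Returns a nested dict: result[a][b] is the correlation of a's and b's series.
    """
    names = list(importance_over_time)
    return {
        a: {b: pearson(importance_over_time[a], importance_over_time[b]) for b in names}
        for a in names
    }


series = {"a": [0.1, 0.4, 0.2, 0.5], "b": [0.2, 0.8, 0.4, 1.0], "c": [0.9, 0.6, 0.8, 0.5]}
print(importance_correlations(series)["a"])
```

With only two variables, the selection weights sum to 1 at every step, so the two importance series are always perfectly anti-correlated. The view is most informative with three or more variables.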
Variables which gain and lose importance at the same moments
4. Auxiliary Functions
4.1 Gating Mechanisms
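The Gated Residual Network (GRN) provides adaptive depth and network complexity, which lets it accommodate datasets of different sizes. Its residual connection allows the network to skip the non-linear transformation of the input $\mathbf{a}$ and context $\mathbf{c}$ entirely:
$$\mathrm{GRN}(\mathbf{a}, \mathbf{c}) = \mathrm{LayerNorm}\big(\mathbf{a} + \mathrm{GLU}(\eta_1)\big), \qquad \eta_1 = W_1 \eta_2 + b_1, \qquad \eta_2 = \mathrm{ELU}(W_2 \mathbf{a} + W_3 \mathbf{c} + b_2)$$
The Gated Linear Unit (GLU) provides the flexibility to suppress unnecessary parts of the GRN. Given an intermediate GRN output $\gamma$, the GLU transformation is defined by:
$$\mathrm{GLU}(\gamma) = \sigma(W_4 \gamma + b_4) \odot (W_5 \gamma + b_5)$$
where $\sigma$ is the sigmoid function and $\odot$ is the element-wise product.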
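The GRN and GLU can be sketched in plain Python, following the TFT paper's formulation GRN(a, c) = LayerNorm(a + GLU(eta1)) with eta1 = W1 eta2 + b1 and eta2 = ELU(W2 a + W3 c + b2). This toy version has several simplifications:

- It uses list-based linear algebra and random weights.
- Its layer norm has no learned scale or shift.
- All helper names are hypothetical.
- It assumes the input a already has the hidden size, so no skip projection is needed.

```python
import math
import random


def matvec(W, x):
    """Matrix-vector product with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def add(*vectors):
    """Element-wise sum of equal-length vectors."""
    return [sum(parts) for parts in zip(*vectors)]


def elu(v, alpha=1.0):
    return [x if x > 0 else alpha * (math.exp(x) - 1.0) for x in v]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-x)) for x in v]


def layer_norm(v, eps=1e-5):
    """Normalize to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def glu(gamma, W4, b4, W5, b5):
    """GLU(gamma) = sigmoid(W4 gamma + b4) * (W5 gamma + b5), element-wise."""
    gate = sigmoid(add(matvec(W4, gamma), b4))
    value = add(matvec(W5, gamma), b5)
    return [g * v for g, v in zip(gate, value)]


def grn(a, c, p):
    """GRN(a, c) = LayerNorm(a + GLU(eta1)) with an ELU hidden layer."""
    eta2 = elu(add(matvec(p["W2"], a), matvec(p["W3"], c), p["b2"]))
    eta1 = add(matvec(p["W1"], eta2), p["b1"])
    return layer_norm(add(a, glu(eta1, p["W4"], p["b4"], p["W5"], p["b5"])))


def random_params(d, d_context, seed=0):
    """Small random weights; d is the hidden size and must equal len(a)."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

    return {
        "W1": mat(d, d), "b1": [0.0] * d,
        "W2": mat(d, d), "b2": [0.0] * d,
        "W3": mat(d, d_context),
        "W4": mat(d, d), "b4": [0.0] * d,
        "W5": mat(d, d), "b5": [0.0] * d,
    }


a, c = [0.5, -1.0, 2.0], [1.0, 0.0]
params = random_params(3, 2)
print(grn(a, c, params))
```

A strongly negative gate bias b4 closes the GLU. The GRN then reduces to LayerNorm(a), which is how the block can effectively be skipped.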
4.2 Variable Selection Networks
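TFT includes automated variable selection through its variable selection network (VSN) components. The VSN transforms the original input $\mathbf{x}$ into a high-dimensional space $\mathbf{E}$, using embeddings or linear transformations. For the observed historic data, the embedding matrix $\mathbf{E}^{(h)}_t$ at time $t$ is a concatenation of the $n_h$ variable embeddings:
$$\mathbf{E}^{(h)}_t = [e^{(h)}_{t,1}, \dots, e^{(h)}_{t,n_h}]$$
The variable selection weights are given by:
$$s^{(h)}_t = \mathrm{Softmax}\big(\mathrm{GRN}(\mathbf{E}^{(h)}_t, c_s)\big)$$
The VSN processed features are then:
$$\tilde{\mathbf{E}}^{(h)}_t = \sum_{j=1}^{n_h} s^{(h)}_{t,j}\, \tilde{e}^{(h)}_{t,j}, \qquad \tilde{e}^{(h)}_{t,j} = \mathrm{GRN}_j\big(e^{(h)}_{t,j}\big)$$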
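A single VSN time step can be sketched as follows. To keep it short, the selection GRN is replaced by a linear scorer over the flattened embeddings and the static context c_s. Each per-variable GRN is replaced by an arbitrary function. All names are hypothetical.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]


def vsn_step(embeddings, context, W_select, transforms):
    """One VSN step for a single time t.

    embeddings: n_vars embedding vectors e_{t,j}, each of length d.
    context:    static context vector c_s.
    W_select:   n_vars rows mapping [flattened embeddings, context] to one logit
                per variable (a linear stand-in for the selection GRN).
    transforms: per-variable functions (stand-ins for the per-variable GRNs).
    Returns (selection weights s_t, combined feature vector).
    """
    flat = [x for e in embeddings for x in e] + list(context)
    logits = [sum(w * x for w, x in zip(row, flat)) for row in W_select]
    s = softmax(logits)
    processed = [f(e) for f, e in zip(transforms, embeddings)]
    d = len(processed[0])
    combined = [sum(s_j * p[k] for s_j, p in zip(s, processed)) for k in range(d)]
    return s, combined


embeddings = [[1.0, 0.0], [0.0, 1.0]]  # e_{t,1}, e_{t,2}
c_s = [1.0]
W_select = [[0.0, 0.0, 0.0, 0.0, 2.0], [0.0] * 5]  # context favours variable 1
transforms = [lambda e: [2.0 * x for x in e], lambda e: list(e)]
s, combined = vsn_step(embeddings, c_s, W_select, transforms)
print(s, combined)
```

The selection weights s are the quantities that feature-importance views aggregate. Because they come from a softmax, each variable's share is directly comparable within a time step.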

