In summary, the Temporal Fusion Transformer (TFT) combines gating layers, an LSTM recurrent encoder, and multi-head attention layers into a multi-step forecasting decoder.
TFT's inputs are static exogenous features $x^{(s)}$, historic exogenous features $x^{(h)}_{[:t]}$, exogenous features available at prediction time $x^{(f)}_{[:t+H]}$, and autoregressive features $y_{[:t]}$; each of these inputs is further decomposed into categorical and continuous variables. The network uses multi-quantile regression to model the following conditional probability:

$$P\big(y_{[t+1:t+H]} \mid y_{[:t]},\, x^{(h)}_{[:t]},\, x^{(f)}_{[:t+H]},\, x^{(s)}\big)$$
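For reference, the multi-quantile regression is typically trained with the quantile (pinball) loss; for a quantile $q$ and the corresponding prediction $\hat{y}^{(q)}$ it reads:

$$\mathrm{QL}\big(y, \hat{y}^{(q)}, q\big) = q\,\max\big(0,\, y-\hat{y}^{(q)}\big) + (1-q)\,\max\big(0,\, \hat{y}^{(q)}-y\big)$$

averaged over the chosen quantiles and the forecast horizons $t+1,\dots,t+H$.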
References
- Jan Golda, Krzysztof Kudrynski, "NVIDIA Deep Learning Forecasting Examples".
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister, "Temporal Fusion Transformers for interpretable multi-horizon time series forecasting".
1. Auxiliary Functions
1.1 Gating Mechanisms
The Gated Residual Network (GRN) provides adaptive depth and network complexity, accommodating datasets of different sizes: its residual connections allow the network to skip the non-linear transformation of the input $a$ and context $c$.
The Gated Linear Unit (GLU) provides the flexibility to suppress unnecessary parts of the GRN. Given the GRN's output $\gamma$, the GLU transformation is defined by:

$$\mathrm{GLU}(\gamma)=\sigma(W_{4}\gamma+b_{4})\odot(W_{5}\gamma+b_{5})$$
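As a rough PyTorch sketch of these two blocks (illustrative only, not neuralforecast's internal implementation; the dimension name d_model is an assumption):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GLU(nn.Module):
    # GLU(gamma) = sigmoid(W4 gamma + b4) ⊙ (W5 gamma + b5)
    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_model)   # W4, b4
        self.value = nn.Linear(d_model, d_model)  # W5, b5

    def forward(self, gamma):
        return torch.sigmoid(self.gate(gamma)) * self.value(gamma)

class GRN(nn.Module):
    # Gated residual transform of input a with optional context c.
    def __init__(self, d_model: int):
        super().__init__()
        self.fc_a = nn.Linear(d_model, d_model)
        self.fc_c = nn.Linear(d_model, d_model, bias=False)
        self.fc_out = nn.Linear(d_model, d_model)
        self.glu = GLU(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, a, c=None):
        eta = self.fc_a(a) if c is None else self.fc_a(a) + self.fc_c(c)
        gamma = self.fc_out(F.elu(eta))
        # the residual connection lets the network skip the non-linear path
        return self.norm(a + self.glu(gamma))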
1.2 Variable Selection Networks
TFT includes automated variable selection through its variable selection network (VSN) components. The VSN takes the original input
$$\{x^{(s)},\, x^{(h)}_{[:t]},\, x^{(f)}_{[:t+H]}\}$$
and transforms it through embeddings or linear transformations into a high-dimensional space
$$\{E^{(s)},\, E^{(h)}_{[:t]},\, E^{(f)}_{[:t+H]}\}.$$
For the observed historic data, the embedding matrix $E^{(h)}_{t}$ at time $t$ is a concatenation of the $j$ variable embeddings $e^{(h)}_{t,j}$:

$$E^{(h)}_{t}=\big[e^{(h)}_{t,1},\dots,e^{(h)}_{t,j}\big]$$

The variable selection weights are given by:

$$s^{(h)}_{t}=\mathrm{SoftMax}\big(\mathrm{GRN}(E^{(h)}_{t},E^{(s)})\big)$$

The VSN-processed features are then:

$$\tilde{E}^{(h)}_{t}=\sum_{j}s^{(h)}_{t,j}\,\tilde{e}^{(h)}_{t,j}$$
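A condensed sketch of the selection step (the per-variable GRNs of the full architecture are replaced by a single linear stand-in here; shapes and names are illustrative):

import torch
import torch.nn as nn

class VariableSelection(nn.Module):
    # Softmax weights over n_vars embeddings, conditioned on the static context.
    def __init__(self, d_model: int, n_vars: int):
        super().__init__()
        # stand-in for GRN(E_t, E_s): flattened embeddings + static context -> n_vars logits
        self.to_logits = nn.Linear(n_vars * d_model + d_model, n_vars)

    def forward(self, embeddings, static_context):
        # embeddings: (batch, n_vars, d_model); static_context: (batch, d_model)
        flat = torch.cat([embeddings.flatten(1), static_context], dim=-1)
        weights = torch.softmax(self.to_logits(flat), dim=-1)       # s_t
        selected = (weights.unsqueeze(-1) * embeddings).sum(dim=1)  # E~_t
        return selected, weights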
1.3. Multi-Head Attention
To avoid the information bottlenecks of the classic Seq2Seq architecture, TFT incorporates an encoder-decoder attention mechanism inherited from transformer architectures (Li et al. 2019, Vaswani et al. 2017). It transforms the outputs of the LSTM-encoded temporal features and helps the decoder capture long-term relationships.
In the original multi-head attention, each head $H_{m}$ has query, key, and value representations $Q_{m},K_{m},V_{m}$, and its transformation is given by:

$$H_{m}=\mathrm{Attention}(Q_{m},K_{m},V_{m})=\mathrm{SoftMax}\!\left(\frac{Q_{m}K_{m}^{\top}}{\sqrt{d_{attn}}}\right)V_{m}$$
TFT modifies the original multi-head attention to improve its interpretability. To do so, it uses values $\tilde{V}$ shared across heads and employs additive aggregation:

$$\mathrm{InterpretableMultiHead}(Q,K,V)=\tilde{H}\,W_{M}$$

The mechanism strongly resembles a single attention layer, but it allows for $M$ different attention weights and can therefore be interpreted as an average ensemble of $M$ single attention layers.
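A minimal sketch of the shared-value, additively aggregated attention (illustrative; the head dimensionality reduction and decoder masking of the full model are omitted):

import torch
import torch.nn as nn

class InterpretableMultiHead(nn.Module):
    # M heads with head-specific Q/K projections but one shared value projection;
    # head outputs are averaged, so a single attention matrix remains interpretable.
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.q_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_heads)])
        self.k_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_heads)])
        self.v_shared = nn.Linear(d_model, d_model)  # V~, shared across heads
        self.out = nn.Linear(d_model, d_model)       # W_M
        self.scale = d_model ** 0.5

    def forward(self, q, k, v):
        v = self.v_shared(v)
        heads, attns = [], []
        for wq, wk in zip(self.q_proj, self.k_proj):
            a = torch.softmax(wq(q) @ wk(k).transpose(-2, -1) / self.scale, dim=-1)
            heads.append(a @ v)
            attns.append(a)
        h_tilde = torch.stack(heads).mean(dim=0)  # additive aggregation across heads
        return self.out(h_tilde), torch.stack(attns).mean(dim=0)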
2. TFT Architecture
TFT's first step is to embed the original input $\{x^{(s)},x^{(h)},x^{(f)}\}$ into a high-dimensional space $\{E^{(s)},E^{(h)},E^{(f)}\}$, after which each embedding is gated by a variable selection network (VSN). The static embedding $E^{(s)}$ is used as context for variable selection and as the initial condition of the LSTM. Finally, the encoded variables are fed into the multi-head attention decoder.
2.1 Static Covariate Encoder
The static embedding $E^{(s)}$ is transformed by the StaticCovariateEncoder into four contexts $c_{s},c_{e},c_{h},c_{c}$, where $c_{s}$ is the temporal variable selection context, $c_{e}$ enriches the TemporalFusionDecoder, and $c_{h},c_{c}$ are the LSTM's initial hidden and cell states for the TemporalCovariateEncoder.
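Reusing the GRN sketch from section 1.1, this mapping can be sketched as (names are illustrative, not the library's internal API):

import torch.nn as nn

class StaticCovariateEncoder(nn.Module):
    # One GRN per context: E(s) -> c_s, c_e, c_h, c_c
    def __init__(self, d_model: int):
        super().__init__()
        self.grns = nn.ModuleDict({name: GRN(d_model) for name in ("c_s", "c_e", "c_h", "c_c")})

    def forward(self, static_embedding):
        # static_embedding: (batch, d_model) -> four (batch, d_model) contexts
        return {name: grn(static_embedding) for name, grn in self.grns.items()}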
2.2 Temporal Covariate Encoder
The TemporalCovariateEncoder encodes the embeddings $E^{(h)},E^{(f)}$ and the contexts $(c_{h},c_{c})$ with an LSTM. An analogous process is repeated for the future data, the main difference being that $E^{(f)}$ contains the information available at prediction time.
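A self-contained snippet of this encoding flow (all tensors and shapes below are dummy placeholders for illustration, not the library's internals):

import torch
import torch.nn as nn

batch, t_hist, horizon, d_model = 8, 48, 12, 20
E_hist = torch.randn(batch, t_hist, d_model)   # VSN-selected historic embeddings
E_futr = torch.randn(batch, horizon, d_model)  # VSN-selected future embeddings
c_h = torch.randn(batch, d_model)              # hidden-state context from the static encoder
c_c = torch.randn(batch, d_model)              # cell-state context from the static encoder

lstm = nn.LSTM(input_size=d_model, hidden_size=d_model, batch_first=True)
state = (c_h.unsqueeze(0), c_c.unsqueeze(0))   # static contexts seed the LSTM state
hist_encoded, state = lstm(E_hist, state)      # encode the observed past
futr_encoded, _ = lstm(E_futr, state)          # analogous pass over the known future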
2.3 Temporal Fusion Decoder
The TemporalFusionDecoder enriches the LSTM's outputs with $c_{e}$ and then applies the interpretable multi-head attention layer followed by a multi-step forecasting adapter.
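Continuing the sketches above (GRN, InterpretableMultiHead, and the encoded sequences), the decoder flow can be approximated as:

import torch

c_e = torch.randn(batch, d_model)                          # enrichment context from the static encoder
temporal = torch.cat([hist_encoded, futr_encoded], dim=1)  # (batch, t_hist + horizon, d_model)

enrich = GRN(d_model)
attn = InterpretableMultiHead(d_model, n_heads=4)

enriched = enrich(temporal, c_e.unsqueeze(1))                # broadcast the static context over time
attended, attn_weights = attn(enriched, enriched, enriched)  # decoder masking omitted for brevity
# attn_weights is the kind of matrix TFT.attention_weights() summarizes for interpretability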
source
TFT
TFT (h, input_size, tgt_size:int=1, stat_exog_list=None,
hist_exog_list=None, futr_exog_list=None, hidden_size:int=128,
n_head:int=4, attn_dropout:float=0.0, grn_activation:str='ELU',
dropout:float=0.1, loss=MAE(), valid_loss=None, max_steps:int=1000,
learning_rate:float=0.001, num_lr_decays:int=-1,
early_stop_patience_steps:int=-1, val_check_steps:int=100,
batch_size:int=32, valid_batch_size:Optional[int]=None,
windows_batch_size:int=1024, inference_windows_batch_size:int=1024,
start_padding_enabled=False, step_size:int=1,
scaler_type:str='robust', num_workers_loader=0,
drop_last_loader=False, random_seed:int=1, optimizer=None,
optimizer_kwargs=None, lr_scheduler=None, lr_scheduler_kwargs=None,
**trainer_kwargs)
*TFT
The Temporal Fusion Transformer architecture (TFT) is a Sequence-to-Sequence model that combines static, historic, and future available data to predict a univariate target. The method combines gating layers, an LSTM recurrent encoder, and an interpretable multi-head attention layer with a multi-step forecasting strategy decoder.
Parameters:
- h: int, Forecast horizon.
- input_size: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
- stat_exog_list: str list, static continuous columns.
- hist_exog_list: str list, historic continuous columns.
- futr_exog_list: str list, future continuous columns.
- hidden_size: int=128, units of embeddings and encoders.
- dropout: float (0, 1), dropout of the inputs' VSNs.
- n_head: int=4, number of attention heads in the temporal fusion decoder.
- attn_dropout: float (0, 1), dropout of the fusion decoder's attention layer.
- grn_activation: str='ELU', activation for the GRN module, one of ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', 'ELU', 'GLU'].
- loss: PyTorch module, instantiated train loss class from the losses collection.
- valid_loss: PyTorch module=loss, instantiated validation loss class from the losses collection.
- max_steps: int=1000, maximum number of training steps.
- learning_rate: float=1e-3, learning rate between (0, 1).
- num_lr_decays: int=-1, number of learning rate decays, evenly distributed across max_steps.
- early_stop_patience_steps: int=-1, number of validation iterations before early stopping.
- val_check_steps: int=100, number of training steps between every validation loss check.
- batch_size: int=32, number of different series in each batch.
- windows_batch_size: int=1024, number of windows sampled from the rolled data; None uses all.
- inference_windows_batch_size: int=1024, number of windows to sample in each inference batch; -1 uses all.
- start_padding_enabled: bool=False, if True, the model pads the time series with zeros at the beginning, by input size.
- valid_batch_size: int=None, number of different series in each validation and test batch.
- step_size: int=1, step size between each window of temporal data.
- scaler_type: str='robust', type of scaler for temporal inputs normalization; see temporal scalers.
- random_seed: int=1, random seed initialization for replicability.
- num_workers_loader: int=os.cpu_count(), workers to be used by TimeSeriesDataLoader.
- drop_last_loader: bool=False, if True, TimeSeriesDataLoader drops the last non-full batch.
- alias: str, optional, custom name of the model.
- optimizer: subclass of 'torch.optim.Optimizer', optional, user-specified optimizer instead of the default choice (Adam).
- optimizer_kwargs: dict, optional, parameters used by the user-specified optimizer.
- lr_scheduler: subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user-specified lr_scheduler instead of the default choice (StepLR).
- lr_scheduler_kwargs: dict, optional, parameters used by the user-specified lr_scheduler.
- **trainer_kwargs: keyword trainer arguments inherited from PyTorch Lightning's trainer.
References:
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister, "Temporal Fusion Transformers for interpretable multi-horizon time series forecasting".*
3. TFT methods
TFT.fit
TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,
distributed_config=None)
*Fit.
The fit method optimizes the neural network's weights using the initialization parameters (learning_rate, windows_batch_size, ...) and the loss function as defined during initialization. Within fit we use a PyTorch Lightning Trainer that inherits the initialization's self.trainer_kwargs to customize its inputs; see PL's trainer arguments.
The method is designed to be compatible with SKLearn-like classes, and in particular with the StatsForecast library.
By default, the model does not save training checkpoints to protect disk memory; to enable them, set enable_checkpointing=True in __init__.
Parameters:
- dataset: NeuralForecast's TimeSeriesDataset, see documentation.
- val_size: int, validation size for temporal cross-validation.
- random_seed: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
- test_size: int, test size for temporal cross-validation.*
TFT.predict
TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,
**data_module_kwargs)
*Predict.
Neural network prediction with PL's Trainer execution of predict_step.
Parameters:
- dataset: NeuralForecast's TimeSeriesDataset, see documentation.
- test_size: int=None, test size for temporal cross-validation.
- step_size: int=1, step size between each window.
- random_seed: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
- **data_module_kwargs: PL's TimeSeriesDataModule args, see documentation.*
source
TFT.feature_importances
TFT.feature_importances ()
*Compute the feature importances for historical, future, and static features.
Returns: dict: A dictionary containing the feature importances for each feature type. The keys are 'Past variable importance over time', 'Future variable importance over time', and 'Static covariates', and the values are pandas DataFrames with the corresponding feature importances.*
source
TFT.attention_weights
*Batch average attention weights.
Returns: np.ndarray: A 1D array containing the attention weights for each time step.*
source
TFT.feature_importance_correlations
TFT.feature_importance_correlations ()
*Compute the correlation between the past and future feature
importances and the mean attention weights.
Returns: pd.DataFrame: A DataFrame containing the correlation coefficients between the feature importances over time and the mean attention weights.*
Usage Example
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from neuralforecast import NeuralForecast
from neuralforecast.models import TFT
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
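# Add a calendar covariate and hold out the last 12 months of each series as test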
AirPassengersPanel['month'] = AirPassengersPanel.ds.dt.month
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]]
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)
nf = NeuralForecast(
    models=[TFT(h=12, input_size=48,
                hidden_size=20,
                grn_activation='ELU',
                loss=DistributionLoss(distribution='StudentT', level=[80, 90]),
                learning_rate=0.005,
                stat_exog_list=['airline1'],
                futr_exog_list=['y_[lag12]', 'month'],
                hist_exog_list=['trend'],
                max_steps=300,
                val_check_steps=10,
                early_stop_patience_steps=10,
                scaler_type='robust',
                windows_batch_size=None,
                enable_progress_bar=True),
            ],
    freq='M'
)
nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
Y_hat_df = nf.predict(futr_df=Y_test_df)
Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
plot_df = pd.concat([Y_train_df, plot_df])
plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)
plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')
plt.plot(plot_df['ds'], plot_df['TFT'], c='purple', label='mean')
plt.plot(plot_df['ds'], plot_df['TFT-median'], c='blue', label='median')
plt.fill_between(x=plot_df['ds'][-12:],
                 y1=plot_df['TFT-lo-90'][-12:].values,
                 y2=plot_df['TFT-hi-90'][-12:].values,
                 alpha=0.4, label='level 90')
plt.legend()
plt.grid()
plt.show()
Interpretability
1. Attention Weights
attention = nf.models[0].attention_weights()
def plot_attention(self, plot: str = "time", output: str = 'plot', width: int = 800, height: int = 400):
    """
    Plot the attention weights.

    Args:
        plot (str, optional): The type of plot to generate. Can be one of the following:
            - 'time': Display the mean attention weights over time.
            - 'all': Display the attention weights for each horizon.
            - 'heatmap': Display the attention weights as a heatmap.
            - An integer in the range [1, model.h] to display the attention weights for a specific horizon.
        output (str, optional): The type of output to generate. Can be one of the following:
            - 'plot': Display the plot directly.
            - 'figure': Return the plot as a figure object.
        width (int, optional): Width of the plot in pixels. Default is 800.
        height (int, optional): Height of the plot in pixels. Default is 400.

    Returns:
        matplotlib.figure.Figure: If `output` is 'figure', the function returns the plot as a figure object.
    """
    attention = (
        self.mean_on_batch(self.interpretability_params["attn_wts"])
        .mean(dim=0)
        .cpu()
        .numpy()
    )
    fig, ax = plt.subplots(figsize=(width / 100, height / 100))
    if plot == "time":
        attention = attention[self.input_size:, :].mean(axis=0)
        ax.plot(np.arange(-self.input_size, self.h), attention)
        ax.axvline(x=0, color='black', linewidth=3, linestyle='--', label="prediction start")
        ax.set_title("Mean Attention")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()
    elif plot == "all":
        for i in range(self.input_size, attention.shape[0]):
            ax.plot(np.arange(-self.input_size, self.h), attention[i, :], label=f"horizon {i - self.input_size + 1}")
        ax.axvline(x=0, color='black', linewidth=3, linestyle='--', label="prediction start")
        ax.set_title("Attention per horizon")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()
    elif plot == "heatmap":
        cax = ax.imshow(attention, aspect='auto', cmap='viridis',
                        extent=[-self.input_size, self.h, -self.input_size, self.h])
        fig.colorbar(cax)
        ax.set_title("Attention Heatmap")
        ax.set_xlabel("Attention (current time step)")
        ax.set_ylabel("Attention (previous time step)")
    elif isinstance(plot, int) and (plot in np.arange(1, self.h + 1)):
        i = self.input_size + plot - 1
        ax.plot(np.arange(-self.input_size, self.h), attention[i, :], label=f"horizon {plot}")
        ax.axvline(x=0, color='black', linewidth=3, linestyle='--', label="prediction start")
        ax.set_title(f"Attention weight for horizon {plot}")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()
    else:
        raise ValueError('plot has to be in ["time", "all", "heatmap"] or an integer in [1, model.h]')
    plt.tight_layout()
    if output == 'plot':
        plt.show()
    elif output == 'figure':
        return fig
    else:
        raise ValueError(f"Invalid output: {output}. Expected 'plot' or 'figure'.")
1.1 Mean attention
plot_attention(nf.models[0], plot="time")
1.2 Attention of all future time steps
plot_attention(nf.models[0], plot="all")
1.3 Attention of a specific future time step
plot_attention(nf.models[0], plot=8)
2. Feature Importance
2.1 Global feature importance
feature_importances = nf.models[0].feature_importances()
feature_importances.keys()
dict_keys(['Past variable importance over time', 'Future variable importance over time', 'Static covariates'])
Static variable importances
feature_importances['Static covariates'].sort_values(by='importance').plot(kind='barh')
Past variable importances
feature_importances['Past variable importance over time'].mean().sort_values().plot(kind='barh')
Future variable importances
feature_importances['Future variable importance over time'].mean().sort_values().plot(kind='barh')
2.2 Variable importances over time
Future variable importance over time
Importance of each future covariate at each future time step
df = feature_importances['Future variable importance over time']
fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))
for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title('Future variable importance over time')
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.grid(True)
ax.legend()
plt.show()
2.3 Past variable importance over time
df = feature_importances['Past variable importance over time']
fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))
for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title('Past variable importance over time')
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.legend()
ax.grid(True)
plt.show()
Past variable importance over time weighted by attention
Decomposition of the importance of each time step based on the importance of each variable at that time step.
df = feature_importances['Past variable importance over time']
mean_attention = nf.models[0].attention_weights()[nf.models[0].input_size:, :].mean(axis=0)[:nf.models[0].input_size]
df = df.multiply(mean_attention, axis=0)
fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))
for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title('Past variable importance over time weighted by attention')
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.grid(True)
plt.plot(np.arange(-len(df), 0), mean_attention, color='black', marker='o', linestyle='-', linewidth=2, label='mean_attention')
plt.legend()
plt.show()
3. Variable importance correlations over time
Variables that gain and lose importance at the same moments:
nf.models[0].feature_importance_correlations()
| | trend | y_[lag12] | month | observed_target | Correlation with Mean Attention |
|---|---|---|---|---|---|
| trend | 1.00 | 0.69 | -0.77 | 0.44 | 0.40 |
| y_[lag12] | 0.69 | 1.00 | -0.79 | 0.41 | 0.22 |
| month | -0.77 | -0.79 | 1.00 | -0.72 | -0.48 |
| observed_target | 0.44 | 0.41 | -0.72 | 1.00 | 0.74 |
| Correlation with Mean Attention | 0.40 | 0.22 | -0.48 | 0.74 | 1.00 |