Forecasting with TFT: Temporal Fusion Transformer
Temporal Fusion Transformer (TFT), proposed by Lim et al. [1], is one of the most popular transformer-based models for time-series forecasting. In summary, TFT combines gating layers and an LSTM recurrent encoder with multi-head attention layers in a decoder that follows a multi-step forecasting strategy. For more details on Nixtla's TFT implementation, visit this link.
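The gating layers in TFT are built on gated linear units (GLUs), as described in the TFT paper: GLU(x) = sigmoid(W_g x + b_g) ⊙ (W_v x + b_v). The sketch below is a plain-Python illustration of that formula, using lists instead of tensors. It is not Nixtla's implementation, and the weights are arbitrary values chosen only to show how a gate close to 0 suppresses its output.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def glu(x, w_gate, b_gate, w_value, b_value):
    """Gated linear unit: sigmoid(W_g x + b_g) * (W_v x + b_v), element-wise.

    x is a list of floats; each weight matrix is a list of rows.
    """
    def affine(w, b):
        return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]

    gate = [sigmoid(z) for z in affine(w_gate, b_gate)]  # values in (0, 1)
    value = affine(w_value, b_value)
    return [g * v for g, v in zip(gate, value)]


# Arbitrary illustrative weights: the second gate has a very negative bias.
x = [1.0, 2.0]
w = [[0.5, 0.25], [1.0, 1.0]]
out = glu(x, w, [0.0, -50.0], w, [0.0, 0.0])
# out[0] is about 0.731 (gate partly open), out[1] is about 0 (gate closed)
```

Because the gate can drive a component towards zero, the network can learn to skip parts of the architecture that do not help on a given dataset.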
In this notebook we show how to train the TFT model on the Texas electricity market load data (ERCOT). Accurately forecasting electricity markets is of great interest, as it is useful for planning distribution and consumption.
We will show you how to load the data, train the TFT with automatic hyperparameter tuning, and produce forecasts. Then, we will show you how to produce multiple historical forecasts for cross validation.
You can run these experiments on a GPU with Google Colab.
1. Libraries
2. Load ERCOT Data
The input to NeuralForecast is always a data frame in long format with three columns: `unique_id`, `ds` and `y`:

- The `unique_id` (string, int or category) represents an identifier for the series.
- The `ds` (datestamp or int) column should be either an integer indexing time or a datestamp, ideally like YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
- The `y` (numeric) represents the measurement we wish to forecast. We will rename the
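To make the long format concrete, here is a small standard-library sketch of one hourly series as a list of rows with the three required columns. In the notebook the data is a pandas data frame; the helpers `make_long_format` and `check_long_format` are hypothetical and exist only for this illustration.

```python
from datetime import datetime, timedelta

REQUIRED_COLUMNS = {"unique_id", "ds", "y"}


def make_long_format(unique_id, start, values, step=timedelta(hours=1)):
    """Build long-format rows (unique_id, ds, y) for one regularly spaced series."""
    return [
        {"unique_id": unique_id, "ds": start + i * step, "y": float(v)}
        for i, v in enumerate(values)
    ]


def check_long_format(rows):
    """Raise ValueError if any row lacks one of the three required columns."""
    for i, row in enumerate(rows):
        missing = REQUIRED_COLUMNS - set(row)
        if missing:
            raise ValueError(f"row {i} is missing columns: {sorted(missing)}")
    return True


# Three hourly observations of a single series identified as "ERCOT".
rows = make_long_format("ERCOT", datetime(2021, 1, 1), [43719.85, 43321.05, 43063.07])
check_long_format(rows)
```

Several series can share one table: their rows are simply stacked, and `unique_id` tells them apart.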
First, read the historic hourly total demand of the ERCOT market, which starts in January 2021 as the preview below shows. We processed the original data (available here) by adding the missing hour due to daylight saving time, parsing the date to datetime format, and filtering the columns of interest.
| | unique_id | ds | y |
|---|---|---|---|
| 0 | ERCOT | 2021-01-01 00:00:00 | 43719.849616 |
| 1 | ERCOT | 2021-01-01 01:00:00 | 43321.050347 |
| 2 | ERCOT | 2021-01-01 02:00:00 | 43063.067063 |
| 3 | ERCOT | 2021-01-01 03:00:00 | 43090.059203 |
| 4 | ERCOT | 2021-01-01 04:00:00 | 43486.590073 |
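The preprocessing step that adds the hour skipped by daylight saving time can be sketched as follows. The text does not say how the value for the added hour was computed; this sketch assumes linear interpolation between the neighboring hours, and `fill_missing_hours` is a hypothetical helper. The example uses the US spring-forward transition on 2022-03-13, when 02:00 local time does not exist.

```python
from datetime import datetime, timedelta


def fill_missing_hours(ds, y):
    """Insert any missing hourly timestamps, linearly interpolating y.

    Assumes ds is sorted, strictly increasing, and that gaps are whole hours.
    """
    step = timedelta(hours=1)
    out_ds, out_y = [ds[0]], [y[0]]
    for t_prev, y_prev, t, y_cur in zip(ds, y, ds[1:], y[1:]):
        n_steps = int((t - t_prev) / step)  # 1 when nothing is missing
        for k in range(1, n_steps):
            out_ds.append(t_prev + k * step)
            out_y.append(y_prev + (y_cur - y_prev) * k / n_steps)
        out_ds.append(t)
        out_y.append(y_cur)
    return out_ds, out_y


# Illustrative values: 02:00 is absent on the spring-forward day.
ds = [datetime(2022, 3, 13, 1), datetime(2022, 3, 13, 3)]
y = [100.0, 110.0]
filled_ds, filled_y = fill_missing_hours(ds, y)
```

After filling, every consecutive pair of timestamps is exactly one hour apart, which keeps the series consistent with an hourly frequency.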
3. Model training and forecast
First, instantiate the `AutoTFT` model. The `AutoTFT` class automatically performs hyperparameter tuning with the Tune library, exploring a user-defined or default search space. Models are selected based on their error on a validation set, and the best model is then stored and used during inference.
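The selection rule can be sketched in a few lines: hold out the last stretch of the series as a validation set, score each candidate configuration's forecast over it, and keep the one with the lowest error. This only illustrates the idea; the actual search and bookkeeping are done by Tune inside `AutoTFT`. The sketch assumes mean absolute error as the validation metric, and the helper names and forecast values are hypothetical.

```python
def train_valid_split(y, val_size):
    """Chronological split: the last val_size observations form the validation set."""
    return y[:-val_size], y[-val_size:]


def mae(actual, predicted):
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


def select_best(candidates, y_valid):
    """Return (name, error) of the candidate with the lowest validation MAE.

    candidates maps a configuration name to its forecast over the validation window.
    """
    scores = {name: mae(y_valid, forecast) for name, forecast in candidates.items()}
    best = min(scores, key=scores.get)
    return best, scores[best]


y = [8.0, 9.0, 9.5, 10.0, 12.0, 11.0]
y_train, y_valid = train_valid_split(y, val_size=3)  # y_valid = [10.0, 12.0, 11.0]

# Hypothetical forecasts produced by models trained on y_train.
candidates = {
    "config_a": [9.0, 12.5, 11.0],   # MAE = 0.5
    "config_b": [10.0, 10.0, 10.0],  # MAE = 1.0
}
best, err = select_best(candidates, y_valid)
```

The split is chronological rather than random, so the validation error measures how well each configuration forecasts data that comes after what it was trained on.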
To instantiate `AutoTFT` you need to define:

- `h`: forecasting horizon.
- `loss`: training loss.
- `config`: hyperparameter search space. If `None`, the `AutoTFT` class will use a pre-defined suggested hyperparameter space.
- `num_samples`: number of configurations explored.
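The interplay between `config` and `num_samples` resembles plain random search: each of the `num_samples` trials draws one value per hyperparameter from the search space. The sketch below mimics this with Python's `random` module. Ray Tune offers samplers such as `tune.choice` and `tune.loguniform` for the same purpose; the parameter ranges here are hypothetical and are not AutoTFT's defaults.

```python
import math
import random

# Hypothetical search space: each entry is a function that draws one value.
search_space = {
    "hidden_size": lambda rng: rng.choice([64, 128, 256]),
    # log-uniform draw between 1e-4 and 1e-1
    "learning_rate": lambda rng: math.exp(rng.uniform(math.log(1e-4), math.log(1e-1))),
    "windows_batch_size": lambda rng: rng.choice([128, 256, 512]),
}


def sample_configs(space, num_samples, seed=0):
    """Draw num_samples independent configurations (plain random search)."""
    rng = random.Random(seed)
    return [{name: draw(rng) for name, draw in space.items()} for _ in range(num_samples)]


configs = sample_configs(search_space, num_samples=3)
for cfg in configs:
    print(cfg)
```

Each sampled configuration would then be trained and scored on the validation set. A larger `num_samples` covers the space more thoroughly at the cost of proportionally more training runs.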
Tip

Increase the `num_samples` parameter to explore a wider set of configurations for the selected models. As a rule of thumb, choose it to be bigger than 15.

With `num_samples=3` this example should run in around 20 minutes.
Tip

All our models can be used for both point and probabilistic forecasting. To produce probabilistic outputs, simply change the loss to one of our `DistributionLoss` options. The complete list of losses is available in this link.
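A distribution loss trains the model to output the parameters of a predictive distribution and scores them with a negative log-likelihood. The sketch below computes that quantity for a Normal distribution in plain Python; it illustrates the principle and is not NeuralForecast's `DistributionLoss` code.

```python
import math


def normal_nll(y, mu, sigma):
    """Average negative log-likelihood of observations y under N(mu_t, sigma_t^2)."""
    total = 0.0
    for yt, mt, st in zip(y, mu, sigma):
        total += 0.5 * math.log(2 * math.pi * st ** 2) + (yt - mt) ** 2 / (2 * st ** 2)
    return total / len(y)


# A perfect mean with unit scale: the NLL equals 0.5 * log(2 * pi).
base = normal_nll([0.0], [0.0], [1.0])

# When the mean is off, an over-confident (small) sigma is penalized heavily.
nll_tight = normal_nll([2.0], [0.0], [0.1])
nll_wide = normal_nll([2.0], [0.0], [2.0])
```

Minimizing this loss rewards forecasts whose uncertainty matches their actual errors, which is what makes the resulting prediction intervals meaningful.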
Important

TFT is a very large model and can require a lot of memory! If you run out of GPU memory, try declaring your own config search space with smaller values of the `hidden_size`, `n_heads`, and `windows_batch_size` parameters. These are all the parameters of the config:
The `NeuralForecast` class has built-in methods to simplify the forecasting pipelines, such as `fit`, `predict`, and `cross_validation`. Instantiate a `NeuralForecast` object with the following required parameters:

- `models`: a list of models.
- `freq`: a string indicating the frequency of the data. (See pandas' available frequencies.)
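The `freq` string must describe the actual spacing of `ds`; for the ERCOT data that is an hourly alias such as `'H'`. A quick standard-library check that a series really is regularly spaced before choosing `freq` is sketched below; `irregular_steps` is a hypothetical helper, and the timestamps are illustrative.

```python
from datetime import datetime, timedelta


def irregular_steps(ds, step=timedelta(hours=1)):
    """Return the indices i where ds[i] - ds[i - 1] differs from the expected step."""
    return [i for i in range(1, len(ds)) if ds[i] - ds[i - 1] != step]


# Hourly timestamps with 02:00 missing (for example, a daylight saving gap).
hourly = [datetime(2022, 3, 13, h) for h in (0, 1, 3, 4)]
print(irregular_steps(hourly))  # [2]
```

An empty result means every step matches the declared frequency; a non-empty one points at gaps like the daylight saving hour that had to be filled during preprocessing.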
Then, use the `fit` method to train the `AutoTFT` model on the ERCOT data. The total training time depends on the hardware and the explored configurations; it should take between 10 and 30 minutes.
Finally, use the `predict` method to forecast the next 24 hours after the training data and plot the forecasts.
| unique_id | ds | AutoTFT |
|---|---|---|
| ERCOT | 2022-10-01 00:00:00 | 38644.019531 |
| ERCOT | 2022-10-01 01:00:00 | 36833.121094 |
| ERCOT | 2022-10-01 02:00:00 | 35698.265625 |
| ERCOT | 2022-10-01 03:00:00 | 35065.148438 |
| ERCOT | 2022-10-01 04:00:00 | 34788.566406 |
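The forecast timestamps follow directly from the last training timestamp, the frequency, and `h`. Assuming the training data ends at 2022-09-30 23:00 (inferred from the first forecast being for 2022-10-01 00:00), the 24 forecast stamps can be reproduced as below. `future_timestamps` is a hypothetical helper, not part of NeuralForecast.

```python
from datetime import datetime, timedelta


def future_timestamps(last_ds, h, step=timedelta(hours=1)):
    """Timestamps of the h steps that follow the last observed timestamp."""
    return [last_ds + k * step for k in range(1, h + 1)]


horizon = future_timestamps(datetime(2022, 9, 30, 23), h=24)
print(horizon[0], horizon[-1])  # 2022-10-01 00:00:00 2022-10-01 23:00:00
```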
Plot the results with matplotlib.
4. Cross validation for multiple historic forecasts
The `cross_validation` method allows you to simulate multiple historic forecasts, greatly simplifying pipelines by replacing for loops with `fit` and `predict` methods. See this tutorial for an animation of how the windows are defined.
With time series data, cross validation is done by defining a sliding window across the historical data and predicting the period following it. This form of cross validation gives a better estimate of the model's predictive ability across a wider range of temporal instances, while keeping the data in the training set contiguous, as our models require. The `cross_validation` method uses the validation set for hyperparameter selection and then produces the forecasts for the test set.
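The window logic described above can be sketched with integer positions: each window uses a contiguous prefix of the series as history and forecasts the next `h` observations, and consecutive windows start `step_size` apart. The window count below assumes the windows tile the test period so that the last forecast ends exactly at the end of the series; `sliding_windows` is a hypothetical helper, not NeuralForecast's implementation.

```python
def sliding_windows(n_obs, h, test_size, step_size):
    """Yield (train_end, forecast_positions) pairs for rolling-origin evaluation.

    Each window uses observations [0, train_end) as history and forecasts the next h.
    """
    n_windows = (test_size - h) // step_size + 1
    first_end = n_obs - test_size
    for w in range(n_windows):
        train_end = first_end + w * step_size
        yield train_end, list(range(train_end, train_end + h))


# Toy setting: 10 observations, 2-step forecasts, last 6 points used for testing.
windows = list(sliding_windows(n_obs=10, h=2, test_size=6, step_size=2))
# [(4, [4, 5]), (6, [6, 7]), (8, [8, 9])]
```

Because every window's history is a prefix of the series, the training data is never split into disconnected pieces, which is what recurrent and attention-based models need.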
Use the `cross_validation` method to produce all the daily forecasts for September. Set the validation and test sizes. To produce daily forecasts, set the step size between windows to 24, so that only one forecast is produced per day.
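As a worked example of the window arithmetic, assume the test set covers all 720 hours of September 2022 and the series ends at 2022-09-30 23:00 (the first forecast in the `predict` table is for 2022-10-01 00:00). These sizes are assumptions for illustration, since the exact values passed to `cross_validation` are not shown here. With `h = 24` and a step size of 24, the windows tile the month with one forecast per day:

```python
from datetime import datetime, timedelta

h = 24
step_size = 24
test_size = 30 * 24  # assumed: every hour of September 2022
n_windows = (test_size - h) // step_size + 1  # (720 - 24) // 24 + 1 = 30

# Cutoff = last observation available to each window.
first_cutoff = datetime(2022, 8, 31, 23)
cutoffs = [first_cutoff + timedelta(hours=w * step_size) for w in range(n_windows)]

print(n_windows)                 # 30
print(cutoffs[0], cutoffs[-1])   # 2022-08-31 23:00:00 2022-09-29 23:00:00
```

The last window uses data up to 2022-09-29 23:00 and forecasts the 24 hours of September 30.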
Finally, we merge the forecasts with the `Y_df` dataset and plot the forecasts.
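Merging pairs each forecast with the observed value at the same series and timestamp, so `unique_id` and `ds` act as the join key. With pandas this is typically a `merge` on those two columns; the standard-library sketch below shows the same left join. `merge_on_keys` is a hypothetical helper, and the numbers are toy values.

```python
from datetime import datetime


def merge_on_keys(forecasts, actuals, keys=("unique_id", "ds")):
    """Left-join actual y values onto forecast rows using (unique_id, ds) as the key."""
    lookup = {tuple(row[k] for k in keys): row["y"] for row in actuals}
    merged = []
    for row in forecasts:
        key = tuple(row[k] for k in keys)
        merged.append({**row, "y": lookup.get(key)})  # None if no actual value exists
    return merged


forecasts = [
    {"unique_id": "ERCOT", "ds": datetime(2022, 9, 1, 0), "AutoTFT": 101.0},
    {"unique_id": "ERCOT", "ds": datetime(2022, 9, 1, 1), "AutoTFT": 99.0},
]
actuals = [{"unique_id": "ERCOT", "ds": datetime(2022, 9, 1, 0), "y": 100.0}]
merged = merge_on_keys(forecasts, actuals)
```

Rows without a matching observation keep `y = None`, which makes gaps in the actuals easy to spot before plotting or computing errors.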
Next Steps
In Challu et al. [2], we demonstrate that the N-HiTS model outperforms the latest transformers by more than 20% with 50 times less computation.
Learn how to use the N-HiTS and the NeuralForecast library in this tutorial.