Skip to content

Commit 8ff36c2

Browse files
authored
[DOC] improve and add tide model to docs (#1762)
### Description Fixes #1758
1 parent 097403e commit 8ff36c2

File tree

5 files changed

+45
-23
lines changed

5 files changed

+45
-23
lines changed

docs/source/models.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ and you should take into account. Here is an overview over the pros and cons of
3030
:py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
3131
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
3232
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
33-
33+
:py:class:`~pytorch_forecasting.model.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
3434

3535
.. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model.
3636

pytorch_forecasting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
NHiTS,
4747
RecurrentNetwork,
4848
TemporalFusionTransformer,
49+
TiDEModel,
4950
get_rnn,
5051
)
5152
from pytorch_forecasting.utils import (
@@ -70,6 +71,7 @@
7071
"NaNLabelEncoder",
7172
"MultiNormalizer",
7273
"TemporalFusionTransformer",
74+
"TiDEModel",
7375
"NBeats",
7476
"NHiTS",
7577
"Baseline",

pytorch_forecasting/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pytorch_forecasting.models.temporal_fusion_transformer import (
1919
TemporalFusionTransformer,
2020
)
21+
from pytorch_forecasting.models.tide import TiDEModel
2122

2223
__all__ = [
2324
"NBeats",
@@ -35,4 +36,5 @@
3536
"GRU",
3637
"MultiEmbedding",
3738
"DecoderMLP",
39+
"TiDEModel",
3840
]

pytorch_forecasting/models/tide/_tide.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
Implements the TiDE (Time-series Dense Encoder-decoder) model, which is designed for
3+
long-term time-series forecasting.
4+
"""
5+
16
from copy import copy
27
from typing import Dict, List, Optional, Tuple, Union
38

@@ -44,30 +49,39 @@ def __init__(
4449
):
4550
"""An implementation of the TiDE model.
4651
47-
TiDE shares similarities with Transformers (implemented in :class:TransformerModel), but aims to deliver
48-
better performance with reduced computational requirements by utilizing MLP-based encoder-decoder architectures
49-
without attention mechanisms.
52+
TiDE shares similarities with Transformers
53+
(implemented in :class:TransformerModel), but aims to deliver better performance
54+
with reduced computational requirements by utilizing MLP-based encoder-decoder
55+
architectures without attention mechanisms.
5056
51-
This model supports future covariates (known for output_chunk_length points after the prediction time) and
52-
static covariates.
57+
This model supports future covariates (known for output_chunk_length points
58+
after the prediction time) and static covariates.
5359
54-
The encoder and decoder are constructed using residual blocks. The number of residual blocks in the encoder and
55-
decoder can be specified with `num_encoder_layers` and `num_decoder_layers` respectively. The layer width in the
56-
residual blocks can be adjusted using `hidden_size`, while the layer width in the temporal decoder can be
57-
controlled via `temporal_decoder_hidden`.
60+
The encoder and decoder are constructed using residual blocks. The number of
61+
residual blocks in the encoder and decoder can be specified with
62+
`num_encoder_layers` and `num_decoder_layers` respectively. The layer width in
63+
the residual blocks can be adjusted using `hidden_size`, while the layer width
64+
in the temporal decoder can be controlled via `temporal_decoder_hidden`.
5865
5966
Parameters
6067
----------
61-
input_chunk_length (int): Number of past time steps to use as input for the model (per chunk).
62-
This applies to the target series and future covariates (if supported by the model).
63-
output_chunk_length (int): Number of time steps the internal model predicts simultaneously (per chunk).
64-
This also determines how many future values from future covariates are used as input
68+
input_chunk_length : int
69+
Number of past time steps to use as input for the model (per chunk).
70+
This applies to the target series and future covariates
6571
(if supported by the model).
66-
num_encoder_layers (int): Number of residual blocks in the encoder. Defaults to 2.
67-
num_decoder_layers (int): Number of residual blocks in the decoder. Defaults to 2.
68-
decoder_output_dim (int): Dimensionality of the decoder's output. Defaults to 16.
69-
hidden_size (int): Size of hidden layers in the encoder and decoder. Typically ranges from 32 to 128 when
70-
no covariates are used. Defaults to 128.
72+
output_chunk_length : int
73+
Number of time steps the internal model predicts simultaneously (per chunk).
74+
This also determines how many future values from future covariates
75+
are used as input (if supported by the model).
76+
num_encoder_layers : int, default=2
77+
Number of residual blocks in the encoder
78+
num_decoder_layers : int, default=2
79+
Number of residual blocks in the decoder
80+
decoder_output_dim : int, default=16
81+
Dimensionality of the decoder's output
82+
hidden_size : int, default=128
83+
Size of hidden layers in the encoder and decoder.
84+
Typically ranges from 32 to 128 when no covariates are used.
7185
temporal_width_future (int): Width of the output layer in the residual block for future covariate projections.
7286
If set to 0, bypasses feature projection and uses raw feature data. Defaults to 4.
7387
temporal_hidden_size_future (int): Width of the hidden layer in the residual block for future covariate
@@ -98,8 +112,10 @@ def __init__(
98112
**kwargs
99113
Allows optional arguments to configure pytorch_lightning.Module, pytorch_lightning.Trainer, and
100114
pytorch-forecasting's :class:BaseModelWithCovariates.
101-
""" # noqa: E501
102115
116+
Note:
117+
The model supports future covariates and static covariates.
118+
""" # noqa: E501
103119
if static_categoricals is None:
104120
static_categoricals = []
105121
if static_reals is None:
@@ -200,15 +216,16 @@ def static_size(self) -> int:
200216
@classmethod
201217
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
202218
"""
203-
Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
219+
Convenience function to create network from
220+
:py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
204221
205222
Args:
206223
dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
207224
**kwargs: additional arguments to be passed to `__init__` method.
208225
209226
Returns:
210227
TiDE
211-
""" # noqa: E501
228+
"""
212229

213230
# validate arguments
214231
assert not isinstance(

pytorch_forecasting/models/tide/sub_modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
Time-series Dense Encoder (TiDE)
3-
------
3+
--------------------------------
44
"""
55

66
from typing import Optional, Tuple
@@ -226,6 +226,7 @@ def forward(
226226
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
227227
) -> torch.Tensor:
228228
"""TiDE model forward pass.
229+
229230
Parameters
230231
----------
231232
x_in

0 commit comments

Comments
 (0)