Skip to content

Commit 797b6de

Browse files
dennisbaderhrzn
andauthored
Feature/tft darts (#513)
* added tff.py for TFFModel * savepoint switching to master * TFT with MixedCovariates * savepoint software update * checked tft_submodels * original model performs well * all submodels checked and ready * new tft submodels work as original TFT * github test * adapted forward for fitting with darts * predict returns output but issue with target scaling * fixed bug in _produce_train_output(), now predictions look much better * train and prediction work * TFT now supports variable prediction lengths * added multivariate TFT forecast support * added model docstrings * general clean up of unnecessary variables and code * removed files used for building the model * added TFTModel to README.md * added probabilistic forecasting support for TFT with likelihood models * minor docstring fixes * fixed error from testing with MixedCovariatesInferenceDataset for regression models * removed unused changes in darts/utils/data/inference_dataset.py * Revert "removed unused changes in darts/utils/data/inference_dataset.py" This reverts commit a8521cc. * remove unused changes in darts/utils/data/inference_dataset.py * added TFT unit tests * cleaned up unused features in TFT unit test * fixed *args to predict() method for model.historical_forecastings() * solve an issue with Cuda * made compatible with pandas==1.3.0 and statsmodels==0.13.0 * added package pytorch-forecasting>=0.9.1 * added submodels from pytorch-forecasting * removed pytorch-forecasting dependency and applied part of PR review suggestions * Feat/quantile loss as likelihood (#526) * applied changes from PR review Co-authored-by: Julien Herzen <[email protected]> Co-authored-by: Julien Herzen <[email protected]>
1 parent 7ec330c commit 797b6de

13 files changed

+1791
-19
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ Model | Univariate | Multivariate | Probabilistic | Multiple-series training | P
138138
`NBEATSModel` | ✅ | ✅ | | ✅ | ✅ | | [N-BEATS paper](https://arxiv.org/abs/1905.10437)
139139
`TCNModel` | ✅ | ✅ | ✅ | ✅ | ✅ | | [TCN paper](https://arxiv.org/abs/1803.01271), [DeepTCN paper](https://arxiv.org/abs/1906.04397), [blog post](https://medium.com/unit8-machine-learning-publication/temporal-convolutional-networks-and-forecasting-5ce1b6e97ce4)
140140
`TransformerModel` | ✅ | ✅ | ✅ | ✅ | ✅ | |
141+
`TFTModel` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [TFT paper](https://arxiv.org/pdf/1912.09363.pdf), [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/latest/models.html)
141142
Naive Baselines | ✅ | | | | | |
142143

143144

darts/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from darts.models.forecasting.tcn_model import TCNModel
3333
from darts.models.forecasting.nbeats import NBEATSModel
3434
from darts.models.forecasting.transformer_model import TransformerModel
35+
from darts.models.forecasting.tft_model import TFTModel
3536

3637
except ModuleNotFoundError:
3738
logger.warning("Support Torch based models not available. To enable it, install u8darts[torch] or u8darts[all].")

darts/models/forecasting/fft.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _crop_to_match_seasons(series: TimeSeries, required_matches: Optional[set])
171171
return series
172172

173173
first_ts = series.time_index[0]
174-
freq = first_ts.freq
174+
freq = series.freq
175175
pred_ts = series.time_index[-1] + freq
176176

177177
# start at first timestamp of given series and move forward until a matching timestamp is found

0 commit comments

Comments
 (0)