Skip to content

Commit 19b17d2

Browse files
Fixes bug when using historical forecast with discrete lags < -1 (#2715)
* notes * dennis note * trying a few things * seems to fix it * fix and tests * note removal * overlap_end taken into account * pr corrections * changelog updated * comment adpated * make optimized hfc work identically as non optimized * update changelog * update changelog --------- Co-authored-by: dennisbader <[email protected]>
1 parent f76aedf commit 19b17d2

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2828
- 🔴 / 🟢 Fixed a bug which raised an error when loading torch models that were saved with Darts versions < 0.33.0. This is a breaking change and models saved with version 0.33.0 will not be loadable anymore. [#2692](https://github.com/unit8co/darts/pull/2692) by [Dennis Bader](https://github.com/dennisbader).
2929
- Fixed a bug in `StaticCovariatesTransformer` which raised an error when trying to inverse transform one-hot encoded categorical static covariates with identical values across time-series. Each categorical static covariates is now referred to by `{covariate_name}_{category_name}`, regardless of the number of categories. [#2710](https://github.com/unit8co/darts/pull/2710) by [Antoine Madrona](https://github.com/madtoinou)
3030
- Fixed a bug in `13-TFT-examples.ipynb` where two calls to `TimeSeries.from_series()` were not providing `series` but `pd.Index`. The method calls were changed to `TimeSeries.from_values()`. [#2719](https://github.com/unit8co/darts/pull/2719) by [Jules Authier](https://github.com/authierj)
31+
- Fixed a bug in `RegressionModel` where performing optimized historical forecasts with `max(lags) < -1` resulted in forecasts that extended too far into the future. [#2715](https://github.com/unit8co/darts/pull/2715) by [Jules Authier](https://github.com/authierj)
3132

3233
**Dependencies**
3334

darts/tests/utils/historical_forecasts/test_historical_forecasts.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
{},
7373
(5, 3),
7474
),
75+
(LinearRegressionModel, {"lags": [-5]}, {}, (5, 1)),
76+
(LinearRegressionModel, {"lags": [-5], "output_chunk_shift": 1}, {}, (5, 2)),
7577
]
7678
if not isinstance(CatBoostModel, NotImportedModule):
7779
models_reg_no_cov_cls_kwargs.append((
@@ -665,7 +667,7 @@ def test_historical_forecasts_negative_rangeindex(self):
665667
def test_historical_forecasts(self, config):
666668
"""Tests historical forecasts with retraining for expected forecast lengths and times"""
667669
forecast_horizon = 8
668-
# if no fit and retrain=false, should fit at fist iteration
670+
# if no fit and retrain=false, should fit at first iteration
669671
model_cls, kwargs, model_kwarg, bounds = config
670672
model = model_cls(**kwargs, **model_kwarg)
671673
# set train length to be the minimum required training length
@@ -1248,11 +1250,12 @@ def test_regression_auto_start_multiple_no_cov(self, config):
12481250
[ts_univariate, ts_multivariate],
12491251
models_reg_no_cov_cls_kwargs + models_reg_cov_cls_kwargs,
12501252
[True, False],
1253+
[True, False],
12511254
[1, 5],
12521255
),
12531256
)
12541257
def test_optimized_historical_forecasts_regression(self, config):
1255-
ts, model_config, multi_models, forecast_horizon = config
1258+
ts, model_config, multi_models, overlap_end, forecast_horizon = config
12561259
# slightly longer to not affect the last predictable timestamp
12571260
ts_covs = self.ts_covs
12581261
start = 14
@@ -1301,6 +1304,7 @@ def test_optimized_historical_forecasts_regression(self, config):
13011304
last_points_only=last_points_only,
13021305
stride=stride,
13031306
forecast_horizon=forecast_horizon,
1307+
overlap_end=overlap_end,
13041308
enable_optimization=False,
13051309
)
13061310

@@ -1317,6 +1321,7 @@ def test_optimized_historical_forecasts_regression(self, config):
13171321
last_points_only=last_points_only,
13181322
stride=stride,
13191323
forecast_horizon=forecast_horizon,
1324+
overlap_end=overlap_end,
13201325
)
13211326

13221327
self.helper_compare_hf(hist_fct, opti_hist_fct)

darts/utils/historical_forecasts/utils.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -782,8 +782,7 @@ def _adjust_historical_forecasts_time_index(
782782
show_warnings: bool,
783783
) -> TimeIndex:
784784
"""
785-
Shrink the beginning and end of the historical forecasts time index based on the values of `start`,
786-
`forecast_horizon` and `overlap_end`.
785+
Shrink the beginning and end of the historical forecasts time index based on the value of `start`.
787786
"""
788787
# retrieve actual start
789788
# when applicable, shift the start of the forecastable index based on `start`
@@ -992,7 +991,7 @@ def _get_historical_forecast_boundaries(
992991
overlap_end,
993992
)
994993

995-
# adjust boundaries based on start, forecast_horizon and overlap_end
994+
# adjust boundaries based on start
996995
historical_forecasts_time_index = _adjust_historical_forecasts_time_index(
997996
series=series,
998997
series_idx=series_idx,
@@ -1020,7 +1019,13 @@ def _get_historical_forecast_boundaries(
10201019
hist_fct_tgt_start, hist_fct_tgt_end = historical_forecasts_time_index
10211020
if min_target_lag is not None:
10221021
hist_fct_tgt_start += min_target_lag * freq
1023-
hist_fct_tgt_end -= 1 * freq
1022+
1023+
# target lag has a gap between the max lag and the present
1024+
if hasattr(model, "lags") and model._get_lags("target"):
1025+
hist_fct_tgt_end += 1 * freq * model._get_lags("target")[-1]
1026+
else:
1027+
hist_fct_tgt_end -= 1 * freq
1028+
10241029
# past lags are <= 0
10251030
hist_fct_pc_start, hist_fct_pc_end = historical_forecasts_time_index
10261031
if min_past_cov_lag is not None:

0 commit comments

Comments
 (0)