Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes bug when using historical forecast with discrete lags < -1 #2715

Merged
merged 17 commits into from
Mar 7, 2025
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- 🔴 / 🟢 Fixed a bug which raised an error when loading torch models that were saved with Darts versions < 0.33.0. This is a breaking change and models saved with version 0.33.0 will not be loadable anymore. [#2692](https://github.com/unit8co/darts/pull/2692) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug in `StaticCovariatesTransformer` which raised an error when trying to inverse transform one-hot encoded categorical static covariates with identical values across time-series. Each categorical static covariates is now referred to by `{covariate_name}_{category_name}`, regardless of the number of categories. [#2710](https://github.com/unit8co/darts/pull/2710) by [Antoine Madrona](https://github.com/madtoinou)
- Fixed a bug in `13-TFT-examples.ipynb` where two calls to `TimeSeries.from_series()` were not providing `series` but `pd.Index`. The method calls were changed to `TimeSeries.from_values()`. [#2719](https://github.com/unit8co/darts/pull/2719) by [Jules Authier](https://github.com/authierj)
- Fixed a bug in `RegressionModel` where performing optimized historical forecasts with `max(lags) < -1` resulted in forecasts that extended too far into the future. [#2715](https://github.com/unit8co/darts/pull/2715) by [Jules Authier](https://github.com/authierj)

**Dependencies**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
{},
(5, 3),
),
(LinearRegressionModel, {"lags": [-5]}, {}, (5, 1)),
(LinearRegressionModel, {"lags": [-5], "output_chunk_shift": 1}, {}, (5, 2)),
]
if not isinstance(CatBoostModel, NotImportedModule):
models_reg_no_cov_cls_kwargs.append((
Expand Down Expand Up @@ -665,7 +667,7 @@ def test_historical_forecasts_negative_rangeindex(self):
def test_historical_forecasts(self, config):
"""Tests historical forecasts with retraining for expected forecast lengths and times"""
forecast_horizon = 8
# if no fit and retrain=false, should fit at fist iteration
# if no fit and retrain=false, should fit at first iteration
model_cls, kwargs, model_kwarg, bounds = config
model = model_cls(**kwargs, **model_kwarg)
# set train length to be the minimum required training length
Expand Down Expand Up @@ -1248,11 +1250,12 @@ def test_regression_auto_start_multiple_no_cov(self, config):
[ts_univariate, ts_multivariate],
models_reg_no_cov_cls_kwargs + models_reg_cov_cls_kwargs,
[True, False],
[True, False],
[1, 5],
),
)
def test_optimized_historical_forecasts_regression(self, config):
ts, model_config, multi_models, forecast_horizon = config
ts, model_config, multi_models, overlap_end, forecast_horizon = config
# slightly longer to not affect the last predictable timestamp
ts_covs = self.ts_covs
start = 14
Expand Down Expand Up @@ -1301,6 +1304,7 @@ def test_optimized_historical_forecasts_regression(self, config):
last_points_only=last_points_only,
stride=stride,
forecast_horizon=forecast_horizon,
overlap_end=overlap_end,
enable_optimization=False,
)

Expand All @@ -1317,6 +1321,7 @@ def test_optimized_historical_forecasts_regression(self, config):
last_points_only=last_points_only,
stride=stride,
forecast_horizon=forecast_horizon,
overlap_end=overlap_end,
)

self.helper_compare_hf(hist_fct, opti_hist_fct)
Expand Down
13 changes: 9 additions & 4 deletions darts/utils/historical_forecasts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,8 +782,7 @@ def _adjust_historical_forecasts_time_index(
show_warnings: bool,
) -> TimeIndex:
"""
Shrink the beginning and end of the historical forecasts time index based on the values of `start`,
`forecast_horizon` and `overlap_end`.
Shrink the beginning and end of the historical forecasts time index based on the value of `start`.
"""
# retrieve actual start
# when applicable, shift the start of the forecastable index based on `start`
Expand Down Expand Up @@ -992,7 +991,7 @@ def _get_historical_forecast_boundaries(
overlap_end,
)

# adjust boundaries based on start, forecast_horizon and overlap_end
# adjust boundaries based on start
historical_forecasts_time_index = _adjust_historical_forecasts_time_index(
series=series,
series_idx=series_idx,
Expand Down Expand Up @@ -1020,7 +1019,13 @@ def _get_historical_forecast_boundaries(
hist_fct_tgt_start, hist_fct_tgt_end = historical_forecasts_time_index
if min_target_lag is not None:
hist_fct_tgt_start += min_target_lag * freq
hist_fct_tgt_end -= 1 * freq

# target lag has a gap between the max lag and the present
if hasattr(model, "lags") and model._get_lags("target"):
hist_fct_tgt_end += 1 * freq * model._get_lags("target")[-1]
else:
hist_fct_tgt_end -= 1 * freq

# past lags are <= 0
hist_fct_pc_start, hist_fct_pc_end = historical_forecasts_time_index
if min_past_cov_lag is not None:
Expand Down
Loading