Skip to content

Commit 5b05d2b

Browse files
Fix/RegressionEnsemble with single model regressor and coef access in LinearRegressionModel (#2205)
* fix: overwrite the self.model attribute with the model container * fix: prevent creation of RegressionEnsemble with a regression model created with multi_models=False * update changelog * Update darts/models/forecasting/regression_ensemble_model.py Co-authored-by: Dennis Bader <[email protected]> * rephrasing changelog * fix: enforce multi_models=True when ocl=1 --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent 1d7d854 commit 5b05d2b

File tree

5 files changed

+37
-1
lines changed

5 files changed

+37
-1
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1818
- Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.
1919

2020
**Fixed**
21+
- Fixed a bug in probabilistic `LinearRegressionModel.fit()`, where the `model` attribute was not pointing to all underlying estimators. [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
22+
- Raise an error in `RegressionEsembleModel` when the `regression_model` was created with `multi_models=False` (not supported). [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
2123
- Fixed a bug in `coefficient_of_variaton()` with `intersect=True`, where the coefficient was not computed on the intersection. [#2202](https://github.com/unit8co/darts/pull/2202) by [Antoine Madrona](https://github.com/madtoinou).
2224

2325
### For developers of the library:

darts/models/forecasting/linear_regression_model.py

+4
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def fit(
245245

246246
for quantile in self.quantiles:
247247
self.kwargs["quantile"] = quantile
248+
# assign the Quantile regressor to self.model to leverage existing logic
248249
self.model = QuantileRegressor(**self.kwargs)
249250
super().fit(
250251
series=series,
@@ -256,6 +257,9 @@ def fit(
256257

257258
self._model_container[quantile] = self.model
258259

260+
# replace the last trained QuantileRegressor with the dictionnary of Regressors.
261+
self.model = self._model_container
262+
259263
return self
260264

261265
else:

darts/models/forecasting/regression_ensemble_model.py

+5
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def __init__(
123123
lags=None, lags_future_covariates=[0], fit_intercept=False
124124
)
125125
elif isinstance(regression_model, RegressionModel):
126+
raise_if_not(
127+
regression_model.multi_models,
128+
"Cannot use `regression_model` that was created with `multi_models = False`.",
129+
logger,
130+
)
126131
regression_model = regression_model
127132
else:
128133
# scikit-learn like model

darts/models/forecasting/regression_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def encode_year(idx):
194194
self.lags: Dict[str, List[int]] = {}
195195
self.component_lags: Dict[str, Dict[str, List[int]]] = {}
196196
self.input_dim = None
197-
self.multi_models = multi_models
197+
self.multi_models = True if multi_models or output_chunk_length == 1 else False
198198
self._considers_static_covariates = use_static_covariates
199199
self._static_covariates_shape: Optional[Tuple[int, int]] = None
200200
self._lagged_feature_names: Optional[List[str]] = None

darts/tests/models/forecasting/test_regression_ensemble_model.py

+25
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,31 @@ def test_predict_likelihood_parameters_multivariate_regression_ensemble(self):
884884
pred_ens["linear_q0.05"].values() < pred_ens["linear_q0.50"].values()
885885
) and all(pred_ens["linear_q0.50"].values() < pred_ens["linear_q0.95"].values())
886886

887+
def test_wrong_model_creation_params(self):
888+
"""Since `multi_models=False` requires to shift the regression model lags in the past (outside of the forecasting
889+
model predictions), it is not supported."""
890+
forcasting_models = [
891+
self.get_deterministic_global_model(2),
892+
self.get_deterministic_global_model([-5, -7]),
893+
]
894+
RegressionEnsembleModel(
895+
forecasting_models=forcasting_models,
896+
regression_train_n_points=10,
897+
regression_model=LinearRegressionModel(
898+
lags_future_covariates=[0], output_chunk_length=2, multi_models=True
899+
),
900+
)
901+
with pytest.raises(ValueError):
902+
RegressionEnsembleModel(
903+
forecasting_models=forcasting_models,
904+
regression_train_n_points=10,
905+
regression_model=LinearRegressionModel(
906+
lags_future_covariates=[0],
907+
output_chunk_length=2,
908+
multi_models=False,
909+
),
910+
)
911+
887912
@staticmethod
888913
def get_probabilistic_global_model(
889914
lags: Union[int, List[int]],

0 commit comments

Comments
 (0)