From ce0242ed9969267ef204a9053866a61ed74ea718 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 5 Mar 2025 11:59:17 +0100 Subject: [PATCH 1/7] feat: adding support for getting estimator based on quantile and associated tests --- darts/models/forecasting/regression_model.py | 73 +++++++--- .../forecasting/test_regression_models.py | 133 +++++++++++++++++- 2 files changed, 186 insertions(+), 20 deletions(-) diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index 6c69c48c3e..c6d0bda33a 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -505,43 +505,77 @@ def output_chunk_length(self) -> int: def output_chunk_shift(self) -> int: return self._output_chunk_shift - def get_multioutput_estimator(self, horizon: int, target_dim: int): + def get_multioutput_estimator( + self, horizon: int, target_dim: int, quantile: Optional[float] = None + ): """Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component. Internally, estimators are grouped by `output_chunk_length` position, then by component. + Note: for probabilistic models fitting quantiles, there is an additional abstraction layer, + grouping the estimators by `quantile`. + Parameters ---------- horizon The index of the forecasting point within `output_chunk_length`. target_dim The index of the target component. + quantile + Optionaly, for probabilistic model fitting quantiles, the quantile value """ - raise_if_not( - isinstance(self.model, MultiOutputRegressor), - "The sklearn model is not a MultiOutputRegressor object.", - logger, - ) - raise_if_not( - 0 <= horizon < self.output_chunk_length, - f"`horizon` must be `>= 0` and `< output_chunk_length={self.output_chunk_length}`.", - logger, - ) - raise_if_not( - 0 <= target_dim < self.input_dim["target"], - f"`target_dim` must be `>= 0`, and `< n_target_components={self.input_dim['target']}`.", - logger, - ) + if not isinstance(self.model, MultiOutputRegressor): + raise_log( + ValueError("The sklearn model is not a MultiOutputRegressor object."), + logger, + ) + if not 0 <= horizon < self.output_chunk_length: + raise_log( + ValueError( + f"`horizon` must be `>= 0` and `< output_chunk_length={self.output_chunk_length}`." + ), + logger, + ) + if not 0 <= target_dim < self.input_dim["target"]: + raise_log( + ValueError( + f"`target_dim` must be `>= 0`, and `< n_target_components={self.input_dim['target']}`." + ), + logger, + ) # when multi_models=True, one model per horizon and target component idx_estimator = ( self.multi_models * self.input_dim["target"] * horizon + target_dim ) - return self.model.estimators_[idx_estimator] + if quantile is None: + return self.model.estimators_[idx_estimator] + # for quantile-models, the estimators are also grouped by quantiles + else: + if self.likelihood != "quantile": + raise_log( + ValueError( + "`quantile` is supported only when the `RegressionModel` is probabilistic and using the " + "'quantile' likelihood." + ), + logger, + ) + if quantile not in self._model_container.keys(): + raise_log( + ValueError( + f"The fitted quantiles are {list(self._model_container.keys())}, received quantile={quantile}." + ), + logger, + ) + return self._model_container[quantile].estimators_[idx_estimator] - def get_estimator(self, horizon: int, target_dim: int): + def get_estimator( + self, horizon: int, target_dim: int, quantile: Optional[float] = None + ): """Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component. + For probabilistic models fitting quantiles, it is possible to also specify the quantile. + The model is returned directly if it supports multi-output natively. Parameters @@ -550,8 +584,9 @@ def get_estimator(self, horizon: int, target_dim: int): The index of the forecasting point within `output_chunk_length`. target_dim The index of the target component. + quantile + Optionaly, for probabilistic model fitting quantiles, the quantile value """ - if isinstance(self.model, MultiOutputRegressor): return self.get_multioutput_estimator( horizon=horizon, target_dim=target_dim diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index cc9a514b51..76ef5d7af3 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1337,7 +1337,7 @@ def test_opti_historical_forecast_predict_checks(self): ], ) def test_multioutput_wrapper(self, config): - """Check that with input_chunk_length=1, wrapping in MultiOutputRegressor is not happening""" + """Check that with input_chunk_length=1, wrapping in MultiOutputRegressor occurs only when necessary""" model, supports_multioutput_natively = config model.fit(series=self.sine_multivariate1) if supports_multioutput_natively: @@ -1469,6 +1469,137 @@ def test_get_multioutput_estimator_single_model(self): # sub-model forecast only depend on the target_dim assert np.abs(j + 1 - pred) < 1e-2 + @pytest.mark.parametrize("multi_models", [True, False]) + def test_get_multioutput_estimator_quantile(self, multi_models): + """Check estimator getter when using quantile value""" + ocl = 3 + lags = 3 + quantiles = [0.01, 0.5, 0.99] + ts = tg.gaussian_timeseries( + mean=0, std=1, length=100, column_name="normal" + ).stack( + tg.gaussian_timeseries(mean=10, std=1, length=100, column_name="gaussian"), + ) + + m = XGBModel( + lags=lags, + output_chunk_length=ocl, + multi_models=multi_models, + likelihood="quantile", + quantiles=quantiles, + random_state=1, + ) + m.fit(ts) + + assert len(m._model_container) == len(quantiles) + for quantile_container in m._model_container.values(): + # one sub-model per quantile, per component, per horizon + if multi_models: + assert len(quantile_container.estimators_) == ocl * ts.width + # one sub-model per quantile, per component + else: + assert len(quantile_container.estimators_) == ts.width + + # check that retrieve sub-models prediction match the "wrapper" model predictions + pred = m.predict( + n=ocl, + series=ts[-lags:] if multi_models else ts[-lags - ocl + 1 :], + num_samples=1, + predict_likelihood_parameters=True, + ) + for j in range(ts.width): + dummy_feats = np.array([[0, 0.1, -0.1] * ts.width]) + 10 * j + for i in range(ocl): + for q in quantiles: + sub_model = m.get_multioutput_estimator( + horizon=i, target_dim=j, quantile=q + ) + pred_sub_model = sub_model.predict(dummy_feats)[0] + # due to the difference in inputs, the predictions are not exactly identical + assert ( + np.abs( + pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0] + - pred_sub_model + ) + < 3 + ) + + def test_get_multioutput_estimator_exceptions(self): + """Check that all the corner-cases are properly covered by the method""" + ts = TimeSeries.from_values( + values=np.array([ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 2], + ]).T, + columns=["a", "b"], + ) + m = LinearRegressionModel( + lags=2, + output_chunk_length=2, + random_state=1, + ) + m.fit(ts["a"]) + # not wrapped in MultiOutputRegressor because of native multi-output support + with pytest.raises(ValueError) as err: + m.get_multioutput_estimator(horizon=0, target_dim=0) + assert str(err.value).startswith( + "The sklearn model is not a MultiOutputRegressor object." + ) + + # univariate, deterministic, ocl > 2 + m = RegressionModel( + model=HistGradientBoostingRegressor(), + lags=2, + output_chunk_length=2, + ) + m.fit(ts["a"]) + # horizon > ocl + with pytest.raises(ValueError) as err: + m.get_multioutput_estimator(horizon=3, target_dim=0) + assert str(err.value).startswith( + "`horizon` must be `>= 0` and `< output_chunk_length" + ) + # target dim > training series width + with pytest.raises(ValueError) as err: + m.get_multioutput_estimator(horizon=0, target_dim=1) + assert str(err.value).startswith( + "`target_dim` must be `>= 0`, and `< n_target_components=" + ) + + # univariate, probabilistic + # using the quantiles argument to force wrapping in MultiOutputRegressor + m = XGBModel( + lags=2, + output_chunk_length=2, + random_state=1, + likelihood="poisson", + quantiles=[0.5], + ) + m.fit(ts["a"]) + # incorrect likelihood + with pytest.raises(ValueError) as err: + m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) + assert str(err.value).startswith( + "`quantile` is supported only when the `RegressionModel` is probabilistic " + "and using the 'quantile' likelihood." + ) + + # univariate, probabilistic + m = XGBModel( + lags=2, + output_chunk_length=2, + random_state=1, + likelihood="quantile", + quantiles=[0.01, 0.5, 0.99], + ) + m.fit(ts["a"]) + # retrieving a non-defined quantile + with pytest.raises(ValueError) as err: + m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) + assert str(err.value).startswith( + "The fitted quantiles are [0.01, 0.5, 0.99], received quantile=0.1" + ) + @pytest.mark.parametrize("mode", [True, False]) def test_regression_model(self, mode): lags = 4 From 41ca60891430c8da6fe51755f71f288052abf725 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 5 Mar 2025 12:04:09 +0100 Subject: [PATCH 2/7] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a592e63e5d..e9ed0218e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon) - `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj) - Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj) +- `get_multioutput_estimator` now support using `quantile` in addition of `horizon` and `target_dim` to get the estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) **Fixed** From 873a1120dce9bc2e662e3f4056057ca43ac66ef0 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Fri, 7 Mar 2025 13:23:45 +0100 Subject: [PATCH 3/7] minor updates --- CHANGELOG.md | 2 +- darts/models/forecasting/regression_model.py | 39 ++++++++++--------- .../forecasting/test_regression_models.py | 31 ++++++++------- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e8a5e7ed6..4df55d062a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon) - `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj) - Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj) -- `get_multioutput_estimator` now support using `quantile` in addition of `horizon` and `target_dim` to get the estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) +- Added `quantile` parameter to `RegressionModel.get_estimator()` to get the specific quantile estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) **Fixed** diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index 49bef66a28..b2eeab16a7 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -522,7 +522,7 @@ def get_multioutput_estimator( target_dim The index of the target component. quantile - Optionaly, for probabilistic model fitting quantiles, the quantile value + Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value. """ if not isinstance(self.model, MultiOutputRegressor): raise_log( @@ -550,24 +550,25 @@ def get_multioutput_estimator( ) if quantile is None: return self.model.estimators_[idx_estimator] + # for quantile-models, the estimators are also grouped by quantiles - else: - if self.likelihood != "quantile": - raise_log( - ValueError( - "`quantile` is supported only when the `RegressionModel` is probabilistic and using the " - "'quantile' likelihood." - ), - logger, - ) - if quantile not in self._model_container.keys(): - raise_log( - ValueError( - f"The fitted quantiles are {list(self._model_container.keys())}, received quantile={quantile}." - ), - logger, - ) - return self._model_container[quantile].estimators_[idx_estimator] + if self.likelihood != "quantile": + raise_log( + ValueError( + "`quantile` is only supported for probabilistic models that " + "use `likelihood='quantile'`." + ), + logger, + ) + if quantile not in self._model_container: + raise_log( + ValueError( + f"Invalid `quantile={quantile}`. Must be one of the fitted quantiles " + f"`{list(self._model_container.keys())}`." + ), + logger, + ) + return self._model_container[quantile].estimators_[idx_estimator] def get_estimator( self, horizon: int, target_dim: int, quantile: Optional[float] = None @@ -585,7 +586,7 @@ def get_estimator( target_dim The index of the target component. quantile - Optionaly, for probabilistic model fitting quantiles, the quantile value + Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value. """ if isinstance(self.model, MultiOutputRegressor): return self.get_multioutput_estimator( diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index 90f514a063..ed83edd18b 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1496,10 +1496,8 @@ def test_get_multioutput_estimator_quantile(self, multi_models): ocl = 3 lags = 3 quantiles = [0.01, 0.5, 0.99] - ts = tg.gaussian_timeseries( - mean=0, std=1, length=100, column_name="normal" - ).stack( - tg.gaussian_timeseries(mean=10, std=1, length=100, column_name="gaussian"), + ts = tg.sine_timeseries(length=100, column_name="sine").stack( + tg.linear_timeseries(length=100, column_name="linear"), ) m = XGBModel( @@ -1513,6 +1511,7 @@ def test_get_multioutput_estimator_quantile(self, multi_models): m.fit(ts) assert len(m._model_container) == len(quantiles) + assert sorted(list(m._model_container.keys())) == sorted(quantiles) for quantile_container in m._model_container.values(): # one sub-model per quantile, per component, per horizon if multi_models: @@ -1522,27 +1521,28 @@ def test_get_multioutput_estimator_quantile(self, multi_models): assert len(quantile_container.estimators_) == ts.width # check that retrieve sub-models prediction match the "wrapper" model predictions + pred_input = ts[-lags:] if multi_models else ts[-lags - ocl + 1 :] pred = m.predict( n=ocl, - series=ts[-lags:] if multi_models else ts[-lags - ocl + 1 :], + series=pred_input, num_samples=1, predict_likelihood_parameters=True, ) for j in range(ts.width): - dummy_feats = np.array([[0, 0.1, -0.1] * ts.width]) + 10 * j for i in range(ocl): + if multi_models: + dummy_feats = pred_input.values()[:lags] + else: + dummy_feats = pred_input.values()[i : +i + lags] + dummy_feats = np.expand_dims(dummy_feats.flatten(), 0) for q in quantiles: sub_model = m.get_multioutput_estimator( horizon=i, target_dim=j, quantile=q ) pred_sub_model = sub_model.predict(dummy_feats)[0] - # due to the difference in inputs, the predictions are not exactly identical assert ( - np.abs( - pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0] - - pred_sub_model - ) - < 3 + pred_sub_model + == pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0] ) def test_get_multioutput_estimator_exceptions(self): @@ -1601,8 +1601,8 @@ def test_get_multioutput_estimator_exceptions(self): with pytest.raises(ValueError) as err: m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) assert str(err.value).startswith( - "`quantile` is supported only when the `RegressionModel` is probabilistic " - "and using the 'quantile' likelihood." + "`quantile` is only supported for probabilistic models that " + "use `likelihood='quantile'`." ) # univariate, probabilistic @@ -1618,7 +1618,8 @@ def test_get_multioutput_estimator_exceptions(self): with pytest.raises(ValueError) as err: m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) assert str(err.value).startswith( - "The fitted quantiles are [0.01, 0.5, 0.99], received quantile=0.1" + "Invalid `quantile=0.1`. Must be one of the fitted quantiles " + "`[0.01, 0.5, 0.99]`." ) @pytest.mark.parametrize("mode", [True, False]) From d494707e190308d1dea5bc5da4742215af92403b Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 7 Mar 2025 13:58:09 +0100 Subject: [PATCH 4/7] fix: combined get_multioutput_estimator and get_estimator into a single method --- darts/explainability/shap_explainer.py | 2 +- darts/models/forecasting/regression_model.py | 47 +++++-------------- .../forecasting/test_regression_models.py | 26 +++++----- 3 files changed, 24 insertions(+), 51 deletions(-) diff --git a/darts/explainability/shap_explainer.py b/darts/explainability/shap_explainer.py index b931bbdc74..d80091acb0 100644 --- a/darts/explainability/shap_explainer.py +++ b/darts/explainability/shap_explainer.py @@ -611,7 +611,7 @@ def __init__( self.explainers[i] = {} for j in range(self.target_dim): self.explainers[i][j] = self._build_explainer_sklearn( - self.model.get_multioutput_estimator(horizon=i, target_dim=j), + self.model.get_estimator(horizon=i, target_dim=j), self.background_X, self.shap_method, **kwargs, diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index b2eeab16a7..df0d48dd54 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -505,15 +505,17 @@ def output_chunk_length(self) -> int: def output_chunk_shift(self) -> int: return self._output_chunk_shift - def get_multioutput_estimator( + def get_estimator( self, horizon: int, target_dim: int, quantile: Optional[float] = None ): """Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component. - Internally, estimators are grouped by `output_chunk_length` position, then by component. + For probabilistic models fitting quantiles, it is possible to also specify the quantile. - Note: for probabilistic models fitting quantiles, there is an additional abstraction layer, - grouping the estimators by `quantile`. + The model is returned directly if it supports multi-output natively. + + Note: Internally, estimators are grouped by `output_chunk_length` position, then by component. For probabilistic + models fitting quantiles, there is an additional abstraction layer, grouping the estimators by `quantile`. Parameters ---------- @@ -525,10 +527,11 @@ def get_multioutput_estimator( Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value. """ if not isinstance(self.model, MultiOutputRegressor): - raise_log( - ValueError("The sklearn model is not a MultiOutputRegressor object."), - logger, + logger.warning( + "Model supports multi-output; a single estimator forecasts all the horizons and components." ) + return self.model + if not 0 <= horizon < self.output_chunk_length: raise_log( ValueError( @@ -568,35 +571,7 @@ def get_multioutput_estimator( ), logger, ) - return self._model_container[quantile].estimators_[idx_estimator] - - def get_estimator( - self, horizon: int, target_dim: int, quantile: Optional[float] = None - ): - """Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component. - - For probabilistic models fitting quantiles, it is possible to also specify the quantile. - - The model is returned directly if it supports multi-output natively. - - Parameters - ---------- - horizon - The index of the forecasting point within `output_chunk_length`. - target_dim - The index of the target component. - quantile - Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value. - """ - if isinstance(self.model, MultiOutputRegressor): - return self.get_multioutput_estimator( - horizon=horizon, target_dim=target_dim - ) - else: - logger.info( - "Model supports multi-output; a single estimator forecasts all the horizons and components." - ) - return self.model + return self._model_container[quantile].estimators_[idx_estimator] def _add_val_set_to_kwargs( self, diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index ed83edd18b..2ef1fbe09d 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1396,7 +1396,7 @@ def test_multioutput_validation(self, config): else: assert not isinstance(model.model, MultiOutputRegressor) - def test_get_multioutput_estimator_multi_models(self): + def test_get_estimator_multi_models(self): """Craft training data so that estimator_[i].predict(X) == i + 1""" def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int): @@ -1417,7 +1417,7 @@ def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int): estimator_counter = 0 for i in range(ocl): for j in range(ts.width): - sub_model = m.get_multioutput_estimator(horizon=i, target_dim=j) + sub_model = m.get_estimator(horizon=i, target_dim=j) pred = sub_model.predict(dummy_feats)[0] # sub-model is overfitted on the training series assert np.abs(estimator_counter - pred) < 1e-2 @@ -1455,7 +1455,7 @@ def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int): # estimators_[3] labels : [3] helper_check_overfitted_estimators(ts, ocl) - def test_get_multioutput_estimator_single_model(self): + def test_get_estimator_single_model(self): """Check estimator getter when multi_models=False""" # multivariate, one sub-model per component ocl = 2 @@ -1485,13 +1485,13 @@ def test_get_multioutput_estimator_single_model(self): dummy_feats = np.array([[0, 0, 0] * ts.width]) for i in range(ocl): for j in range(ts.width): - sub_model = m.get_multioutput_estimator(horizon=i, target_dim=j) + sub_model = m.get_estimator(horizon=i, target_dim=j) pred = sub_model.predict(dummy_feats)[0] # sub-model forecast only depend on the target_dim assert np.abs(j + 1 - pred) < 1e-2 @pytest.mark.parametrize("multi_models", [True, False]) - def test_get_multioutput_estimator_quantile(self, multi_models): + def test_get_estimator_quantile(self, multi_models): """Check estimator getter when using quantile value""" ocl = 3 lags = 3 @@ -1536,16 +1536,14 @@ def test_get_multioutput_estimator_quantile(self, multi_models): dummy_feats = pred_input.values()[i : +i + lags] dummy_feats = np.expand_dims(dummy_feats.flatten(), 0) for q in quantiles: - sub_model = m.get_multioutput_estimator( - horizon=i, target_dim=j, quantile=q - ) + sub_model = m.get_estimator(horizon=i, target_dim=j, quantile=q) pred_sub_model = sub_model.predict(dummy_feats)[0] assert ( pred_sub_model == pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0] ) - def test_get_multioutput_estimator_exceptions(self): + def test_get_estimator_exceptions(self): """Check that all the corner-cases are properly covered by the method""" ts = TimeSeries.from_values( values=np.array([ @@ -1562,7 +1560,7 @@ def test_get_multioutput_estimator_exceptions(self): m.fit(ts["a"]) # not wrapped in MultiOutputRegressor because of native multi-output support with pytest.raises(ValueError) as err: - m.get_multioutput_estimator(horizon=0, target_dim=0) + m.get_estimator(horizon=0, target_dim=0) assert str(err.value).startswith( "The sklearn model is not a MultiOutputRegressor object." ) @@ -1576,13 +1574,13 @@ def test_get_multioutput_estimator_exceptions(self): m.fit(ts["a"]) # horizon > ocl with pytest.raises(ValueError) as err: - m.get_multioutput_estimator(horizon=3, target_dim=0) + m.get_estimator(horizon=3, target_dim=0) assert str(err.value).startswith( "`horizon` must be `>= 0` and `< output_chunk_length" ) # target dim > training series width with pytest.raises(ValueError) as err: - m.get_multioutput_estimator(horizon=0, target_dim=1) + m.get_estimator(horizon=0, target_dim=1) assert str(err.value).startswith( "`target_dim` must be `>= 0`, and `< n_target_components=" ) @@ -1599,7 +1597,7 @@ def test_get_multioutput_estimator_exceptions(self): m.fit(ts["a"]) # incorrect likelihood with pytest.raises(ValueError) as err: - m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) + m.get_estimator(horizon=0, target_dim=0, quantile=0.1) assert str(err.value).startswith( "`quantile` is only supported for probabilistic models that " "use `likelihood='quantile'`." @@ -1616,7 +1614,7 @@ def test_get_multioutput_estimator_exceptions(self): m.fit(ts["a"]) # retrieving a non-defined quantile with pytest.raises(ValueError) as err: - m.get_multioutput_estimator(horizon=0, target_dim=0, quantile=0.1) + m.get_estimator(horizon=0, target_dim=0, quantile=0.1) assert str(err.value).startswith( "Invalid `quantile=0.1`. Must be one of the fitted quantiles " "`[0.01, 0.5, 0.99]`." From e0ac18076afef1000f61657944a191fb99ec5418 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 7 Mar 2025 15:24:45 +0100 Subject: [PATCH 5/7] fix: typo --- darts/models/forecasting/regression_model.py | 2 +- .../models/forecasting/test_regression_models.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index df0d48dd54..2c8444a89b 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -571,7 +571,7 @@ def get_estimator( ), logger, ) - return self._model_container[quantile].estimators_[idx_estimator] + return self._model_container[quantile].estimators_[idx_estimator] def _add_val_set_to_kwargs( self, diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index 2ef1fbe09d..26ed894cf5 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1,6 +1,7 @@ import functools import importlib import inspect +import logging import math from copy import deepcopy from itertools import product @@ -1543,7 +1544,7 @@ def test_get_estimator_quantile(self, multi_models): == pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0] ) - def test_get_estimator_exceptions(self): + def test_get_estimator_exceptions(self, caplog): """Check that all the corner-cases are properly covered by the method""" ts = TimeSeries.from_values( values=np.array([ @@ -1559,10 +1560,13 @@ def test_get_estimator_exceptions(self): ) m.fit(ts["a"]) # not wrapped in MultiOutputRegressor because of native multi-output support - with pytest.raises(ValueError) as err: + with caplog.at_level(logging.WARNING): m.get_estimator(horizon=0, target_dim=0) - assert str(err.value).startswith( - "The sklearn model is not a MultiOutputRegressor object." + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert caplog.records[0].message == ( + "Model supports multi-output; a single estimator " + "forecasts all the horizons and components." ) # univariate, deterministic, ocl > 2 From 0e01826e8179219f7db428fb3857aa605e90b4e3 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 7 Mar 2025 15:28:44 +0100 Subject: [PATCH 6/7] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4df55d062a..abb33873ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - Refactored and improved the multi-output support handling for `RegressionModel`. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc) +- 🔴 The `RegressionModel.get_multioutput_estimator()` method was removed, all the logic is now contained in `RegressionModel.get_estimator()`. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) ## [0.33.0](https://github.com/unit8co/darts/tree/0.33.0) (2025-02-14) From b5ce31e01421f6bd23c8fb12de0cb1292201671f Mon Sep 17 00:00:00 2001 From: dennisbader Date: Fri, 7 Mar 2025 16:26:59 +0100 Subject: [PATCH 7/7] update changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abb33873ad..7a2eba14b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj) - Added `quantile` parameter to `RegressionModel.get_estimator()` to get the specific quantile estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) +**Removed** + +- 🔴 Removed method `RegressionModel.get_multioutput_estimator()`. Use `RegressionModel.get_estimator()` instead. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) + **Fixed** - 🔴 / 🟢 Fixed a bug which raised an error when loading torch models that were saved with Darts versions < 0.33.0. This is a breaking change and models saved with version 0.33.0 will not be loadable anymore. [#2692](https://github.com/unit8co/darts/pull/2692) by [Dennis Bader](https://github.com/dennisbader). @@ -35,7 +39,6 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - Refactored and improved the multi-output support handling for `RegressionModel`. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc) -- 🔴 The `RegressionModel.get_multioutput_estimator()` method was removed, all the logic is now contained in `RegressionModel.get_estimator()`. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou) ## [0.33.0](https://github.com/unit8co/darts/tree/0.33.0) (2025-02-14)