Skip to content

Commit 62dca8f

Browse files
Catboost native multi-output with RMSE (#2659)
* Replace MultioutputRegressor by native multioutput from CastBoost, wip * Add _native_support_multioutput() to regression models, wip * Fix logic, replace RMSE with MultiRMSE when multioutput is required * Add entry to change log * Adapt shapexplainer test to support all native multioutput * Let user decide on native multioutput usage * Update CHANGELOG.md * Update CHANGELOG.md * Add recursive call for native multioutput * Add test for native multioutput support * Fix logic, only sklearn model should be passed to RegressionModel * Set loss_function in kwargs following model default * Fix SNM logic to be independant of MOR wrapper * Use __sklearn_tags__ instead of _get_tags as required in sklearn 1.7 * Update CHANGELOG.md * Minor rewriting * Remove logic for xboost<2.0.0 * minor updates --------- Co-authored-by: dennisbader <[email protected]>
1 parent b1f7327 commit 62dca8f

File tree

8 files changed

+87
-59
lines changed

8 files changed

+87
-59
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- Improved `CatBoostModel` documentation by describing how to use native multi-output regression. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
1415
- `TimeSeries.from_dataframe()` and `from_series()` now support creating `TimeSeries` from additional backends (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between dataframe libraries. See the `narwhals` [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2661](https://github.com/unit8co/darts/pull/2661) by [Jules Authier](https://github.com/authierj)
1516
- Added ONNX support for torch-based models with method `TorchForecastingModel.to_onnx()`. Check out [this example](https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html#exporting-model-to-onnx-format-for-inference) from the user guide on how to export and load a model for inference. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou)
1617
- Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon)
@@ -24,8 +25,15 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2425

2526
**Dependencies**
2627

28+
- Bumped minimum scikit-learn version from `1.0.1` to `1.6.0`. This was required due to sklearn deprecating `_get_tags` in favor of `BaseEstimator.__sklearn_tags__` in version 1.7. This leads to increasing both sklearn and XGBoost minimum supported version to 1.6 and 2.1.4 respectively. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
29+
- Bumped minimum xgboost version from `1.6.0` to `2.1.4` for the same reason as bumping the minimum sklearn version. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
30+
2731
### For developers of the library:
2832

33+
**Improved**
34+
35+
- Refactored and improved the multi-output support handling for `RegressionModel`. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
36+
2937
## [0.33.0](https://github.com/unit8co/darts/tree/0.33.0) (2025-02-14)
3038

3139
### For users of the library:

darts/models/forecasting/catboost_model.py

+14
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def encode_year(idx):
132132
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()``.
133133
**kwargs
134134
Additional keyword arguments passed to `catboost.CatBoostRegressor`.
135+
Native multi-output support can be achieved by using an appropriate `loss_function` ('MultiRMSE',
136+
'MultiRMSEWithMissingValues'). Otherwise, Darts uses its `MultiOutputRegressor` wrapper to add multi-output
137+
support.
135138
136139
Examples
137140
--------
@@ -208,6 +211,9 @@ def encode_year(idx):
208211
use_static_covariates=use_static_covariates,
209212
)
210213

214+
# if no loss provided, get the default loss from the model
215+
self.kwargs["loss_function"] = self.model.get_params().get("loss_function")
216+
211217
def fit(
212218
self,
213219
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -404,3 +410,11 @@ def min_train_series_length(self) -> int:
404410
else self.output_chunk_length
405411
),
406412
)
413+
414+
@property
415+
def _supports_native_multioutput(self):
416+
# CatBoostRegressor supports multi-output natively, but only with selected loss functions
417+
# ("MultiRMSE", "MultiRMSEWithMissingValues", ...)
418+
return CatBoostRegressor._is_multiregression_objective(
419+
self.kwargs.get("loss_function")
420+
)

darts/models/forecasting/regression_model.py

+25-29
Original file line numberDiff line numberDiff line change
@@ -846,36 +846,21 @@ def fit(
846846
"future": future_covariates[0].width if future_covariates else None,
847847
}
848848

849-
# if multi-output regression
850-
use_mor = False
851-
if not series[0].is_univariate or (
852-
self.output_chunk_length > 1
853-
and self.multi_models
849+
# Check if multi-output regression is required
850+
requires_multioutput = not series[0].is_univariate or (
851+
self.output_chunk_length > 1 and self.multi_models
852+
)
853+
854+
# If multi-output required and model doesn't support it natively, wrap it in a MultiOutputRegressor
855+
if (
856+
requires_multioutput
854857
and not isinstance(self.model, MultiOutputRegressor)
858+
and (
859+
not self._supports_native_multioutput
860+
or sample_weight
861+
is not None # we have 2D sample (and time) weights, only supported in Darts
862+
)
855863
):
856-
if sample_weight is not None:
857-
# we have 2D sample (and time) weights, only supported in Darts
858-
use_mor = True
859-
elif not (
860-
callable(getattr(self.model, "_get_tags", None))
861-
and isinstance(self.model._get_tags(), dict)
862-
and self.model._get_tags().get("multioutput")
863-
):
864-
# model does not support multi-output regression natively
865-
use_mor = True
866-
elif (
867-
self.model.__class__.__name__ == "CatBoostRegressor"
868-
and self.model.get_params()["loss_function"] == "RMSEWithUncertainty"
869-
):
870-
use_mor = True
871-
elif (
872-
self.model.__class__.__name__ == "XGBRegressor"
873-
and self.likelihood is not None
874-
):
875-
# since xgboost==2.1.0, likelihoods do not support native multi output regression
876-
use_mor = True
877-
878-
if use_mor:
879864
val_set_name, val_weight_name = self.val_set_params
880865
mor_kwargs = {
881866
"eval_set_name": val_set_name,
@@ -884,7 +869,6 @@ def fit(
884869
}
885870
self.model = MultiOutputRegressor(self.model, **mor_kwargs)
886871

887-
# warn if n_jobs_multioutput_wrapper was provided but not used
888872
if (
889873
not isinstance(self.model, MultiOutputRegressor)
890874
and n_jobs_multioutput_wrapper is not None
@@ -1391,6 +1375,18 @@ def _optimized_historical_forecasts(
13911375
)
13921376
return series2seq(hfc, seq_type_out=series_seq_type)
13931377

1378+
@property
1379+
def _supports_native_multioutput(self) -> bool:
1380+
"""
1381+
Returns True if the model supports multi-output regression natively.
1382+
"""
1383+
model = (
1384+
self.model.estimator
1385+
if isinstance(self.model, MultiOutputRegressor)
1386+
else self.model
1387+
)
1388+
return model.__sklearn_tags__().target_tags.multi_output
1389+
13941390

13951391
class _LikelihoodMixin:
13961392
"""

darts/models/forecasting/xgboost.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"""
99

1010
from collections.abc import Sequence
11-
from functools import partial
1211
from typing import Optional, Union
1312

1413
import numpy as np
@@ -25,10 +24,6 @@
2524

2625
logger = get_logger(__name__)
2726

28-
# Check whether we are running xgboost >= 2.0.0 for quantile regression
29-
tokens = xgb.__version__.split(".")
30-
xgb_200_or_above = int(tokens[0]) >= 2
31-
3227

3328
def xgb_quantile_loss(labels: np.ndarray, preds: np.ndarray, quantile: float):
3429
"""Custom loss function for XGBoost to compute quantile loss gradient.
@@ -205,9 +200,7 @@ def encode_year(idx):
205200
if likelihood in {"poisson"}:
206201
self.kwargs["objective"] = f"count:{likelihood}"
207202
elif likelihood == "quantile":
208-
if xgb_200_or_above:
209-
# leverage built-in Quantile Regression
210-
self.kwargs["objective"] = "reg:quantileerror"
203+
self.kwargs["objective"] = "reg:quantileerror"
211204
self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
212205
self._model_container = self._get_model_container()
213206

@@ -289,11 +282,7 @@ def fit(
289282
# empty model container in case of multiple calls to fit, e.g. when backtesting
290283
self._model_container.clear()
291284
for quantile in self.quantiles:
292-
if xgb_200_or_above:
293-
self.kwargs["quantile_alpha"] = quantile
294-
else:
295-
objective = partial(xgb_quantile_loss, quantile=quantile)
296-
self.kwargs["objective"] = objective
285+
self.kwargs["quantile_alpha"] = quantile
297286
self.model = xgb.XGBRegressor(**self.kwargs)
298287
super().fit(
299288
series=series,
@@ -368,3 +357,8 @@ def min_train_series_length(self) -> int:
368357
else self.output_chunk_length
369358
),
370359
)
360+
361+
@property
362+
def _supports_native_multioutput(self):
363+
# since xgboost==2.1.0, likelihoods do not support native multi output regression
364+
return super()._supports_native_multioutput and self.likelihood is None

darts/tests/ad/test_scorers.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import pytest
6-
import sklearn
76
from pyod.models.knn import KNN
87
from scipy.stats import cauchy, expon, gamma, laplace, norm, poisson
98

@@ -1221,12 +1220,7 @@ def test_multivariate_componentwise_kmeans(self):
12211220

12221221
assert np.abs(1.0 - auc_roc_cwtrue[0]) < delta
12231222
assert np.abs(0.97666 - auc_roc_cwtrue[1]) < delta
1224-
# sklearn changed the centroid initialization in version 1.3.0
1225-
# so the results are slightly different for older versions
1226-
if sklearn.__version__ < "1.3.0":
1227-
assert np.abs(0.9851 - auc_roc_cwfalse) < delta
1228-
else:
1229-
assert np.abs(0.99007 - auc_roc_cwfalse) < delta
1223+
assert np.abs(0.99007 - auc_roc_cwfalse) < delta
12301224

12311225
def test_PyODScorer(self):
12321226
# Check parameters and inputs

darts/tests/explainability/test_shap_explainer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ def test_creation(self):
205205

206206
# Good type of explainers
207207
shap_explain = ShapExplainer(m)
208-
if isinstance(m, XGBModel):
209-
# since xgboost > 2.1.0, model supports native multi output regression
208+
if m._supports_native_multioutput:
209+
# since xgboost > 2.1.0, model supports native multi-output regression
210+
# CatBoostModel supports multi-output for certain loss functions
210211
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
211212
else:
212213
assert isinstance(
@@ -270,7 +271,7 @@ def test_creation(self):
270271
future_covariates=self.fut_cov_ts,
271272
)
272273
shap_explain = ShapExplainer(m)
273-
if isinstance(m, XGBModel):
274+
if m._supports_native_multioutput:
274275
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
275276
else:
276277
assert isinstance(

darts/tests/models/forecasting/test_regression_models.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,7 @@ class NewCls(cls):
182182
class TestRegressionModels:
183183
np.random.seed(42)
184184
# default regression models
185-
models = [
186-
RandomForest,
187-
LinearRegressionModel,
188-
RegressionModel,
189-
]
185+
models = [RandomForest, LinearRegressionModel, RegressionModel]
190186

191187
# register likelihood regression models
192188
QuantileLinearRegressionModel = partialclass(
@@ -1355,6 +1351,31 @@ def test_multioutput_wrapper(self, config):
13551351
horizon=0, target_dim=1
13561352
)
13571353

1354+
model_configs_multioutput = [
1355+
(
1356+
RegressionModel,
1357+
{"lags": 4, "model": LinearRegression()},
1358+
True,
1359+
),
1360+
(LinearRegressionModel, {"lags": 4}, True),
1361+
(XGBModel, {"lags": 4}, True),
1362+
(XGBModel, {"lags": 4, "likelihood": "poisson"}, False),
1363+
]
1364+
if lgbm_available:
1365+
model_configs_multioutput += [(LightGBMModel, {"lags": 4}, False)]
1366+
if cb_available:
1367+
model_configs_multioutput += [
1368+
(CatBoostModel, {"lags": 4, "loss_function": "RMSE"}, False),
1369+
(CatBoostModel, {"lags": 4, "loss_function": "MultiRMSE"}, True),
1370+
(CatBoostModel, {"lags": 4, "loss_function": "RMSEWithUncertainty"}, False),
1371+
]
1372+
1373+
@pytest.mark.parametrize("config", model_configs_multioutput)
1374+
def test_supports_native_multioutput(self, config):
1375+
model_cls, model_config, supports_native_multioutput = config
1376+
model = model_cls(**model_config)
1377+
assert model._supports_native_multioutput == supports_native_multioutput
1378+
13581379
model_configs = [(XGBModel, dict({"likelihood": "poisson"}, **xgb_test_params))]
13591380
if lgbm_available:
13601381
model_configs += [(LightGBMModel, lgbm_test_params)]

requirements/core.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pandas>=1.0.5
88
pmdarima>=1.8.0
99
pyod>=0.9.5
1010
requests>=2.22.0
11-
scikit-learn>=1.0.1
11+
scikit-learn>=1.6.0
1212
scipy>=1.3.2
1313
shap>=0.40.0
1414
statsforecast>=1.4
@@ -17,4 +17,4 @@ tbats>=1.1.0
1717
tqdm>=4.60.0
1818
typing-extensions
1919
xarray>=0.17.0
20-
xgboost>=1.6.0
20+
xgboost>=2.1.4

0 commit comments

Comments
 (0)