Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catboost native multi-output with RMSE #2659

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e692116
Replace MultioutputRegressor by native multioutput from CastBoost, wip
jonasblanc Jan 24, 2025
a3c812d
Add _native_support_multioutput() to regression models, wip
jonasblanc Jan 29, 2025
06e0521
Fix logic, replace RMSE with MultiRMSE when multioutput is required
jonasblanc Jan 30, 2025
110be86
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Jan 30, 2025
5c1eeef
Add entry to change log
jonasblanc Jan 31, 2025
849a7e5
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 3, 2025
7eef26f
Adapt shapexplainer test to support all native multioutput
jonasblanc Feb 3, 2025
e140258
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 3, 2025
554deff
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 3, 2025
5fa5a51
Let user decide on native multioutput usage
jonasblanc Feb 17, 2025
fac56c0
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 17, 2025
02c484b
Update CHANGELOG.md
jonasblanc Feb 18, 2025
60f9ed1
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 18, 2025
1ad3c48
Update CHANGELOG.md
jonasblanc Feb 18, 2025
fe202d2
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 18, 2025
c2f7724
Add recursive call for native multioutput
jonasblanc Feb 18, 2025
d811189
Add test for native multioutput support
jonasblanc Feb 18, 2025
36573b3
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 18, 2025
bd3f745
Fix logic, only sklearn model should be passed to RegressionModel
jonasblanc Feb 18, 2025
806c869
Set loss_function in kwargs following model default
jonasblanc Feb 20, 2025
5973a4b
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 20, 2025
1595477
Fix SNM logic to be independant of MOR wrapper
jonasblanc Mar 3, 2025
c4dab0a
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 3, 2025
4963330
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Mar 3, 2025
dd9075d
Use __sklearn_tags__ instead of _get_tags as required in sklearn 1.7
jonasblanc Mar 3, 2025
8b5bdd9
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 3, 2025
62c7187
Update CHANGELOG.md
jonasblanc Mar 3, 2025
55a17a6
Minor rewriting
jonasblanc Mar 6, 2025
ef78e6a
Remove logic for xboost<2.0.0
jonasblanc Mar 6, 2025
bbcc7b4
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 6, 2025
357ab51
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Mar 6, 2025
76f29d1
minor updates
dennisbader Mar 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- Improved `CatBoostModel` documentation by describing how to use native multi-output regression. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
- `TimeSeries.from_dataframe()` and `from_series()` now support creating `TimeSeries` from additional backends (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between dataframe libraries. See the `narwhals` [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2661](https://github.com/unit8co/darts/pull/2661) by [Jules Authier](https://github.com/authierj)
- Added ONNX support for torch-based models with method `TorchForecastingModel.to_onnx()`. Check out [this example](https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html#exporting-model-to-onnx-format-for-inference) from the user guide on how to export and load a model for inference. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou)
- Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon)
Expand All @@ -24,8 +25,15 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Dependencies**

- Bumped minimum scikit-learn version from `1.0.1` to `1.6.0`. This was required due to sklearn deprecating `_get_tags` in favor of `BaseEstimator.__sklearn_tags__` in version 1.7. This leads to increasing both sklearn and XGBoost minimum supported version to 1.6 and 2.1.4 respectively. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
- Bumped minimum xgboost version from `1.6.0` to `2.1.4` for the same reason as bumping the minimum sklearn version. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)

### For developers of the library:

**Improved**

- Refactored and improved the multi-output support handling for `RegressionModel`. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)

## [0.33.0](https://github.com/unit8co/darts/tree/0.33.0) (2025-02-14)

### For users of the library:
Expand Down
14 changes: 14 additions & 0 deletions darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def encode_year(idx):
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()``.
**kwargs
Additional keyword arguments passed to `catboost.CatBoostRegressor`.
Native multi-output support can be achieved by using an appropriate `loss_function` ('MultiRMSE',
'MultiRMSEWithMissingValues'). Otherwise, Darts uses its `MultiOutputRegressor` wrapper to add multi-output
support.

Examples
--------
Expand Down Expand Up @@ -208,6 +211,9 @@ def encode_year(idx):
use_static_covariates=use_static_covariates,
)

# if no loss provided, get the default loss from the model
self.kwargs["loss_function"] = self.model.get_params().get("loss_function")

def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
Expand Down Expand Up @@ -404,3 +410,11 @@ def min_train_series_length(self) -> int:
else self.output_chunk_length
),
)

@property
def _supports_native_multioutput(self):
# CatBoostRegressor supports multi-output natively, but only with selected loss functions
# ("MultiRMSE", "MultiRMSEWithMissingValues", ...)
return CatBoostRegressor._is_multiregression_objective(
self.kwargs.get("loss_function")
)
54 changes: 25 additions & 29 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,36 +846,21 @@ def fit(
"future": future_covariates[0].width if future_covariates else None,
}

# if multi-output regression
use_mor = False
if not series[0].is_univariate or (
self.output_chunk_length > 1
and self.multi_models
# Check if multi-output regression is required
requires_multioutput = not series[0].is_univariate or (
self.output_chunk_length > 1 and self.multi_models
)

# If multi-output required and model doesn't support it natively, wrap it in a MultiOutputRegressor
if (
requires_multioutput
and not isinstance(self.model, MultiOutputRegressor)
and (
not self._supports_native_multioutput
or sample_weight
is not None # we have 2D sample (and time) weights, only supported in Darts
)
):
if sample_weight is not None:
# we have 2D sample (and time) weights, only supported in Darts
use_mor = True
elif not (
callable(getattr(self.model, "_get_tags", None))
and isinstance(self.model._get_tags(), dict)
and self.model._get_tags().get("multioutput")
):
# model does not support multi-output regression natively
use_mor = True
elif (
self.model.__class__.__name__ == "CatBoostRegressor"
and self.model.get_params()["loss_function"] == "RMSEWithUncertainty"
):
use_mor = True
elif (
self.model.__class__.__name__ == "XGBRegressor"
and self.likelihood is not None
):
# since xgboost==2.1.0, likelihoods do not support native multi output regression
use_mor = True

if use_mor:
val_set_name, val_weight_name = self.val_set_params
mor_kwargs = {
"eval_set_name": val_set_name,
Expand All @@ -884,7 +869,6 @@ def fit(
}
self.model = MultiOutputRegressor(self.model, **mor_kwargs)

# warn if n_jobs_multioutput_wrapper was provided but not used
if (
not isinstance(self.model, MultiOutputRegressor)
and n_jobs_multioutput_wrapper is not None
Expand Down Expand Up @@ -1391,6 +1375,18 @@ def _optimized_historical_forecasts(
)
return series2seq(hfc, seq_type_out=series_seq_type)

@property
def _supports_native_multioutput(self) -> bool:
"""
Returns True if the model supports multi-output regression natively.
"""
model = (
self.model.estimator
if isinstance(self.model, MultiOutputRegressor)
else self.model
)
return model.__sklearn_tags__().target_tags.multi_output


class _LikelihoodMixin:
"""
Expand Down
20 changes: 7 additions & 13 deletions darts/models/forecasting/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

from collections.abc import Sequence
from functools import partial
from typing import Optional, Union

import numpy as np
Expand All @@ -25,10 +24,6 @@

logger = get_logger(__name__)

# Check whether we are running xgboost >= 2.0.0 for quantile regression
tokens = xgb.__version__.split(".")
xgb_200_or_above = int(tokens[0]) >= 2


def xgb_quantile_loss(labels: np.ndarray, preds: np.ndarray, quantile: float):
"""Custom loss function for XGBoost to compute quantile loss gradient.
Expand Down Expand Up @@ -205,9 +200,7 @@ def encode_year(idx):
if likelihood in {"poisson"}:
self.kwargs["objective"] = f"count:{likelihood}"
elif likelihood == "quantile":
if xgb_200_or_above:
# leverage built-in Quantile Regression
self.kwargs["objective"] = "reg:quantileerror"
self.kwargs["objective"] = "reg:quantileerror"
self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
self._model_container = self._get_model_container()

Expand Down Expand Up @@ -289,11 +282,7 @@ def fit(
# empty model container in case of multiple calls to fit, e.g. when backtesting
self._model_container.clear()
for quantile in self.quantiles:
if xgb_200_or_above:
self.kwargs["quantile_alpha"] = quantile
else:
objective = partial(xgb_quantile_loss, quantile=quantile)
self.kwargs["objective"] = objective
self.kwargs["quantile_alpha"] = quantile
self.model = xgb.XGBRegressor(**self.kwargs)
super().fit(
series=series,
Expand Down Expand Up @@ -368,3 +357,8 @@ def min_train_series_length(self) -> int:
else self.output_chunk_length
),
)

@property
def _supports_native_multioutput(self):
# since xgboost==2.1.0, likelihoods do not support native multi output regression
return super()._supports_native_multioutput and self.likelihood is None
8 changes: 1 addition & 7 deletions darts/tests/ad/test_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import pytest
import sklearn
from pyod.models.knn import KNN
from scipy.stats import cauchy, expon, gamma, laplace, norm, poisson

Expand Down Expand Up @@ -1221,12 +1220,7 @@ def test_multivariate_componentwise_kmeans(self):

assert np.abs(1.0 - auc_roc_cwtrue[0]) < delta
assert np.abs(0.97666 - auc_roc_cwtrue[1]) < delta
# sklearn changed the centroid initialization in version 1.3.0
# so the results are slightly different for older versions
if sklearn.__version__ < "1.3.0":
assert np.abs(0.9851 - auc_roc_cwfalse) < delta
else:
assert np.abs(0.99007 - auc_roc_cwfalse) < delta
assert np.abs(0.99007 - auc_roc_cwfalse) < delta

def test_PyODScorer(self):
# Check parameters and inputs
Expand Down
7 changes: 4 additions & 3 deletions darts/tests/explainability/test_shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ def test_creation(self):

# Good type of explainers
shap_explain = ShapExplainer(m)
if isinstance(m, XGBModel):
# since xgboost > 2.1.0, model supports native multi output regression
if m._supports_native_multioutput:
# since xgboost > 2.1.0, model supports native multi-output regression
# CatBoostModel supports multi-output for certain loss functions
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
Expand Down Expand Up @@ -270,7 +271,7 @@ def test_creation(self):
future_covariates=self.fut_cov_ts,
)
shap_explain = ShapExplainer(m)
if isinstance(m, XGBModel):
if m._supports_native_multioutput:
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
Expand Down
31 changes: 26 additions & 5 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,7 @@ class NewCls(cls):
class TestRegressionModels:
np.random.seed(42)
# default regression models
models = [
RandomForest,
LinearRegressionModel,
RegressionModel,
]
models = [RandomForest, LinearRegressionModel, RegressionModel]

# register likelihood regression models
QuantileLinearRegressionModel = partialclass(
Expand Down Expand Up @@ -1355,6 +1351,31 @@ def test_multioutput_wrapper(self, config):
horizon=0, target_dim=1
)

model_configs_multioutput = [
(
RegressionModel,
{"lags": 4, "model": LinearRegression()},
True,
),
(LinearRegressionModel, {"lags": 4}, True),
(XGBModel, {"lags": 4}, True),
(XGBModel, {"lags": 4, "likelihood": "poisson"}, False),
]
if lgbm_available:
model_configs_multioutput += [(LightGBMModel, {"lags": 4}, False)]
if cb_available:
model_configs_multioutput += [
(CatBoostModel, {"lags": 4, "loss_function": "RMSE"}, False),
(CatBoostModel, {"lags": 4, "loss_function": "MultiRMSE"}, True),
(CatBoostModel, {"lags": 4, "loss_function": "RMSEWithUncertainty"}, False),
]

@pytest.mark.parametrize("config", model_configs_multioutput)
def test_supports_native_multioutput(self, config):
model_cls, model_config, supports_native_multioutput = config
model = model_cls(**model_config)
assert model._supports_native_multioutput == supports_native_multioutput

model_configs = [(XGBModel, dict({"likelihood": "poisson"}, **xgb_test_params))]
if lgbm_available:
model_configs += [(LightGBMModel, lgbm_test_params)]
Expand Down
4 changes: 2 additions & 2 deletions requirements/core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pandas>=1.0.5
pmdarima>=1.8.0
pyod>=0.9.5
requests>=2.22.0
scikit-learn>=1.0.1
scikit-learn>=1.6.0
scipy>=1.3.2
shap>=0.40.0
statsforecast>=1.4
Expand All @@ -17,4 +17,4 @@ tbats>=1.1.0
tqdm>=4.60.0
typing-extensions
xarray>=0.17.0
xgboost>=1.6.0
xgboost>=2.1.4
Loading