Catboost native multi-output with RMSE #2659

Merged
Changes from 9 commits
32 commits
e692116
Replace MultioutputRegressor by native multioutput from CastBoost, wip
jonasblanc Jan 24, 2025
a3c812d
Add _native_support_multioutput() to regression models, wip
jonasblanc Jan 29, 2025
06e0521
Fix logic, replace RMSE with MultiRMSE when multioutput is required
jonasblanc Jan 30, 2025
110be86
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Jan 30, 2025
5c1eeef
Add entry to change log
jonasblanc Jan 31, 2025
849a7e5
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 3, 2025
7eef26f
Adapt shapexplainer test to support all native multioutput
jonasblanc Feb 3, 2025
e140258
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 3, 2025
554deff
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 3, 2025
5fa5a51
Let user decide on native multioutput usage
jonasblanc Feb 17, 2025
fac56c0
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 17, 2025
02c484b
Update CHANGELOG.md
jonasblanc Feb 18, 2025
60f9ed1
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 18, 2025
1ad3c48
Update CHANGELOG.md
jonasblanc Feb 18, 2025
fe202d2
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 18, 2025
c2f7724
Add recursive call for native multioutput
jonasblanc Feb 18, 2025
d811189
Add test for native multioutput support
jonasblanc Feb 18, 2025
36573b3
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Feb 18, 2025
bd3f745
Fix logic, only sklearn model should be passed to RegressionModel
jonasblanc Feb 18, 2025
806c869
Set loss_function in kwargs following model default
jonasblanc Feb 20, 2025
5973a4b
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Feb 20, 2025
1595477
Fix SNM logic to be independant of MOR wrapper
jonasblanc Mar 3, 2025
c4dab0a
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 3, 2025
4963330
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Mar 3, 2025
dd9075d
Use __sklearn_tags__ instead of _get_tags as required in sklearn 1.7
jonasblanc Mar 3, 2025
8b5bdd9
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 3, 2025
62c7187
Update CHANGELOG.md
jonasblanc Mar 3, 2025
55a17a6
Minor rewriting
jonasblanc Mar 6, 2025
ef78e6a
Remove logic for xboost<2.0.0
jonasblanc Mar 6, 2025
bbcc7b4
Merge branch 'feat/native-multioutput-catboost' of github.com:jonasbl…
jonasblanc Mar 6, 2025
357ab51
Merge branch 'master' into feat/native-multioutput-catboost
jonasblanc Mar 6, 2025
76f29d1
minor updates
dennisbader Mar 7, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added more resampling methods to `TimeSeries.resample()`. This allows to aggregate values when down-sampling and to fill or keep the holes when up-sampling. [#2654](https://github.com/unit8co/darts/pull/2654) by [Jonas Blanc](https://github.com/jonasblanc)
- Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich).
- Added new time aggregated metric `wmape()` (Weighted Mean Absolute Percentage Error). [#2544](https://github.com/unit8co/darts/pull/2648) by [He Weilin](https://github.com/cnhwl).
+- Added support for native `CatBoostModel` multioutput (by converting loss function from `RMSE` to `MultiRMSE` when required) instead of relying on Darts `MultiOutputRegressor`. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
- Improvements to `ForecastingModel`:
- Added parameter `clean: bool` to `ForecastingModel.save()` to store a cleaned version of the model (removes training data from global models, and Lightning Trainer-related parameters from torch models). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
- Added parameter `pl_trainer_kwargs` to `TorchForecastingModel.load()` to setup a new Lightning Trainer used to configure the model for downstream tasks (e.g. prediction). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
28 changes: 25 additions & 3 deletions darts/models/forecasting/catboost_model.py
@@ -16,6 +16,7 @@
from darts.logging import get_logger
from darts.models.forecasting.regression_model import RegressionModel, _LikelihoodMixin
from darts.timeseries import TimeSeries
+from darts.utils.multioutput import MultiOutputRegressor

logger = get_logger(__name__)

@@ -208,6 +209,10 @@ def encode_year(idx):
use_static_covariates=use_static_covariates,
)

+    def _native_support_multioutput(self):
+        # CatBoostRegressor supports multioutput natively, but only with the "MultiRMSE" loss function
+        return self.kwargs["loss_function"] == "MultiRMSE"
+
def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -249,9 +254,9 @@ def fit(
creation) to know their sizes, which might be expensive on big datasets.
If some series turn out to have a length that would allow more than `max_samples_per_ts`, only the
most recent `max_samples_per_ts` samples will be considered.
-        n_jobs_multioutput_wrapper
-            Number of jobs of the MultiOutputRegressor wrapper to run in parallel. Only used if the model doesn't
-            support multi-output regression natively.
+        # n_jobs_multioutput_wrapper
+        #     Number of jobs of the MultiOutputRegressor wrapper to run in parallel. Only used if the model doesn't
+        #     support multi-output regression natively.
sample_weight
Optionally, some sample weights to apply to the target `series` labels. They are applied per observation,
per label (each step in `output_chunk_length`), and per component.
@@ -294,6 +299,23 @@ def fit(
self._model_container[quantile] = self.model
return self

+        # If multioutput and not probabilistic, use the MultiRMSE loss for CatBoost native multioutput support
+        require_multioutput = not series[0].is_univariate or (
+            self.output_chunk_length > 1
+            and self.multi_models
+            and not isinstance(self.model, MultiOutputRegressor)
+        )
+
+        if require_multioutput and (
+            self.kwargs.get("loss_function") is None
+            or self.kwargs["loss_function"] == "RMSE"
+        ):
+            self.kwargs["loss_function"] = "MultiRMSE"
+            self.model = CatBoostRegressor(**self.kwargs)
+            logger.warning(
+                "Changed loss function to 'MultiRMSE' for multioutput support"
+            )
+
super().fit(
series=series,
past_covariates=past_covariates,
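
The loss-function switch in the hunk above can be sketched in isolation. This is a minimal illustration and not the Darts implementation: `resolve_loss_function` and `native_support_multioutput` are hypothetical helpers, and a plain dictionary stands in for the model's `kwargs`. The sketch assumes that a missing `loss_function` key should be treated like the default, so it uses `dict.get` rather than direct indexing.

```python
from typing import Any, Dict, Tuple


def resolve_loss_function(
    kwargs: Dict[str, Any], require_multioutput: bool
) -> Tuple[Dict[str, Any], bool]:
    """Return (possibly updated kwargs, whether the loss was switched).

    Hypothetical helper mirroring the condition in the diff: an unset or
    "RMSE" loss is replaced by "MultiRMSE" when multi-output is required.
    """
    loss = kwargs.get("loss_function")  # None when the key is absent
    if require_multioutput and loss in (None, "RMSE"):
        # Build a new dict so the caller's dictionary is left untouched.
        return {**kwargs, "loss_function": "MultiRMSE"}, True
    return dict(kwargs), False


def native_support_multioutput(kwargs: Dict[str, Any]) -> bool:
    """Lookup that yields False instead of KeyError when the key is missing."""
    return kwargs.get("loss_function") == "MultiRMSE"
```

Any other loss (for example `RMSEWithUncertainty`) is left alone in this sketch, so such a model would still go through the wrapper path.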
44 changes: 17 additions & 27 deletions darts/models/forecasting/regression_model.py
@@ -736,6 +736,16 @@ def _fit_model(
)
)

+    def _native_support_multioutput(self) -> bool:
+        """
+        Returns True if the model supports multi-output regression natively.
+        """
+        return (
+            callable(getattr(self.model, "_get_tags", None))
+            and isinstance(self.model._get_tags(), dict)
+            and self.model._get_tags().get("multioutput")
+        )
+
def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -846,36 +856,17 @@ def fit(
"future": future_covariates[0].width if future_covariates else None,
}

-        # if multi-output regression
-        use_mor = False
-        if not series[0].is_univariate or (
+        # Check if multi-output regression is required
+        require_multioutput = not series[0].is_univariate or (
             self.output_chunk_length > 1
             and self.multi_models
             and not isinstance(self.model, MultiOutputRegressor)
-        ):
-            if sample_weight is not None:
-                # we have 2D sample (and time) weights, only supported in Darts
-                use_mor = True
-            elif not (
-                callable(getattr(self.model, "_get_tags", None))
-                and isinstance(self.model._get_tags(), dict)
-                and self.model._get_tags().get("multioutput")
-            ):
-                # model does not support multi-output regression natively
-                use_mor = True
-            elif (
-                self.model.__class__.__name__ == "CatBoostRegressor"
-                and self.model.get_params()["loss_function"] == "RMSEWithUncertainty"
-            ):
-                use_mor = True
-            elif (
-                self.model.__class__.__name__ == "XGBRegressor"
-                and self.likelihood is not None
-            ):
-                # since xgboost==2.1.0, likelihoods do not support native multi output regression
-                use_mor = True
+        )

-        if use_mor:
+        # If multi-output required and model doesn't support it natively, wrap it in a MultiOutputRegressor
+        if require_multioutput and (
+            not self._native_support_multioutput() or sample_weight is not None
+        ):
val_set_name, val_weight_name = self.val_set_params
mor_kwargs = {
"eval_set_name": val_set_name,
@@ -884,7 +875,6 @@ def fit(
}
self.model = MultiOutputRegressor(self.model, **mor_kwargs)

-        # warn if n_jobs_multioutput_wrapper was provided but not used
if (
not isinstance(self.model, MultiOutputRegressor)
and n_jobs_multioutput_wrapper is not None
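
The refactored wrapping decision in this file reduces to two boolean expressions. The sketch below restates them as standalone functions so the truth table is easy to check. The function and parameter names are hypothetical stand-ins for the attributes used in the diff (`series[0].is_univariate`, `self.output_chunk_length`, `self.multi_models`, `isinstance(self.model, MultiOutputRegressor)`, `sample_weight is not None`); this is an illustration, not the library code.

```python
def requires_multioutput(
    is_univariate: bool,
    output_chunk_length: int,
    multi_models: bool,
    already_wrapped: bool,
) -> bool:
    """Multi-output is needed for multivariate targets, or for multi-step
    output with one sub-model per step when the model is not wrapped yet."""
    return (not is_univariate) or (
        output_chunk_length > 1 and multi_models and not already_wrapped
    )


def needs_wrapper(
    require_multioutput: bool, native_support: bool, has_sample_weight: bool
) -> bool:
    """Wrap when multi-output is required and either the model lacks native
    support or sample weights were given (the removed comment in the diff
    notes that 2D sample weights are only supported in Darts)."""
    return require_multioutput and (not native_support or has_sample_weight)
```

Under these assumptions, a natively multi-output model is still wrapped as soon as sample weights are passed, which matches the `sample_weight is not None` clause in the new condition.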
4 changes: 4 additions & 0 deletions darts/models/forecasting/xgboost.py
@@ -225,6 +225,10 @@ def encode_year(idx):
use_static_covariates=use_static_covariates,
)

+    def _native_support_multioutput(self):
+        # since xgboost==2.1.0, likelihoods do not support native multi output regression
+        return super()._native_support_multioutput() and self.likelihood is None
+
def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
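
The override above follows a hook pattern: the base class answers a capability question, and a subclass narrows the answer by combining `super()` with its own condition. A minimal sketch with stub classes is shown here. `BaseModelSketch` and `XGBModelSketch` are hypothetical names, and the boolean `estimator_multioutput` stands in for the estimator tag lookup done in `RegressionModel`.

```python
from typing import Optional


class BaseModelSketch:
    """Stand-in for a regression model exposing a capability hook."""

    def __init__(self, estimator_multioutput: bool, likelihood: Optional[str] = None):
        self.estimator_multioutput = estimator_multioutput
        self.likelihood = likelihood

    def _native_support_multioutput(self) -> bool:
        # Base answer: whatever the wrapped estimator reports.
        return self.estimator_multioutput


class XGBModelSketch(BaseModelSketch):
    """Subclass that only narrows the base answer."""

    def _native_support_multioutput(self) -> bool:
        # Native support is lost as soon as a likelihood is configured.
        return super()._native_support_multioutput() and self.likelihood is None
```

Callers can then branch on `m._native_support_multioutput()` instead of on the concrete model type, which is what the updated tests in this pull request do.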
5 changes: 3 additions & 2 deletions darts/tests/explainability/test_shap_explainer.py
@@ -205,8 +205,9 @@ def test_creation(self):

# Good type of explainers
shap_explain = ShapExplainer(m)
-        if isinstance(m, XGBModel):
+        if m._native_support_multioutput():
             # since xgboost > 2.1.0, model supports native multi output regression
+            # CatBoostModel supports multioutput for certain loss functions
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
@@ -270,7 +271,7 @@ def test_creation(self):
future_covariates=self.fut_cov_ts,
)
shap_explain = ShapExplainer(m)
-        if isinstance(m, XGBModel):
+        if m._native_support_multioutput():
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
10 changes: 4 additions & 6 deletions darts/tests/models/forecasting/test_regression_models.py
@@ -182,11 +182,7 @@ class NewCls(cls):
class TestRegressionModels:
np.random.seed(42)
# default regression models
-    models = [
-        RandomForest,
-        LinearRegressionModel,
-        RegressionModel,
-    ]
+    models = [RandomForest, LinearRegressionModel, RegressionModel]

# register likelihood regression models
QuantileLinearRegressionModel = partialclass(
@@ -1362,7 +1358,9 @@ def test_multioutput_wrapper(self, config):
if lgbm_available:
model_configs += [(LightGBMModel, lgbm_test_params)]
if cb_available:
-        model_configs += [(CatBoostModel, cb_test_params)]
+        model_configs += [
+            (CatBoostModel, dict({"likelihood": "poisson"}, **cb_test_params))
+        ]

@pytest.mark.parametrize("config", product(model_configs, [1, 2], [True, False]))
def test_multioutput_validation(self, config):
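
The `dict({"likelihood": "poisson"}, **cb_test_params)` expression in the hunk above merges two mappings, with keyword arguments taking precedence over the positional mapping. The sketch below shows that precedence with made-up values. `test_params` is a hypothetical stand-in, since the actual contents of `cb_test_params` are not shown in this diff.

```python
# Base mapping passed positionally to dict().
defaults = {"likelihood": "poisson"}

# Hypothetical stand-in for cb_test_params; the real values are not in the diff.
test_params = {"iterations": 1, "depth": 1}

# Keys from both mappings end up in the result; defaults itself is not modified.
merged = dict(defaults, **test_params)

# If the keyword mapping repeats a key, its value replaces the positional one.
overridden = dict(defaults, **{"likelihood": "quantile"})

# Equivalent spelling with dictionary unpacking.
unpacked = {**defaults, **test_params}
```

One consequence is that a `likelihood` entry inside the test parameters would silently replace `"poisson"`.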