Skip to content

Commit f76aedf

Browse files
feat: adding support for getting estimator based on quantile (#2716)
* feat: adding support for getting estimator based on quantile and associated tests * updated changelog * minor updates * fix: combined get_multioutput_estimator and get_estimator into a single method * fix: typo * update changelog * update changelog --------- Co-authored-by: dennisbader <[email protected]>
1 parent fabba0a commit f76aedf

File tree

4 files changed

+194
-44
lines changed

4 files changed

+194
-44
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1717
- Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon)
1818
- `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj)
1919
- Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj)
20+
- Added `quantile` parameter to `RegressionModel.get_estimator()` to get the specific quantile estimator for probabilistic regression models using the `quantile` likelihood. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou)
21+
22+
**Removed**
23+
24+
- 🔴 Removed method `RegressionModel.get_multioutput_estimator()`. Use `RegressionModel.get_estimator()` instead. [#2716](https://github.com/unit8co/darts/pull/2716) by [Antoine Madrona](https://github.com/madtoinou)
2025

2126
**Fixed**
2227

darts/explainability/shap_explainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def __init__(
611611
self.explainers[i] = {}
612612
for j in range(self.target_dim):
613613
self.explainers[i][j] = self._build_explainer_sklearn(
614-
self.model.get_multioutput_estimator(horizon=i, target_dim=j),
614+
self.model.get_estimator(horizon=i, target_dim=j),
615615
self.background_X,
616616
self.shap_method,
617617
**kwargs,

darts/models/forecasting/regression_model.py

+49-38
Original file line numberDiff line numberDiff line change
@@ -505,62 +505,73 @@ def output_chunk_length(self) -> int:
505505
def output_chunk_shift(self) -> int:
506506
return self._output_chunk_shift
507507

508-
def get_multioutput_estimator(self, horizon: int, target_dim: int):
508+
def get_estimator(
509+
self, horizon: int, target_dim: int, quantile: Optional[float] = None
510+
):
509511
"""Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component.
510512
511-
Internally, estimators are grouped by `output_chunk_length` position, then by component.
513+
For probabilistic models fitting quantiles, it is possible to also specify the quantile.
514+
515+
The model is returned directly if it supports multi-output natively.
516+
517+
Note: Internally, estimators are grouped by `output_chunk_length` position, then by component. For probabilistic
518+
models fitting quantiles, there is an additional abstraction layer, grouping the estimators by `quantile`.
512519
513520
Parameters
514521
----------
515522
horizon
516523
The index of the forecasting point within `output_chunk_length`.
517524
target_dim
518525
The index of the target component.
526+
quantile
527+
Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value.
519528
"""
520-
raise_if_not(
521-
isinstance(self.model, MultiOutputRegressor),
522-
"The sklearn model is not a MultiOutputRegressor object.",
523-
logger,
524-
)
525-
raise_if_not(
526-
0 <= horizon < self.output_chunk_length,
527-
f"`horizon` must be `>= 0` and `< output_chunk_length={self.output_chunk_length}`.",
528-
logger,
529-
)
530-
raise_if_not(
531-
0 <= target_dim < self.input_dim["target"],
532-
f"`target_dim` must be `>= 0`, and `< n_target_components={self.input_dim['target']}`.",
533-
logger,
534-
)
529+
if not isinstance(self.model, MultiOutputRegressor):
530+
logger.warning(
531+
"Model supports multi-output; a single estimator forecasts all the horizons and components."
532+
)
533+
return self.model
534+
535+
if not 0 <= horizon < self.output_chunk_length:
536+
raise_log(
537+
ValueError(
538+
f"`horizon` must be `>= 0` and `< output_chunk_length={self.output_chunk_length}`."
539+
),
540+
logger,
541+
)
542+
if not 0 <= target_dim < self.input_dim["target"]:
543+
raise_log(
544+
ValueError(
545+
f"`target_dim` must be `>= 0`, and `< n_target_components={self.input_dim['target']}`."
546+
),
547+
logger,
548+
)
535549

536550
# when multi_models=True, one model per horizon and target component
537551
idx_estimator = (
538552
self.multi_models * self.input_dim["target"] * horizon + target_dim
539553
)
540-
return self.model.estimators_[idx_estimator]
554+
if quantile is None:
555+
return self.model.estimators_[idx_estimator]
541556

542-
def get_estimator(self, horizon: int, target_dim: int):
543-
"""Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component.
544-
545-
The model is returned directly if it supports multi-output natively.
546-
547-
Parameters
548-
----------
549-
horizon
550-
The index of the forecasting point within `output_chunk_length`.
551-
target_dim
552-
The index of the target component.
553-
"""
554-
555-
if isinstance(self.model, MultiOutputRegressor):
556-
return self.get_multioutput_estimator(
557-
horizon=horizon, target_dim=target_dim
557+
# for quantile-models, the estimators are also grouped by quantiles
558+
if self.likelihood != "quantile":
559+
raise_log(
560+
ValueError(
561+
"`quantile` is only supported for probabilistic models that "
562+
"use `likelihood='quantile'`."
563+
),
564+
logger,
558565
)
559-
else:
560-
logger.info(
561-
"Model supports multi-output; a single estimator forecasts all the horizons and components."
566+
if quantile not in self._model_container:
567+
raise_log(
568+
ValueError(
569+
f"Invalid `quantile={quantile}`. Must be one of the fitted quantiles "
570+
f"`{list(self._model_container.keys())}`."
571+
),
572+
logger,
562573
)
563-
return self.model
574+
return self._model_container[quantile].estimators_[idx_estimator]
564575

565576
def _add_val_set_to_kwargs(
566577
self,

darts/tests/models/forecasting/test_regression_models.py

+139-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import importlib
33
import inspect
4+
import logging
45
import math
56
from copy import deepcopy
67
from itertools import product
@@ -1333,7 +1334,7 @@ def test_opti_historical_forecast_predict_checks(self):
13331334
],
13341335
)
13351336
def test_multioutput_wrapper(self, config):
1336-
"""Check that with input_chunk_length=1, wrapping in MultiOutputRegressor is not happening"""
1337+
"""Check that with input_chunk_length=1, wrapping in MultiOutputRegressor occurs only when necessary"""
13371338
model, supports_multioutput_natively = config
13381339
model.fit(series=self.sine_multivariate1)
13391340
if supports_multioutput_natively:
@@ -1396,7 +1397,7 @@ def test_multioutput_validation(self, config):
13961397
else:
13971398
assert not isinstance(model.model, MultiOutputRegressor)
13981399

1399-
def test_get_multioutput_estimator_multi_models(self):
1400+
def test_get_estimator_multi_models(self):
14001401
"""Craft training data so that estimator_[i].predict(X) == i + 1"""
14011402

14021403
def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int):
@@ -1417,7 +1418,7 @@ def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int):
14171418
estimator_counter = 0
14181419
for i in range(ocl):
14191420
for j in range(ts.width):
1420-
sub_model = m.get_multioutput_estimator(horizon=i, target_dim=j)
1421+
sub_model = m.get_estimator(horizon=i, target_dim=j)
14211422
pred = sub_model.predict(dummy_feats)[0]
14221423
# sub-model is overfitted on the training series
14231424
assert np.abs(estimator_counter - pred) < 1e-2
@@ -1455,7 +1456,7 @@ def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int):
14551456
# estimators_[3] labels : [3]
14561457
helper_check_overfitted_estimators(ts, ocl)
14571458

1458-
def test_get_multioutput_estimator_single_model(self):
1459+
def test_get_estimator_single_model(self):
14591460
"""Check estimator getter when multi_models=False"""
14601461
# multivariate, one sub-model per component
14611462
ocl = 2
@@ -1485,11 +1486,144 @@ def test_get_multioutput_estimator_single_model(self):
14851486
dummy_feats = np.array([[0, 0, 0] * ts.width])
14861487
for i in range(ocl):
14871488
for j in range(ts.width):
1488-
sub_model = m.get_multioutput_estimator(horizon=i, target_dim=j)
1489+
sub_model = m.get_estimator(horizon=i, target_dim=j)
14891490
pred = sub_model.predict(dummy_feats)[0]
14901491
# sub-model forecast only depend on the target_dim
14911492
assert np.abs(j + 1 - pred) < 1e-2
14921493

1494+
@pytest.mark.parametrize("multi_models", [True, False])
1495+
def test_get_estimator_quantile(self, multi_models):
1496+
"""Check estimator getter when using quantile value"""
1497+
ocl = 3
1498+
lags = 3
1499+
quantiles = [0.01, 0.5, 0.99]
1500+
ts = tg.sine_timeseries(length=100, column_name="sine").stack(
1501+
tg.linear_timeseries(length=100, column_name="linear"),
1502+
)
1503+
1504+
m = XGBModel(
1505+
lags=lags,
1506+
output_chunk_length=ocl,
1507+
multi_models=multi_models,
1508+
likelihood="quantile",
1509+
quantiles=quantiles,
1510+
random_state=1,
1511+
)
1512+
m.fit(ts)
1513+
1514+
assert len(m._model_container) == len(quantiles)
1515+
assert sorted(list(m._model_container.keys())) == sorted(quantiles)
1516+
for quantile_container in m._model_container.values():
1517+
# one sub-model per quantile, per component, per horizon
1518+
if multi_models:
1519+
assert len(quantile_container.estimators_) == ocl * ts.width
1520+
# one sub-model per quantile, per component
1521+
else:
1522+
assert len(quantile_container.estimators_) == ts.width
1523+
1524+
# check that retrieve sub-models prediction match the "wrapper" model predictions
1525+
pred_input = ts[-lags:] if multi_models else ts[-lags - ocl + 1 :]
1526+
pred = m.predict(
1527+
n=ocl,
1528+
series=pred_input,
1529+
num_samples=1,
1530+
predict_likelihood_parameters=True,
1531+
)
1532+
for j in range(ts.width):
1533+
for i in range(ocl):
1534+
if multi_models:
1535+
dummy_feats = pred_input.values()[:lags]
1536+
else:
1537+
dummy_feats = pred_input.values()[i : +i + lags]
1538+
dummy_feats = np.expand_dims(dummy_feats.flatten(), 0)
1539+
for q in quantiles:
1540+
sub_model = m.get_estimator(horizon=i, target_dim=j, quantile=q)
1541+
pred_sub_model = sub_model.predict(dummy_feats)[0]
1542+
assert (
1543+
pred_sub_model
1544+
== pred[f"{ts.components[j]}_q{q:.2f}"].values()[i][0]
1545+
)
1546+
1547+
def test_get_estimator_exceptions(self, caplog):
1548+
"""Check that all the corner-cases are properly covered by the method"""
1549+
ts = TimeSeries.from_values(
1550+
values=np.array([
1551+
[0, 0, 0, 0, 1],
1552+
[0, 0, 0, 0, 2],
1553+
]).T,
1554+
columns=["a", "b"],
1555+
)
1556+
m = LinearRegressionModel(
1557+
lags=2,
1558+
output_chunk_length=2,
1559+
random_state=1,
1560+
)
1561+
m.fit(ts["a"])
1562+
# not wrapped in MultiOutputRegressor because of native multi-output support
1563+
with caplog.at_level(logging.WARNING):
1564+
m.get_estimator(horizon=0, target_dim=0)
1565+
assert len(caplog.records) == 1
1566+
assert caplog.records[0].levelname == "WARNING"
1567+
assert caplog.records[0].message == (
1568+
"Model supports multi-output; a single estimator "
1569+
"forecasts all the horizons and components."
1570+
)
1571+
1572+
# univariate, deterministic, ocl > 2
1573+
m = RegressionModel(
1574+
model=HistGradientBoostingRegressor(),
1575+
lags=2,
1576+
output_chunk_length=2,
1577+
)
1578+
m.fit(ts["a"])
1579+
# horizon > ocl
1580+
with pytest.raises(ValueError) as err:
1581+
m.get_estimator(horizon=3, target_dim=0)
1582+
assert str(err.value).startswith(
1583+
"`horizon` must be `>= 0` and `< output_chunk_length"
1584+
)
1585+
# target dim > training series width
1586+
with pytest.raises(ValueError) as err:
1587+
m.get_estimator(horizon=0, target_dim=1)
1588+
assert str(err.value).startswith(
1589+
"`target_dim` must be `>= 0`, and `< n_target_components="
1590+
)
1591+
1592+
# univariate, probabilistic
1593+
# using the quantiles argument to force wrapping in MultiOutputRegressor
1594+
m = XGBModel(
1595+
lags=2,
1596+
output_chunk_length=2,
1597+
random_state=1,
1598+
likelihood="poisson",
1599+
quantiles=[0.5],
1600+
)
1601+
m.fit(ts["a"])
1602+
# incorrect likelihood
1603+
with pytest.raises(ValueError) as err:
1604+
m.get_estimator(horizon=0, target_dim=0, quantile=0.1)
1605+
assert str(err.value).startswith(
1606+
"`quantile` is only supported for probabilistic models that "
1607+
"use `likelihood='quantile'`."
1608+
)
1609+
1610+
# univariate, probabilistic
1611+
m = XGBModel(
1612+
lags=2,
1613+
output_chunk_length=2,
1614+
random_state=1,
1615+
likelihood="quantile",
1616+
quantiles=[0.01, 0.5, 0.99],
1617+
)
1618+
m.fit(ts["a"])
1619+
# retrieving a non-defined quantile
1620+
with pytest.raises(ValueError) as err:
1621+
m.get_estimator(horizon=0, target_dim=0, quantile=0.1)
1622+
assert str(err.value).startswith(
1623+
"Invalid `quantile=0.1`. Must be one of the fitted quantiles "
1624+
"`[0.01, 0.5, 0.99]`."
1625+
)
1626+
14931627
@pytest.mark.parametrize("mode", [True, False])
14941628
def test_regression_model(self, mode):
14951629
lags = 4

0 commit comments

Comments
 (0)