|
1 | 1 | from typing import Optional
|
2 | 2 |
|
3 |
| -from sklearn import __version__ as sklearn_version |
4 | 3 | from sklearn.base import is_classifier
|
5 | 4 | from sklearn.multioutput import MultiOutputRegressor as sk_MultiOutputRegressor
|
6 | 5 | from sklearn.multioutput import _fit_estimator
|
7 | 6 | from sklearn.utils.multiclass import check_classification_targets
|
8 |
| -from sklearn.utils.validation import has_fit_parameter |
| 7 | +from sklearn.utils.parallel import Parallel, delayed |
| 8 | +from sklearn.utils.validation import ( |
| 9 | + _check_method_params, |
| 10 | + has_fit_parameter, |
| 11 | + validate_data, |
| 12 | +) |
9 | 13 |
|
10 | 14 | from darts.logging import get_logger, raise_log
|
11 | 15 |
|
12 |
| -if sklearn_version >= "1.4": |
13 |
| - # sklearn renamed `_check_fit_params` to `_check_method_params` in v1.4 |
14 |
| - from sklearn.utils.validation import _check_method_params |
15 |
| -else: |
16 |
| - from sklearn.utils.validation import _check_fit_params as _check_method_params |
17 |
| - |
18 |
| -if sklearn_version >= "1.3": |
19 |
| - # delayed was moved from sklearn.utils.fixes to sklearn.utils.parallel in v1.3 |
20 |
| - from sklearn.utils.parallel import Parallel, delayed |
21 |
| -else: |
22 |
| - from joblib import Parallel |
23 |
| - from sklearn.utils.fixes import delayed |
24 |
| - |
25 | 16 | logger = get_logger(__name__)
|
26 | 17 |
|
27 | 18 |
|
@@ -78,8 +69,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
|
78 | 69 | ValueError("The base estimator should implement a fit method"),
|
79 | 70 | logger=logger,
|
80 | 71 | )
|
81 |
| - |
82 |
| - y = self._validate_data(X="no_validation", y=y, multi_output=True) |
| 72 | + y = validate_data(self.estimator, X="no_validation", y=y, multi_output=True) |
83 | 73 |
|
84 | 74 | if is_classifier(self):
|
85 | 75 | check_classification_targets(y)
|
|
0 commit comments