Skip to content

Commit d24d1b5

Browse files
Fix/code maintenance after dependencies update (#2722)
* fix: simplified logic after increase in minimym version for statsmodels * fix: simplified logic after increase of minimum version for sklearn * update changelog * Update CHANGELOG.md --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent f1850d6 commit d24d1b5

File tree

3 files changed

+13
-30
lines changed

3 files changed

+13
-30
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
3434

3535
- Bumped minimum scikit-learn version from `1.0.1` to `1.6.0`. This was required due to sklearn deprecating `_get_tags` in favor of `BaseEstimator.__sklearn_tags__` in version 1.7. This leads to increasing both sklearn and XGBoost minimum supported version to 1.6 and 2.1.4 respectively. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
3636
- Bumped minimum xgboost version from `1.6.0` to `2.1.4` for the same reason as bumping the minimum sklearn version. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
37+
- Various code changes to reflect the updated dependencies versions (sklearn, statsmodel) and eliminate various warnings. [#2722](https://github.com/unit8co/darts/pull/2722) by [Antoine Madrona](https://github.com/madtoinou)
3738

3839
### For developers of the library:
3940

darts/models/forecasting/arima.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing_extensions import TypeAlias
2121

2222
import numpy as np
23-
from statsmodels import __version_tuple__ as statsmodels_version
2423
from statsmodels.tsa.arima.model import ARIMA as staARIMA
2524

2625
from darts.logging import get_logger
@@ -31,9 +30,6 @@
3130

3231
logger = get_logger(__name__)
3332

34-
# Check whether we are running statsmodels >= 0.13.5 or not:
35-
statsmodels_above_0135 = statsmodels_version > (0, 13, 5)
36-
3733

3834
IntOrIntSequence: TypeAlias = Union[int, Sequence[int]]
3935

@@ -137,15 +133,11 @@ def encode_year(idx):
137133
self.seasonal_order = seasonal_order
138134
self.trend = trend
139135
self.model = None
140-
if statsmodels_above_0135:
141-
self._random_state = (
142-
random_state
143-
if random_state is None
144-
else np.random.RandomState(random_state)
145-
)
146-
else:
147-
self._random_state = None
148-
np.random.seed(random_state if random_state is not None else 0)
136+
self._random_state = (
137+
random_state
138+
if random_state is None
139+
else np.random.RandomState(random_state)
140+
)
149141

150142
@property
151143
def supports_multivariate(self) -> bool:

darts/utils/multioutput.py

+7-17
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,18 @@
11
from typing import Optional
22

3-
from sklearn import __version__ as sklearn_version
43
from sklearn.base import is_classifier
54
from sklearn.multioutput import MultiOutputRegressor as sk_MultiOutputRegressor
65
from sklearn.multioutput import _fit_estimator
76
from sklearn.utils.multiclass import check_classification_targets
8-
from sklearn.utils.validation import has_fit_parameter
7+
from sklearn.utils.parallel import Parallel, delayed
8+
from sklearn.utils.validation import (
9+
_check_method_params,
10+
has_fit_parameter,
11+
validate_data,
12+
)
913

1014
from darts.logging import get_logger, raise_log
1115

12-
if sklearn_version >= "1.4":
13-
# sklearn renamed `_check_fit_params` to `_check_method_params` in v1.4
14-
from sklearn.utils.validation import _check_method_params
15-
else:
16-
from sklearn.utils.validation import _check_fit_params as _check_method_params
17-
18-
if sklearn_version >= "1.3":
19-
# delayed was moved from sklearn.utils.fixes to sklearn.utils.parallel in v1.3
20-
from sklearn.utils.parallel import Parallel, delayed
21-
else:
22-
from joblib import Parallel
23-
from sklearn.utils.fixes import delayed
24-
2516
logger = get_logger(__name__)
2617

2718

@@ -78,8 +69,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
7869
ValueError("The base estimator should implement a fit method"),
7970
logger=logger,
8071
)
81-
82-
y = self._validate_data(X="no_validation", y=y, multi_output=True)
72+
y = validate_data(self.estimator, X="no_validation", y=y, multi_output=True)
8373

8474
if is_classifier(self):
8575
check_classification_targets(y)

0 commit comments

Comments
 (0)