Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/code maintenance after dependencies update #2722

Merged
merged 6 commits into from
Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

- Bumped minimum scikit-learn version from `1.0.1` to `1.6.0`. This was required due to sklearn deprecating `_get_tags` in favor of `BaseEstimator.__sklearn_tags__` in version 1.7. This leads to increasing both sklearn and XGBoost minimum supported version to 1.6 and 2.1.4 respectively. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
- Bumped minimum xgboost version from `1.6.0` to `2.1.4` for the same reason as bumping the minimum sklearn version. [#2659](https://github.com/unit8co/darts/pull/2659) by [Jonas Blanc](https://github.com/jonasblanc)
- Various code changes to reflect the updated dependencies versions (sklearn, statsmodel) et eliminate various warnings. [#2722](https://github.com/unit8co/darts/pull/2722) by [Antoine Madrona](https://github.com/madtoinou)

### For developers of the library:

Expand Down
18 changes: 5 additions & 13 deletions darts/models/forecasting/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing_extensions import TypeAlias

import numpy as np
from statsmodels import __version_tuple__ as statsmodels_version
from statsmodels.tsa.arima.model import ARIMA as staARIMA

from darts.logging import get_logger
Expand All @@ -31,9 +30,6 @@

logger = get_logger(__name__)

# Check whether we are running statsmodels >= 0.13.5 or not:
statsmodels_above_0135 = statsmodels_version > (0, 13, 5)


IntOrIntSequence: TypeAlias = Union[int, Sequence[int]]

Expand Down Expand Up @@ -137,15 +133,11 @@ def encode_year(idx):
self.seasonal_order = seasonal_order
self.trend = trend
self.model = None
if statsmodels_above_0135:
self._random_state = (
random_state
if random_state is None
else np.random.RandomState(random_state)
)
else:
self._random_state = None
np.random.seed(random_state if random_state is not None else 0)
self._random_state = (
random_state
if random_state is None
else np.random.RandomState(random_state)
)

@property
def supports_multivariate(self) -> bool:
Expand Down
24 changes: 7 additions & 17 deletions darts/utils/multioutput.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
from typing import Optional

from sklearn import __version__ as sklearn_version
from sklearn.base import is_classifier
from sklearn.multioutput import MultiOutputRegressor as sk_MultiOutputRegressor
from sklearn.multioutput import _fit_estimator
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import has_fit_parameter
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import (
_check_method_params,
has_fit_parameter,
validate_data,
)

from darts.logging import get_logger, raise_log

if sklearn_version >= "1.4":
# sklearn renamed `_check_fit_params` to `_check_method_params` in v1.4
from sklearn.utils.validation import _check_method_params
else:
from sklearn.utils.validation import _check_fit_params as _check_method_params

if sklearn_version >= "1.3":
# delayed was moved from sklearn.utils.fixes to sklearn.utils.parallel in v1.3
from sklearn.utils.parallel import Parallel, delayed
else:
from joblib import Parallel
from sklearn.utils.fixes import delayed

logger = get_logger(__name__)


Expand Down Expand Up @@ -78,8 +69,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
ValueError("The base estimator should implement a fit method"),
logger=logger,
)

y = self._validate_data(X="no_validation", y=y, multi_output=True)
y = validate_data(self.estimator, X="no_validation", y=y, multi_output=True)

if is_classifier(self):
check_classification_targets(y)
Expand Down
Loading