Skip to content

Commit c975dd6

Browse files
authored
MAINT fix compatibility with scikit-learn 1.7 (#1137)
1 parent 2370cf9 commit c975dd6

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

imblearn/ensemble/_easy_ensemble.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# License: MIT
66

77
import copy
8+
import inspect
89
import numbers
910

1011
import numpy as np
@@ -13,7 +14,6 @@
1314
from sklearn.ensemble._bagging import _parallel_decision_function
1415
from sklearn.ensemble._base import _partition_estimators
1516
from sklearn.utils._param_validation import Interval, StrOptions
16-
from sklearn.utils._tags import _safe_tags
1717
from sklearn.utils.fixes import parse_version
1818
from sklearn.utils.metaestimators import available_if
1919
from sklearn.utils.parallel import Parallel, delayed
@@ -312,11 +312,16 @@ def decision_function(self, X):
312312
# Parallel loop
313313
n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs)
314314

315+
kwargs = {}
316+
if "params" in inspect.signature(_parallel_decision_function).parameters:
317+
kwargs["params"] = {}
318+
315319
all_decisions = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
316320
delayed(_parallel_decision_function)(
317321
self.estimators_[starts[i] : starts[i + 1]],
318322
self.estimators_features_[starts[i] : starts[i + 1]],
319323
X,
324+
**kwargs,
320325
)
321326
for i in range(n_jobs)
322327
)
@@ -343,7 +348,7 @@ def _get_estimator(self):
343348
return self.estimator
344349

345350
def _more_tags(self):
346-
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
351+
return {"allow_nan": get_tags(self._get_estimator()).input_tags.allow_nan}
347352

348353
def __sklearn_tags__(self):
349354
tags = super().__sklearn_tags__()

imblearn/metrics/tests/test_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def test_iba_error_y_score_prob_error(score_loss):
454454
y_true, y_pred, _ = make_prediction(binary=True)
455455

456456
aps = make_index_balanced_accuracy(alpha=0.5, squared=True)(score_loss)
457-
with pytest.raises(AttributeError):
457+
with pytest.raises((AttributeError, TypeError)):
458458
aps(y_true, y_pred)
459459

460460

imblearn/over_sampling/_smote/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,5 @@ def _more_tags(self):
981981

982982
def __sklearn_tags__(self):
983983
tags = super().__sklearn_tags__()
984-
tags.input_tags.sparse = False
985984
tags.input_tags.string = True
986985
return tags

imblearn/tests/test_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from imblearn.pipeline import Pipeline, make_pipeline
4040
from imblearn.under_sampling import EditedNearestNeighbours as ENN
4141
from imblearn.under_sampling import RandomUnderSampler
42-
from imblearn.utils._sklearn_compat import sklearn_version
42+
from imblearn.utils._sklearn_compat import Tags, sklearn_version
4343
from imblearn.utils.estimator_checks import check_param_validation
4444

4545
JUNK_FOOD_DOCS = (
@@ -61,6 +61,9 @@ def __init__(self, a=None, b=None):
6161
self.a = a
6262
self.b = b
6363

64+
def __sklearn_tags__(self):
65+
return Tags()
66+
6467

6568
class NoTrans(NoFit):
6669
def fit(self, X, y):

0 commit comments

Comments
 (0)