Skip to content

Commit d90e7eb

Browse files
Fix compatibility with scikit-learn 1.8 (#852)
* fixing more colab issues * updating the utils
1 parent a3e6db4 commit d90e7eb

File tree

3 files changed

+88
-3
lines changed

3 files changed

+88
-3
lines changed

moabb/datasets/preprocessing.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,20 @@
77
import numpy as np
88
from sklearn.base import BaseEstimator, TransformerMixin
99
from sklearn.pipeline import FunctionTransformer, Pipeline, _name_estimators
10-
from sklearn.utils._repr_html.estimator import _VisualBlock
10+
11+
12+
# Handle different scikit-learn versions for _VisualBlock import
13+
# sklearn >= 1.6 moved _VisualBlock to sklearn.utils._repr_html.estimator
14+
# sklearn < 1.6 has it in sklearn.utils._estimator_html_repr
15+
try:
16+
from sklearn.utils._repr_html.estimator import _VisualBlock
17+
except (ImportError, ModuleNotFoundError):
18+
try:
19+
from sklearn.utils._estimator_html_repr import _VisualBlock
20+
except (ImportError, ModuleNotFoundError):
21+
# Fallback: create a dummy _VisualBlock for older sklearn versions
22+
# that don't have HTML representation support
23+
_VisualBlock = None
1124

1225

1326
log = logging.getLogger(__name__)
@@ -88,7 +101,9 @@ def __sklearn_is_fitted__(self):
88101
return True
89102

90103
def _sk_visual_block_(self):
91-
"""Tell sklearn’s diagrammer to lay us out in parallel."""
104+
"""Tell sklearn's diagrammer to lay us out in parallel."""
105+
if _VisualBlock is None:
106+
return NotImplemented
92107
names, estimators = zip(*self.transformers)
93108
return _VisualBlock(
94109
kind="parallel",
@@ -114,7 +129,9 @@ def __sklearn_is_fitted__(self):
114129
return True
115130

116131
def _sk_visual_block_(self):
117-
"""Tell sklearn’s diagrammer to lay us out in parallel."""
132+
"""Tell sklearn's diagrammer to lay us out in parallel."""
133+
if _VisualBlock is None:
134+
return NotImplemented
118135
return _VisualBlock(
119136
kind="parallel",
120137
name_caption=str(self.__class__.__name__),
@@ -405,6 +422,8 @@ def __repr__(self):
405422
return self._display_name
406423

407424
def _sk_visual_block_(self):
425+
if _VisualBlock is None:
426+
return NotImplemented
408427
return _VisualBlock(
409428
kind="single",
410429
estimators=self,

moabb/evaluations/evaluations.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from moabb.evaluations.utils import (
2727
_create_save_path,
28+
_ensure_fitted,
2829
_save_model_cv,
2930
)
3031

@@ -226,6 +227,7 @@ def _evaluate(
226227
cvclf = clone(grid_clf)
227228

228229
cvclf.fit(X_[train], y_[train])
230+
_ensure_fitted(cvclf)
229231

230232
score = scorer(cvclf, X_[test], y_[test])
231233

@@ -308,6 +310,7 @@ def score_explicit(self, clf, X_train, y_train, X_test, y_test):
308310
t_start = time()
309311
try:
310312
model = clf.fit(X_train, y_train)
313+
_ensure_fitted(model)
311314
score = _score(
312315
estimator=model,
313316
X_test=X_test,
@@ -540,6 +543,7 @@ def evaluate(
540543
cvclf = clone(grid_clf)
541544

542545
cvclf.fit(X[train], y[train])
546+
_ensure_fitted(cvclf)
543547

544548
model_list.append(cvclf)
545549
score = scorer(cvclf, X[test], y[test])
@@ -724,6 +728,7 @@ def evaluate(
724728
)
725729

726730
model = deepcopy(clf).fit(X[train], y[train])
731+
_ensure_fitted(model)
727732

728733
if _carbonfootprint:
729734
emissions = tracker.stop()

moabb/evaluations/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,67 @@
2323
optuna_available = False
2424

2525

26+
def _ensure_fitted(estimator):
27+
"""Ensure an estimator is properly marked as fitted for sklearn 1.8+.
28+
29+
In sklearn 1.8+, Pipeline.predict() calls check_is_fitted(self) which
30+
may fail for some estimators (especially deep learning wrappers) that
31+
don't properly set fitted attributes. This function adds the necessary
32+
attributes to ensure the estimator passes sklearn's fitted check.
33+
34+
Parameters
35+
----------
36+
estimator : sklearn-compatible estimator
37+
The fitted estimator to mark as fitted. This should be called
38+
after fit() has been called on the estimator.
39+
40+
Returns
41+
-------
42+
estimator : sklearn-compatible estimator
43+
The same estimator with fitted attributes set.
44+
45+
Notes
46+
-----
47+
This function modifies the estimator in-place and returns it for
48+
convenience. sklearn's check_is_fitted looks for:
49+
1. __sklearn_is_fitted__() method returning True
50+
2. Or any attribute ending with '_' (like classes_, coef_, etc.)
51+
52+
We add a __sklearn_is_fitted__ method that returns True.
53+
"""
54+
55+
# Define a method that returns True to indicate fitted state
56+
def _sklearn_is_fitted_true(self):
57+
return True
58+
59+
# Add __sklearn_is_fitted__ method if not present or if it returns False
60+
if not hasattr(estimator, "__sklearn_is_fitted__"):
61+
import types
62+
63+
estimator.__sklearn_is_fitted__ = types.MethodType(
64+
_sklearn_is_fitted_true, estimator
65+
)
66+
else:
67+
# Check if existing method returns False (unfitted)
68+
try:
69+
if not estimator.__sklearn_is_fitted__():
70+
import types
71+
72+
estimator.__sklearn_is_fitted__ = types.MethodType(
73+
_sklearn_is_fitted_true, estimator
74+
)
75+
except Exception:
76+
pass
77+
78+
# For Pipeline objects, also ensure all steps are marked
79+
if isinstance(estimator, Pipeline):
80+
for name, step in estimator.steps:
81+
if step is not None:
82+
_ensure_fitted(step)
83+
84+
return estimator
85+
86+
2687
def _check_if_is_pytorch_model(model):
2788
"""Check if the model is a skorch model.
2889

0 commit comments

Comments
 (0)