Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions moabb/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,20 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import FunctionTransformer, Pipeline, _name_estimators
from sklearn.utils._repr_html.estimator import _VisualBlock


# Handle different scikit-learn versions for _VisualBlock import
# sklearn >= 1.6 moved _VisualBlock to sklearn.utils._repr_html.estimator
# sklearn < 1.6 has it in sklearn.utils._estimator_html_repr
try:
from sklearn.utils._repr_html.estimator import _VisualBlock
except (ImportError, ModuleNotFoundError):
try:
from sklearn.utils._estimator_html_repr import _VisualBlock
except (ImportError, ModuleNotFoundError):
# Fallback: create a dummy _VisualBlock for older sklearn versions
# that don't have HTML representation support
_VisualBlock = None


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +101,9 @@ def __sklearn_is_fitted__(self):
return True

def _sk_visual_block_(self):
"""Tell sklearn’s diagrammer to lay us out in parallel."""
"""Tell sklearn's diagrammer to lay us out in parallel."""
if _VisualBlock is None:
return NotImplemented
names, estimators = zip(*self.transformers)
return _VisualBlock(
kind="parallel",
Expand All @@ -114,7 +129,9 @@ def __sklearn_is_fitted__(self):
return True

def _sk_visual_block_(self):
"""Tell sklearn’s diagrammer to lay us out in parallel."""
"""Tell sklearn's diagrammer to lay us out in parallel."""
if _VisualBlock is None:
return NotImplemented
return _VisualBlock(
kind="parallel",
name_caption=str(self.__class__.__name__),
Expand Down Expand Up @@ -405,6 +422,8 @@ def __repr__(self):
return self._display_name

def _sk_visual_block_(self):
if _VisualBlock is None:
return NotImplemented
return _VisualBlock(
kind="single",
estimators=self,
Expand Down
5 changes: 5 additions & 0 deletions moabb/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from moabb.evaluations.utils import (
_create_save_path,
_ensure_fitted,
_save_model_cv,
)

Expand Down Expand Up @@ -226,6 +227,7 @@ def _evaluate(
cvclf = clone(grid_clf)

cvclf.fit(X_[train], y_[train])
_ensure_fitted(cvclf)

score = scorer(cvclf, X_[test], y_[test])

Expand Down Expand Up @@ -308,6 +310,7 @@ def score_explicit(self, clf, X_train, y_train, X_test, y_test):
t_start = time()
try:
model = clf.fit(X_train, y_train)
_ensure_fitted(model)
score = _score(
estimator=model,
X_test=X_test,
Expand Down Expand Up @@ -540,6 +543,7 @@ def evaluate(
cvclf = clone(grid_clf)

cvclf.fit(X[train], y[train])
_ensure_fitted(cvclf)

model_list.append(cvclf)
score = scorer(cvclf, X[test], y[test])
Expand Down Expand Up @@ -724,6 +728,7 @@ def evaluate(
)

model = deepcopy(clf).fit(X[train], y[train])
_ensure_fitted(model)

if _carbonfootprint:
emissions = tracker.stop()
Expand Down
61 changes: 61 additions & 0 deletions moabb/evaluations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,67 @@
optuna_available = False


def _ensure_fitted(estimator):
"""Ensure an estimator is properly marked as fitted for sklearn 1.8+.

In sklearn 1.8+, Pipeline.predict() calls check_is_fitted(self) which
may fail for some estimators (especially deep learning wrappers) that
don't properly set fitted attributes. This function adds the necessary
attributes to ensure the estimator passes sklearn's fitted check.

Parameters
----------
estimator : sklearn-compatible estimator
The fitted estimator to mark as fitted. This should be called
after fit() has been called on the estimator.

Returns
-------
estimator : sklearn-compatible estimator
The same estimator with fitted attributes set.

Notes
-----
This function modifies the estimator in-place and returns it for
convenience. sklearn's check_is_fitted looks for:
1. __sklearn_is_fitted__() method returning True
2. Or any attribute ending with '_' (like classes_, coef_, etc.)

We add a __sklearn_is_fitted__ method that returns True.
"""

# Define a method that returns True to indicate fitted state
def _sklearn_is_fitted_true(self):
return True

# Add __sklearn_is_fitted__ method if not present or if it returns False
if not hasattr(estimator, "__sklearn_is_fitted__"):
import types

estimator.__sklearn_is_fitted__ = types.MethodType(
_sklearn_is_fitted_true, estimator
)
else:
# Check if existing method returns False (unfitted)
try:
if not estimator.__sklearn_is_fitted__():
import types

estimator.__sklearn_is_fitted__ = types.MethodType(
_sklearn_is_fitted_true, estimator
)
except Exception:
pass

# For Pipeline objects, also ensure all steps are marked
if isinstance(estimator, Pipeline):
for name, step in estimator.steps:
if step is not None:
_ensure_fitted(step)

return estimator


def _check_if_is_pytorch_model(model):
"""Check if the model is a skorch model.

Expand Down
Loading