Skip to content

Commit a3e6db4

Browse files
Fix warnings from scikit-learn 1.8+ and introduce FixedPipeline (#850)
* p
* fixing number
1 parent 1d8a638 commit a3e6db4

File tree

4 files changed

+71
-16
lines changed

4 files changed

+71
-16
lines changed

docs/source/whats_new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Version 1.5 (Source - GitHub)
3131

3232
⚕️ Code health
3333
^^^^^^^^^^^^^^
34-
34+
- Fixing warnings from the latest scikit-learn version within the Preprocessing logic (:gh:`850` by `Bruno Aristimunha`_)
3535

3636

3737
Version 1.4.2 (Stable - PyPi)

moabb/datasets/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
import mne_bids
1616
import numpy as np
1717
import pandas as pd
18-
from sklearn.pipeline import Pipeline
1918

2019
from moabb.datasets.bids_interface import StepType, _interface_map
21-
from moabb.datasets.preprocessing import SetRawAnnotations
20+
from moabb.datasets.preprocessing import FixedPipeline, SetRawAnnotations
2221

2322

2423
log = logging.getLogger(__name__)
@@ -393,7 +392,7 @@ def __init__(
393392
self.unit_factor = unit_factor
394393

395394
def _create_process_pipeline(self):
396-
return Pipeline(
395+
return FixedPipeline(
397396
[
398397
(
399398
StepType.RAW,
@@ -620,7 +619,7 @@ def _get_single_subject_data_using_cache(
620619
self,
621620
subject,
622621
path=cache_config.path,
623-
process_pipeline=Pipeline(cached_steps),
622+
process_pipeline=FixedPipeline(cached_steps),
624623
verbose=cache_config.verbose,
625624
)
626625

@@ -667,7 +666,7 @@ def _get_single_subject_data_using_cache(
667666
self,
668667
subject,
669668
path=cache_config.path,
670-
process_pipeline=Pipeline(
669+
process_pipeline=FixedPipeline(
671670
cached_steps + remaining_steps[: step_idx + 1]
672671
),
673672
verbose=cache_config.verbose,

moabb/datasets/preprocessing.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,50 @@
66
import mne
77
import numpy as np
88
from sklearn.base import BaseEstimator, TransformerMixin
9-
from sklearn.pipeline import FunctionTransformer, Pipeline
9+
from sklearn.pipeline import FunctionTransformer, Pipeline, _name_estimators
1010
from sklearn.utils._repr_html.estimator import _VisualBlock
1111

1212

1313
log = logging.getLogger(__name__)
1414

1515

16+
class FixedPipeline(Pipeline):
17+
"""A Pipeline that is always considered fitted.
18+
19+
This is useful for pre-processing pipelines that don't require fitting,
20+
as they only apply fixed transformations (e.g., filtering, epoching).
21+
This avoids the FutureWarning from sklearn 1.8+ about unfitted pipelines.
22+
"""
23+
24+
def __sklearn_is_fitted__(self):
25+
"""Return True to indicate this pipeline is always considered fitted."""
26+
return True
27+
28+
29+
def make_fixed_pipeline(*steps, memory=None, verbose=False):
30+
"""Create a FixedPipeline that is always considered fitted.
31+
32+
This is a drop-in replacement for sklearn's make_pipeline that creates
33+
a pipeline marked as fitted, suitable for fixed transformations.
34+
35+
Parameters
36+
----------
37+
*steps : list of estimators
38+
List of (name, transform) tuples that are chained in the pipeline.
39+
memory : str or object with the joblib.Memory interface, default=None
40+
Used to cache the fitted transformers of the pipeline.
41+
verbose : bool, default=False
42+
If True, the time elapsed while fitting each step will be printed.
43+
44+
Returns
45+
-------
46+
p : FixedPipeline
47+
A FixedPipeline object.
48+
"""
49+
50+
return FixedPipeline(_name_estimators(steps), memory=memory, verbose=verbose)
51+
52+
1653
def _is_none_pipeline(pipeline):
1754
"""Check if a pipeline is the result of make_pipeline(None)"""
1855
return (
@@ -44,6 +81,11 @@ def transform(self, X, y=None):
4481
def fit(self, X, y=None):
4582
for _, t in self.transformers:
4683
t.fit(X)
84+
return self
85+
86+
def __sklearn_is_fitted__(self):
87+
"""Return True to indicate this transformer is always considered fitted."""
88+
return True
4789

4890
def _sk_visual_block_(self):
4991
"""Tell sklearn’s diagrammer to lay us out in parallel."""
@@ -65,7 +107,11 @@ def __init__(self):
65107
# when using the pipeline
66108

67109
def fit(self, X, y=None):
68-
pass
110+
return self
111+
112+
def __sklearn_is_fitted__(self):
113+
"""Return True to indicate this transformer is always considered fitted."""
114+
return True
69115

70116
def _sk_visual_block_(self):
71117
"""Tell sklearn’s diagrammer to lay us out in parallel."""
@@ -103,6 +149,7 @@ class SetRawAnnotations(FixedTransformer):
103149
"""
104150

105151
def __init__(self, event_id, interval: Tuple[float, float]):
152+
super().__init__()
106153
assert isinstance(event_id, dict) # not None
107154
self.event_id = event_id
108155
values = _get_event_id_values(self.event_id)
@@ -153,6 +200,7 @@ class RawToEvents(FixedTransformer):
153200
"""
154201

155202
def __init__(self, event_id: dict[str, int], interval: Tuple[float, float]):
203+
super().__init__()
156204
assert isinstance(event_id, dict) # not None
157205
self.event_id = event_id
158206
self.interval = interval
@@ -212,6 +260,7 @@ def __init__(
212260
stop_offset,
213261
marker=1,
214262
):
263+
super().__init__()
215264
self.length = length
216265
self.stride = stride
217266
self.start_offset = start_offset
@@ -245,12 +294,16 @@ def transform(self, raw: mne.io.BaseRaw, y=None):
245294

246295

247296
class EpochsToEvents(FixedTransformer):
297+
def __init__(self):
298+
super().__init__()
299+
248300
def transform(self, epochs, y=None):
249301
return epochs.events
250302

251303

252304
class EventsToLabels(FixedTransformer):
253305
def __init__(self, event_id):
306+
super().__init__()
254307
self.event_id = event_id
255308

256309
def transform(self, events, y=None):
@@ -269,6 +322,7 @@ def __init__(
269322
channels: List[str] = None,
270323
interpolate_missing_channels: bool = False,
271324
):
325+
super().__init__()
272326
assert isinstance(event_id, dict) # not None
273327
self.event_id = event_id
274328
self.tmin = tmin

moabb/paradigms/base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,23 @@
66
import mne
77
import numpy as np
88
import pandas as pd
9-
from sklearn.pipeline import Pipeline, make_pipeline
9+
from sklearn.pipeline import Pipeline
1010
from sklearn.preprocessing import FunctionTransformer
1111

1212
from moabb.datasets.base import BaseDataset
1313
from moabb.datasets.bids_interface import StepType
1414
from moabb.datasets.preprocessing import (
1515
EpochsToEvents,
1616
EventsToLabels,
17+
FixedPipeline,
1718
ForkPipelines,
1819
RawToEpochs,
1920
RawToEvents,
2021
SetRawAnnotations,
2122
get_crop_pipeline,
2223
get_filter_pipeline,
2324
get_resample_pipeline,
25+
make_fixed_pipeline,
2426
)
2527

2628

@@ -203,20 +205,20 @@ def make_process_pipelines(
203205
]
204206
)
205207
steps.append((StepType.ARRAY, array_events_pipeline))
206-
process_pipelines.append(Pipeline(steps))
208+
process_pipelines.append(FixedPipeline(steps))
207209
return process_pipelines
208210

209211
def make_labels_pipeline(self, dataset, return_epochs=False, return_raws=False):
210212
"""Returns the pipeline that extracts the labels from the
211213
output of the postprocess_pipeline.
212214
Refer to the arguments of :func:`get_data` for more information."""
213215
if return_epochs:
214-
labels_pipeline = make_pipeline(
216+
labels_pipeline = make_fixed_pipeline(
215217
EpochsToEvents(),
216218
EventsToLabels(event_id=self.used_events(dataset)),
217219
)
218220
elif return_raws:
219-
labels_pipeline = make_pipeline(
221+
labels_pipeline = make_fixed_pipeline(
220222
self._get_events_pipeline(dataset),
221223
EventsToLabels(event_id=self.used_events(dataset)),
222224
)
@@ -424,10 +426,10 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
424426
steps.append(
425427
(
426428
"epoching",
427-
make_pipeline(
429+
make_fixed_pipeline(
428430
ForkPipelines(
429431
[
430-
("raw", make_pipeline(None)),
432+
("raw", make_fixed_pipeline(None)),
431433
("events", self._get_events_pipeline(dataset)),
432434
]
433435
),
@@ -448,7 +450,7 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
448450
steps.append(("resample", get_resample_pipeline(self.resample)))
449451
if return_epochs: # needed to concatenate epochs
450452
steps.append(("load_data", FunctionTransformer(methodcaller("load_data"))))
451-
return Pipeline(steps)
453+
return FixedPipeline(steps)
452454

453455
def _get_array_pipeline(
454456
self, return_epochs, return_raws, dataset, processing_pipeline
@@ -466,7 +468,7 @@ def _get_array_pipeline(
466468
steps.append(("postprocess_pipeline", processing_pipeline))
467469
if len(steps) == 0:
468470
return None
469-
return Pipeline(steps)
471+
return FixedPipeline(steps)
470472

471473
def match_all(
472474
self,

0 commit comments

Comments (0)