66import mne
77import numpy as np
88from sklearn .base import BaseEstimator , TransformerMixin
9- from sklearn .pipeline import FunctionTransformer , Pipeline
9+ from sklearn .pipeline import FunctionTransformer , Pipeline , _name_estimators
1010from sklearn .utils ._repr_html .estimator import _VisualBlock
1111
1212
1313log = logging .getLogger (__name__ )
1414
1515
16+ class FixedPipeline (Pipeline ):
17+ """A Pipeline that is always considered fitted.
18+
19+ This is useful for pre-processing pipelines that don't require fitting,
20+ as they only apply fixed transformations (e.g., filtering, epoching).
21+ This avoids the FutureWarning from sklearn 1.8+ about unfitted pipelines.
22+ """
23+
24+ def __sklearn_is_fitted__ (self ):
25+ """Return True to indicate this pipeline is always considered fitted."""
26+ return True
27+
28+
29+ def make_fixed_pipeline (* steps , memory = None , verbose = False ):
30+ """Create a FixedPipeline that is always considered fitted.
31+
32+ This is a drop-in replacement for sklearn's make_pipeline that creates
33+ a pipeline marked as fitted, suitable for fixed transformations.
34+
35+ Parameters
36+ ----------
37+ *steps : list of estimators
38+ List of (name, transform) tuples that are chained in the pipeline.
39+ memory : str or object with the joblib.Memory interface, default=None
40+ Used to cache the fitted transformers of the pipeline.
41+ verbose : bool, default=False
42+ If True, the time elapsed while fitting each step will be printed.
43+
44+ Returns
45+ -------
46+ p : FixedPipeline
47+ A FixedPipeline object.
48+ """
49+
50+ return FixedPipeline (_name_estimators (steps ), memory = memory , verbose = verbose )
51+
52+
1653def _is_none_pipeline (pipeline ):
1754 """Check if a pipeline is the result of make_pipeline(None)"""
1855 return (
@@ -44,6 +81,11 @@ def transform(self, X, y=None):
4481 def fit (self , X , y = None ):
4582 for _ , t in self .transformers :
4683 t .fit (X )
84+ return self
85+
86+ def __sklearn_is_fitted__ (self ):
87+ """Return True to indicate this transformer is always considered fitted."""
88+ return True
4789
4890 def _sk_visual_block_ (self ):
4991 """Tell sklearn’s diagrammer to lay us out in parallel."""
@@ -65,7 +107,11 @@ def __init__(self):
65107 # when using the pipeline
66108
67109 def fit (self , X , y = None ):
68- pass
110+ return self
111+
112+ def __sklearn_is_fitted__ (self ):
113+ """Return True to indicate this transformer is always considered fitted."""
114+ return True
69115
70116 def _sk_visual_block_ (self ):
71117 """Tell sklearn’s diagrammer to lay us out in parallel."""
@@ -103,6 +149,7 @@ class SetRawAnnotations(FixedTransformer):
103149 """
104150
105151 def __init__ (self , event_id , interval : Tuple [float , float ]):
152+ super ().__init__ ()
106153 assert isinstance (event_id , dict ) # not None
107154 self .event_id = event_id
108155 values = _get_event_id_values (self .event_id )
@@ -153,6 +200,7 @@ class RawToEvents(FixedTransformer):
153200 """
154201
155202 def __init__ (self , event_id : dict [str , int ], interval : Tuple [float , float ]):
203+ super ().__init__ ()
156204 assert isinstance (event_id , dict ) # not None
157205 self .event_id = event_id
158206 self .interval = interval
@@ -212,6 +260,7 @@ def __init__(
212260 stop_offset ,
213261 marker = 1 ,
214262 ):
263+ super ().__init__ ()
215264 self .length = length
216265 self .stride = stride
217266 self .start_offset = start_offset
@@ -245,12 +294,16 @@ def transform(self, raw: mne.io.BaseRaw, y=None):
245294
246295
247296class EpochsToEvents (FixedTransformer ):
297+ def __init__ (self ):
298+ super ().__init__ ()
299+
248300 def transform (self , epochs , y = None ):
249301 return epochs .events
250302
251303
252304class EventsToLabels (FixedTransformer ):
253305 def __init__ (self , event_id ):
306+ super ().__init__ ()
254307 self .event_id = event_id
255308
256309 def transform (self , events , y = None ):
@@ -269,6 +322,7 @@ def __init__(
269322 channels : List [str ] = None ,
270323 interpolate_missing_channels : bool = False ,
271324 ):
325+ super ().__init__ ()
272326 assert isinstance (event_id , dict ) # not None
273327 self .event_id = event_id
274328 self .tmin = tmin
0 commit comments