Skip to content

Commit 3079414

Browse files
authored
Update for scikit-learn 1.8.0 and pandas 3.0.0 compatibility (#106)
1 parent 331ca47 commit 3079414

4 files changed

Lines changed: 22 additions & 13 deletions

File tree

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- matplotlib
1111
- numba
1212
- pandas
13-
- scikit-learn
13+
- scikit-learn>=1.6
1414
- PyWavelets
1515
- tqdm
1616
- download

mne_features/feature_extraction.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import joblib
1818
from sklearn.pipeline import FeatureUnion
1919
from sklearn.preprocessing import FunctionTransformer
20+
from sklearn.utils.validation import validate_data
2021

2122
from .bivariate import get_bivariate_funcs, get_bivariate_func_names
2223
from .univariate import get_univariate_funcs, get_univariate_func_names
@@ -103,7 +104,13 @@ def fit(self, X, y=None):
103104
-------
104105
self
105106
"""
106-
self._check_input(X, reset=True)
107+
validate_data(
108+
self,
109+
X,
110+
reset=True,
111+
accept_sparse=self.accept_sparse,
112+
skip_check_array=not self.validate,
113+
)
107114
_feature_func = _get_python_func(self.func)
108115
if hasattr(_feature_func, 'get_feature_names'):
109116
_params = self.get_params()

mne_features/tests/test_univariate.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_feature_names_quantile():
306306
df = extract_features(
307307
_data, sfreq, selected_funcs, funcs_params={'quantile__q': q},
308308
return_as_df=True)
309-
assert_equal(df.columns.get_level_values(1).values, col_names)
309+
assert_equal(df.columns.get_level_values(1).to_list(), col_names)
310310

311311

312312
def test_feature_names_spect_edge_freq():
@@ -325,7 +325,7 @@ def test_feature_names_spect_edge_freq():
325325
_data, sfreq, selected_funcs,
326326
funcs_params={'spect_edge_freq__edge': edge},
327327
return_as_df=True)
328-
assert_equal(df.columns.get_level_values(1).values, col_names)
328+
assert_equal(df.columns.get_level_values(1).to_list(), col_names)
329329

330330

331331
def test_feature_names_spect_slope():
@@ -338,7 +338,7 @@ def test_feature_names_spect_slope():
338338
col_names = ['ch%s_%s' % (ch, stat) for ch in range(n_chans) for
339339
stat in stats]
340340
df = extract_features(_data, sfreq, selected_funcs, return_as_df=True)
341-
assert_equal(df.columns.get_level_values(1).values, col_names)
341+
assert_equal(df.columns.get_level_values(1).to_list(), col_names)
342342

343343

344344
def test_feature_names_wavelet_coef_energy(wavelet_name='db4'):
@@ -357,7 +357,7 @@ def test_feature_names_wavelet_coef_energy(wavelet_name='db4'):
357357
_data, sfreq, selected_funcs,
358358
funcs_params={'wavelet_coef_energy__wavelet_name': wavelet_name},
359359
return_as_df=True)
360-
assert_equal(df.columns.get_level_values(1).values, col_names)
360+
assert_equal(df.columns.get_level_values(1).to_list(), col_names)
361361

362362

363363
def test_feature_names_teager_kaiser_energy(wavelet_name='db4'):
@@ -376,7 +376,7 @@ def test_feature_names_teager_kaiser_energy(wavelet_name='db4'):
376376
_data, sfreq, selected_funcs,
377377
funcs_params={'teager_kaiser_energy__wavelet_name': wavelet_name},
378378
return_as_df=True)
379-
assert_equal(df.columns.get_level_values(1).values, col_names)
379+
assert_equal(df.columns.get_level_values(1).to_list(), col_names)
380380

381381

382382
def test_feature_names_pow_freq_bands():
@@ -402,15 +402,17 @@ def test_feature_names_pow_freq_bands():
402402
funcs_params={'pow_freq_bands__ratios': 'only',
403403
'pow_freq_bands__freq_bands': fb},
404404
return_as_df=True)
405-
assert_equal(df_only.columns.get_level_values(1).values, ratios_names)
405+
assert_equal(
406+
df_only.columns.get_level_values(1).to_list(), ratios_names
407+
)
406408

407409
# With `ratios = 'all'`:
408410
df_all = extract_features(
409411
_data, sfreq, selected_funcs,
410412
funcs_params={'pow_freq_bands__ratios': 'all',
411413
'pow_freq_bands__freq_bands': fb},
412414
return_as_df=True)
413-
assert_equal(df_all.columns.get_level_values(1).values,
415+
assert_equal(df_all.columns.get_level_values(1).to_list(),
414416
pow_names + ratios_names)
415417

416418
# With `ratios = None`:
@@ -419,7 +421,7 @@ def test_feature_names_pow_freq_bands():
419421
funcs_params={'pow_freq_bands__ratios': None,
420422
'pow_freq_bands__freq_bands': fb},
421423
return_as_df=True)
422-
assert_equal(df.columns.get_level_values(1).values, pow_names)
424+
assert_equal(df.columns.get_level_values(1).to_list(), pow_names)
423425

424426
# With `ratios = 'only'` and `ratios_triu = True`:
425427
df_only = extract_features(
@@ -428,7 +430,7 @@ def test_feature_names_pow_freq_bands():
428430
'pow_freq_bands__ratios_triu': True,
429431
'pow_freq_bands__freq_bands': fb},
430432
return_as_df=True)
431-
assert_equal(df_only.columns.get_level_values(1).values,
433+
assert_equal(df_only.columns.get_level_values(1).to_list(),
432434
ratios_names[::2])
433435

434436

@@ -531,7 +533,7 @@ def test_feature_names_energy_freq_bands():
531533
_data, sfreq, selected_funcs,
532534
funcs_params={'energy_freq_bands__freq_bands': fb},
533535
return_as_df=True)
534-
assert_equal(df.columns.get_level_values(1).values, feat_names)
536+
assert_equal(df.columns.get_level_values(1).to_list(), feat_names)
535537

536538

537539
def test_spect_slope():

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def package_tree(pkgroot):
6060
platforms='any',
6161
python_requires='>=3.10',
6262
packages=package_tree('mne_features'),
63-
install_requires=['numpy', 'scipy', 'numba', 'scikit-learn', 'mne',
63+
install_requires=['numpy', 'scipy', 'numba', 'scikit-learn>=1.6', 'mne',
6464
'PyWavelets', 'pandas'],
6565
project_urls={
6666
'Documentation': 'https://mne-tools.github.io/mne-features/',

0 commit comments

Comments
 (0)