Skip to content

Commit 8e80778

Browse files
authored
Fixed _calc_score for *scikit-learn* version compatibility (#1109)
* added sklearn version compability for fit_params * added change to changelog * fixed format * fixed order of imports * formatted with isort and black
1 parent a78bd0b commit 8e80778

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

docs/sources/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Files updated:
2323
- ['mlxtend.frequent_patterns.fpcommon']
2424
- ['mlxtend.frequent_patterns.fpgrowth'](https://rasbt.github.io/mlxtend/user_guide/frequent_patterns/fpgrowth/)
2525
- ['mlxtend.frequent_patterns.fpmax'](https://rasbt.github.io/mlxtend/user_guide/frequent_patterns/fpmax/)
26+
- ['mlxtend/feature_selection/utilities.py'](https://github.com/rasbt/mlxtend/blob/master/mlxtend/feature_selection/utilities.py)
27+
- Modified `_calc_score` function to ensure compatibility with *scikit-learn* versions 1.4 and above by dynamically selecting between `fit_params` and `params` in `cross_val_score`.
2628
- [`mlxtend.feature_selection.SequentialFeatureSelector`](https://github.com/rasbt/mlxtend/blob/master/mlxtend/feature_selection/sequential_feature_selector.py)
2729
- Updated negative infinity constant to be compatible with old and new (>=2.0) `numpy` versions
2830
- [`mlxtend.frequent_patterns.association_rules`](https://rasbt.github.io/mlxtend/user_guide/frequent_patterns/association_rules/)

mlxtend/feature_selection/utilities.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from copy import deepcopy
2-
31
import numpy as np
2+
from sklearn import __version__ as sklearn_version
43
from sklearn.model_selection import cross_val_score
54

65

@@ -94,6 +93,9 @@ def _calc_score(
9493
feature_groups = [[i] for i in range(X.shape[1])]
9594

9695
IDX = _merge_lists(feature_groups, indices)
96+
97+
param_name = "fit_params" if sklearn_version < "1.4" else "params"
98+
9799
if selector.cv:
98100
scores = cross_val_score(
99101
selector.est_,
@@ -104,7 +106,7 @@ def _calc_score(
104106
scoring=selector.scorer,
105107
n_jobs=1,
106108
pre_dispatch=selector.pre_dispatch,
107-
fit_params=fit_params,
109+
**{param_name: fit_params},
108110
)
109111
else:
110112
selector.est_.fit(X[:, IDX], y, **fit_params)

0 commit comments

Comments
 (0)