Commit 70161f8

Add tol support to SequentialFeatureSelector
1 parent 366f717 commit 70161f8

3 files changed: +239 -34 lines

3 files changed

+239
-34
lines changed

docs/sources/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,11 @@ The CHANGELOG for the current development version is available at
 
 - Fixes an edge-case bug where decision regions plots didn't have unique colors ([#1157](https://github.com/rasbt/mlxtend/issues/1157) via [mariam851](https://github.com/mariam851))
 
+- Added functional `tol` support to `SequentialFeatureSelector` for auto-like
+  modes (`k_features="best"`/`"parsimonious"`) with forward-stop validation.
+  ([#1079](https://github.com/rasbt/mlxtend/issues/1079) via
+  [@vaaven](https://github.com/vaaven))
+
 
 ### Version 0.24.0 (13 Dec 2025)
 

mlxtend/feature_selection/sequential_feature_selector.py

Lines changed: 72 additions & 34 deletions
@@ -76,6 +76,13 @@ class SequentialFeatureSelector(_BaseXComposition, MetaEstimatorMixin):
         labels) stratified k-fold. Otherwise regular k-fold cross-validation
         is performed. No cross-validation if cv is None, False, or 0.
 
+    tol : float or None (default: None)
+        Early stopping tolerance. This is only active when
+        `k_features` is `"best"` or `"parsimonious"` and ignored for
+        integer or tuple input.
+        Forward selection requires `tol > 0`; backward selection allows
+        non-positive values.
+
     n_jobs : int (default: 1)
         The number of CPUs to use for evaluating different feature subsets
         in parallel. -1 means 'all CPUs'.
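
(Not part of the diff: a minimal usage sketch of the parameter documented above. The estimator, dataset, and `tol` value are illustrative assumptions, not something this commit prescribes.)

```python
# Illustrative only: forward selection in "best" mode that stops early once the
# CV-score gain between consecutive subset sizes falls below tol (assumed 0.01).
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

from mlxtend.feature_selection import SequentialFeatureSelector as SFS

X, y = load_iris(return_X_y=True)

sfs = SFS(
    KNeighborsClassifier(n_neighbors=3),
    k_features="best",  # tol is only honored for "best"/"parsimonious"
    forward=True,
    tol=0.01,           # forward selection requires tol > 0
    cv=5,
)
sfs = sfs.fit(X, y)
print(sfs.k_feature_idx_, sfs.k_score_)
```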
@@ -192,6 +199,7 @@ def __init__(
         verbose=0,
         scoring=None,
         cv=5,
+        tol=None,
         n_jobs=1,
         pre_dispatch="2*n_jobs",
         clone_estimator=True,
@@ -215,6 +223,7 @@ def __init__(
             )
             raise TypeError(err_msg)
         self.cv = cv
+        self.tol = tol
         self.n_jobs = n_jobs
         self.verbose = verbose
 
@@ -444,6 +453,8 @@ def fit(self, X, y, groups=None, **fit_params):
 
         self.k_lb = max(1, len(self.fixed_features_group_set))
         self.k_ub = len(self.feature_groups_)
+        original_k_features = self.k_features
+        k_features = self.k_features
 
         if (
             not isinstance(self.k_features, int)
@@ -495,23 +506,38 @@ def fit(self, X, y, groups=None, **fit_params):
         # )
 
         self.is_parsimonious = False
-        if isinstance(self.k_features, str):
-            if self.k_features not in {"best", "parsimonious"}:
+        if isinstance(k_features, str):
+            if k_features not in {"best", "parsimonious"}:
                 raise AttributeError(
                     "If a string argument is provided, "
                     'it must be "best" or "parsimonious"'
                 )
-            if self.k_features == "parsimonious":
+            if k_features == "parsimonious":
                 self.is_parsimonious = True
 
-        if isinstance(self.k_features, str):
-            self.k_features = (self.k_lb, self.k_ub)
-        elif isinstance(self.k_features, int):
+        if self.tol is not None:
+            is_auto = original_k_features in {"best", "parsimonious"}
+            if not is_auto:
+                raise ValueError(
+                    "tol is only enabled when k_features is `best` or `parsimonious`."
+                )
+            if not isinstance(self.tol, (int, float, np.number)):
+                raise TypeError("tol must be numeric.")
+            if not np.isfinite(float(self.tol)):
+                raise ValueError("tol must be finite.")
+            if self.forward and self.tol <= 0:
+                raise ValueError(
+                    "tol must be strictly positive when doing forward selection"
+                )
+
+        if isinstance(k_features, str):
+            k_features = (self.k_lb, self.k_ub)
+        elif isinstance(k_features, int):
             # we treat k_features as k group of features
-            self.k_features = (self.k_features, self.k_features)
+            k_features = (k_features, k_features)
 
-        self.min_k = self.k_features[0]
-        self.max_k = self.k_features[1]
+        self.min_k = k_features[0]
+        self.max_k = k_features[1]
 
         if self.forward:
             k_idx = tuple(sorted(self.fixed_features_group_set))
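
(Not part of the diff: a sketch of what the fit-time validation above accepts and rejects, using an assumed KNN/iris setup purely for illustration; the error notes in the comments paraphrase the checks in the hunk.)

```python
# Sketch of the fit-time validation above; knn, X, y are assumed stand-ins.
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

from mlxtend.feature_selection import SequentialFeatureSelector as SFS

X, y = load_iris(return_X_y=True)
knn = KNeighborsClassifier(n_neighbors=3)

# Rejected during fit():
#   SFS(knn, k_features=3, tol=0.01).fit(X, y)                  # ValueError: tol needs "best"/"parsimonious"
#   SFS(knn, k_features="best", tol="0.1").fit(X, y)            # TypeError: tol must be numeric
#   SFS(knn, k_features="best", tol=float("nan")).fit(X, y)     # ValueError: tol must be finite
#   SFS(knn, k_features="best", forward=True, tol=0).fit(X, y)  # ValueError: forward needs tol > 0

# Accepted: backward selection may use a non-positive tol, tolerating small score drops.
sfs = SFS(knn, k_features="best", forward=False, tol=-0.005).fit(X, y)
print(sfs.k_feature_idx_)
```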
@@ -536,8 +562,15 @@ def fit(self, X, y, groups=None, **fit_params):
                     "cv_scores": k_score,
                     "avg_score": np.nanmean(k_score),
                 }
+                k_score_current = np.nanmean(k_score)
+            else:
+                k_score_current = np.nan
 
         orig_set = set(range(self.k_ub))
+        auto_selection = self.tol is not None and original_k_features in {
+            "best",
+            "parsimonious",
+        }
         try:
             while k != k_stop:
                 prev_subset = set(k_idx)
@@ -548,7 +581,7 @@ def fit(self, X, y, groups=None, **fit_params):
                 search_set = prev_subset
                 must_include_set = self.fixed_features_group_set
 
-                k_idx, k_score, cv_scores = self._feature_selector(
+                k_idx_next, k_score_next, cv_scores_next = self._feature_selector(
                     search_set,
                     must_include_set,
                     X=X_,
@@ -558,39 +591,34 @@ def fit(self, X, y, groups=None, **fit_params):
                     feature_groups=self.feature_groups_,
                     **fit_params,
                 )
+                if k_idx_next is None:
+                    break
 
-                k = len(k_idx)
-                # floating can lead to multiple same-sized subsets
-                if k not in self.subsets_ or (k_score > self.subsets_[k]["avg_score"]):
-                    k_idx = tuple(sorted(k_idx))
-                    self.subsets_[k] = {
-                        "feature_idx": k_idx,
-                        "cv_scores": cv_scores,
-                        "avg_score": k_score,
-                    }
+                k_idx_next = tuple(sorted(k_idx_next))
+                k = len(k_idx_next)
 
                 if self.floating:
                     # floating direction is opposite of self.forward, i.e. in
                     # forward selection, we do floating in backward manner,
                     # and in backward selection, we do floating in forward manner
                     is_float_forward = not self.forward
-                    (new_feature_idx,) = set(k_idx) ^ prev_subset
+                    (new_feature_idx,) = set(k_idx_next) ^ prev_subset
                     for _ in range(X_.shape[1]):
                         if (
                             self.forward
-                            and (len(k_idx) - len(self.fixed_features_group_set)) <= 2
+                            and (len(k_idx_next) - len(self.fixed_features_group_set)) <= 2
                         ):
                             break
-                        if not self.forward and (len(orig_set) - len(k_idx) <= 2):
+                        if not self.forward and (len(orig_set) - len(k_idx_next) <= 2):
                             break
 
                         if is_float_forward:
                             # corresponding to self.forward=False
                             search_set = orig_set - {new_feature_idx}
-                            must_include_set = set(k_idx)
+                            must_include_set = set(k_idx_next)
                         else:
                             # corresponding to self.forward=True
-                            search_set = set(k_idx)
+                            search_set = set(k_idx_next)
                             must_include_set = self.fixed_features_group_set | {
                                 new_feature_idx
                             }
@@ -610,7 +638,7 @@ def fit(self, X, y, groups=None, **fit_params):
                             **fit_params,
                         )
 
-                        if k_score_c <= k_score:
+                        if k_score_c <= k_score_next:
                             break
 
                         # In the floating process, we basically revisit our previous
@@ -619,14 +647,24 @@ def fit(self, X, y, groups=None, **fit_params):
                         if k_score_c <= self.subsets_[len(k_idx_c)]["avg_score"]:
                             break
                         else:
-                            k_idx, k_score, cv_scores = k_idx_c, k_score_c, cv_scores_c
-                            k_idx = tuple(sorted(k_idx))
-                            k = len(k_idx)
-                            self.subsets_[k] = {
-                                "feature_idx": k_idx,
-                                "cv_scores": cv_scores,
-                                "avg_score": k_score,
-                            }
+                            k_idx_next = tuple(sorted(k_idx_c))
+                            k_score_next = k_score_c
+                            cv_scores_next = cv_scores_c
+                            k = len(k_idx_next)
+
+                if auto_selection and not np.isnan(k_score_current):
+                    if (k_score_next - k_score_current) < self.tol:
+                        break
+
+                if k not in self.subsets_ or (k_score_next > self.subsets_[k]["avg_score"]):
+                    self.subsets_[k] = {
+                        "feature_idx": k_idx_next,
+                        "cv_scores": cv_scores_next,
+                        "avg_score": k_score_next,
+                    }
+
+                k_idx = k_idx_next
+                k_score_current = k_score_next
 
                 if self.verbose == 1:
                     sys.stderr.write("\rFeatures: %d/%s" % (len(k_idx), k_stop))
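
(Not part of the diff: a standalone sketch of the new stopping rule with made-up scores and a hypothetical helper name, to make the break condition above concrete.)

```python
# Standalone illustration of the break condition added above: stop once the
# CV-score gain between consecutive subset sizes falls below tol.
def first_stop_size(avg_scores_by_size, tol):
    """avg_scores_by_size: dict mapping subset size -> mean CV score."""
    sizes = sorted(avg_scores_by_size)
    for prev_k, next_k in zip(sizes, sizes[1:]):
        gain = avg_scores_by_size[next_k] - avg_scores_by_size[prev_k]
        if gain < tol:     # mirrors (k_score_next - k_score_current) < self.tol
            return prev_k  # the larger subset is not recorded after the break
    return sizes[-1]

# Made-up forward-selection scores: the 3 -> 4 gain (0.002) is below tol=0.01,
# so the search would stop after the 3-feature subset.
scores = {1: 0.80, 2: 0.88, 3: 0.93, 4: 0.932}
print(first_stop_size(scores, tol=0.01))  # -> 3
```

When the gain drops below `tol`, the loop above breaks before recording the larger subset, so the final `"best"`/`"parsimonious"` pick is made among the sizes stored up to that point.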
@@ -638,7 +676,7 @@ def fit(self, X, y, groups=None, **fit_params):
                             datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                             len(k_idx),
                             k_stop,
-                            k_score,
+                            k_score_next,
                         )
                     )
 