diff --git a/mlxtend/feature_selection/sequential_feature_selector.py b/mlxtend/feature_selection/sequential_feature_selector.py index 835d7a3f9..80e08f343 100644 --- a/mlxtend/feature_selection/sequential_feature_selector.py +++ b/mlxtend/feature_selection/sequential_feature_selector.py @@ -197,12 +197,14 @@ def __init__( clone_estimator=True, fixed_features=None, feature_groups=None, + tol=None, ): self.estimator = estimator self.k_features = k_features self.forward = forward self.floating = floating self.pre_dispatch = pre_dispatch + self.tol = tol # Want to raise meaningful error message if a # cross-validation generator is inputted if isinstance(cv, types.GeneratorType): @@ -569,6 +571,13 @@ def fit(self, X, y, groups=None, **fit_params): "avg_score": k_score, } + if self.tol is not None and k > 1: + prev_k = k - 1 if self.forward else k + 1 + if prev_k in self.subsets_: + diff = k_score - self.subsets_[prev_k]["avg_score"] + if diff < self.tol: + k_stop = k + if self.floating: # floating direction is opposite of self.forward, i.e. in # forward selection, we do floating in backward manner,