
Commit ad43e17

Enable Auto Early Stopping
1 parent 5792b5b commit ad43e17

6 files changed: +226 -66 lines changed

docs/Python-Intro.rst

Lines changed: 5 additions & 0 deletions
@@ -241,6 +241,11 @@ This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.).
 Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
 However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in the ``early_stopping`` callback constructor.
 
+In the scikit-learn API of LightGBM, early stopping is enabled by default if the training set passed to ``fit()`` has more than 10,000 rows.
+This behavior can be controlled by explicitly setting the parameter ``early_stopping`` to ``True`` or ``False`` in the class constructor.
+When auto early stopping is enabled, a portion of the training data is used as a validation set. The amount of data to use for validation
+is controlled by the parameter ``validation_fraction`` and defaults to 0.1.
+
 Prediction
 ----------

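A minimal usage sketch of the behavior documented above (assumes a LightGBM build containing this commit; ``early_stopping`` and ``validation_fraction`` are the estimator parameters introduced here, and ``early_stopping`` is passed through as a keyword alias of ``early_stopping_round``):

import numpy as np
from lightgbm import LGBMRegressor

X = np.random.rand(20_000, 10)   # more than 10,000 rows, so early stopping turns on by default
y = np.random.rand(20_000)

# Auto behavior: a 10% holdout (validation_fraction=0.1) is carved out of X internally.
auto_model = LGBMRegressor(n_estimators=1_000).fit(X, y)

# Opt out explicitly, or change the size of the internal validation split.
no_es_model = LGBMRegressor(early_stopping=False).fit(X, y)
custom_model = LGBMRegressor(early_stopping=True, validation_fraction=0.2).fit(X, y)

Passing an ``eval_set`` to ``fit()`` keeps the previous behavior; the internal split is only created when early stopping is active and no validation data was supplied.
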
python-package/lightgbm/basic.py

Lines changed: 1 addition & 3 deletions
@@ -2539,14 +2539,12 @@ def set_categorical_feature(
         self : Dataset
             Dataset with set categorical features.
         """
-        if self.categorical_feature == categorical_feature:
+        if self.categorical_feature == categorical_feature or categorical_feature == 'auto':
             return self
         if self.data is not None:
             if self.categorical_feature is None:
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
-            elif categorical_feature == 'auto':
-                return self
             else:
                 if self.categorical_feature != 'auto':
                     _log_warning('categorical_feature in Dataset is overridden.\n'

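The change above reorders the short-circuit: passing ``'auto'`` to ``Dataset.set_categorical_feature`` is now always an immediate no-op, instead of one handled only inside the ``self.data is not None`` branch. A tiny sketch of the resulting behavior (synthetic data, assumes this branch):

import numpy as np
import lightgbm as lgb

ds = lgb.Dataset(np.random.rand(100, 3),
                 label=np.random.randint(0, 2, size=100),
                 categorical_feature=[0])

# 'auto' returns the same Dataset unchanged, without overriding the existing
# categorical features or freeing the underlying handle.
assert ds.set_categorical_feature('auto') is ds
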
python-package/lightgbm/dask.py

Lines changed: 3 additions & 0 deletions
@@ -1145,6 +1145,7 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = 'split',
+        validation_fraction: Optional[float] = 0.1,
         client: Optional[Client] = None,
         **kwargs: Any
     ):
@@ -1350,6 +1351,7 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = 'split',
+        validation_fraction: Optional[float] = 0.1,
         client: Optional[Client] = None,
         **kwargs: Any
     ):
@@ -1520,6 +1522,7 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = 'split',
+        validation_fraction: Optional[float] = 0.1,
         client: Optional[Client] = None,
         **kwargs: Any
     ):

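These hunks only add the new keyword to the Dask estimator signatures; the forwarding to the underlying sklearn estimator is not shown here. A construction-only sketch (assumes the optional ``dask`` dependencies and this branch are installed):

import lightgbm as lgb

# Accepted at construction time just like on the non-Dask estimators; training
# still requires a dask.distributed client and Dask arrays or dataframes.
dask_model = lgb.DaskLGBMRegressor(n_estimators=500, validation_fraction=0.2)
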
python-package/lightgbm/engine.py

Lines changed: 23 additions & 14 deletions
@@ -455,11 +455,9 @@ def _make_n_folds(
     nfold: int,
     params: Dict[str, Any],
     seed: int,
-    fpreproc: Optional[_LGBM_PreprocFunction],
     stratified: bool,
     shuffle: bool,
-    eval_train_metric: bool
-) -> CVBooster:
+) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
     """Make a n-fold list of Booster from random indices."""
     full_data = full_data.construct()
     num_data = full_data.num_data()
@@ -500,7 +498,16 @@ def _make_n_folds(
         test_id = [randidx[i: i + kstep] for i in range(0, num_data, kstep)]
         train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
         folds = zip(train_id, test_id)
+    return folds
 
+
+def _make_cvbooster(
+    full_data: Dataset,
+    params: Dict[str, Any],
+    folds: Iterable[Tuple[np.ndarray, np.ndarray]],
+    fpreproc: Optional[_LGBM_PreprocFunction],
+    eval_train_metric: bool,
+) -> CVBooster:
     ret = CVBooster()
     for train_idx, test_idx in folds:
         train_set = full_data.subset(sorted(train_idx))
@@ -720,8 +727,10 @@ def cv(
 
     results = defaultdict(list)
     cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
-                            params=params, seed=seed, fpreproc=fpreproc,
-                            stratified=stratified, shuffle=shuffle,
+                            params=params, seed=seed,
+                            stratified=stratified, shuffle=shuffle)
+    cvbooster = _make_cvbooster(full_data=train_set, folds=cvfolds,
+                                params=params, fpreproc=fpreproc,
                             eval_train_metric=eval_train_metric)
 
     # setup callbacks
@@ -752,34 +761,34 @@ def cv(
 
     for i in range(num_boost_round):
         for cb in callbacks_before_iter:
-            cb(callback.CallbackEnv(model=cvfolds,
+            cb(callback.CallbackEnv(model=cvbooster,
                                     params=params,
                                     iteration=i,
                                     begin_iteration=0,
                                     end_iteration=num_boost_round,
                                     evaluation_result_list=None))
-        cvfolds.update(fobj=fobj)  # type: ignore[call-arg]
-        res = _agg_cv_result(cvfolds.eval_valid(feval))  # type: ignore[call-arg]
+        cvbooster.update(fobj=fobj)  # type: ignore[call-arg]
+        res = _agg_cv_result(cvbooster.eval_valid(feval))  # type: ignore[call-arg]
         for _, key, mean, _, std in res:
            results[f'{key}-mean'].append(mean)
            results[f'{key}-stdv'].append(std)
        try:
            for cb in callbacks_after_iter:
-                cb(callback.CallbackEnv(model=cvfolds,
+                cb(callback.CallbackEnv(model=cvbooster,
                                        params=params,
                                        iteration=i,
                                        begin_iteration=0,
                                        end_iteration=num_boost_round,
                                        evaluation_result_list=res))
        except callback.EarlyStopException as earlyStopException:
-            cvfolds.best_iteration = earlyStopException.best_iteration + 1
-            for bst in cvfolds.boosters:
-                bst.best_iteration = cvfolds.best_iteration
+            cvbooster.best_iteration = earlyStopException.best_iteration + 1
+            for bst in cvbooster.boosters:
+                bst.best_iteration = cvbooster.best_iteration
            for k in results:
-                results[k] = results[k][:cvfolds.best_iteration]
+                results[k] = results[k][:cvbooster.best_iteration]
            break
 
    if return_cvbooster:
-        results['cvbooster'] = cvfolds  # type: ignore[assignment]
+        results['cvbooster'] = cvbooster  # type: ignore[assignment]
 
    return dict(results)

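The point of this refactor is that fold generation can now be consumed on its own, without building a ``CVBooster``. A self-contained sketch of that usage (synthetic data; ``_make_n_folds`` is a private helper and is shown only to illustrate how the sklearn wrapper below carves out its holdout):

import numpy as np
import lightgbm as lgb
from lightgbm.engine import _make_n_folds  # private helper, used here for illustration only

X, y = np.random.rand(200, 5), np.random.rand(200)
train_set = lgb.Dataset(X, label=y)

# nfold=10 corresponds to a ~10% holdout, i.e. validation_fraction=0.1.
folds = _make_n_folds(full_data=train_set, folds=None, nfold=10,
                      params={"objective": "regression"}, seed=42,
                      stratified=False, shuffle=True)
train_idx, val_idx = next(folds)          # keep only the first train/validation split
valid_set = train_set.subset(sorted(val_idx))
train_set = train_set.subset(sorted(train_idx))
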
python-package/lightgbm/sklearn.py

Lines changed: 89 additions & 49 deletions
@@ -16,7 +16,7 @@
                      _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
                      _LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
                      dt_DataTable, pd_DataFrame)
-from .engine import train
+from .engine import _make_n_folds, train
 
 __all__ = [
     'LGBMClassifier',
@@ -412,6 +412,7 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = 'split',
+        validation_fraction: Optional[float] = 0.1,
         **kwargs
     ):
         r"""Construct a gradient boosting model.
@@ -491,6 +492,10 @@ def __init__(
             The type of feature importance to be filled into ``feature_importances_``.
             If 'split', result contains numbers of times the feature is used in a model.
             If 'gain', result contains total gains of splits which use the feature.
+        validation_fraction : float or None, optional (default=0.1)
+            Proportion of training data to set aside as
+            validation data for early stopping. If None, early stopping is done on
+            the training data. Only used if early stopping is performed.
         **kwargs
             Other parameters for the model.
             Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
@@ -553,6 +558,7 @@ def __init__(
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.importance_type = importance_type
+        self.validation_fraction = validation_fraction
         self._Booster: Optional[Booster] = None
         self._evals_result: _EvalResultDict = {}
         self._best_score: _LGBM_BoosterBestScoreType = {}
@@ -668,9 +674,27 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
         params.pop('importance_type', None)
         params.pop('n_estimators', None)
         params.pop('class_weight', None)
+        params.pop("validation_fraction", None)
 
         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
+
+        params = _choose_param_value(
+            main_param_name="early_stopping_round",
+            params=params,
+            default_value="auto",
+        )
+        if params["early_stopping_round"] == "auto":
+            if hasattr(self, "_n_rows_train") and self._n_rows_train > 10_000:
+                params["early_stopping_round"] = 10
+            else:
+                params["early_stopping_round"] = None
+
+        if params["early_stopping_round"] is True:
+            params["early_stopping_round"] = 10
+        elif params["early_stopping_round"] is False:
+            params["early_stopping_round"] = None
+
         if self._n_classes > 2:
             for alias in _ConfigAliases.get('num_class'):
                 params.pop(alias, None)
@@ -745,6 +769,19 @@ def fit(
         init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None
     ) -> "LGBMModel":
         """Docstring is set after definition, using a template."""
+        if not isinstance(X, (pd_DataFrame, dt_DataTable)):
+            _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
+            if sample_weight is not None:
+                sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
+        else:
+            _X, _y = X, y
+
+        self._n_features = _X.shape[1]
+        # copy for consistency
+        self._n_features_in = self._n_features
+
+        self._n_rows_train = _X.shape[0]
+
         params = self._process_params(stage="fit")
 
         # Do not modify original args in fit function
@@ -766,13 +803,6 @@ def fit(
             params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
             params['metric'] = [metric for metric in params['metric'] if metric is not None]
 
-        if not isinstance(X, (pd_DataFrame, dt_DataTable)):
-            _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
-            if sample_weight is not None:
-                sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
-        else:
-            _X, _y = X, y
-
         if self._class_weight is None:
             self._class_weight = self.class_weight
         if self._class_weight is not None:
@@ -782,51 +812,61 @@ def fit(
             else:
                 sample_weight = np.multiply(sample_weight, class_sample_weight)
 
-        self._n_features = _X.shape[1]
-        # copy for consistency
-        self._n_features_in = self._n_features
-
         train_set = Dataset(data=_X, label=_y, weight=sample_weight, group=group,
                             init_score=init_score, categorical_feature=categorical_feature,
                             params=params)
+        if params["early_stopping_round"] is not None and eval_set is None:
+            if self.validation_fraction is not None:
+                n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
+                stratified = isinstance(self, LGBMClassifier)
+                cvfolds = _make_n_folds(full_data=train_set, folds=None, nfold=n_splits,
+                                        params=params, seed=self.random_state,
+                                        stratified=stratified, shuffle=True)
+                train_idx, val_idx = next(cvfolds)
+                valid_set = train_set.subset(sorted(val_idx))
+                train_set = train_set.subset(sorted(train_idx))
+            else:
+                valid_set = train_set
+            valid_set = valid_set.construct()
+            valid_sets = [valid_set]
 
-        valid_sets: List[Dataset] = []
-        if eval_set is not None:
-
-            def _get_meta_data(collection, name, i):
-                if collection is None:
-                    return None
-                elif isinstance(collection, list):
-                    return collection[i] if len(collection) > i else None
-                elif isinstance(collection, dict):
-                    return collection.get(i, None)
-                else:
-                    raise TypeError(f"{name} should be dict or list")
-
-            if isinstance(eval_set, tuple):
-                eval_set = [eval_set]
-            for i, valid_data in enumerate(eval_set):
-                # reduce cost for prediction training data
-                if valid_data[0] is X and valid_data[1] is y:
-                    valid_set = train_set
-                else:
-                    valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
-                    valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
-                    if valid_class_weight is not None:
-                        if isinstance(valid_class_weight, dict) and self._class_map is not None:
-                            valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
-                        valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
-                        if valid_weight is None or len(valid_weight) == 0:
-                            valid_weight = valid_class_sample_weight
-                        else:
-                            valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
-                    valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
-                    valid_group = _get_meta_data(eval_group, 'eval_group', i)
-                    valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
-                                        group=valid_group, init_score=valid_init_score,
-                                        categorical_feature='auto', params=params)
-
-                valid_sets.append(valid_set)
+        else:
+            valid_sets: List[Dataset] = []
+            if eval_set is not None:
+                def _get_meta_data(collection, name, i):
+                    if collection is None:
+                        return None
+                    elif isinstance(collection, list):
+                        return collection[i] if len(collection) > i else None
+                    elif isinstance(collection, dict):
+                        return collection.get(i, None)
+                    else:
+                        raise TypeError(f"{name} should be dict or list")
+
+                if isinstance(eval_set, tuple):
+                    eval_set = [eval_set]
+                for i, valid_data in enumerate(eval_set):
+                    # reduce cost for prediction training data
+                    if valid_data[0] is X and valid_data[1] is y:
+                        valid_set = train_set
+                    else:
+                        valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
+                        valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
+                        if valid_class_weight is not None:
+                            if isinstance(valid_class_weight, dict) and self._class_map is not None:
+                                valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
+                            valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
+                            if valid_weight is None or len(valid_weight) == 0:
+                                valid_weight = valid_class_sample_weight
+                            else:
+                                valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
+                        valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
+                        valid_group = _get_meta_data(eval_group, 'eval_group', i)
+                        valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
+                                            group=valid_group, init_score=valid_init_score,
+                                            categorical_feature='auto', params=params)
+
+                    valid_sets.append(valid_set)
 
         if isinstance(init_model, LGBMModel):
             init_model = init_model.booster_

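Taken together, the ``_process_params`` and ``fit()`` hunks implement a small decision rule. A standalone sketch in plain Python (the helper names below are hypothetical, not part of the library; the constants come from the hunks above):

from math import ceil
from typing import Optional, Union

def resolve_early_stopping_round(early_stopping: Union[str, bool, int, None],
                                 n_rows_train: int) -> Optional[int]:
    """Mirror of the resolution in _process_params: 'auto' becomes 10 rounds only
    when training has more than 10,000 rows, True means 10 rounds, False/None
    disables early stopping, and an explicit int passes through unchanged."""
    if early_stopping == "auto":
        return 10 if n_rows_train > 10_000 else None
    if early_stopping is True:
        return 10
    if early_stopping is False:
        return None
    return early_stopping

def holdout_n_splits(validation_fraction: float) -> int:
    """Mirror of the fit() hunk: the holdout is the first fold of an n-fold split,
    with n chosen so one fold is roughly validation_fraction of the data."""
    return max(int(ceil(1 / validation_fraction)), 2)

assert resolve_early_stopping_round("auto", n_rows_train=50_000) == 10
assert resolve_early_stopping_round("auto", n_rows_train=5_000) is None
assert holdout_n_splits(0.1) == 10   # 10 folds -> ~10% of rows held out
assert holdout_n_splits(0.5) == 2
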
0 commit comments
