Enable Auto Early Stopping
ClaudioSalvatoreArcidiacono committed Sep 15, 2023
1 parent 5792b5b commit ad43e17
Showing 6 changed files with 226 additions and 66 deletions.
5 changes: 5 additions & 0 deletions docs/Python-Intro.rst
@@ -241,6 +241,11 @@ This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.)
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in the ``early_stopping`` callback constructor.

In LightGBM's scikit-learn API, early stopping is enabled by default if the training set passed to ``fit()`` has more than 10,000 rows.
This behavior can be controlled by explicitly setting the parameter ``early_stopping`` to ``True`` or ``False`` in the class constructor.
When auto early stopping is enabled, a portion of the training data is set aside as a validation set. The amount of data used for validation
is controlled by the parameter ``validation_fraction``, which defaults to 0.1.
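
For illustration, a minimal usage sketch of the behavior described above (assuming the scikit-learn API as changed in this commit; ``validation_fraction`` and the auto-enable threshold come from this change, and ``early_stopping`` is used here as the constructor alias of ``early_stopping_round``):

import numpy as np
from lightgbm import LGBMClassifier

X = np.random.rand(20_000, 10)
y = np.random.randint(0, 2, size=20_000)

# More than 10,000 training rows, so early stopping is enabled automatically;
# 10% of the rows are held out as a validation set (validation_fraction=0.1).
clf = LGBMClassifier(validation_fraction=0.1)
clf.fit(X, y)

# Disable the automatic behavior explicitly via the constructor.
clf_no_es = LGBMClassifier(early_stopping=False)
clf_no_es.fit(X, y)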

Prediction
----------

4 changes: 1 addition & 3 deletions python-package/lightgbm/basic.py
@@ -2539,14 +2539,12 @@ def set_categorical_feature(
self : Dataset
Dataset with set categorical features.
"""
if self.categorical_feature == categorical_feature:
if self.categorical_feature == categorical_feature or categorical_feature == 'auto':
return self
if self.data is not None:
if self.categorical_feature is None:
self.categorical_feature = categorical_feature
return self._free_handle()
elif categorical_feature == 'auto':
return self
else:
if self.categorical_feature != 'auto':
_log_warning('categorical_feature in Dataset is overridden.\n'
3 changes: 3 additions & 0 deletions python-package/lightgbm/dask.py
@@ -1145,6 +1145,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any
):
@@ -1350,6 +1351,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any
):
@@ -1520,6 +1522,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any
):
37 changes: 23 additions & 14 deletions python-package/lightgbm/engine.py
@@ -455,11 +455,9 @@ def _make_n_folds(
nfold: int,
params: Dict[str, Any],
seed: int,
fpreproc: Optional[_LGBM_PreprocFunction],
stratified: bool,
shuffle: bool,
eval_train_metric: bool
) -> CVBooster:
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
"""Make a n-fold list of Booster from random indices."""
full_data = full_data.construct()
num_data = full_data.num_data()
@@ -500,7 +498,16 @@ def _make_n_folds(
test_id = [randidx[i: i + kstep] for i in range(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
folds = zip(train_id, test_id)
return folds


def _make_cvbooster(
full_data: Dataset,
params: Dict[str, Any],
folds: Iterable[Tuple[np.ndarray, np.ndarray]],
fpreproc: Optional[_LGBM_PreprocFunction],
eval_train_metric: bool,
) -> CVBooster:
ret = CVBooster()
for train_idx, test_idx in folds:
train_set = full_data.subset(sorted(train_idx))
@@ -720,8 +727,10 @@ def cv(

results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc,
stratified=stratified, shuffle=shuffle,
params=params, seed=seed,
stratified=stratified, shuffle=shuffle)
cvbooster = _make_cvbooster(full_data=train_set, folds=cvfolds,
params=params, fpreproc=fpreproc,
eval_train_metric=eval_train_metric)

# setup callbacks
@@ -752,34 +761,34 @@ def cv(

for i in range(num_boost_round):
for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=cvfolds,
cb(callback.CallbackEnv(model=cvbooster,
params=params,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
cvfolds.update(fobj=fobj) # type: ignore[call-arg]
res = _agg_cv_result(cvfolds.eval_valid(feval)) # type: ignore[call-arg]
cvbooster.update(fobj=fobj) # type: ignore[call-arg]
res = _agg_cv_result(cvbooster.eval_valid(feval)) # type: ignore[call-arg]
for _, key, mean, _, std in res:
results[f'{key}-mean'].append(mean)
results[f'{key}-stdv'].append(std)
try:
for cb in callbacks_after_iter:
cb(callback.CallbackEnv(model=cvfolds,
cb(callback.CallbackEnv(model=cvbooster,
params=params,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=res))
except callback.EarlyStopException as earlyStopException:
cvfolds.best_iteration = earlyStopException.best_iteration + 1
for bst in cvfolds.boosters:
bst.best_iteration = cvfolds.best_iteration
cvbooster.best_iteration = earlyStopException.best_iteration + 1
for bst in cvbooster.boosters:
bst.best_iteration = cvbooster.best_iteration
for k in results:
results[k] = results[k][:cvfolds.best_iteration]
results[k] = results[k][:cvbooster.best_iteration]
break

if return_cvbooster:
results['cvbooster'] = cvfolds # type: ignore[assignment]
results['cvbooster'] = cvbooster # type: ignore[assignment]

return dict(results)
138 changes: 89 additions & 49 deletions python-package/lightgbm/sklearn.py
@@ -16,7 +16,7 @@
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
dt_DataTable, pd_DataFrame)
from .engine import train
from .engine import _make_n_folds, train

__all__ = [
'LGBMClassifier',
@@ -412,6 +412,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
validation_fraction: Optional[float] = 0.1,
**kwargs
):
r"""Construct a gradient boosting model.
@@ -491,6 +492,10 @@ def __init__(
The type of feature importance to be filled into ``feature_importances_``.
If 'split', result contains numbers of times the feature is used in a model.
If 'gain', result contains total gains of splits which use the feature.
validation_fraction : float or None, optional (default=0.1)
Proportion of training data to set aside as
validation data for early stopping. If None, early stopping is done on
the training data. Only used if early stopping is performed.
**kwargs
Other parameters for the model.
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
@@ -553,6 +558,7 @@ def __init__(
self.random_state = random_state
self.n_jobs = n_jobs
self.importance_type = importance_type
self.validation_fraction = validation_fraction
self._Booster: Optional[Booster] = None
self._evals_result: _EvalResultDict = {}
self._best_score: _LGBM_BoosterBestScoreType = {}
@@ -668,9 +674,27 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
params.pop('importance_type', None)
params.pop('n_estimators', None)
params.pop('class_weight', None)
params.pop("validation_fraction", None)

if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)

params = _choose_param_value(
main_param_name="early_stopping_round",
params=params,
default_value="auto",
)
if params["early_stopping_round"] == "auto":
if hasattr(self, "_n_rows_train") and self._n_rows_train > 10_000:
params["early_stopping_round"] = 10
else:
params["early_stopping_round"] = None

if params["early_stopping_round"] is True:
params["early_stopping_round"] = 10
elif params["early_stopping_round"] is False:
params["early_stopping_round"] = None

if self._n_classes > 2:
for alias in _ConfigAliases.get('num_class'):
params.pop(alias, None)
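
To summarize the early-stopping branching added above, the effective ``early_stopping_round`` can be restated as a small standalone helper (an illustrative sketch, not code from this commit):

def _resolve_early_stopping_round(early_stopping_round, n_rows_train):
    # "auto" is the new default: enable 10 rounds of early stopping only when
    # the training set has more than 10,000 rows, otherwise disable it.
    if early_stopping_round == "auto":
        return 10 if n_rows_train > 10_000 else None
    # Explicit True/False map to 10 rounds / disabled.
    if early_stopping_round is True:
        return 10
    if early_stopping_round is False:
        return None
    # Any other value (an int number of rounds, or None) is used as given.
    return early_stopping_round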
@@ -745,6 +769,19 @@ def fit(
init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None
) -> "LGBMModel":
"""Docstring is set after definition, using a template."""
if not isinstance(X, (pd_DataFrame, dt_DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
if sample_weight is not None:
sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
else:
_X, _y = X, y

self._n_features = _X.shape[1]
# copy for consistency
self._n_features_in = self._n_features

self._n_rows_train = _X.shape[0]

params = self._process_params(stage="fit")

# Do not modify original args in fit function
@@ -766,13 +803,6 @@
params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
params['metric'] = [metric for metric in params['metric'] if metric is not None]

if not isinstance(X, (pd_DataFrame, dt_DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
if sample_weight is not None:
sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
else:
_X, _y = X, y

if self._class_weight is None:
self._class_weight = self.class_weight
if self._class_weight is not None:
Expand All @@ -782,51 +812,61 @@ def fit(
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)

self._n_features = _X.shape[1]
# copy for consistency
self._n_features_in = self._n_features

train_set = Dataset(data=_X, label=_y, weight=sample_weight, group=group,
init_score=init_score, categorical_feature=categorical_feature,
params=params)
if params["early_stopping_round"] is not None and eval_set is None:
if self.validation_fraction is not None:
n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
stratified = isinstance(self, LGBMClassifier)
cvfolds = _make_n_folds(full_data=train_set, folds=None, nfold=n_splits,
params=params, seed=self.random_state,
stratified=stratified, shuffle=True)
train_idx, val_idx = next(cvfolds)
valid_set = train_set.subset(sorted(val_idx))
train_set = train_set.subset(sorted(train_idx))
else:
valid_set = train_set
valid_set = valid_set.construct()
valid_sets = [valid_set]

valid_sets: List[Dataset] = []
if eval_set is not None:

def _get_meta_data(collection, name, i):
if collection is None:
return None
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")

if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
valid_group = _get_meta_data(eval_group, 'eval_group', i)
valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
group=valid_group, init_score=valid_init_score,
categorical_feature='auto', params=params)

valid_sets.append(valid_set)
else:
valid_sets: List[Dataset] = []
if eval_set is not None:
def _get_meta_data(collection, name, i):
if collection is None:
return None
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")

if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
valid_group = _get_meta_data(eval_group, 'eval_group', i)
valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight,
group=valid_group, init_score=valid_init_score,
categorical_feature='auto', params=params)

valid_sets.append(valid_set)

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
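For context on the ``fit()`` changes above: when early stopping is active and no ``eval_set`` is supplied, the training ``Dataset`` is split using the fold machinery from ``engine.py``, keeping the first fold as validation data. A rough sketch of how the held-out size follows from ``validation_fraction`` (an illustration based on this diff, not library code):

import numpy as np

def validation_split_sizes(n_rows, validation_fraction):
    # fit() requests n_splits folds and uses the first fold as the validation
    # set, so roughly n_rows / n_splits rows are held out.
    n_splits = max(int(np.ceil(1 / validation_fraction)), 2)
    n_valid = n_rows // n_splits
    return n_rows - n_valid, n_valid

print(validation_split_sizes(50_000, 0.1))   # roughly (45000, 5000)
print(validation_split_sizes(50_000, 0.25))  # roughly (37500, 12500)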
