Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Add progress_bar callback using tqdm #5867

Closed
wants to merge 14 commits into from
3 changes: 2 additions & 1 deletion .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ conda create -q -y -n $CONDA_ENV \
${CONDA_PYTHON_REQUIREMENT} \
python-graphviz \
scikit-learn \
scipy || exit -1
scipy \
tqdm || exit -1

source activate $CONDA_ENV

Expand Down
3 changes: 2 additions & 1 deletion .ci/test_windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ conda create -q -y -n $env:CONDA_ENV `
"python=$env:PYTHON_VERSION[build=*cpython]" `
python-graphviz `
scikit-learn `
scipy ; Check-Output $?
scipy `
tqdm ; Check-Output $?

if ($env:TASK -ne "bdist") {
conda activate $env:CONDA_ENV
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path

from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter, progress_bar
from .engine import CVBooster, cv, train

try:
Expand All @@ -32,5 +32,5 @@
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'progress_bar',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
162 changes: 161 additions & 1 deletion python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
# coding: utf-8
"""Callbacks library."""
from __future__ import annotations

import collections
import importlib
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import sys

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

try:
import tqdm
except ImportError:
pass

__all__ = [
'early_stopping',
'log_evaluation',
'record_evaluation',
'reset_parameter',
'progress_bar',
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
Expand Down Expand Up @@ -413,3 +425,151 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
The callback that activates early stopping.
"""
return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)


class _ProgressBarCallback:
    """Internal callable that reports training progress through a tqdm progress bar.

    An instance is called with a ``CallbackEnv``; it creates the bar on the first
    iteration of a run, advances it by one on every call and shows the latest
    evaluation results as the bar's postfix.
    """

    tqdm_cls: "Type[tqdm.std.tqdm]"
    pbar: "tqdm.std.tqdm | None"

    def __init__(
        self,
        tqdm_cls: "str | Type[tqdm.std.tqdm]" = "auto",
        early_stopping_callback: "Any | None" = None,
        **tqdm_kwargs: Any,
    ) -> None:
        """Progress bar callback for LightGBM training.

        Parameters
        ----------
        tqdm_cls : str or Type[tqdm.std.tqdm], optional (default="auto")
            Either a tqdm class to instantiate directly, or the name of a tqdm
            submodule (``"auto"``, ``"autonotebook"``, ``"std"``, ``"notebook"``,
            ``"asyncio"``, ``"keras"``, ``"dask"``, ``"tk"``, ``"gui"``, ``"rich"``,
            ``"contrib.slack"``, ``"contrib.discord"``, ``"contrib.telegram"``,
            ``"contrib.bells"``) whose ``tqdm`` attribute is used.
        early_stopping_callback : _EarlyStoppingCallback or None, optional (default=None)
            The early stopping callback. When given, its ``best_score`` and
            ``best_iter`` are shown next to each current score.
        **tqdm_kwargs
            Extra keyword arguments forwarded to the tqdm class.
            ``total`` is always overridden with the number of iterations.

        Raises
        ------
        ImportError
            If ``tqdm_cls`` is a string and ``tqdm.<tqdm_cls>`` cannot be imported.

        .. rubric:: Example

        .. code-block:: python

            early_stopping_callback = early_stopping(stopping_rounds=50)
            callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)]
            estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
        """
        # NOTE(review): `order` / `before_iteration` are presumably read by the
        # training loop to schedule callbacks; 40 is assumed to place this one
        # after early stopping so the displayed best scores are current - confirm.
        self.order = 40
        self.before_iteration = False
        if isinstance(tqdm_cls, str):
            try:
                tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}")
            except ImportError as e:
                raise ImportError(
                    f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e
            self.tqdm_cls = tqdm_module.tqdm
        else:
            self.tqdm_cls = tqdm_cls
        self.early_stopping_callback = early_stopping_callback
        self.tqdm_kwargs = tqdm_kwargs
        if "total" in tqdm_kwargs:
            warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning)
        self.pbar = None

    def _init(self, env: "CallbackEnv") -> None:
        """Create a fresh progress bar sized to the iteration range of ``env``."""
        if self.pbar is not None:
            # The callback is being reused for another run: release the old bar
            # instead of leaving it open.
            self.pbar.close()
        tqdm_kwargs = {**self.tqdm_kwargs, "total": env.end_iteration - env.begin_iteration}
        self.pbar = self.tqdm_cls(**tqdm_kwargs)

    def _build_postfix(self, env: "CallbackEnv") -> "OrderedDict[str, str]":
        """Format the evaluation results of ``env`` as ``{"<data>'s <metric>": "<text>"}``."""
        stopper = self.early_stopping_callback
        best_scores = [] if stopper is None else stopper.best_score
        best_iters = [] if stopper is None else stopper.best_iter

        # An ordered mapping keeps the metrics in the order they were evaluated;
        # otherwise the display is disjointed and slightly difficult to read.
        postfix: "OrderedDict[str, str]" = OrderedDict()
        for idx, entry in enumerate(env.evaluation_result_list):
            key = f"{entry[0]}'s {entry[1]}"
            score = entry[2]
            if idx < len(best_scores) and idx < len(best_iters):
                best = best_scores[idx]
                # Compare the values directly. Testing the truthiness of the early
                # stopping callback's `cmp_op` entries (presumably comparison
                # callables, hence always truthy) could never yield '<'.
                if score == best:
                    relation = "="
                elif score > best:
                    relation = ">"
                else:
                    relation = "<"
                postfix[key] = f"{score:g}{relation}{best:g}@{best_iters[idx]}it"
            else:
                # No best-score bookkeeping for this metric: show the plain score
                # rather than silently dropping it.
                postfix[key] = f"{score:g}"
        return postfix

    def __call__(self, env: "CallbackEnv") -> None:
        """Advance the bar by one iteration and refresh the displayed scores."""
        # Create the bar on the first iteration of a run. Also create it lazily when
        # it is missing (e.g. the callback was unpickled mid-run); in that case the
        # count restarts from zero.
        if self.pbar is None or env.iteration == env.begin_iteration:
            self._init(env)

        # update postfix
        if env.evaluation_result_list:
            self.pbar.set_postfix(ordered_dict=self._build_postfix(env), refresh=False)

        # update pbar
        self.pbar.update()
        self.pbar.refresh()

        # Release the bar once the last scheduled iteration has been reported.
        # The closed instance stays available on `self.pbar` for inspection.
        if env.iteration + 1 >= env.end_iteration:
            self.pbar.close()

    # do not pickle tqdm instance
    def __getstate__(self) -> "dict[str, Any]":
        """Return the instance state without the live progress bar."""
        state = self.__dict__.copy()
        state["pbar"] = None
        # class should be picklable
        return state

    def __setstate__(self, state: "dict[str, Any]") -> None:
        """Restore the instance state; the progress bar is re-created on demand."""
        self.__dict__.update(state)
        self.pbar = None

if sys.version_info >= (3, 8):
    from typing import Literal, overload

    # Typing-only stubs (``Literal`` needs Python 3.8+). Type checkers require at
    # least two overloads, so the "module name" and "tqdm class" forms of
    # ``tqdm_cls`` are declared separately.
    @overload
    def progress_bar(
        tqdm_cls: Literal[
            "auto",
            "autonotebook",
            "std",
            "notebook",
            "asyncio",
            "keras",
            "dask",
            "tk",
            "gui",
            "rich",
            "contrib.slack",
            "contrib.discord",
            "contrib.telegram",
            "contrib.bells",
        ] = "auto",
        early_stopping_callback: _EarlyStoppingCallback | None = None,
        **tqdm_kwargs: Any,
    ) -> _ProgressBarCallback:
        ...

    @overload
    def progress_bar(
        tqdm_cls: Type["tqdm.std.tqdm"],
        early_stopping_callback: _EarlyStoppingCallback | None = None,
        **tqdm_kwargs: Any,
    ) -> _ProgressBarCallback:
        ...


def progress_bar(
    tqdm_cls: str | Type["tqdm.std.tqdm"] = "auto",
    early_stopping_callback: _EarlyStoppingCallback | None = None,
    **tqdm_kwargs: Any,
) -> _ProgressBarCallback:
    """Progress bar callback for LightGBM training.

    Parameters
    ----------
    tqdm_cls : str or Type[tqdm.std.tqdm], optional (default="auto")
        The tqdm class, or the name of the tqdm submodule that provides it. One of
        "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk",
        "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram",
        "contrib.bells".
    early_stopping_callback : _EarlyStoppingCallback or None, optional (default=None)
        The early stopping callback.
    **tqdm_kwargs
        Other keyword arguments passed on to the progress bar callback.

    .. rubric:: Example

    .. code-block:: python

        early_stopping_callback = early_stopping(stopping_rounds=50)
        callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)]
        estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)

    Returns
    -------
    callback : _ProgressBarCallback
        The callback that displays the progress bar.
    """
    return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs)
51 changes: 50 additions & 1 deletion tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# coding: utf-8
import pytest
from sklearn.model_selection import train_test_split
import tqdm

import lightgbm as lgb
import lightgbm.callback

from .utils import SERIALIZERS, pickle_and_unpickle_object
from .utils import SERIALIZERS, load_breast_cancer, pickle_and_unpickle_object


def reset_feature_fraction(boosting_round):
Expand Down Expand Up @@ -55,3 +58,49 @@ def test_reset_parameter_callback_is_picklable(serializer):
assert callback_from_disk.before_iteration is True
assert callback.kwargs == callback_from_disk.kwargs
assert callback.kwargs == params

@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_progress_bar_callback_is_picklable(serializer):
    """The callback must survive a serialization round trip even while it holds a live bar."""
    callback = lgb.progress_bar()
    # Attach a tqdm instance *before* serializing so the round trip actually
    # exercises the code path that must leave the bar out of the pickled state.
    callback.pbar = tqdm.tqdm(total=100)
    try:
        callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    finally:
        callback.pbar.close()
    assert callback_from_disk.order == 40
    assert callback_from_disk.before_iteration is False
    assert callback_from_disk.pbar is None
    assert callback_from_disk.tqdm_kwargs == callback.tqdm_kwargs

def test_progress_bar_warn_override() -> None:
    """Supplying ``total`` through the tqdm keyword arguments must emit a ``UserWarning``."""
    overridden_kwargs = {"total": 100}
    with pytest.warns(UserWarning):
        lgb.progress_bar(**overridden_kwargs)

def test_progress_bar_binary():
    """Without early stopping, the bar must be advanced once per boosting round."""
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1)
    callback = lgb.progress_bar()
    # No early-stopping callback here: it may end training (or presumably preempt
    # this callback) before every round has been reported, which would make the
    # exact-count assertion below unreliable. The combination with early stopping
    # is covered by test_progress_bar_early_stopping_binary.
    gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[callback])

    assert issubclass(callback.tqdm_cls, tqdm.std.tqdm)
    assert isinstance(callback.pbar, tqdm.std.tqdm)
    assert callback.pbar.total == gbm.n_estimators
    assert callback.pbar.n == gbm.n_estimators

def test_progress_bar_early_stopping_binary():
    """With a linked early-stopping callback, the bar count must stay within the iteration budget."""
    features, labels = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=42)

    stopper = lgb.early_stopping(5)
    bar_callback = lgb.progress_bar(early_stopping_callback=stopper)
    gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1)
    gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[stopper, bar_callback])

    bar = bar_callback.pbar
    assert issubclass(bar_callback.tqdm_cls, tqdm.std.tqdm)
    assert isinstance(bar, tqdm.std.tqdm)
    assert bar is not None
    assert bar.total == gbm.n_estimators
    assert 0 <= bar.n <= gbm.n_estimators