Skip to content

Commit 0065605

Browse files
committed
feat(callback): add progress_bar callback
1 parent a97c444 commit 0065605

File tree

3 files changed

+173
-2
lines changed

3 files changed

+173
-2
lines changed

python-package/lightgbm/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77

88
from .basic import Booster, Dataset, Sequence, register_logger
9-
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
9+
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter, progress_bar
1010
from .engine import CVBooster, cv, train
1111

1212
try:
@@ -32,5 +32,5 @@
3232
'train', 'cv',
3333
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
3434
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
35-
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
35+
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'progress_bar',
3636
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']

python-package/lightgbm/callback.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
# coding: utf-8
22
"""Callbacks library."""
3+
from __future__ import annotations
4+
35
import collections
6+
import importlib
7+
import warnings
8+
from collections import OrderedDict
49
from functools import partial
510
from typing import Any, Callable, Dict, List, Tuple, Union
11+
from typing import Any, Literal, Type
12+
try:
13+
import tqdm
14+
except ImportError:
15+
pass
616

717
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
818

@@ -11,6 +21,7 @@
1121
'log_evaluation',
1222
'record_evaluation',
1323
'reset_parameter',
24+
'progress_bar',
1425
]
1526

1627
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
@@ -413,3 +424,149 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
413424
The callback that activates early stopping.
414425
"""
415426
return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
427+
428+
429+
class _ProgressBarCallback:
430+
"""Internal class to handle progress bar."""
431+
tqdm_cls: "Type[tqdm.std.tqdm]"
432+
pbar: "tqdm.std.tqdm" | None
433+
434+
def __init__(
435+
self,
436+
tqdm_cls: Literal[
437+
"auto",
438+
"autonotebook",
439+
"std",
440+
"notebook",
441+
"asyncio",
442+
"keras",
443+
"dask",
444+
"tk",
445+
"gui",
446+
"rich",
447+
"contrib.slack",
448+
"contrib.discord",
449+
"contrib.telegram",
450+
"contrib.bells",
451+
]
452+
| "Type[tqdm.std.tqdm]" = "auto",
453+
early_stopping_callback: Any | None = None,
454+
**tqdm_kwargs: Any,
455+
) -> None:
456+
"""Progress bar callback for LightGBM training.
457+
458+
Parameters
459+
----------
460+
tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional
461+
The tqdm class or module name, by default "auto"
462+
early_stopping_callback : _EarlyStoppingCallback | None, optional
463+
The early stopping callback, by default None
464+
465+
.. rubric:: Example
466+
467+
.. code-block:: python
468+
early_stopping_callback = early_stopping(stopping_rounds=50)
469+
callbacks = [early_stopping_callback, ProgressBarCallback(early_stopping_callback=early_stopping_callback)]
470+
estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
471+
"""
472+
if isinstance(tqdm_cls, str):
473+
try:
474+
tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}")
475+
except ImportError as e:
476+
raise ImportError(
477+
f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e
478+
self.tqdm_cls = getattr(tqdm_module, "tqdm")
479+
else:
480+
self.tqdm_cls = tqdm_cls
481+
self.early_stopping_callback = early_stopping_callback
482+
self.tqdm_kwargs = tqdm_kwargs
483+
if "total" in tqdm_kwargs:
484+
warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning)
485+
self.pbar = None
486+
487+
def _init(self, env: CallbackEnv) -> None:
488+
# create pbar on first call
489+
tqdm_kwargs = self.tqdm_kwargs.copy()
490+
tqdm_kwargs["total"] = env.end_iteration - env.begin_iteration
491+
self.pbar = self.tqdm_cls(**tqdm_kwargs)
492+
493+
def __call__(self, env: CallbackEnv) -> None:
494+
if env.iteration == env.begin_iteration:
495+
self._init(env)
496+
assert self.pbar is not None
497+
498+
# update postfix
499+
if len(env.evaluation_result_list) > 0:
500+
# If OrderedDict is not used, the order of display is disjointed and slightly difficult to see.
501+
# https://github.com/microsoft/LightGBM/blob/a97c444b4cf9d2755bd888911ce65ace1fe13e4b/python-package/lightgbm/callback.py#L56-66
502+
if self.early_stopping_callback is not None:
503+
postfix = OrderedDict(
504+
[
505+
(
506+
f"{entry[0]}'s {entry[1]}",
507+
f"{entry[2]:g}{'=' if entry[2] == best_score else ('>' if cmp_op else '<')}{best_score:g}@{best_iter}it",
508+
)
509+
for entry, cmp_op, best_score, best_iter in zip(
510+
env.evaluation_result_list,
511+
self.early_stopping_callback.cmp_op,
512+
self.early_stopping_callback.best_score,
513+
self.early_stopping_callback.best_iter,
514+
)
515+
]
516+
)
517+
else:
518+
postfix = OrderedDict(
519+
[
520+
(f"{entry[0]}'s {entry[1]}", f"{entry[2]:g}")
521+
for entry in env.evaluation_result_list
522+
]
523+
)
524+
self.pbar.set_postfix(ordered_dict=postfix, refresh=False)
525+
526+
# update pbar
527+
self.pbar.update()
528+
self.pbar.refresh()
529+
530+
531+
def progress_bar(tqdm_cls: Literal[
532+
"auto",
533+
"autonotebook",
534+
"std",
535+
"notebook",
536+
"asyncio",
537+
"keras",
538+
"dask",
539+
"tk",
540+
"gui",
541+
"rich",
542+
"contrib.slack",
543+
"contrib.discord",
544+
"contrib.telegram",
545+
"contrib.bells",
546+
]
547+
| "Type[tqdm.std.tqdm]" = "auto",
548+
early_stopping_callback: _EarlyStoppingCallback | None = None,
549+
**tqdm_kwargs: Any,
550+
) -> _ProgressBarCallback:
551+
"""Progress bar callback for LightGBM training.
552+
553+
Parameters
554+
----------
555+
tqdm_cls : Literal[ &quot;auto&quot;, &quot;autonotebook&quot;, &quot;std&quot;, &quot;notebook&quot;, &quot;asyncio&quot;, &quot;keras&quot;, &quot;dask&quot;, &quot;tk&quot;, &quot;gui&quot;, &quot;rich&quot;, &quot;contrib.slack&quot;, &quot;contrib.discord&quot;, &quot;contrib.telegram&quot;, &quot;contrib.bells&quot;, ] | Type[tqdm.std.tqdm], optional
556+
The tqdm class or module name, by default &quot;auto&quot;
557+
early_stopping_callback : Any | None, optional
558+
The early stopping callback, by default None
559+
560+
.. rubric:: Example
561+
562+
.. code-block:: python
563+
early_stopping_callback = early_stopping(stopping_rounds=50)
564+
callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)]
565+
estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
566+
567+
Returns
568+
-------
569+
callback : _ProgressBarCallback
570+
The callback that displays the progress bar.
571+
"""
572+
return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs)

tests/python_package_test/test_callback.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,17 @@ def test_reset_parameter_callback_is_picklable(serializer):
5555
assert callback_from_disk.before_iteration is True
5656
assert callback.kwargs == callback_from_disk.kwargs
5757
assert callback.kwargs == params
58+
59+
@pytest.mark.parametrize('serializer', SERIALIZERS)
60+
def test_progress_bar_callback_is_picklable(serializer):
61+
rounds = 5
62+
callback = lgb.progress_bar()
63+
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
64+
assert callback_from_disk.order == 30
65+
assert callback_from_disk.before_iteration is False
66+
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
67+
assert callback.stopping_rounds == rounds
68+
69+
def test_progress_bar_warn_override(self) -> None:
70+
with pytest.warns(UserWarning):
71+
lgb.progress_bar(self.tqdm_cls, total=100, **self.tqdm_kwargs)

0 commit comments

Comments
 (0)