From d1f39ea1e15966882cffaf9ee9ef2f75ce9cd0bc Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 15:35:38 +0900 Subject: [PATCH 01/14] feat(callback): add `progress_bar` callback --- python-package/lightgbm/__init__.py | 4 +- python-package/lightgbm/callback.py | 159 ++++++++++++++++++++- tests/python_package_test/test_callback.py | 14 ++ 3 files changed, 174 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index 5815bc602bde..d311ba934073 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -6,7 +6,7 @@ from pathlib import Path from .basic import Booster, Dataset, Sequence, register_logger -from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter +from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter, progress_bar from .engine import CVBooster, cv, train try: @@ -32,5 +32,5 @@ 'train', 'cv', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker', - 'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', + 'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'progress_bar', 'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph'] diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 868a6fc15534..5b057529c29b 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -1,16 +1,27 @@ # coding: utf-8 """Callbacks library.""" +from __future__ import annotations + import collections +import importlib +import warnings +from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Tuple, Type, Union from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning +try: + import tqdm +except ImportError: + pass + __all__ = [ 'early_stopping', 'log_evaluation', 'record_evaluation', 'reset_parameter', + 'progress_bar', ] _EvalResultDict = Dict[str, Dict[str, List[Any]]] @@ -413,3 +424,149 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos The callback that activates early stopping. """ return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta) + + +class _ProgressBarCallback: + """Internal class to handle progress bar.""" + tqdm_cls: "Type[tqdm.std.tqdm]" + pbar: "tqdm.std.tqdm" | None + + def __init__( + self, + tqdm_cls: Literal[ + "auto", + "autonotebook", + "std", + "notebook", + "asyncio", + "keras", + "dask", + "tk", + "gui", + "rich", + "contrib.slack", + "contrib.discord", + "contrib.telegram", + "contrib.bells", + ] + | "Type[tqdm.std.tqdm]" = "auto", + early_stopping_callback: Any | None = None, + **tqdm_kwargs: Any, + ) -> None: + """Progress bar callback for LightGBM training. + + Parameters + ---------- + tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional + The tqdm class or module name, by default "auto" + early_stopping_callback : _EarlyStoppingCallback | None, optional + The early stopping callback, by default None + + .. rubric:: Example + + .. code-block:: python + early_stopping_callback = early_stopping(stopping_rounds=50) + callbacks = [early_stopping_callback, ProgressBarCallback(early_stopping_callback=early_stopping_callback)] + estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) + """ + if isinstance(tqdm_cls, str): + try: + tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}") + except ImportError as e: + raise ImportError( + f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e + self.tqdm_cls = getattr(tqdm_module, "tqdm") + else: + self.tqdm_cls = tqdm_cls + self.early_stopping_callback = early_stopping_callback + self.tqdm_kwargs = tqdm_kwargs + if "total" in tqdm_kwargs: + warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning) + self.pbar = None + + def _init(self, env: CallbackEnv) -> None: + # create pbar on first call + tqdm_kwargs = self.tqdm_kwargs.copy() + tqdm_kwargs["total"] = env.end_iteration - env.begin_iteration + self.pbar = self.tqdm_cls(**tqdm_kwargs) + + def __call__(self, env: CallbackEnv) -> None: + if env.iteration == env.begin_iteration: + self._init(env) + assert self.pbar is not None + + # update postfix + if len(env.evaluation_result_list) > 0: + # If OrderedDict is not used, the order of display is disjointed and slightly difficult to see. + # https://github.com/microsoft/LightGBM/blob/a97c444b4cf9d2755bd888911ce65ace1fe13e4b/python-package/lightgbm/callback.py#L56-66 + if self.early_stopping_callback is not None: + postfix = OrderedDict( + [ + ( + f"{entry[0]}'s {entry[1]}", + f"{entry[2]:g}{'=' if entry[2] == best_score else ('>' if cmp_op else '<')}{best_score:g}@{best_iter}it", + ) + for entry, cmp_op, best_score, best_iter in zip( + env.evaluation_result_list, + self.early_stopping_callback.cmp_op, + self.early_stopping_callback.best_score, + self.early_stopping_callback.best_iter, + ) + ] + ) + else: + postfix = OrderedDict( + [ + (f"{entry[0]}'s {entry[1]}", f"{entry[2]:g}") + for entry in env.evaluation_result_list + ] + ) + self.pbar.set_postfix(ordered_dict=postfix, refresh=False) + + # update pbar + self.pbar.update() + self.pbar.refresh() + + +def progress_bar(tqdm_cls: Literal[ + "auto", + "autonotebook", + "std", + "notebook", + "asyncio", + "keras", + "dask", + "tk", + "gui", + "rich", + "contrib.slack", + "contrib.discord", + "contrib.telegram", + "contrib.bells", +] + | "Type[tqdm.std.tqdm]" = "auto", + early_stopping_callback: _EarlyStoppingCallback | None = None, + **tqdm_kwargs: Any, +) -> _ProgressBarCallback: + """Progress bar callback for LightGBM training. + + Parameters + ---------- + tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional + The tqdm class or module name, by default "auto" + early_stopping_callback : Any | None, optional + The early stopping callback, by default None + + .. rubric:: Example + + .. code-block:: python + early_stopping_callback = early_stopping(stopping_rounds=50) + callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)] + estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) + + Returns + ------- + callback : _ProgressBarCallback + The callback that displays the progress bar. + """ + return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index cb5dc707bf43..cd340098d1a8 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -55,3 +55,17 @@ def test_reset_parameter_callback_is_picklable(serializer): assert callback_from_disk.before_iteration is True assert callback.kwargs == callback_from_disk.kwargs assert callback.kwargs == params + +@pytest.mark.parametrize('serializer', SERIALIZERS) +def test_progress_bar_callback_is_picklable(serializer): + rounds = 5 + callback = lgb.progress_bar() + callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) + assert callback_from_disk.order == 30 + assert callback_from_disk.before_iteration is False + assert callback.stopping_rounds == callback_from_disk.stopping_rounds + assert callback.stopping_rounds == rounds + +def test_progress_bar_warn_override() -> None: + with pytest.warns(UserWarning): + lgb.progress_bar(total=100) From d27ebed88c9d76c0e70a37ccee9332e530abc44b Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 16:45:01 +0900 Subject: [PATCH 02/14] fix(callback): do not use Literal if Python version is < 3.8 --- python-package/lightgbm/callback.py | 42 ++++++++++++----------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 5b057529c29b..7bb9f93f0744 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -7,7 +7,8 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Literal, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Tuple, Type, Union +import sys from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning @@ -426,14 +427,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta) -class _ProgressBarCallback: - """Internal class to handle progress bar.""" - tqdm_cls: "Type[tqdm.std.tqdm]" - pbar: "tqdm.std.tqdm" | None - - def __init__( - self, - tqdm_cls: Literal[ +if sys.version_info >= (3, 8): + from typing import Literal + _tqdm_module_name_str = Literal[ "auto", "autonotebook", "std", @@ -449,6 +445,17 @@ def __init__( "contrib.telegram", "contrib.bells", ] +else: + _tqdm_module_name_str = str + +class _ProgressBarCallback: + """Internal class to handle progress bar.""" + tqdm_cls: "Type[tqdm.std.tqdm]" + pbar: "tqdm.std.tqdm" | None + + def __init__( + self, + tqdm_cls: _tqdm_module_name_str | "Type[tqdm.std.tqdm]" = "auto", early_stopping_callback: Any | None = None, **tqdm_kwargs: Any, @@ -528,22 +535,7 @@ def __call__(self, env: CallbackEnv) -> None: self.pbar.refresh() -def progress_bar(tqdm_cls: Literal[ - "auto", - "autonotebook", - "std", - "notebook", - "asyncio", - "keras", - "dask", - "tk", - "gui", - "rich", - "contrib.slack", - "contrib.discord", - "contrib.telegram", - "contrib.bells", -] +def progress_bar(tqdm_cls: _tqdm_module_name_str | "Type[tqdm.std.tqdm]" = "auto", early_stopping_callback: _EarlyStoppingCallback | None = None, **tqdm_kwargs: Any, From 2aa94330d390888471eaf29c6b9be41441c22153 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 17:54:58 +0900 Subject: [PATCH 03/14] fix(callback): use overload instead --- python-package/lightgbm/callback.py | 48 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 7bb9f93f0744..35b599dd29ce 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -427,27 +427,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta) -if sys.version_info >= (3, 8): - from typing import Literal - _tqdm_module_name_str = Literal[ - "auto", - "autonotebook", - "std", - "notebook", - "asyncio", - "keras", - "dask", - "tk", - "gui", - "rich", - "contrib.slack", - "contrib.discord", - "contrib.telegram", - "contrib.bells", - ] -else: - _tqdm_module_name_str = str - class _ProgressBarCallback: """Internal class to handle progress bar.""" tqdm_cls: "Type[tqdm.std.tqdm]" @@ -455,8 +434,7 @@ class _ProgressBarCallback: def __init__( self, - tqdm_cls: _tqdm_module_name_str - | "Type[tqdm.std.tqdm]" = "auto", + tqdm_cls: str | Type["tqdm.std.tqdm"] = "auto", early_stopping_callback: Any | None = None, **tqdm_kwargs: Any, ) -> None: @@ -534,9 +512,29 @@ def __call__(self, env: CallbackEnv) -> None: self.pbar.update() self.pbar.refresh() +if sys.version_info >= (3, 8): + from typing import Literal, overload + + @overload + def progress_bar(tqdm_cls: Literal[ + "auto", + "autonotebook", + "std", + "notebook", + "asyncio", + "keras", + "dask", + "tk", + "gui", + "rich", + "contrib.slack", + "contrib.discord", + "contrib.telegram", + "contrib.bells", + ], early_stopping_callback: _EarlyStoppingCallback | None = None, **tqdm_kwargs: Any) -> _ProgressBarCallback: + ... -def progress_bar(tqdm_cls: _tqdm_module_name_str - | "Type[tqdm.std.tqdm]" = "auto", +def progress_bar(tqdm_cls: str | Type["tqdm.std.tqdm"] = "auto", early_stopping_callback: _EarlyStoppingCallback | None = None, **tqdm_kwargs: Any, ) -> _ProgressBarCallback: From ab1ed89f4ea247ced2575fe2d5b106797780eec0 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 17:56:04 +0900 Subject: [PATCH 04/14] ci: add tqdm to test deps --- .ci/test.sh | 3 ++- .ci/test_windows.ps1 | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index e4b6eb7fbae0..895c38ca5370 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -134,7 +134,8 @@ conda create -q -y -n $CONDA_ENV \ ${CONDA_PYTHON_REQUIREMENT} \ python-graphviz \ scikit-learn \ - scipy || exit -1 + scipy \ + tqdm || exit -1 source activate $CONDA_ENV diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index 3d07496d855b..b3c93efd7851 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -56,7 +56,8 @@ conda create -q -y -n $env:CONDA_ENV ` "python=$env:PYTHON_VERSION[build=*cpython]" ` python-graphviz ` scikit-learn ` - scipy ; Check-Output $? + scipy ` + tqdm ; Check-Output $? if ($env:TASK -ne "bdist") { conda activate $env:CONDA_ENV From 90194011e16b7cc3a1e4d5b416113313ca264f81 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 18:14:44 +0900 Subject: [PATCH 05/14] test(callback): add basic tests --- tests/python_package_test/test_callback.py | 32 +++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index cd340098d1a8..e9dfe8b797e6 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -1,9 +1,11 @@ # coding: utf-8 import pytest +from sklearn.model_selection import train_test_split +import tqdm import lightgbm as lgb -from .utils import SERIALIZERS, pickle_and_unpickle_object +from .utils import SERIALIZERS, load_breast_cancer, pickle_and_unpickle_object def reset_feature_fraction(boosting_round): @@ -69,3 +71,31 @@ def test_progress_bar_callback_is_picklable(serializer): def test_progress_bar_warn_override() -> None: with pytest.warns(UserWarning): lgb.progress_bar(total=100) + +def test_progress_bar_binary(): + X, y = load_breast_cancer(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1) + callback = lgb.progress_bar() + gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5), callback]) + + assert issubclass(callback.tqdm_cls, tqdm.std.tqdm) + assert isinstance(callback.pbar, tqdm.std.tqdm) + assert callback.pbar is not None + assert callback.pbar.total == gbm.n_estimators + assert callback.pbar.n == gbm.n_estimators + +def test_progress_bar_early_stopping_binary(): + X, y = load_breast_cancer(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1) + early_stopping = lgb.early_stopping(5) + callback = lgb.progress_bar(early_stopping=early_stopping) + gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[early_stopping, callback]) + + assert issubclass(callback.tqdm_cls, tqdm.std.tqdm) + assert isinstance(callback.pbar, tqdm.std.tqdm) + assert callback.pbar is not None + assert callback.pbar.total == gbm.n_estimators + assert callback.pbar.n >= 0 + assert callback.pbar.n <= gbm.n_estimators From d178623d5c7d8eb8c5cb09af9de58ae6bce97bce Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 20:45:20 +0900 Subject: [PATCH 06/14] fix(callback): set order and before_iterationi --- python-package/lightgbm/callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 35b599dd29ce..dee9d4ca2543 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -454,6 +454,8 @@ def __init__( callbacks = [early_stopping_callback, ProgressBarCallback(early_stopping_callback=early_stopping_callback)] estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) """ + self.order = 40 + self.before_iteration = False if isinstance(tqdm_cls, str): try: tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}") From b55ba929b8d329ff5d3a2c93d3f0d07971350c0e Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 20:46:05 +0900 Subject: [PATCH 07/14] fix(callback): make the class picklable --- python-package/lightgbm/callback.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index dee9d4ca2543..5223a622d3d0 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -513,6 +513,17 @@ def __call__(self, env: CallbackEnv) -> None: # update pbar self.pbar.update() self.pbar.refresh() + + # do not pickle tqdm instance + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + state["pbar"] = None + # class should be picklable + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + self.pbar = None if sys.version_info >= (3, 8): from typing import Literal, overload From 10a025de2a44e5feeffaf65325d3f7cbaff994c3 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 20:52:21 +0900 Subject: [PATCH 08/14] test: fix pickle test --- tests/python_package_test/test_callback.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index e9dfe8b797e6..fd5d15f8efb6 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -60,13 +60,17 @@ def test_reset_parameter_callback_is_picklable(serializer): @pytest.mark.parametrize('serializer', SERIALIZERS) def test_progress_bar_callback_is_picklable(serializer): - rounds = 5 callback = lgb.progress_bar() callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) - assert callback_from_disk.order == 30 + callback.CallbackEnv(model=None, + params={}, + iteration=1, + begin_iteration=0, + end_iteration=100, + evaluation_result_list=[]) + callback.pbar = tqdm.tqdm(total=100) + assert callback_from_disk.order == 40 assert callback_from_disk.before_iteration is False - assert callback.stopping_rounds == callback_from_disk.stopping_rounds - assert callback.stopping_rounds == rounds def test_progress_bar_warn_override() -> None: with pytest.warns(UserWarning): From 44c9897dbbc6ef7df96316b110ae78448e29de52 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 20:53:06 +0900 Subject: [PATCH 09/14] test: fix keywords for the callback --- tests/python_package_test/test_callback.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index fd5d15f8efb6..ada0599cbbea 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -4,6 +4,7 @@ import tqdm import lightgbm as lgb +import lightgbm.callback from .utils import SERIALIZERS, load_breast_cancer, pickle_and_unpickle_object @@ -93,9 +94,9 @@ def test_progress_bar_early_stopping_binary(): X, y = load_breast_cancer(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1) - early_stopping = lgb.early_stopping(5) - callback = lgb.progress_bar(early_stopping=early_stopping) - gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[early_stopping, callback]) + early_stopping_callback = lgb.early_stopping(5) + callback = lgb.progress_bar(early_stopping_callback=early_stopping_callback) + gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[early_stopping_callback, callback]) assert issubclass(callback.tqdm_cls, tqdm.std.tqdm) assert isinstance(callback.pbar, tqdm.std.tqdm) From c137fca39e66c68a228396a84a0783724a0fef4f Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 22:00:55 +0900 Subject: [PATCH 10/14] test: fix callback test --- tests/python_package_test/test_callback.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index ada0599cbbea..1c58bedc5288 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -63,13 +63,12 @@ def test_reset_parameter_callback_is_picklable(serializer): def test_progress_bar_callback_is_picklable(serializer): callback = lgb.progress_bar() callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) - callback.CallbackEnv(model=None, + callback(lightgbm.callback.CallbackEnv(model=None, params={}, iteration=1, begin_iteration=0, end_iteration=100, - evaluation_result_list=[]) - callback.pbar = tqdm.tqdm(total=100) + evaluation_result_list=[])) assert callback_from_disk.order == 40 assert callback_from_disk.before_iteration is False From 8e52051741b644fd929261690379e15e76010566 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 22:23:21 +0900 Subject: [PATCH 11/14] test(callback): fix CallbackEnv kwargs in `test_progress_bar_callback_is_picklable` --- tests/python_package_test/test_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 1c58bedc5288..11050fcc4303 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -65,7 +65,7 @@ def test_progress_bar_callback_is_picklable(serializer): callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer) callback(lightgbm.callback.CallbackEnv(model=None, params={}, - iteration=1, + iteration=0, begin_iteration=0, end_iteration=100, evaluation_result_list=[])) From 8f07b5c4a64e6ddc269a1c075e76c2b519c6216b Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 5 May 2023 22:26:58 +0900 Subject: [PATCH 12/14] test(callback): do not use early stopping in `test_progress_bar_binary()` --- tests/python_package_test/test_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 11050fcc4303..8d4c718a0ef2 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -81,7 +81,7 @@ def test_progress_bar_binary(): X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1) callback = lgb.progress_bar() - gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5), callback]) + gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[callback]) assert issubclass(callback.tqdm_cls, tqdm.std.tqdm) assert isinstance(callback.pbar, tqdm.std.tqdm) From 6fbef47672348b4a0e958ee607b0ec746e9bd7b7 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Sat, 6 May 2023 00:26:08 +0900 Subject: [PATCH 13/14] ci: add types-tqdm for mypy --- .ci/test.sh | 3 ++- .ci/test_windows.ps1 | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 895c38ca5370..72469e597cf6 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -135,7 +135,8 @@ conda create -q -y -n $CONDA_ENV \ python-graphviz \ scikit-learn \ scipy \ - tqdm || exit -1 + tqdm \ + types-tqdm || exit -1 source activate $CONDA_ENV diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index b3c93efd7851..c62e8fd041d9 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -57,7 +57,8 @@ conda create -q -y -n $env:CONDA_ENV ` python-graphviz ` scikit-learn ` scipy ` - tqdm ; Check-Output $? + tqdm ` + types-tqdm ; Check-Output $? if ($env:TASK -ne "bdist") { conda activate $env:CONDA_ENV From 8dddfd3d0d1a267641e1a8a9afc2d77953440de5 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Sat, 6 May 2023 00:27:21 +0900 Subject: [PATCH 14/14] style: lint files --- docs/conf.py | 1 + python-package/lightgbm/__init__.py | 2 +- python-package/lightgbm/callback.py | 45 ++++++++++++++++++++++++----- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 38ec99e75a36..125026cf99e5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,6 +103,7 @@ def run(self) -> List: 'pandas', 'scipy', 'scipy.sparse', + 'tqdm' ] try: import sklearn # noqa: F401 diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index d311ba934073..dc36a4bf7c03 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -6,7 +6,7 @@ from pathlib import Path from .basic import Booster, Dataset, Sequence, register_logger -from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter, progress_bar +from .callback import early_stopping, log_evaluation, progress_bar, record_evaluation, reset_parameter from .engine import CVBooster, cv, train try: diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 5223a622d3d0..2206b8d148dc 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -4,11 +4,11 @@ import collections import importlib +import sys import warnings from collections import OrderedDict from functools import partial from typing import Any, Callable, Dict, List, Tuple, Type, Union -import sys from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning @@ -429,6 +429,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos class _ProgressBarCallback: """Internal class to handle progress bar.""" + tqdm_cls: "Type[tqdm.std.tqdm]" pbar: "tqdm.std.tqdm" | None @@ -513,23 +514,25 @@ def __call__(self, env: CallbackEnv) -> None: # update pbar self.pbar.update() self.pbar.refresh() - + # do not pickle tqdm instance def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["pbar"] = None # class should be picklable return state - + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) self.pbar = None + if sys.version_info >= (3, 8): from typing import Literal, overload - - @overload - def progress_bar(tqdm_cls: Literal[ + + @overload # type: ignore + def progress_bar( # type: ignore + tqdm_cls: Literal[ "auto", "autonotebook", "std", @@ -544,10 +547,36 @@ def progress_bar(tqdm_cls: Literal[ "contrib.discord", "contrib.telegram", "contrib.bells", - ], early_stopping_callback: _EarlyStoppingCallback | None = None, **tqdm_kwargs: Any) -> _ProgressBarCallback: + ], + early_stopping_callback: _EarlyStoppingCallback | None = None, + **tqdm_kwargs: Any + ) -> _ProgressBarCallback: + """Progress bar callback for LightGBM training. + + Parameters + ---------- + tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional + The tqdm class or module name, by default "auto" + early_stopping_callback : Any | None, optional + The early stopping callback, by default None + + .. rubric:: Example + + .. code-block:: python + early_stopping_callback = early_stopping(stopping_rounds=50) + callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)] + estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) + + Returns + ------- + callback : _ProgressBarCallback + The callback that displays the progress bar. + """ ... -def progress_bar(tqdm_cls: str | Type["tqdm.std.tqdm"] = "auto", + +def progress_bar( # type: ignore + tqdm_cls: str | Type["tqdm.std.tqdm"] = "auto", early_stopping_callback: _EarlyStoppingCallback | None = None, **tqdm_kwargs: Any, ) -> _ProgressBarCallback: