|
1 | 1 | # coding: utf-8
|
2 | 2 | """Callbacks library."""
|
| 3 | +from __future__ import annotations |
| 4 | + |
3 | 5 | import collections
|
| 6 | +import importlib |
| 7 | +import warnings |
| 8 | +from collections import OrderedDict |
4 | 9 | from functools import partial
|
5 | 10 | from typing import Any, Callable, Dict, List, Tuple, Union
|
| 11 | +from typing import Any, Literal, Type |
| 12 | +try: |
| 13 | + import tqdm |
| 14 | +except ImportError: |
| 15 | + pass |
6 | 16 |
|
7 | 17 | from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
|
8 | 18 |
|
|
11 | 21 | 'log_evaluation',
|
12 | 22 | 'record_evaluation',
|
13 | 23 | 'reset_parameter',
|
| 24 | + 'progress_bar', |
14 | 25 | ]
|
15 | 26 |
|
16 | 27 | _EvalResultDict = Dict[str, Dict[str, List[Any]]]
|
@@ -413,3 +424,149 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
|
413 | 424 | The callback that activates early stopping.
|
414 | 425 | """
|
415 | 426 | return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
|
| 427 | + |
| 428 | + |
| 429 | +class _ProgressBarCallback: |
| 430 | + """Internal class to handle progress bar.""" |
| 431 | + tqdm_cls: "Type[tqdm.std.tqdm]" |
| 432 | + pbar: "tqdm.std.tqdm" | None |
| 433 | + |
| 434 | + def __init__( |
| 435 | + self, |
| 436 | + tqdm_cls: Literal[ |
| 437 | + "auto", |
| 438 | + "autonotebook", |
| 439 | + "std", |
| 440 | + "notebook", |
| 441 | + "asyncio", |
| 442 | + "keras", |
| 443 | + "dask", |
| 444 | + "tk", |
| 445 | + "gui", |
| 446 | + "rich", |
| 447 | + "contrib.slack", |
| 448 | + "contrib.discord", |
| 449 | + "contrib.telegram", |
| 450 | + "contrib.bells", |
| 451 | + ] |
| 452 | + | "Type[tqdm.std.tqdm]" = "auto", |
| 453 | + early_stopping_callback: Any | None = None, |
| 454 | + **tqdm_kwargs: Any, |
| 455 | + ) -> None: |
| 456 | + """Progress bar callback for LightGBM training. |
| 457 | +
|
| 458 | + Parameters |
| 459 | + ---------- |
| 460 | + tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional |
| 461 | + The tqdm class or module name, by default "auto" |
| 462 | + early_stopping_callback : _EarlyStoppingCallback | None, optional |
| 463 | + The early stopping callback, by default None |
| 464 | +
|
| 465 | + .. rubric:: Example |
| 466 | +
|
| 467 | + .. code-block:: python |
| 468 | + early_stopping_callback = early_stopping(stopping_rounds=50) |
| 469 | + callbacks = [early_stopping_callback, ProgressBarCallback(early_stopping_callback=early_stopping_callback)] |
| 470 | + estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) |
| 471 | + """ |
| 472 | + if isinstance(tqdm_cls, str): |
| 473 | + try: |
| 474 | + tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}") |
| 475 | + except ImportError as e: |
| 476 | + raise ImportError( |
| 477 | + f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e |
| 478 | + self.tqdm_cls = getattr(tqdm_module, "tqdm") |
| 479 | + else: |
| 480 | + self.tqdm_cls = tqdm_cls |
| 481 | + self.early_stopping_callback = early_stopping_callback |
| 482 | + self.tqdm_kwargs = tqdm_kwargs |
| 483 | + if "total" in tqdm_kwargs: |
| 484 | + warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning) |
| 485 | + self.pbar = None |
| 486 | + |
| 487 | + def _init(self, env: CallbackEnv) -> None: |
| 488 | + # create pbar on first call |
| 489 | + tqdm_kwargs = self.tqdm_kwargs.copy() |
| 490 | + tqdm_kwargs["total"] = env.end_iteration - env.begin_iteration |
| 491 | + self.pbar = self.tqdm_cls(**tqdm_kwargs) |
| 492 | + |
| 493 | + def __call__(self, env: CallbackEnv) -> None: |
| 494 | + if env.iteration == env.begin_iteration: |
| 495 | + self._init(env) |
| 496 | + assert self.pbar is not None |
| 497 | + |
| 498 | + # update postfix |
| 499 | + if len(env.evaluation_result_list) > 0: |
| 500 | + # If OrderedDict is not used, the order of display is disjointed and slightly difficult to see. |
| 501 | + # https://github.com/microsoft/LightGBM/blob/a97c444b4cf9d2755bd888911ce65ace1fe13e4b/python-package/lightgbm/callback.py#L56-66 |
| 502 | + if self.early_stopping_callback is not None: |
| 503 | + postfix = OrderedDict( |
| 504 | + [ |
| 505 | + ( |
| 506 | + f"{entry[0]}'s {entry[1]}", |
| 507 | + f"{entry[2]:g}{'=' if entry[2] == best_score else ('>' if cmp_op else '<')}{best_score:g}@{best_iter}it", |
| 508 | + ) |
| 509 | + for entry, cmp_op, best_score, best_iter in zip( |
| 510 | + env.evaluation_result_list, |
| 511 | + self.early_stopping_callback.cmp_op, |
| 512 | + self.early_stopping_callback.best_score, |
| 513 | + self.early_stopping_callback.best_iter, |
| 514 | + ) |
| 515 | + ] |
| 516 | + ) |
| 517 | + else: |
| 518 | + postfix = OrderedDict( |
| 519 | + [ |
| 520 | + (f"{entry[0]}'s {entry[1]}", f"{entry[2]:g}") |
| 521 | + for entry in env.evaluation_result_list |
| 522 | + ] |
| 523 | + ) |
| 524 | + self.pbar.set_postfix(ordered_dict=postfix, refresh=False) |
| 525 | + |
| 526 | + # update pbar |
| 527 | + self.pbar.update() |
| 528 | + self.pbar.refresh() |
| 529 | + |
| 530 | + |
| 531 | +def progress_bar(tqdm_cls: Literal[ |
| 532 | + "auto", |
| 533 | + "autonotebook", |
| 534 | + "std", |
| 535 | + "notebook", |
| 536 | + "asyncio", |
| 537 | + "keras", |
| 538 | + "dask", |
| 539 | + "tk", |
| 540 | + "gui", |
| 541 | + "rich", |
| 542 | + "contrib.slack", |
| 543 | + "contrib.discord", |
| 544 | + "contrib.telegram", |
| 545 | + "contrib.bells", |
| 546 | +] |
| 547 | + | "Type[tqdm.std.tqdm]" = "auto", |
| 548 | + early_stopping_callback: _EarlyStoppingCallback | None = None, |
| 549 | + **tqdm_kwargs: Any, |
| 550 | +) -> _ProgressBarCallback: |
| 551 | + """Progress bar callback for LightGBM training. |
| 552 | +
|
| 553 | + Parameters |
| 554 | + ---------- |
| 555 | + tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional |
| 556 | + The tqdm class or module name, by default "auto" |
| 557 | + early_stopping_callback : Any | None, optional |
| 558 | + The early stopping callback, by default None |
| 559 | +
|
| 560 | + .. rubric:: Example |
| 561 | +
|
| 562 | + .. code-block:: python |
| 563 | + early_stopping_callback = early_stopping(stopping_rounds=50) |
| 564 | + callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)] |
| 565 | + estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks) |
| 566 | +
|
| 567 | + Returns |
| 568 | + ------- |
| 569 | + callback : _ProgressBarCallback |
| 570 | + The callback that displays the progress bar. |
| 571 | + """ |
| 572 | + return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs) |
0 commit comments