Skip to content

Commit 19d3c53

Browse files
authored
Merge pull request #636 from mdekstrand/feature/notebook-logging
Improve logging and progress support
2 parents f91c5d3 + 57aa2e7 commit 19d3c53

File tree

15 files changed

+1152
-168
lines changed

15 files changed

+1152
-168
lines changed

docs/releases/2025.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ New Features (incremental)
5050

5151
* Many LensKit components (batch running, model training, etc.) now report
5252
progress via the progress API in :mod:`lenskit.logging.progress`, and can be
53-
connected to TQDM or Rich.
53+
connected to Jupyter or Rich.
5454
* Added RBP top-N metric (:pr:`334`).
5555
* Added command-line tool to fetch datasets (:pr:`347`).
5656

lenskit/lenskit/logging/_console.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ class ConsoleHandler(Handler):
2222

2323
@property
2424
def supports_color(self) -> bool:
25-
return console.is_terminal and not console.no_color
25+
return (console.is_terminal or console.is_jupyter) and not console.no_color
2626

2727
def emit(self, record: LogRecord) -> None:
2828
try:
2929
fmt = self.format(record)
30-
print(fmt)
31-
# console.print(*self._decoder.decode(fmt))
30+
console.print(self._decoder.decode_line(fmt))
3231
except Exception:
3332
self.handleError(record)
3433

lenskit/lenskit/logging/config.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import structlog
1616

17-
from ._console import ConsoleHandler, setup_console
17+
from ._console import ConsoleHandler, console, setup_console
1818
from .processors import format_timestamp, log_warning, remove_internal
1919
from .progress import set_progress_impl
2020
from .tracing import lenskit_filtering_logger
@@ -62,7 +62,7 @@ def notebook_logging(level: int = logging.INFO):
6262
"""
6363
cfg = LoggingConfig()
6464
cfg.level = level
65-
cfg.set_stream_mode("simple")
65+
cfg.progress_backend = "notebook"
6666
cfg.apply()
6767

6868

@@ -81,6 +81,7 @@ class LoggingConfig: # pragma: nocover
8181

8282
level: int = logging.INFO
8383
stream: Literal["full", "simple", "json"] = "full"
84+
progress_backend: Literal["notebook", "rich"] | None = None
8485
file: Path | None = None
8586
file_level: int | None = None
8687
file_format: LogFormat = "json"
@@ -108,6 +109,8 @@ def set_stream_mode(self, mode: Literal["full", "simple", "json"]):
108109
Configure the standard error stream mode.
109110
"""
110111
self.stream = mode
112+
if mode == "full":
113+
self.force_console = True
111114

112115
def set_verbose(self, verbose: bool | int = True):
113116
"""
@@ -157,6 +160,12 @@ def apply(self):
157160
term = logging.StreamHandler(sys.stderr)
158161
term.setLevel(self.level)
159162
proc_fmt = structlog.processors.JSONRenderer()
163+
elif console.is_jupyter:
164+
term = logging.StreamHandler(sys.stdout)
165+
term.setLevel(self.level)
166+
proc_fmt = structlog.dev.ConsoleRenderer(
167+
colors=self.stream == "full" and not console.no_color
168+
)
160169
else:
161170
term = ConsoleHandler()
162171
term.setLevel(self.level)
@@ -207,7 +216,9 @@ def apply(self):
207216

208217
root.setLevel(self.effective_level)
209218

210-
if self.stream == "full":
219+
if self.progress_backend is not None:
220+
set_progress_impl(self.progress_backend)
221+
elif self.stream == "full":
211222
set_progress_impl("rich")
212223

213224
warnings.showwarning = log_warning

lenskit/lenskit/logging/progress/_dispatch.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
1-
from functools import partial
1+
import warnings
22
from typing import Any, Callable, Literal, overload
33

44
from ._base import Progress
55

66
_backend: Callable[..., Progress] = Progress
77

88

9-
@overload
10-
def set_progress_impl(name: Literal["tqdm"], impl: Callable[..., Any] | None = None, /): ...
119
@overload
1210
def set_progress_impl(name: Literal["rich"]): ...
11+
@overload
12+
def set_progress_impl(name: Literal["notebook"]): ...
1313
def set_progress_impl(name: str | None, *options: Any):
14+
"""
15+
Set the progress bar implementation.
16+
"""
1417
global _backend
1518

1619
match name:
17-
case "tqdm":
18-
from tqdm.autonotebook import tqdm
19-
20-
from ._tqdm import TQDMProgress
21-
22-
impl = tqdm
23-
if options and options[0]:
24-
impl = options[0]
25-
26-
_backend = partial(TQDMProgress, impl)
20+
case "notebook":
21+
try:
22+
from ._notebook import JupyterProgress
23+
24+
_backend = JupyterProgress
25+
except ImportError:
26+
warnings.warn("notebook progress backend needs ipywidgets")
27+
_backend = Progress
2728

2829
case "rich":
2930
from ._rich import RichProgress
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def field_format(name: str, fs: str | None):
2+
if fs:
3+
return "{%s:%s}" % (name, fs)
4+
else:
5+
return "{%s}" % (name,)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
Jupyter notebook progress support.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from time import perf_counter
8+
9+
import ipywidgets as widgets
10+
from humanize import metric
11+
from IPython.display import display
12+
13+
from ._base import Progress
14+
from ._formats import field_format
15+
16+
__all__ = ["JupyterProgress"]
17+
18+
19+
class JupyterProgress(Progress):
20+
"""
21+
Progress logging to Jupyter notebook widgets.
22+
"""
23+
24+
widget: widgets.IntProgress
25+
text: widgets.Label
26+
box: widgets.HBox
27+
total: int | None
28+
current: int
29+
_last_update: float = 0
30+
_field_format: str | None = None
31+
32+
def __init__(
33+
self,
34+
label: str | None,
35+
total: int | None,
36+
fields: dict[str, str | None],
37+
):
38+
self.current = 0
39+
self.total = total
40+
if total:
41+
self.widget = widgets.IntProgress(value=0, min=0, max=total, step=1)
42+
else:
43+
self.widget = widgets.IntProgress(value=1, min=0, max=1, step=1)
44+
self.widget.bar_style = "info"
45+
46+
pieces = []
47+
if label:
48+
pieces.append(widgets.Label(value=label))
49+
pieces.append(self.widget)
50+
51+
self.text = widgets.Label()
52+
if total:
53+
self.text.value = "0 / {}".format(metric(total))
54+
pieces.append(self.text)
55+
56+
self.box = widgets.HBox(pieces)
57+
display(self.box)
58+
59+
if fields:
60+
self._field_format = ", ".join(
61+
[f"{name}: {field_format(name, fs)}" for (name, fs) in fields.items()]
62+
)
63+
64+
def update(self, advance: int = 1, **kwargs: float | int | str):
65+
"""
66+
Update the progress bar.
67+
"""
68+
self.current += advance
69+
now = perf_counter()
70+
if now - self._last_update >= 0.1 or (self.total and self.current >= self.total):
71+
self.widget.value = self.current
72+
if self.total:
73+
if self.total >= 1000:
74+
txt = "{} / {}".format(metric(self.current), metric(self.total))
75+
else:
76+
txt = "{} / {}".format(self.current, self.total)
77+
else:
78+
txt = "{} / ?".format(metric(self.current))
79+
if self._field_format:
80+
txt += " ({})".format(self._field_format.format(**kwargs))
81+
self.text.value = txt
82+
83+
self._last_update = now
84+
# if self._field_format:
85+
# self.tqdm.set_postfix_str(self._field_format.format(kwargs))
86+
87+
def finish(self):
88+
"""
89+
Finish and clean up this progress bar. If the progress bar is used as
90+
a context manager, this is automatically called on context exit.
91+
"""
92+
self.box.close()
93+
94+
def __enter__(self):
95+
return self
96+
97+
def __exit__(self, *args):
98+
self.finish()

lenskit/lenskit/logging/progress/_rich.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .._console import console, get_live
2424
from .._proxy import get_logger
2525
from ._base import Progress
26+
from ._formats import field_format
2627

2728
_log = get_logger("lenskit.logging.progress")
2829
_pb_lock = Lock()
@@ -38,24 +39,34 @@ class RichProgress(Progress):
3839
uuid: UUID
3940
label: str
4041
total: int | None
41-
fields: dict[str, str | None]
4242
logger: structlog.stdlib.BoundLogger
43+
_field_format: str | None = None
4344
_task: TaskID | None = None
4445

4546
def __init__(self, label: str, total: int | None, fields: dict[str, str | None]):
4647
super().__init__()
4748
self.uuid = uuid4()
4849
self.label = label
4950
self.total = total
50-
self.fields = fields
5151

5252
self.logger = _log.bind(label=label, uuid=str(self.uuid))
5353

5454
self._task = _install_bar(self)
5555

56+
if fields:
57+
self._field_format = ", ".join(
58+
[
59+
f"[json.key]{name}[/json.key]: {field_format(name, fs)}"
60+
for (name, fs) in fields.items()
61+
]
62+
)
63+
5664
def update(self, advance: int = 1, **kwargs: float | int | str):
65+
extra = ""
66+
if self._field_format:
67+
extra = self._field_format.format(**kwargs)
5768
if _progress is not None:
58-
_progress.update(self._task, advance=advance, **kwargs) # type: ignore
69+
_progress.update(self._task, advance=advance, extra=extra) # type: ignore
5970

6071
def finish(self):
6172
_remove_bar(self)
@@ -78,12 +89,13 @@ def _install_bar(bar: RichProgress) -> TaskID | None:
7889
RateColumn(),
7990
TaskProgressColumn(),
8091
TimeRemainingColumn(),
92+
TextColumn("{task.fields[extra]}"),
8193
console=console,
8294
)
8395
live.update(_progress)
8496

8597
_active_bars[bar.uuid] = bar
86-
return _progress.add_task(bar.label, total=bar.total)
98+
return _progress.add_task(bar.label, total=bar.total, extra="")
8799

88100

89101
def _remove_bar(bar: RichProgress):

lenskit/lenskit/logging/progress/_tqdm.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

lenskit/lenskit/metrics/bulk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def compute(
229229

230230
n = len(outputs)
231231
_log.info("computing %d listwise metrics for %d output lists", len(lms), n)
232-
with item_progress("lists", n) as pb:
232+
with item_progress("Measuring", n) as pb:
233233
for i, (key, out) in enumerate(outputs):
234234
list_test = test.lookup_projected(key)
235235
if out is None:

lenskit/lenskit/parallel/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
log.debug("persisting function")
6666
job = worker.WorkerData(func, model)
6767
job = shm_serialize(job, self.manager)
68-
log.info("setting up process pool")
68+
log.debug("setting up process pool")
6969
self.pool = ProcessPoolExecutor(n_jobs, ctx, worker.initalize, (job,))
7070
except Exception as e:
7171
self.manager.shutdown()

0 commit comments

Comments
 (0)