Skip to content

Commit f429966

Browse files
committed
adjust the printing, so it doesn't print inf/nans to save space.
also correct the doc for the printing function
1 parent 0417a53 commit f429966

2 files changed

Lines changed: 75 additions & 74 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
77
[Unreleased]
88
### Added
99
### Changed
10+
- Show the ETA when sampling
1011
### Fixed
1112

1213
[3.0.0 - 2025-10-04]

py/dynesty/utils.py

Lines changed: 74 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ def _update_tqdm_eta_from_dlogz(pbar,
391391
state = {}
392392
setattr(pbar, "_dynesty_eta_state", state)
393393

394-
# Dynamic batches: estimate completion from progress within [logl_min,logl_max]
394+
# Dynamic batches: estimate completion from progress
395+
# within [logl_min,logl_max]
395396
# only when both bounds are finite; otherwise fall back to dlogz trend.
396397
if (nbatch is not None and np.isfinite(loglstar) and np.isfinite(logl_min)
397398
and np.isfinite(logl_max) and (logl_max > logl_min)):
@@ -462,7 +463,7 @@ def _update_tqdm_eta_from_dlogz(pbar,
462463
pbar.total = max(niter + int(np.ceil(rem_iters)), pbar.n + 1)
463464

464465

465-
def print_fn(results,
466+
def print_fn(itresult,
466467
niter,
467468
ncall,
468469
add_live_it=None,
@@ -478,24 +479,9 @@ def print_fn(results,
478479
Parameters
479480
----------
480481
481-
results : tuple
482-
Collection of variables output from the current state of the sampler.
483-
Currently includes:
484-
(1) particle index,
485-
(2) unit cube position,
486-
(3) parameter position,
487-
(4) ln(likelihood),
488-
(5) ln(volume),
489-
(6) ln(weight),
490-
(7) ln(evidence),
491-
(8) Var[ln(evidence)],
492-
(9) information,
493-
(10) number of (current) function calls,
494-
(11) iteration when the point was originally proposed,
495-
(12) index of the bounding object originally proposed from,
496-
(13) index of the bounding object active at a given iteration,
497-
(14) cumulative efficiency, and
498-
(15) estimated remaining ln(evidence).
482+
itresult : IteratorResult or IteratorResultShort
483+
Single iterator output record from the sampler loop that is used
484+
for progress/status printing.
499485
500486
niter : int
501487
The current iteration of the sampler.
@@ -528,7 +514,7 @@ def print_fn(results,
528514
529515
"""
530516
if pbar is None:
531-
print_fn_fallback(results,
517+
print_fn_fallback(itresult,
532518
niter,
533519
ncall,
534520
add_live_it=add_live_it,
@@ -539,7 +525,7 @@ def print_fn(results,
539525
logl_max=logl_max)
540526
else:
541527
print_fn_tqdm(pbar,
542-
results,
528+
itresult,
543529
niter,
544530
ncall,
545531
add_live_it=add_live_it,
@@ -550,7 +536,7 @@ def print_fn(results,
550536
logl_max=logl_max)
551537

552538

553-
def get_print_fn_args(results,
539+
def get_print_fn_args(itresult,
554540
niter,
555541
ncall,
556542
add_live_it=None,
@@ -559,55 +545,69 @@ def get_print_fn_args(results,
559545
nbatch=None,
560546
logl_min=-np.inf,
561547
logl_max=np.inf):
562-
# Extract results at the current iteration.
563-
loglstar = results.loglstar
564-
logz = results.logz
565-
logzvar = results.logzvar
566-
delta_logz = results.delta_logz
567-
bounditer = results.bounditer
568-
nc = results.nc
569-
eff = results.eff
570-
571-
# Adjusting outputs for printing.
572-
if delta_logz > 1e6:
573-
delta_logz = np.inf
574-
if logzvar >= 0. and logzvar <= 1e6:
575-
logzerr = np.sqrt(logzvar)
576-
else:
577-
logzerr = np.nan
578-
if logz <= -1e6:
579-
logz = -np.inf
580-
if loglstar <= -1e6:
581-
loglstar = -np.inf
582-
583-
# Constructing output.
584-
long_str = []
585-
# long_str.append("iter: {:d}".format(niter))
586-
if add_live_it is not None:
587-
long_str.append("+{:d}".format(add_live_it))
548+
"""
549+
Build preformatted status strings for progress printing.
550+
551+
Parameters
552+
----------
553+
itresult : IteratorResult or IteratorResultShort
554+
Single iterator output record from the sampler loop.
555+
"""
556+
loglstar_val = itresult.loglstar
557+
logz_val = itresult.logz
558+
delta_logz_val = itresult.delta_logz
559+
logzvar = itresult.logzvar
560+
561+
loglstar = -np.inf if loglstar_val <= -1e6 else loglstar_val
562+
logz = -np.inf if logz_val <= -1e6 else logz_val
563+
delta_logz = np.inf if delta_logz_val > 1e6 else delta_logz_val
564+
logzerr = np.sqrt(logzvar) if 0. <= logzvar <= 1e6 else np.nan
565+
566+
long_str = [f"+{add_live_it:d}"] if add_live_it is not None else []
588567
short_str = list(long_str)
589568
if nbatch is not None:
590-
long_str.append("batch: {:d}".format(nbatch))
591-
long_str.append("bound: {:d}".format(bounditer))
592-
long_str.append("nc: {:d}".format(nc))
593-
long_str.append("ncall: {:d}".format(ncall))
594-
long_str.append("eff(%): {:6.3f}".format(eff))
595-
short_str.append(long_str[-1])
596-
long_str.append("loglstar: {:6.3f} < {:6.3f} < {:6.3f}".format(
597-
logl_min, loglstar, logl_max))
598-
short_str.append("logl*: {:6.1f}<{:6.1f}<{:6.1f}".format(
599-
logl_min, loglstar, logl_max))
600-
long_str.append("logz: {:6.3f} +/- {:6.3f}".format(logz, logzerr))
601-
short_str.append("logz: {:6.1f}+/-{:.1f}".format(logz, logzerr))
602-
mid_str = list(short_str)
603-
show_dlogz = (dlogz is not None and
604-
(nbatch is None or nbatch == 0 or stop_val is None))
569+
long_str.append(f"batch: {nbatch:d}")
570+
long_str.extend([
571+
f"bound: {itresult.bounditer:d}",
572+
f"nc: {itresult.nc:d}",
573+
f"ncall: {ncall:d}",
574+
])
575+
eff_str = f"eff(%): {itresult.eff:6.3f}"
576+
long_str.append(eff_str)
577+
short_str.append(eff_str)
578+
579+
finite_logl_min = np.isfinite(logl_min)
580+
finite_logl_max = np.isfinite(logl_max)
581+
if finite_logl_min:
582+
long_logl = f"loglstar: {logl_min:6.3f} < {loglstar:6.3f}"
583+
short_logl = f"logl*: {logl_min:6.1f}<{loglstar:6.1f}"
584+
else:
585+
long_logl = f"loglstar: {loglstar:6.3f}"
586+
short_logl = f"logl*: {loglstar:6.1f}"
587+
if finite_logl_max:
588+
long_logl += f" < {logl_max:6.3f}"
589+
short_logl += f"<{logl_max:6.1f}"
590+
long_str.append(long_logl)
591+
short_str.append(short_logl)
592+
593+
long_logz = f"logz: {logz:6.3f}"
594+
short_logz = f"logz: {logz:6.1f}"
595+
if not np.isnan(logzerr):
596+
long_logz += f" +/- {logzerr:6.3f}"
597+
short_logz += f"+/-{logzerr:.1f}"
598+
long_str.append(long_logz)
599+
short_str.append(short_logz)
600+
601+
show_dlogz = (dlogz is not None
602+
and (nbatch is None or nbatch == 0 or stop_val is None))
605603
if show_dlogz:
606-
long_str.append("dlogz: {:6.3f} > {:6.3f}".format(delta_logz, dlogz))
607-
mid_str.append("dlogz: {:6.1f}>{:6.1f}".format(delta_logz, dlogz))
604+
long_tail = f"dlogz: {delta_logz:6.3f} > {dlogz:6.3f}"
605+
mid_tail = f"dlogz: {delta_logz:6.1f}>{dlogz:6.1f}"
608606
else:
609-
long_str.append("stop: {:6.3f}".format(stop_val))
610-
mid_str.append("stop: {:6.3f}".format(stop_val))
607+
long_tail = f"stop: {stop_val:6.3f}"
608+
mid_tail = f"stop: {stop_val:6.3f}"
609+
long_str.append(long_tail)
610+
mid_str = short_str + [mid_tail]
611611

612612
return PrintFnArgs(niter=niter,
613613
short_str=short_str,
@@ -616,7 +616,7 @@ def get_print_fn_args(results,
616616

617617

618618
def print_fn_tqdm(pbar,
619-
results,
619+
itresult,
620620
niter,
621621
ncall,
622622
add_live_it=None,
@@ -628,7 +628,7 @@ def print_fn_tqdm(pbar,
628628
"""
629629
This is a function that does the status printing using tqdm module
630630
"""
631-
fn_args = get_print_fn_args(results,
631+
fn_args = get_print_fn_args(itresult,
632632
niter,
633633
ncall,
634634
add_live_it=add_live_it,
@@ -640,17 +640,17 @@ def print_fn_tqdm(pbar,
640640

641641
_update_tqdm_eta_from_dlogz(pbar,
642642
fn_args.niter,
643-
results.delta_logz,
643+
itresult.delta_logz,
644644
dlogz,
645645
nbatch=nbatch,
646-
loglstar=results.loglstar,
646+
loglstar=itresult.loglstar,
647647
logl_min=logl_min,
648648
logl_max=logl_max)
649649
pbar.set_postfix_str(" | ".join(fn_args.long_str), refresh=False)
650650
pbar.update(fn_args.niter - pbar.n)
651651

652652

653-
def print_fn_fallback(results,
653+
def print_fn_fallback(itresult,
654654
niter,
655655
ncall,
656656
add_live_it=None,
@@ -663,7 +663,7 @@ def print_fn_fallback(results,
663663
This is a function that does the status printing using just
664664
standard printing into the console
665665
"""
666-
fn_args = get_print_fn_args(results,
666+
fn_args = get_print_fn_args(itresult,
667667
niter,
668668
ncall,
669669
add_live_it=add_live_it,

0 commit comments

Comments
 (0)