Skip to content

Commit c42b039

Browse files
mridul-sahu and Orbax Authors
authored and committed
No public description
PiperOrigin-RevId: 834114429
1 parent 79b3896 commit c42b039

File tree

9 files changed

+305
-99
lines changed

9 files changed

+305
-99
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/pytree_checkpoint_benchmark_pathways.yaml

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,22 @@ checkpoint_config:
2222
# h: {dtype: "float32", shape: [256, 1048576], sharding: ["tensor", "data"]}
2323
# i: {dtype: "float32", shape: [128, 2097152], sharding: ["tensor", "data"]}
2424
# j: {dtype: "float32", shape: [64, 4194304], sharding: ["tensor", "data"]}
25-
# # k: {dtype: "float32", shape: [32, 8388608], sharding: ["tensor", "data"]}
26-
# # l: {dtype: "float32", shape: [16, 16777216], sharding: ["tensor", "data"]}
27-
# # m: {dtype: "float32", shape: [8, 33554432], sharding: ["tensor", "data"]}
28-
# # n: {dtype: "float32", shape: [4, 67108864], sharding: ["tensor", "data"]}
29-
# # o: {dtype: "float32", shape: [4, 4, 16777216], sharding: [null, "tensor", "data"]}
30-
# # p: {dtype: "float32", shape: [4, 8, 8388608], sharding: [null, "tensor", "data"]}
31-
# # q: {dtype: "float32", shape: [4, 16, 4194304], sharding: [null, "tensor", "data"]}
32-
# # r: {dtype: "float32", shape: [4, 32, 2097152], sharding: [null, "tensor", "data"]}
33-
# # s: {dtype: "float32", shape: [4, 64, 1048576], sharding: [null, "tensor", "data"]}
34-
# # t: {dtype: "float32", shape: [4, 128, 524288], sharding: [null, "tensor", "data"]}
35-
# # u: {dtype: "float32", shape: [4, 256, 262144], sharding: [null, "tensor", "data"]}
36-
# # v: {dtype: "float32", shape: [4, 512, 131072], sharding: [null, "tensor", "data"]}
37-
# # w: {dtype: "float32", shape: [4, 1024, 65536], sharding: [null, "tensor", "data"]}
38-
# # x: {dtype: "float32", shape: [4, 2048, 32768], sharding: [null, "tensor", "data"]}
39-
# # y: {dtype: "float32", shape: [4, 4096, 16384], sharding: [null, "tensor", "data"]}
40-
# # z: {dtype: "float32", shape: [4, 8192, 8192], sharding: [null, "tensor", "data"]}
25+
# k: {dtype: "float32", shape: [32, 8388608], sharding: ["tensor", "data"]}
26+
# l: {dtype: "float32", shape: [16, 16777216], sharding: ["tensor", "data"]}
27+
# m: {dtype: "float32", shape: [8, 33554432], sharding: ["tensor", "data"]}
28+
# n: {dtype: "float32", shape: [4, 67108864], sharding: ["tensor", "data"]}
29+
# o: {dtype: "float32", shape: [4, 4, 16777216], sharding: [null, "tensor", "data"]}
30+
# p: {dtype: "float32", shape: [4, 8, 8388608], sharding: [null, "tensor", "data"]}
31+
# q: {dtype: "float32", shape: [4, 16, 4194304], sharding: [null, "tensor", "data"]}
32+
# r: {dtype: "float32", shape: [4, 32, 2097152], sharding: [null, "tensor", "data"]}
33+
# s: {dtype: "float32", shape: [4, 64, 1048576], sharding: [null, "tensor", "data"]}
34+
# t: {dtype: "float32", shape: [4, 128, 524288], sharding: [null, "tensor", "data"]}
35+
# u: {dtype: "float32", shape: [4, 256, 262144], sharding: [null, "tensor", "data"]}
36+
# v: {dtype: "float32", shape: [4, 512, 131072], sharding: [null, "tensor", "data"]}
37+
# w: {dtype: "float32", shape: [4, 1024, 65536], sharding: [null, "tensor", "data"]}
38+
# x: {dtype: "float32", shape: [4, 2048, 32768], sharding: [null, "tensor", "data"]}
39+
# y: {dtype: "float32", shape: [4, 4096, 16384], sharding: [null, "tensor", "data"]}
40+
# z: {dtype: "float32", shape: [4, 8192, 8192], sharding: [null, "tensor", "data"]}
4141

4242
benchmarks:
4343
- generator: "orbax.checkpoint._src.testing.benchmarks.pytree_checkpoint_benchmark.PyTreeCheckpointBenchmark"
@@ -62,5 +62,4 @@ benchmarks:
6262
# - 4
6363
metric_tracemalloc_enabled: True
6464
metric_tensorstore_enabled: True
65-
use_colocated_python: [False]
66-
use_jax_array_handler: [False, True]
65+
use_colocated_python: [True, False]

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/config_parsing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,8 @@ def create_test_suite_from_config(
193193
generators.append(generator)
194194

195195
return core.TestSuite(
196-
name=suite_name, benchmarks_generators=generators, num_repeats=num_repeats
196+
name=suite_name,
197+
benchmarks_generators=generators,
198+
num_repeats=num_repeats,
199+
output_dir=output_dir,
197200
)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,15 @@ def __init__(
374374
self,
375375
name: str,
376376
benchmarks_generators: Sequence[BenchmarksGenerator],
377+
output_dir: str | None = None,
377378
skip_incompatible_mesh_configs: bool = True,
378379
num_repeats: int = 1,
379380
):
380381
self._name = name
381382
self._benchmarks_generators = benchmarks_generators
382383
self._skip_incompatible_mesh_configs = skip_incompatible_mesh_configs
383384
self._num_repeats = num_repeats
385+
self._output_dir = output_dir
384386
self._suite_metrics = metric_lib.MetricsManager(
385387
name=name, num_repeats=num_repeats
386388
)
@@ -422,12 +424,17 @@ def run(self) -> Sequence[TestResult]:
422424
result = benchmark.run(repeat_index=repeat_index)
423425
all_results.append(result)
424426
self._suite_metrics.add_result(
425-
benchmark.name, result.metrics, result.error
427+
benchmark.name, benchmark.options, result.metrics, result.error
426428
)
427429

428430
if not all_results:
429431
logging.warning("No benchmarks were run for this suite.")
430432

433+
if self._output_dir is not None:
434+
self._suite_metrics.export_to_tensorboard(
435+
epath.Path(self._output_dir) / "tensorboard"
436+
)
437+
431438
logging.info(self._suite_metrics.generate_report())
432439
multihost.sync_global_processes("test_suite:run_end")
433440
return all_results

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,18 @@ def test_generator_mismatched_options(self):
344344

345345
class TestSuiteTest(parameterized.TestCase):
346346

347+
@mock.patch.object(metric_lib, 'MetricsManager')
348+
def test_init_with_output_dir(self, mock_metrics_manager):
349+
gen = MyGenerator(
350+
checkpoint_configs=[configs.CheckpointConfig()],
351+
options=MyBenchmarkOptions(opt1=1),
352+
)
353+
output_dir = '/tmp/foo'
354+
core.TestSuite(
355+
name='my_suite', benchmarks_generators=[gen], output_dir=output_dir
356+
)
357+
mock_metrics_manager.assert_called_once_with(name='my_suite', num_repeats=1)
358+
347359
@mock.patch.object(core.Benchmark, 'run')
348360
def test_run(self, mock_benchmark_run):
349361
gen = MyGenerator(

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py

Lines changed: 175 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from typing import Any
2727

2828
from absl import logging
29+
from clu import metric_writers
30+
from etils import epath
2931
import numpy as np
3032
from orbax.checkpoint._src.multihost import multihost
3133
import psutil
@@ -437,24 +439,146 @@ def report(self):
437439
logging.info("\n".join(report_lines))
438440

439441

442+
class _MetricsCollector:
443+
"""Internal context manager to collect specified metrics."""
444+
445+
def __init__(
446+
self, metrics_obj: Metrics, operation_name: str, metric_keys: list[str]
447+
):
448+
self.metrics_obj = metrics_obj
449+
self.operation_name = operation_name
450+
self._metrics: dict[str, BaseMetric] = {}
451+
452+
for key in metric_keys:
453+
if key in METRIC_REGISTRY:
454+
metric_class = METRIC_REGISTRY[key]
455+
self._metrics[key] = metric_class(operation_name)
456+
else:
457+
logging.warning("Unknown metric key: %s", key)
458+
459+
def __enter__(self):
460+
for metric in self._metrics.values():
461+
metric.start()
462+
return self
463+
464+
def __exit__(self, *exc):
465+
for key, metric in self._metrics.items():
466+
try:
467+
metric_results = metric.stop()
468+
self.metrics_obj._add_results(metric.name, key, metric_results)
469+
except Exception as e: # pylint: disable=broad-exception-caught
470+
logging.exception("Error stopping metric %s: %s", metric.name, e)
471+
472+
473+
################################################################################
474+
# Aggregation and Reporting
475+
################################################################################
476+
477+
478+
@dataclasses.dataclass
479+
class AggregatedStats:
480+
"""Statistics aggregated over multiple benchmark repetitions.
481+
482+
Attributes:
483+
mean: Mean value.
484+
std: Standard deviation.
485+
min: Minimum value.
486+
max: Maximum value.
487+
count: Number of values aggregated.
488+
"""
489+
490+
mean: float
491+
std: float
492+
min: float
493+
max: float
494+
count: int
495+
496+
440497
class MetricsManager:
441-
"""Manages metrics aggregation across multiple benchmark runs."""
498+
"""Manages metrics aggregation and reporting for a test suite.
499+
500+
This class collects metrics from multiple benchmark runs and repetitions,
501+
computes aggregate statistics (mean, std, min, max), generates a
502+
human-readable report for logging, and exports metrics to TensorBoard
503+
if configured.
504+
"""
505+
506+
def __init__(
507+
self,
508+
name: str,
509+
num_repeats: int,
510+
):
511+
"""Initializes the MetricsManager.
442512
443-
def __init__(self, name: str, num_repeats: int):
513+
Args:
514+
name: The name of the test suite.
515+
num_repeats: The number of repetitions for each benchmark configuration.
516+
"""
444517
self._name = name
445518
self._num_repeats = num_repeats
446519
self._runs: dict[str, list[tuple[Metrics, Exception | None]]] = (
447520
collections.defaultdict(list)
448521
)
522+
self._benchmark_options: dict[str, Any] = {}
449523

450524
def add_result(
451-
self, benchmark_name: str, metrics: Metrics, error: Exception | None
525+
self,
526+
benchmark_name: str,
527+
options: Any,
528+
metrics: Metrics,
529+
error: Exception | None,
452530
):
453-
"""Adds a result from a single benchmark run."""
531+
"""Adds metrics from a single benchmark run/repetition.
532+
533+
Args:
534+
benchmark_name: The name of the benchmark configuration.
535+
options: The BenchmarkOptions used for this run.
536+
metrics: The Metrics object containing results for this run.
537+
error: An exception if the run failed, otherwise None.
538+
"""
454539
self._runs[benchmark_name].append((metrics, error))
540+
if benchmark_name not in self._benchmark_options:
541+
self._benchmark_options[benchmark_name] = options
542+
543+
def _aggregate_metrics(
544+
self, results: list[tuple[Metrics, Exception | None]]
545+
) -> tuple[dict[str, AggregatedStats], dict[str, str]]:
546+
"""Computes aggregate stats (mean, std, etc.) for successful runs.
547+
548+
Args:
549+
results: A list of (Metrics, error) tuples for a benchmark configuration.
550+
551+
Returns:
552+
A tuple containing:
553+
- A dict mapping metric keys to AggregatedStats.
554+
- A dict mapping metric keys to their units.
555+
"""
556+
metrics_collector = collections.defaultdict(list)
557+
metric_units = {}
558+
for metrics, error in results:
559+
if error is None:
560+
for key, (value, unit) in metrics.results.items():
561+
if isinstance(value, (int, float)):
562+
metrics_collector[key].append(value)
563+
metric_units[key] = unit
564+
565+
aggregated_stats_dict = {}
566+
for key, values in metrics_collector.items():
567+
aggregated_stats_dict[key] = AggregatedStats(
568+
mean=np.mean(values),
569+
std=np.std(values),
570+
min=np.min(values),
571+
max=np.max(values),
572+
count=len(values),
573+
)
574+
return aggregated_stats_dict, metric_units
455575

456576
def generate_report(self) -> str:
457-
"""Generates a report with statistics from the test results."""
577+
"""Generates a final string report containing aggregated metrics.
578+
579+
Returns:
580+
A formatted string containing the full benchmark report.
581+
"""
458582
report_lines = []
459583
title = f" Test Suite Report: {self._name} "
460584
report_lines.append(f"\n{title:=^80}")
@@ -476,35 +600,29 @@ def generate_report(self) -> str:
476600
f" {passed_runs}, Failed: {failed_runs}"
477601
)
478602

603+
# Aggregate metrics, add to report, and write aggregates to TensorBoard
479604
if self._num_repeats > 1:
480605
report_lines.append("\n" + "-" * 80)
481606
report_lines.append("--- Aggregated Metrics per Benchmark ---")
482607
for benchmark_name, results in self._runs.items():
483608
if not results:
484609
continue
485610
report_lines.append(f"\nBenchmark: {benchmark_name}")
486-
metrics_collector = collections.defaultdict(list)
487-
metric_units = {}
488-
for metrics, error in results:
489-
if error is None:
490-
for key, (value, unit) in metrics.results.items():
491-
if isinstance(value, (int, float)):
492-
metrics_collector[key].append(value)
493-
metric_units[key] = unit
494-
if not metrics_collector:
611+
612+
aggregated_stats_dict, metric_units = self._aggregate_metrics(results)
613+
614+
if not aggregated_stats_dict:
495615
report_lines.append(" No successful runs to aggregate.")
496616
continue
497-
for key, values in metrics_collector.items():
617+
618+
for key, stats in aggregated_stats_dict.items():
498619
unit = metric_units[key]
499-
mean = np.mean(values)
500-
stdev = np.std(values)
501-
min_val = np.min(values)
502-
max_val = np.max(values)
503620
report_lines.append(
504-
f" {key}: {mean:.4f} +/- {stdev:.4f} {unit} (min:"
505-
f" {min_val:.4f}, max: {max_val:.4f}, n={len(values)})"
621+
f" {key}: {stats.mean:.4f} +/- {stats.std:.4f} {unit} (min:"
622+
f" {stats.min:.4f}, max: {stats.max:.4f}, n={stats.count})"
506623
)
507624

625+
# Report failed runs
508626
if failed_runs > 0:
509627
report_lines.append("\n" + "-" * 80)
510628
report_lines.append("--- Failed Runs ---")
@@ -516,36 +634,42 @@ def generate_report(self) -> str:
516634
if len(error_repr) > 1000:
517635
error_repr = error_repr[:1000] + "..."
518636
report_lines.append(f"Test: {metrics.name}, Error: {error_repr}")
637+
519638
report_lines.append("\n" + "=" * 80)
520639
return "\n".join(report_lines)
521640

522-
523-
class _MetricsCollector:
524-
"""Internal context manager to collect specified metrics."""
525-
526-
def __init__(
527-
self, metrics_obj: Metrics, operation_name: str, metric_keys: list[str]
528-
):
529-
self.metrics_obj = metrics_obj
530-
self.operation_name = operation_name
531-
self._metrics: dict[str, BaseMetric] = {}
532-
533-
for key in metric_keys:
534-
if key in METRIC_REGISTRY:
535-
metric_class = METRIC_REGISTRY[key]
536-
self._metrics[key] = metric_class(operation_name)
537-
else:
538-
logging.warning("Unknown metric key: %s", key)
539-
540-
def __enter__(self):
541-
for metric in self._metrics.values():
542-
metric.start()
543-
return self
544-
545-
def __exit__(self, *exc):
546-
for key, metric in self._metrics.items():
547-
try:
548-
metric_results = metric.stop()
549-
self.metrics_obj._add_results(metric.name, key, metric_results)
550-
except Exception as e: # pylint: disable=broad-exception-caught
551-
logging.exception("Error stopping metric %s: %s", metric.name, e)
641+
def export_to_tensorboard(self, tensorboard_dir: epath.Path):
642+
"""Exports metrics to TensorBoard."""
643+
logging.info("Writing per-repetition metrics to TensorBoard...")
644+
for benchmark_name, results in self._runs.items():
645+
is_primary_host = multihost.process_index() == 0
646+
writer = metric_writers.create_default_writer(
647+
tensorboard_dir,
648+
just_logging=not is_primary_host,
649+
collection=benchmark_name,
650+
)
651+
# Write metrics for each repetition
652+
for i, (metrics, error) in enumerate(results):
653+
if error is None:
654+
for key, (value, unit) in metrics.results.items():
655+
tag = f'{key}_{unit.replace("/", "_")}'
656+
if isinstance(value, (int, float)):
657+
writer.write_scalars(step=i, scalars={tag: value})
658+
else:
659+
writer.write_texts(step=i, texts={tag: str(value)})
660+
else:
661+
tag = "error"
662+
writer.write_texts(step=i, texts={tag: f"<pre>{repr(error)}</pre>"})
663+
# Write benchmark options as text
664+
if self._benchmark_options[benchmark_name]:
665+
writer.write_texts(
666+
step=0,
667+
texts={
668+
"options": (
669+
f"<pre>{repr(self._benchmark_options[benchmark_name])}</pre>"
670+
)
671+
},
672+
)
673+
writer.flush()
674+
writer.close()
675+
logging.info("Finished writing metrics to TensorBoard.")

0 commit comments

Comments (0)