Skip to content

Commit 8cb9ca7

Browse files
mridul-sahu and Orbax Authors
authored and committed
No public description
PiperOrigin-RevId: 833407698
1 parent 4b5d0b4 commit 8cb9ca7

File tree

4 files changed

+167
-40
lines changed

4 files changed

+167
-40
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -381,44 +381,17 @@ def __init__(
381381
self._benchmarks_generators = benchmarks_generators
382382
self._skip_incompatible_mesh_configs = skip_incompatible_mesh_configs
383383
self._num_repeats = num_repeats
384-
385-
def _generate_report(self, results: Sequence[TestResult]) -> str:
386-
"""Generates a report from the test results."""
387-
passed_count = 0
388-
failed_tests = []
389-
for result in results:
390-
if result.is_successful():
391-
passed_count += 1
392-
else:
393-
failed_tests.append(result)
394-
395-
failed_count = len(failed_tests)
396-
report_lines = []
397-
title = f" Test Suite Report: {self._name} "
398-
report_lines.append(f"\n{title:=^80}")
399-
report_lines.append(f"Total tests run: {len(results)}")
400-
report_lines.append(f"Passed: {passed_count}")
401-
report_lines.append(f"Failed: {failed_count}")
402-
403-
if failed_count > 0:
404-
report_lines.append("-" * 80)
405-
report_lines.append("--- Failed Tests ---")
406-
for result in failed_tests:
407-
error_repr = repr(result.error)
408-
# Limit error length to avoid flooding logs.
409-
if len(error_repr) > 1000:
410-
error_repr = error_repr[:1000] + "..."
411-
report_lines.append(f"Test: {result.metrics.name}, Error: {error_repr}")
412-
report_lines.append("=" * 80)
413-
return "\n".join(report_lines)
384+
self._suite_metrics = metric_lib.MetricsManager(
385+
name=name, num_repeats=num_repeats
386+
)
414387

415388
def run(self) -> Sequence[TestResult]:
416389
"""Runs all benchmarks in the suite sequentially."""
417390
logging.info(
418391
"\n%s Running Test Suite: %s %s", "=" * 25, self._name, "=" * 25
419392
)
420393

421-
results = []
394+
all_results = []
422395
for i, generator in enumerate(self._benchmarks_generators):
423396
logging.info(
424397
"\n%s Running Generator %d: %s %s",
@@ -432,7 +405,8 @@ def run(self) -> Sequence[TestResult]:
432405
)
433406
if not generated_benchmarks:
434407
logging.warning(
435-
"Generator %s produced no benchmarks.", generator.__class__.__name__
408+
"Generator %s produced no benchmarks.",
409+
generator.__class__.__name__,
436410
)
437411
continue
438412

@@ -445,11 +419,15 @@ def run(self) -> Sequence[TestResult]:
445419
i + 1,
446420
self._num_repeats,
447421
)
448-
results.append(benchmark.run(repeat_index=repeat_index))
422+
result = benchmark.run(repeat_index=repeat_index)
423+
all_results.append(result)
424+
self._suite_metrics.add_result(
425+
benchmark.name, result.metrics, result.error
426+
)
449427

450-
if not results:
428+
if not all_results:
451429
logging.warning("No benchmarks were run for this suite.")
452430

453-
logging.info(self._generate_report(results))
431+
logging.info(self._suite_metrics.generate_report())
454432
multihost.sync_global_processes("test_suite:run_end")
455-
return results
433+
return all_results

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,9 @@ def test_fn(self, test_context: core.TestContext) -> core.TestResult:
452452
report_log = report_log_call[0][0]
453453

454454
self.assertIn(' Test Suite Report: report_suite ', report_log)
455-
self.assertIn('Total tests run: 3', report_log)
456-
self.assertIn('Passed: 2', report_log)
457-
self.assertIn('Failed: 1', report_log)
458-
self.assertIn('--- Failed Tests ---', report_log)
455+
self.assertIn('Total benchmark configurations: 3', report_log)
456+
self.assertIn('Total runs (1 repeats): 3, Passed: 2, Failed: 1', report_log)
457+
self.assertIn('--- Failed Runs ---', report_log)
459458
self.assertIn("Error: ValueError('opt1=2, opt2=b failed')", report_log)
460459

461460

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Metric classes for benchmarking."""
1616

17+
import collections
1718
from collections.abc import MutableMapping
1819
import contextlib
1920
import dataclasses
@@ -25,6 +26,7 @@
2526
from typing import Any
2627

2728
from absl import logging
29+
import numpy as np
2830
from orbax.checkpoint._src.multihost import multihost
2931
import psutil
3032
import tensorstore as ts
@@ -435,6 +437,89 @@ def report(self):
435437
logging.info("\n".join(report_lines))
436438

437439

440+
class MetricsManager:
441+
"""Manages metrics aggregation across multiple benchmark runs."""
442+
443+
def __init__(self, name: str, num_repeats: int):
444+
self._name = name
445+
self._num_repeats = num_repeats
446+
self._runs: dict[str, list[tuple[Metrics, Exception | None]]] = (
447+
collections.defaultdict(list)
448+
)
449+
450+
def add_result(
451+
self, benchmark_name: str, metrics: Metrics, error: Exception | None
452+
):
453+
"""Adds a result from a single benchmark run."""
454+
self._runs[benchmark_name].append((metrics, error))
455+
456+
def generate_report(self) -> str:
457+
"""Generates a report with statistics from the test results."""
458+
report_lines = []
459+
title = f" Test Suite Report: {self._name} "
460+
report_lines.append(f"\n{title:=^80}")
461+
462+
total_runs = 0
463+
passed_runs = 0
464+
failed_runs = 0
465+
for _, results in self._runs.items():
466+
total_runs += len(results)
467+
for _, error in results:
468+
if error is None:
469+
passed_runs += 1
470+
else:
471+
failed_runs += 1
472+
473+
report_lines.append(f"Total benchmark configurations: {len(self._runs)}")
474+
report_lines.append(
475+
f"Total runs ({self._num_repeats} repeats): {total_runs}, Passed:"
476+
f" {passed_runs}, Failed: {failed_runs}"
477+
)
478+
479+
if self._num_repeats > 1:
480+
report_lines.append("\n" + "-" * 80)
481+
report_lines.append("--- Aggregated Metrics per Benchmark ---")
482+
for benchmark_name, results in self._runs.items():
483+
if not results:
484+
continue
485+
report_lines.append(f"\nBenchmark: {benchmark_name}")
486+
metrics_collector = collections.defaultdict(list)
487+
metric_units = {}
488+
for metrics, error in results:
489+
if error is None:
490+
for key, (value, unit) in metrics.results.items():
491+
if isinstance(value, (int, float)):
492+
metrics_collector[key].append(value)
493+
metric_units[key] = unit
494+
if not metrics_collector:
495+
report_lines.append(" No successful runs to aggregate.")
496+
continue
497+
for key, values in metrics_collector.items():
498+
unit = metric_units[key]
499+
mean = np.mean(values)
500+
stdev = np.std(values)
501+
min_val = np.min(values)
502+
max_val = np.max(values)
503+
report_lines.append(
504+
f" {key}: {mean:.4f} +/- {stdev:.4f} {unit} (min:"
505+
f" {min_val:.4f}, max: {max_val:.4f}, n={len(values)})"
506+
)
507+
508+
if failed_runs > 0:
509+
report_lines.append("\n" + "-" * 80)
510+
report_lines.append("--- Failed Runs ---")
511+
for _, results in self._runs.items():
512+
for metrics, error in results:
513+
if error is not None:
514+
error_repr = repr(error)
515+
# Limit error length to avoid flooding logs.
516+
if len(error_repr) > 1000:
517+
error_repr = error_repr[:1000] + "..."
518+
report_lines.append(f"Test: {metrics.name}, Error: {error_repr}")
519+
report_lines.append("\n" + "=" * 80)
520+
return "\n".join(report_lines)
521+
522+
438523
class _MetricsCollector:
439524
"""Internal context manager to collect specified metrics."""
440525

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,70 @@ def test_all_metrics(self):
151151
self.assertIn('test_metric_tensorstore_diff_count', metrics.results)
152152

153153

154+
class MetricsManagerTest(parameterized.TestCase):
  """Tests for metric_lib.MetricsManager report generation."""

  def test_add_result_and_generate_report_no_repeats(self):
    """With num_repeats=1 the report has counts and failures, no aggregation."""
    manager = metric_lib.MetricsManager(name='Suite', num_repeats=1)
    metrics1 = metric_lib.Metrics()
    metrics1.results['op1_time_duration'] = (1.0, 's')
    manager.add_result('bench1', metrics1, None)

    metrics2 = metric_lib.Metrics()
    metrics2.results['op1_time_duration'] = (2.0, 's')
    # Second benchmark fails; it should appear in the Failed Runs section.
    manager.add_result('bench2', metrics2, ValueError('failure'))

    report = manager.generate_report()
    self.assertIn('Suite', report)
    self.assertIn('Total benchmark configurations: 2', report)
    self.assertIn('Total runs (1 repeats): 2, Passed: 1, Failed: 1', report)
    # Aggregation section is only emitted when num_repeats > 1.
    self.assertNotIn('Aggregated Metrics', report)
    self.assertIn('Failed Runs', report)
    self.assertIn("Error: ValueError('failure')", report)

  def test_generate_report_with_repeats_and_aggregation(self):
    """With repeats, successful runs are aggregated per metric key."""
    manager = metric_lib.MetricsManager(name='Suite', num_repeats=3)

    # Benchmark 1, Run 1
    m1r1 = metric_lib.Metrics()
    m1r1.results['op_time_duration'] = (1.0, 's')
    m1r1.results['op_rss_diff'] = (10.0, 'MB')
    manager.add_result('bench1', m1r1, None)
    # Benchmark 1, Run 2
    m1r2 = metric_lib.Metrics()
    m1r2.results['op_time_duration'] = (1.2, 's')
    m1r2.results['op_rss_diff'] = (12.0, 'MB')
    manager.add_result('bench1', m1r2, None)
    # Benchmark 1, Run 3 (Failed) — excluded from aggregation, listed in
    # the Failed Runs section.
    m1r3 = metric_lib.Metrics()
    manager.add_result('bench1', m1r3, RuntimeError('Run 3 failed'))

    report = manager.generate_report()

    self.assertIn('Suite', report)
    self.assertIn('Total benchmark configurations: 1', report)
    self.assertIn('Total runs (3 repeats): 3, Passed: 2, Failed: 1', report)
    self.assertIn('Aggregated Metrics', report)
    self.assertIn('Benchmark: bench1', report)
    # mean=1.1, std=0.1, min=1.0, max=1.2
    self.assertIn(
        'op_time_duration: 1.1000 +/- 0.1000 s (min: 1.0000, max: 1.2000, n=2)',
        report,
    )
    # mean=11.0, std=1.0, min=10.0, max=12.0
    self.assertIn(
        'op_rss_diff: 11.0000 +/- 1.0000 MB (min: 10.0000, max: 12.0000, n=2)',
        report,
    )
    self.assertIn('Failed Runs', report)
    self.assertIn("Error: RuntimeError('Run 3 failed')", report)

  def test_generate_report_no_successful_runs_for_aggregation(self):
    """All-failed benchmarks get the 'No successful runs' placeholder."""
    manager = metric_lib.MetricsManager(name='Suite', num_repeats=2)
    manager.add_result('bench1', metric_lib.Metrics(), ValueError('1'))
    manager.add_result('bench1', metric_lib.Metrics(), ValueError('2'))
    report = manager.generate_report()
    self.assertIn('No successful runs to aggregate', report)
217+
218+
154219
if __name__ == '__main__':
155220
absltest.main()

0 commit comments

Comments
 (0)