2626from typing import Any
2727
2828from absl import logging
29+ from clu import metric_writers
30+ from etils import epath
2931import numpy as np
3032from orbax .checkpoint ._src .multihost import multihost
3133import psutil
@@ -437,24 +439,146 @@ def report(self):
437439 logging .info ("\n " .join (report_lines ))
438440
439441
class _MetricsCollector:
  """Internal context manager that starts and stops a set of named metrics.

  On entry every successfully registered metric is started; on exit each one
  is stopped and its results are folded into the owning Metrics object.
  A failure while stopping one metric is logged and does not prevent the
  remaining metrics from being collected.
  """

  def __init__(
      self, metrics_obj: Metrics, operation_name: str, metric_keys: list[str]
  ):
    self.metrics_obj = metrics_obj
    self.operation_name = operation_name
    self._metrics: dict[str, BaseMetric] = {}

    for key in metric_keys:
      metric_class = METRIC_REGISTRY.get(key)
      if metric_class is None:
        # Unknown keys are skipped rather than treated as fatal.
        logging.warning("Unknown metric key: %s", key)
      else:
        self._metrics[key] = metric_class(operation_name)

  def __enter__(self):
    for metric in self._metrics.values():
      metric.start()
    return self

  def __exit__(self, *exc):
    for key, metric in self._metrics.items():
      try:
        self.metrics_obj._add_results(metric.name, key, metric.stop())
      except Exception as err:  # pylint: disable=broad-exception-caught
        logging.exception("Error stopping metric %s: %s", metric.name, err)
472+
473+ ################################################################################
474+ # Aggregation and Reporting
475+ ################################################################################
476+
477+
@dataclasses.dataclass
class AggregatedStats:
  """Summary statistics computed across multiple benchmark repetitions.

  Attributes:
    mean: Arithmetic mean of the collected values.
    std: Standard deviation of the collected values.
    min: Smallest collected value.
    max: Largest collected value.
    count: How many values were aggregated.
  """

  mean: float
  std: float
  min: float
  max: float
  count: int
496+
440497class MetricsManager :
441- """Manages metrics aggregation across multiple benchmark runs."""
498+ """Manages metrics aggregation and reporting for a test suite.
499+
500+ This class collects metrics from multiple benchmark runs and repetitions,
501+ computes aggregate statistics (mean, std, min, max), generates a
502+ human-readable report for logging, and exports metrics to TensorBoard
503+ if configured.
504+ """
505+
506+ def __init__ (
507+ self ,
508+ name : str ,
509+ num_repeats : int ,
510+ ):
511+ """Initializes the MetricsManager.
442512
443- def __init__ (self , name : str , num_repeats : int ):
513+ Args:
514+ name: The name of the test suite.
515+ num_repeats: The number of repetitions for each benchmark configuration.
516+ """
444517 self ._name = name
445518 self ._num_repeats = num_repeats
446519 self ._runs : dict [str , list [tuple [Metrics , Exception | None ]]] = (
447520 collections .defaultdict (list )
448521 )
522+ self ._benchmark_options : dict [str , Any ] = {}
449523
450524 def add_result (
451- self , benchmark_name : str , metrics : Metrics , error : Exception | None
525+ self ,
526+ benchmark_name : str ,
527+ options : Any ,
528+ metrics : Metrics ,
529+ error : Exception | None ,
452530 ):
453- """Adds a result from a single benchmark run."""
531+ """Adds metrics from a single benchmark run/repetition.
532+
533+ Args:
534+ benchmark_name: The name of the benchmark configuration.
535+ options: The BenchmarkOptions used for this run.
536+ metrics: The Metrics object containing results for this run.
537+ error: An exception if the run failed, otherwise None.
538+ """
454539 self ._runs [benchmark_name ].append ((metrics , error ))
540+ if benchmark_name not in self ._benchmark_options :
541+ self ._benchmark_options [benchmark_name ] = options
542+
543+ def _aggregate_metrics (
544+ self , results : list [tuple [Metrics , Exception | None ]]
545+ ) -> tuple [dict [str , AggregatedStats ], dict [str , str ]]:
546+ """Computes aggregate stats (mean, std, etc.) for successful runs.
547+
548+ Args:
549+ results: A list of (Metrics, error) tuples for a benchmark configuration.
550+
551+ Returns:
552+ A tuple containing:
553+ - A dict mapping metric keys to AggregatedStats.
554+ - A dict mapping metric keys to their units.
555+ """
556+ metrics_collector = collections .defaultdict (list )
557+ metric_units = {}
558+ for metrics , error in results :
559+ if error is None :
560+ for key , (value , unit ) in metrics .results .items ():
561+ if isinstance (value , (int , float )):
562+ metrics_collector [key ].append (value )
563+ metric_units [key ] = unit
564+
565+ aggregated_stats_dict = {}
566+ for key , values in metrics_collector .items ():
567+ aggregated_stats_dict [key ] = AggregatedStats (
568+ mean = np .mean (values ),
569+ std = np .std (values ),
570+ min = np .min (values ),
571+ max = np .max (values ),
572+ count = len (values ),
573+ )
574+ return aggregated_stats_dict , metric_units
455575
456576 def generate_report (self ) -> str :
457- """Generates a report with statistics from the test results."""
577+ """Generates a final string report containing aggregated metrics.
578+
579+ Returns:
580+ A formatted string containing the full benchmark report.
581+ """
458582 report_lines = []
459583 title = f" Test Suite Report: { self ._name } "
460584 report_lines .append (f"\n { title := ^ 80 } " )
@@ -476,35 +600,29 @@ def generate_report(self) -> str:
476600 f" { passed_runs } , Failed: { failed_runs } "
477601 )
478602
603+ # Aggregate metrics, add to report, and write aggregates to TensorBoard
479604 if self ._num_repeats > 1 :
480605 report_lines .append ("\n " + "-" * 80 )
481606 report_lines .append ("--- Aggregated Metrics per Benchmark ---" )
482607 for benchmark_name , results in self ._runs .items ():
483608 if not results :
484609 continue
485610 report_lines .append (f"\n Benchmark: { benchmark_name } " )
486- metrics_collector = collections .defaultdict (list )
487- metric_units = {}
488- for metrics , error in results :
489- if error is None :
490- for key , (value , unit ) in metrics .results .items ():
491- if isinstance (value , (int , float )):
492- metrics_collector [key ].append (value )
493- metric_units [key ] = unit
494- if not metrics_collector :
611+
612+ aggregated_stats_dict , metric_units = self ._aggregate_metrics (results )
613+
614+ if not aggregated_stats_dict :
495615 report_lines .append (" No successful runs to aggregate." )
496616 continue
497- for key , values in metrics_collector .items ():
617+
618+ for key , stats in aggregated_stats_dict .items ():
498619 unit = metric_units [key ]
499- mean = np .mean (values )
500- stdev = np .std (values )
501- min_val = np .min (values )
502- max_val = np .max (values )
503620 report_lines .append (
504- f" { key } : { mean :.4f} +/- { stdev :.4f} { unit } (min:"
505- f" { min_val :.4f} , max: { max_val :.4f} , n={ len ( values ) } )"
621+ f" { key } : { stats . mean :.4f} +/- { stats . std :.4f} { unit } (min:"
622+ f" { stats . min :.4f} , max: { stats . max :.4f} , n={ stats . count } )"
506623 )
507624
625+ # Report failed runs
508626 if failed_runs > 0 :
509627 report_lines .append ("\n " + "-" * 80 )
510628 report_lines .append ("--- Failed Runs ---" )
@@ -516,36 +634,42 @@ def generate_report(self) -> str:
516634 if len (error_repr ) > 1000 :
517635 error_repr = error_repr [:1000 ] + "..."
518636 report_lines .append (f"Test: { metrics .name } , Error: { error_repr } " )
637+
519638 report_lines .append ("\n " + "=" * 80 )
520639 return "\n " .join (report_lines )
521640
522-
523- class _MetricsCollector :
524- """Internal context manager to collect specified metrics."""
525-
526- def __init__ (
527- self , metrics_obj : Metrics , operation_name : str , metric_keys : list [str ]
528- ):
529- self .metrics_obj = metrics_obj
530- self .operation_name = operation_name
531- self ._metrics : dict [str , BaseMetric ] = {}
532-
533- for key in metric_keys :
534- if key in METRIC_REGISTRY :
535- metric_class = METRIC_REGISTRY [key ]
536- self ._metrics [key ] = metric_class (operation_name )
537- else :
538- logging .warning ("Unknown metric key: %s" , key )
539-
540- def __enter__ (self ):
541- for metric in self ._metrics .values ():
542- metric .start ()
543- return self
544-
545- def __exit__ (self , * exc ):
546- for key , metric in self ._metrics .items ():
547- try :
548- metric_results = metric .stop ()
549- self .metrics_obj ._add_results (metric .name , key , metric_results )
550- except Exception as e : # pylint: disable=broad-exception-caught
551- logging .exception ("Error stopping metric %s: %s" , metric .name , e )
641+ def export_to_tensorboard (self , tensorboard_dir : epath .Path ):
642+ """Exports metrics to TensorBoard."""
643+ logging .info ("Writing per-repetition metrics to TensorBoard..." )
644+ for benchmark_name , results in self ._runs .items ():
645+ is_primary_host = multihost .process_index () == 0
646+ writer = metric_writers .create_default_writer (
647+ tensorboard_dir ,
648+ just_logging = not is_primary_host ,
649+ collection = benchmark_name ,
650+ )
651+ # Write metrics for each repetition
652+ for i , (metrics , error ) in enumerate (results ):
653+ if error is None :
654+ for key , (value , unit ) in metrics .results .items ():
655+ tag = f'{ key } _{ unit .replace ("/" , "_" )} '
656+ if isinstance (value , (int , float )):
657+ writer .write_scalars (step = i , scalars = {tag : value })
658+ else :
659+ writer .write_texts (step = i , texts = {tag : str (value )})
660+ else :
661+ tag = "error"
662+ writer .write_texts (step = i , texts = {tag : f"<pre>{ repr (error )} </pre>" })
663+ # Write benchmark options as text
664+ if self ._benchmark_options [benchmark_name ]:
665+ writer .write_texts (
666+ step = 0 ,
667+ texts = {
668+ "options" : (
669+ f"<pre>{ repr (self ._benchmark_options [benchmark_name ])} </pre>"
670+ )
671+ },
672+ )
673+ writer .flush ()
674+ writer .close ()
675+ logging .info ("Finished writing metrics to TensorBoard." )
0 commit comments