Skip to content

Commit fa18d5f

Browse files
mridul-sahuOrbax Authors
authored andcommitted
Export Aggregated Metrics report to tensorboard.
PiperOrigin-RevId: 834280995
1 parent 008c77a commit fa18d5f

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def get_device_memory() -> int:
156156
'NVIDIA H100 80GB HBM3': int(80e9),
157157
'NVIDIA H200': int(144e9),
158158
'NVIDIA B200': int(183e9),
159-
'NVIDIA B300 SXM6 AC': int(275e9)
159+
'NVIDIA B300 SXM6 AC': int(275e9),
160160
}
161161
memory = hbm_memory.get(device.device_kind, None)
162162
if memory is None:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,28 @@ def export_to_tensorboard(self, tensorboard_dir: epath.Path):
670670
)
671671
},
672672
)
673+
# Write Aggreagated metrics as text
674+
aggregated_metrics = []
675+
aggregated_stats_dict, metric_units = self._aggregate_metrics(results)
676+
677+
if not aggregated_stats_dict:
678+
aggregated_metrics.append("No successful runs to aggregate.")
679+
continue
680+
681+
for key, stats in aggregated_stats_dict.items():
682+
unit = metric_units[key]
683+
aggregated_metrics.append(
684+
f"{key}: {stats.mean:.4f} +/- {stats.std:.4f} {unit} (min:"
685+
f" {stats.min:.4f}, max: {stats.max:.4f}, n={stats.count})"
686+
)
687+
aggregated_metrics_str = "\n".join(aggregated_metrics)
688+
writer.write_texts(
689+
step=0,
690+
texts={
691+
"aggregated_metrics": f"<pre>{aggregated_metrics_str}</pre>"
692+
},
693+
)
694+
673695
writer.flush()
674696
writer.close()
675697
logging.info("Finished writing metrics to TensorBoard.")

0 commit comments

Comments
 (0)