Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/cloudai/registration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -98,6 +98,7 @@ def register_all():
)
from cloudai.workloads.megatron_run import (
CheckpointTimingReportGenerationStrategy,
MegatronRunReportGenerationStrategy,
MegatronRunSlurmCommandGenStrategy,
MegatronRunTestDefinition,
)
Expand Down Expand Up @@ -259,6 +260,7 @@ def register_all():
Registry().add_report(GPTTestDefinition, JaxToolboxReportGenerationStrategy)
Registry().add_report(GrokTestDefinition, JaxToolboxReportGenerationStrategy)
Registry().add_report(MegatronRunTestDefinition, CheckpointTimingReportGenerationStrategy)
Registry().add_report(MegatronRunTestDefinition, MegatronRunReportGenerationStrategy)
Registry().add_report(MegatronBridgeTestDefinition, MegatronBridgeReportGenerationStrategy)
Registry().add_report(NCCLTestDefinition, NcclTestPerformanceReportGenerationStrategy)
Registry().add_report(NeMoLauncherTestDefinition, NeMoLauncherReportGenerationStrategy)
Expand Down
5 changes: 3 additions & 2 deletions src/cloudai/workloads/megatron_run/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,12 +15,13 @@
# limitations under the License.

from .megatron_run import MegatronRunCmdArgs, MegatronRunTestDefinition
from .report_generation_strategy import CheckpointTimingReportGenerationStrategy
from .report_generation_strategy import CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy
from .slurm_command_gen_strategy import MegatronRunSlurmCommandGenStrategy

__all__ = [
"CheckpointTimingReportGenerationStrategy",
"MegatronRunCmdArgs",
"MegatronRunReportGenerationStrategy",
"MegatronRunSlurmCommandGenStrategy",
"MegatronRunTestDefinition",
]
142 changes: 140 additions & 2 deletions src/cloudai/workloads/megatron_run/report_generation_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,13 +14,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import re
from pathlib import Path
from statistics import mean, median, pstdev
from typing import ClassVar

from cloudai.core import ReportGenerationStrategy
from cloudai.core import METRIC_ERROR, ReportGenerationStrategy

CHECKPOINT_REGEX = re.compile(r"(save|load)-checkpoint\s.*:\s\((\d+\.\d+),\s(\d+\.\d+)\)")

# Pattern to match lines like:
# [2026-01-16 07:32:39] iteration 6/100 | ... |
# elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | ...
ITERATION_REGEX = re.compile(
r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)"
r".*?"
r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)",
re.IGNORECASE,
)


class CheckpointTimingReportGenerationStrategy(ReportGenerationStrategy):
"""Strategy for generating reports from Checkpoint Timing test outputs."""
Expand Down Expand Up @@ -59,3 +74,126 @@ def generate_report(self) -> None:
for checkpoint_type, timings in [("save", save_timings), ("load", load_timings)]:
for t in timings:
file.write(f"{checkpoint_type},{t[0]},{t[1]}\n")


class MegatronRunReportGenerationStrategy(ReportGenerationStrategy):
"""Parse Megatron-Run stdout.txt for iteration time and GPU TFLOP/s per GPU."""

metrics: ClassVar[list[str]] = ["default", "iteration-time", "tflops-per-gpu"]

def get_log_file(self) -> Path | None:
log = self.test_run.output_path / "stdout.txt"
return log if log.is_file() else None

@property
def results_file(self) -> Path:
return self.get_log_file() or (self.test_run.output_path / "stdout.txt")

def can_handle_directory(self) -> bool:
log_file = self.get_log_file()
if not log_file:
return False
with log_file.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
if ITERATION_REGEX.search(line):
return True
return False

def _extract(self, log_path: Path) -> tuple[list[float], list[float]]:
"""Extract iteration times (ms) and GPU TFLOPS from the log file."""
iter_times_ms: list[float] = []
gpu_tflops: list[float] = []
with log_path.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
m = ITERATION_REGEX.search(line)
if m:
try:
iter_times_ms.append(float(m.group(1)))
gpu_tflops.append(float(m.group(2)))
except (ValueError, TypeError):
logging.debug("Failed to parse iteration metrics line: %s", line.rstrip("\n"))

# Keep only the last 10 iterations for statistics (to exclude warmup)
if len(iter_times_ms) > 10:
iter_times_ms = iter_times_ms[-10:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if taking the last 10 iterations is the most relevant. What if the training has some ups and downs (as I already saw). Maybe just skipping the warmup, so say the 20 first iterations is enough ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, skipping the warmup stage makes more sense, have updated to skipping the first 20 iterations. Originally was following the format in Megatron-Bridge report. Maybe later need to unify the formats for computing statistics.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last 10 iteration is what the GPU perf team uses. I have seen those runs on IB clusters and mostly towards the end it remains stable.

gpu_tflops = gpu_tflops[-10:]
return iter_times_ms, gpu_tflops

def _get_extracted_data(self) -> tuple[Path | None, list[float], list[float]]:
log_file = self.get_log_file()
if not log_file:
return None, [], []
iter_times_ms, gpu_tflops = self._extract(log_file)
return log_file, iter_times_ms, gpu_tflops

def generate_report(self) -> None:
log_file, iter_times_ms, gpu_tflops = self._get_extracted_data()
if not log_file:
logging.error(
"No stdout.txt file found in: %s",
self.test_run.output_path,
)
return

summary_file = self.test_run.output_path / "megatron_run_report.txt"
if not iter_times_ms:
with summary_file.open("w") as f:
f.write("MegatronRun report\n")
f.write("No iteration timing lines were found.\n\n")
f.write("Searched file:\n")
f.write(f" - {log_file}\n")
logging.warning("No iteration metrics found under %s (wrote %s)", self.test_run.output_path, summary_file)
return

iter_stats = {
"avg": mean(iter_times_ms),
"median": median(iter_times_ms),
"min": min(iter_times_ms),
"max": max(iter_times_ms),
"std": pstdev(iter_times_ms) if len(iter_times_ms) > 1 else 0.0,
}
if gpu_tflops:
tflops_stats = {
"avg": mean(gpu_tflops),
"median": median(gpu_tflops),
"min": min(gpu_tflops),
"max": max(gpu_tflops),
"std": pstdev(gpu_tflops) if len(gpu_tflops) > 1 else 0.0,
}
else:
tflops_stats = {"avg": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}

with summary_file.open("w") as f:
f.write(f"Source log: {log_file}\n\n")
f.write("Iteration Time (ms)\n")
f.write(f" avg: {iter_stats['avg']}\n")
f.write(f" median: {iter_stats['median']}\n")
f.write(f" min: {iter_stats['min']}\n")
f.write(f" max: {iter_stats['max']}\n")
f.write(f" std: {iter_stats['std']}\n")
f.write("\n")
f.write("TFLOP/s per GPU\n")
f.write(f" avg: {tflops_stats['avg']}\n")
f.write(f" median: {tflops_stats['median']}\n")
f.write(f" min: {tflops_stats['min']}\n")
f.write(f" max: {tflops_stats['max']}\n")
f.write(f" std: {tflops_stats['std']}\n")

def get_metric(self, metric: str) -> float:
if metric not in {"default", "iteration-time", "tflops-per-gpu"}:
return METRIC_ERROR
log_file, iter_times_ms, gpu_tflops = self._get_extracted_data()
if not log_file:
logging.error(
"No stdout.txt file found in: %s",
self.test_run.output_path,
)
return METRIC_ERROR
if not iter_times_ms:
return METRIC_ERROR

if metric in {"default", "iteration-time"}:
return float(mean(iter_times_ms))
if metric == "tflops-per-gpu":
return float(mean(gpu_tflops)) if gpu_tflops else METRIC_ERROR
return METRIC_ERROR
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import pytest

from cloudai import TestRun
from cloudai.core import METRIC_ERROR
from cloudai.systems.slurm.slurm_system import SlurmSystem
from cloudai.workloads.megatron_run import (
MegatronRunCmdArgs,
MegatronRunReportGenerationStrategy,
MegatronRunTestDefinition,
)


@pytest.fixture
def megatron_run_tr(tmp_path: Path) -> TestRun:
test = MegatronRunTestDefinition(
name="megatron_run",
description="desc",
test_template_name="t",
cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)),
)
tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path)

stdout_content = (
"[2026-01-16 07:32:24] iteration 5/ 100 | consumed samples: 10240 | "
"elapsed time per iteration (ms): 15800.0 | throughput per GPU (TFLOP/s/GPU): 490.0 | "
"learning rate: 4.134000E-07 | global batch size: 2048 | lm loss: 1.344240E+01 | "
"seq_load_balancing_loss: 1.000203E+00 | loss scale: 1.0 | grad norm: 2.870 | "
"num zeros: 1174412544.0 | params norm: 8660.607 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
"[2026-01-16 07:32:39] iteration 6/ 100 | consumed samples: 12288 | "
"elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | "
"learning rate: 4.180800E-07 | global batch size: 2048 | lm loss: 1.342407E+01 | "
"seq_load_balancing_loss: 1.000202E+00 | loss scale: 1.0 | grad norm: 2.867 | "
"num zeros: 1174412672.0 | params norm: 8660.606 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
"[2026-01-16 07:32:54] iteration 7/ 100 | consumed samples: 14336 | "
"elapsed time per iteration (ms): 15448.5 | throughput per GPU (TFLOP/s/GPU): 500.6 | "
"learning rate: 4.227600E-07 | global batch size: 2048 | lm loss: 1.340574E+01 | "
"seq_load_balancing_loss: 1.000201E+00 | loss scale: 1.0 | grad norm: 2.864 | "
"num zeros: 1174412800.0 | params norm: 8660.605 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
)
(tr.output_path / "stdout.txt").write_text(stdout_content)

return tr


@pytest.fixture
def megatron_run_tr_no_data(tmp_path: Path) -> TestRun:
test = MegatronRunTestDefinition(
name="megatron_run",
description="desc",
test_template_name="t",
cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)),
)
tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path)

stdout_content = """
Some random log output without iteration metrics
Starting training...
"""
(tr.output_path / "stdout.txt").write_text(stdout_content)

return tr


def test_megatron_run_can_handle_directory(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
assert strategy.can_handle_directory()


def test_megatron_run_cannot_handle_directory_without_iteration_data(
slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun
) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data)
assert not strategy.can_handle_directory()


def test_megatron_run_extract_and_generate_report(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
strategy.generate_report()
report_path = megatron_run_tr.output_path / "megatron_run_report.txt"
assert report_path.is_file()
content = report_path.read_text()
assert "Iteration Time (ms)" in content
assert "TFLOP/s per GPU" in content
assert "avg:" in content
assert "median:" in content
assert "min:" in content
assert "max:" in content
assert "std:" in content


def test_megatron_run_get_metric_iteration_time(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Expected: avg of [15800.0, 15639.0, 15448.5]
expected_avg = (15800.0 + 15639.0 + 15448.5) / 3
metric = strategy.get_metric("iteration-time")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_default(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Default should return iteration-time
expected_avg = (15800.0 + 15639.0 + 15448.5) / 3
metric = strategy.get_metric("default")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_tflops(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Expected: avg of [490.0, 494.6, 500.6]
expected_avg = (490.0 + 494.6 + 500.6) / 3
metric = strategy.get_metric("tflops-per-gpu")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_invalid(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
metric = strategy.get_metric("invalid-metric")
assert metric == METRIC_ERROR


def test_megatron_run_get_metric_no_data(slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data)
metric = strategy.get_metric("iteration-time")
assert metric == METRIC_ERROR


def test_megatron_run_metrics_class_var() -> None:
assert MegatronRunReportGenerationStrategy.metrics == ["default", "iteration-time", "tflops-per-gpu"]
8 changes: 6 additions & 2 deletions tests/test_test_scenario.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -53,6 +53,7 @@
from cloudai.workloads.megatron_run import (
CheckpointTimingReportGenerationStrategy,
MegatronRunCmdArgs,
MegatronRunReportGenerationStrategy,
MegatronRunTestDefinition,
)
from cloudai.workloads.nccl_test import (
Expand Down Expand Up @@ -481,7 +482,10 @@ def test_default_reporters_size(self):
(DeepEPTestDefinition, {DeepEPReportGenerationStrategy}),
(GPTTestDefinition, {JaxToolboxReportGenerationStrategy}),
(GrokTestDefinition, {JaxToolboxReportGenerationStrategy}),
(MegatronRunTestDefinition, {CheckpointTimingReportGenerationStrategy}),
(
MegatronRunTestDefinition,
{CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy},
),
(MegatronBridgeTestDefinition, {MegatronBridgeReportGenerationStrategy}),
(NCCLTestDefinition, {NcclTestPerformanceReportGenerationStrategy}),
(NeMoLauncherTestDefinition, {NeMoLauncherReportGenerationStrategy}),
Expand Down