4 changes: 2 additions & 2 deletions src/cloudai/models/scenario.py
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -102,7 +102,7 @@ def tdef_model_dump(self, by_alias: bool) -> dict:
"test_template_name": self.test_template_name,
"agent": self.agent,
"agent_steps": self.agent_steps,
"agent_metrics": self.agent_metrics,
"agent_metrics": self.agent_metrics if "agent_metrics" in self.model_fields_set else None,

Contributor:
That would change default values to None. Why is it needed?

Contributor (author):
If it is None here, the definition of agent_metrics in the test config can propagate. Otherwise, when agent_metrics is not defined in the scenario config, the final merged config would always be [default] even though agent_metrics is set in the test config.
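
A minimal sketch of the merge behavior being described (hypothetical function, not CloudAI's actual merge code; the metric names are taken from the report strategy below):

# Hypothetical sketch: a scenario-level value should override the test-level
# value only when it was explicitly set; dumping None for an unset field is
# what lets the test-config value survive the merge.
def merge_field(test_value: list | None, scenario_value: list | None) -> list | None:
    # None means "not set in the scenario config", so the test config wins
    return scenario_value if scenario_value is not None else test_value

# Previously the unset scenario field dumped as ["default"] and clobbered the
# test config; with None, the test value propagates.
assert merge_field(["iteration-time"], None) == ["iteration-time"]
assert merge_field(["iteration-time"], ["tflops-per-gpu"]) == ["tflops-per-gpu"]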

"agent_reward_function": self.agent_reward_function,
"extra_container_mounts": self.extra_container_mounts,
"extra_env_vars": self.extra_env_vars if self.extra_env_vars else None,
4 changes: 3 additions & 1 deletion src/cloudai/registration.py
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -98,6 +98,7 @@ def register_all():
)
from cloudai.workloads.megatron_run import (
CheckpointTimingReportGenerationStrategy,
+MegatronRunReportGenerationStrategy,
MegatronRunSlurmCommandGenStrategy,
MegatronRunTestDefinition,
)
@@ -259,6 +260,7 @@ def register_all():
Registry().add_report(GPTTestDefinition, JaxToolboxReportGenerationStrategy)
Registry().add_report(GrokTestDefinition, JaxToolboxReportGenerationStrategy)
Registry().add_report(MegatronRunTestDefinition, CheckpointTimingReportGenerationStrategy)
+Registry().add_report(MegatronRunTestDefinition, MegatronRunReportGenerationStrategy)
Registry().add_report(MegatronBridgeTestDefinition, MegatronBridgeReportGenerationStrategy)
Registry().add_report(NCCLTestDefinition, NcclTestPerformanceReportGenerationStrategy)
Registry().add_report(NeMoLauncherTestDefinition, NeMoLauncherReportGenerationStrategy)
5 changes: 3 additions & 2 deletions src/cloudai/workloads/megatron_run/__init__.py
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,12 +15,13 @@
# limitations under the License.

from .megatron_run import MegatronRunCmdArgs, MegatronRunTestDefinition
-from .report_generation_strategy import CheckpointTimingReportGenerationStrategy
+from .report_generation_strategy import CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy
from .slurm_command_gen_strategy import MegatronRunSlurmCommandGenStrategy

__all__ = [
"CheckpointTimingReportGenerationStrategy",
"MegatronRunCmdArgs",
"MegatronRunReportGenerationStrategy",
"MegatronRunSlurmCommandGenStrategy",
"MegatronRunTestDefinition",
]
129 changes: 127 additions & 2 deletions src/cloudai/workloads/megatron_run/report_generation_strategy.py
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
-# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,13 +14,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import csv
import logging
import re
from pathlib import Path
from statistics import mean, median, pstdev
from typing import ClassVar

-from cloudai.core import ReportGenerationStrategy
+from cloudai.core import METRIC_ERROR, ReportGenerationStrategy

CHECKPOINT_REGEX = re.compile(r"(save|load)-checkpoint\s.*:\s\((\d+\.\d+),\s(\d+\.\d+)\)")

# Pattern to match lines like:
# [2026-01-16 07:32:39] iteration 6/100 | ... |
# elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | ...
ITERATION_REGEX = re.compile(
r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)"
r".*?"
r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)",
re.IGNORECASE,
)


class CheckpointTimingReportGenerationStrategy(ReportGenerationStrategy):
"""Strategy for generating reports from Checkpoint Timing test outputs."""
@@ -59,3 +75,112 @@ def generate_report(self) -> None:
for checkpoint_type, timings in [("save", save_timings), ("load", load_timings)]:
for t in timings:
file.write(f"{checkpoint_type},{t[0]},{t[1]}\n")


class MegatronRunReportGenerationStrategy(ReportGenerationStrategy):
"""Parse Megatron-Run stdout.txt for iteration time and GPU TFLOP/s per GPU."""

metrics: ClassVar[list[str]] = ["default", "iteration-time", "tflops-per-gpu"]

def get_log_file(self) -> Path | None:
log = self.test_run.output_path / "stdout.txt"
return log if log.is_file() else None

@property
def results_file(self) -> Path:
return self.get_log_file() or (self.test_run.output_path / "stdout.txt")

def can_handle_directory(self) -> bool:
log_file = self.get_log_file()
if not log_file:
return False
with log_file.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
if ITERATION_REGEX.search(line):
return True
return False

def _extract(self, log_path: Path) -> tuple[list[float], list[float]]:
"""Extract iteration times (ms) and GPU TFLOPS from the log file."""
iter_times_ms: list[float] = []
gpu_tflops: list[float] = []
with log_path.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
m = ITERATION_REGEX.search(line)
if m:
try:
iter_times_ms.append(float(m.group(1)))
gpu_tflops.append(float(m.group(2)))
except (ValueError, TypeError):
logging.debug("Failed to parse iteration metrics line: %s", line.rstrip("\n"))

# Skip the first 20 iterations for statistics (to exclude warmup)
if len(iter_times_ms) > 20:
iter_times_ms = iter_times_ms[20:]
gpu_tflops = gpu_tflops[20:]
return iter_times_ms, gpu_tflops

def _get_extracted_data(self) -> tuple[Path | None, list[float], list[float]]:
log_file = self.get_log_file()
if not log_file:
return None, [], []
iter_times_ms, gpu_tflops = self._extract(log_file)
return log_file, iter_times_ms, gpu_tflops

def generate_report(self) -> None:
log_file, iter_times_ms, gpu_tflops = self._get_extracted_data()
if not log_file:
logging.error(
"No stdout.txt file found in: %s",
self.test_run.output_path,
)
return

report_file = self.test_run.output_path / "megatron_run_report.csv"
if not iter_times_ms:
with report_file.open("w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["metric_type", "avg", "median", "min", "max", "std"])
writer.writerow(["error: No iteration timing lines were found.", "", "", "", "", ""])
logging.warning("No iteration metrics found under %s (wrote %s)", self.test_run.output_path, report_file)
return

iter_avg = mean(iter_times_ms)
iter_median = median(iter_times_ms)
iter_min = min(iter_times_ms)
iter_max = max(iter_times_ms)
iter_std = pstdev(iter_times_ms) if len(iter_times_ms) > 1 else 0.0

if gpu_tflops:
tflops_avg = mean(gpu_tflops)
tflops_median = median(gpu_tflops)
tflops_min = min(gpu_tflops)
tflops_max = max(gpu_tflops)
tflops_std = pstdev(gpu_tflops) if len(gpu_tflops) > 1 else 0.0
else:
tflops_avg = tflops_median = tflops_min = tflops_max = tflops_std = 0.0

with report_file.open("w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["metric_type", "avg", "median", "min", "max", "std"])
writer.writerow(["iteration_time_ms", iter_avg, iter_median, iter_min, iter_max, iter_std])
writer.writerow(["tflops_per_gpu", tflops_avg, tflops_median, tflops_min, tflops_max, tflops_std])

def get_metric(self, metric: str) -> float:
if metric not in {"default", "iteration-time", "tflops-per-gpu"}:
return METRIC_ERROR
log_file, iter_times_ms, gpu_tflops = self._get_extracted_data()
if not log_file:
logging.error(
"No stdout.txt file found in: %s",
self.test_run.output_path,
)
return METRIC_ERROR
if not iter_times_ms:
return METRIC_ERROR

if metric in {"default", "iteration-time"}:
return float(mean(iter_times_ms))
if metric == "tflops-per-gpu":
return float(mean(gpu_tflops)) if gpu_tflops else METRIC_ERROR
return METRIC_ERROR
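
As a quick sanity check, here is a standalone sketch (regex copied from the hunk above, log line adapted from the tests below) of what ITERATION_REGEX captures:

import re

ITERATION_REGEX = re.compile(
    r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)"
    r".*?"
    r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)",
    re.IGNORECASE,
)

# A log line in the shape the strategy expects
line = (
    "[2026-01-16 07:32:39] iteration 6/ 100 | "
    "elapsed time per iteration (ms): 15639.0 | "
    "throughput per GPU (TFLOP/s/GPU): 494.6 |"
)
m = ITERATION_REGEX.search(line)
assert m is not None
assert float(m.group(1)) == 15639.0  # elapsed time per iteration (ms)
assert float(m.group(2)) == 494.6    # throughput per GPU (TFLOP/s/GPU)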
@@ -0,0 +1,172 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
from pathlib import Path

import pytest

from cloudai import TestRun
from cloudai.core import METRIC_ERROR
from cloudai.systems.slurm.slurm_system import SlurmSystem
from cloudai.workloads.megatron_run import (
MegatronRunCmdArgs,
MegatronRunReportGenerationStrategy,
MegatronRunTestDefinition,
)


@pytest.fixture
def megatron_run_tr(tmp_path: Path) -> TestRun:
test = MegatronRunTestDefinition(
name="megatron_run",
description="desc",
test_template_name="t",
cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)),
)
tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path)

stdout_content = (
"[2026-01-16 07:32:24] iteration 5/ 100 | consumed samples: 10240 | "
"elapsed time per iteration (ms): 15800.0 | throughput per GPU (TFLOP/s/GPU): 490.0 | "
"learning rate: 4.134000E-07 | global batch size: 2048 | lm loss: 1.344240E+01 | "
"seq_load_balancing_loss: 1.000203E+00 | loss scale: 1.0 | grad norm: 2.870 | "
"num zeros: 1174412544.0 | params norm: 8660.607 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
"[2026-01-16 07:32:39] iteration 6/ 100 | consumed samples: 12288 | "
"elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | "
"learning rate: 4.180800E-07 | global batch size: 2048 | lm loss: 1.342407E+01 | "
"seq_load_balancing_loss: 1.000202E+00 | loss scale: 1.0 | grad norm: 2.867 | "
"num zeros: 1174412672.0 | params norm: 8660.606 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
"[2026-01-16 07:32:54] iteration 7/ 100 | consumed samples: 14336 | "
"elapsed time per iteration (ms): 15448.5 | throughput per GPU (TFLOP/s/GPU): 500.6 | "
"learning rate: 4.227600E-07 | global batch size: 2048 | lm loss: 1.340574E+01 | "
"seq_load_balancing_loss: 1.000201E+00 | loss scale: 1.0 | grad norm: 2.864 | "
"num zeros: 1174412800.0 | params norm: 8660.605 | "
"number of skipped iterations: 0 | number of nan iterations: 0 |\n"
)
(tr.output_path / "stdout.txt").write_text(stdout_content)

return tr


@pytest.fixture
def megatron_run_tr_no_data(tmp_path: Path) -> TestRun:
test = MegatronRunTestDefinition(
name="megatron_run",
description="desc",
test_template_name="t",
cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)),
)
tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path)

stdout_content = """
Some random log output without iteration metrics
Starting training...
"""
(tr.output_path / "stdout.txt").write_text(stdout_content)

return tr


def test_megatron_run_can_handle_directory(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
assert strategy.can_handle_directory()


def test_megatron_run_cannot_handle_directory_without_iteration_data(
slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun
) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data)
assert not strategy.can_handle_directory()


def test_megatron_run_extract_and_generate_report(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
strategy.generate_report()
report_path = megatron_run_tr.output_path / "megatron_run_report.csv"
assert report_path.is_file()

with report_path.open() as f:
reader = csv.DictReader(f)
rows = list(reader)

# Should have 2 rows: iteration_time_ms and tflops_per_gpu
assert len(rows) == 2

expected_headers = {"metric_type", "avg", "median", "min", "max", "std"}
assert set(rows[0].keys()) == expected_headers

data = {row["metric_type"]: row for row in rows}

# Verify iteration_time_ms stats
assert "iteration_time_ms" in data
iter_stats = data["iteration_time_ms"]
expected_iter_avg = (15800.0 + 15639.0 + 15448.5) / 3
assert abs(float(iter_stats["avg"]) - expected_iter_avg) < 0.1
assert abs(float(iter_stats["median"]) - 15639.0) < 0.1
assert abs(float(iter_stats["min"]) - 15448.5) < 0.1
assert abs(float(iter_stats["max"]) - 15800.0) < 0.1

# Verify tflops_per_gpu stats
assert "tflops_per_gpu" in data
tflops_stats = data["tflops_per_gpu"]
expected_tflops_avg = (490.0 + 494.6 + 500.6) / 3
assert abs(float(tflops_stats["avg"]) - expected_tflops_avg) < 0.1
assert abs(float(tflops_stats["median"]) - 494.6) < 0.1
assert abs(float(tflops_stats["min"]) - 490.0) < 0.1
assert abs(float(tflops_stats["max"]) - 500.6) < 0.1


def test_megatron_run_get_metric_iteration_time(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Expected: avg of [15800.0, 15639.0, 15448.5]
expected_avg = (15800.0 + 15639.0 + 15448.5) / 3
metric = strategy.get_metric("iteration-time")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_default(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Default should return iteration-time
expected_avg = (15800.0 + 15639.0 + 15448.5) / 3
metric = strategy.get_metric("default")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_tflops(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
# Expected: avg of [490.0, 494.6, 500.6]
expected_avg = (490.0 + 494.6 + 500.6) / 3
metric = strategy.get_metric("tflops-per-gpu")
assert abs(metric - expected_avg) < 0.1


def test_megatron_run_get_metric_invalid(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr)
metric = strategy.get_metric("invalid-metric")
assert metric == METRIC_ERROR


def test_megatron_run_get_metric_no_data(slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun) -> None:
strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data)
metric = strategy.get_metric("iteration-time")
assert metric == METRIC_ERROR


def test_megatron_run_metrics_class_var() -> None:
assert MegatronRunReportGenerationStrategy.metrics == ["default", "iteration-time", "tflops-per-gpu"]
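
For the three-iteration fixture above (well under the 20-iteration warmup cutoff, so no samples are skipped), the generated megatron_run_report.csv would look roughly like this; values are rounded for readability, and std is the population standard deviation from pstdev:

metric_type,avg,median,min,max,std
iteration_time_ms,15629.17,15639.0,15448.5,15800.0,143.67
tflops_per_gpu,495.07,494.6,490.0,500.6,4.34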