Skip to content

Commit 8e08fe4

Browse files
Merge pull request #444 from srivatsankrishnan/reporter_bug
Nemo Dry-Run/Run Fix
2 parents 3075fac + f0ac0f2 commit 8e08fe4

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

src/cloudai/workloads/nemo_run/report_generation_strategy.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828

2929
@cache
3030
def extract_timings(stdout_file: Path) -> list[float]:
31+
if not stdout_file.exists():
32+
logging.error(f"{stdout_file} not found")
33+
return []
34+
3135
train_step_timings: list[float] = []
3236
step_timings: list[float] = []
3337

@@ -76,6 +80,9 @@ def generate_report(self) -> None:
7680
return
7781

7882
step_timings = extract_timings(self.results_file)
83+
if not step_timings:
84+
logging.error("No valid step timings found. Report generation aborted.")
85+
return
7986

8087
stats = {
8188
"avg": np.mean(step_timings),
@@ -93,7 +100,10 @@ def generate_report(self) -> None:
93100

94101
def get_metric(self, metric: str) -> float:
95102
step_timings = extract_timings(self.results_file)
96-
if metric not in {"default", "step-time"} or not step_timings:
103+
if not step_timings:
104+
return METRIC_ERROR
105+
106+
if metric not in {"default", "step-time"}:
97107
return METRIC_ERROR
98108

99109
return float(np.mean(step_timings))

tests/report_generation_strategy/test_nemo_run_report_generation_strategy.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from cloudai import Test, TestRun
2323
from cloudai.systems.slurm.slurm_system import SlurmSystem
2424
from cloudai.workloads.nemo_run import NeMoRunCmdArgs, NeMoRunReportGenerationStrategy, NeMoRunTestDefinition
25+
from cloudai.workloads.nemo_run.report_generation_strategy import extract_timings
2526

2627

2728
@pytest.fixture
@@ -154,3 +155,79 @@ def test_metrics(nemo_tr: TestRun, slurm_system: SlurmSystem, metric: str):
154155
nemo_tr.test.test_definition.agent_metric = metric
155156
value = nemo_tr.get_metric_value(slurm_system)
156157
assert value == 12.72090909090909
158+
159+
160+
def test_extract_timings_valid_file(tmp_path: Path) -> None:
161+
stdout_file = tmp_path / "stdout.txt"
162+
stdout_file.write_text(
163+
"Training epoch 0, iteration 17/99 | train_step_timing in s: 12.64 | global_step: 17\n"
164+
"Training epoch 0, iteration 18/99 | train_step_timing in s: 12.65 | global_step: 18\n"
165+
"Training epoch 0, iteration 19/99 | train_step_timing in s: 12.66 | global_step: 19\n"
166+
)
167+
168+
timings = extract_timings(stdout_file)
169+
assert timings == [12.65, 12.66], "Timings extraction failed for valid file."
170+
171+
172+
def test_extract_timings_missing_file(tmp_path: Path) -> None:
173+
stdout_file = tmp_path / "missing_stdout.txt"
174+
175+
timings = extract_timings(stdout_file)
176+
assert timings == [], "Timings extraction should return an empty list for missing file."
177+
178+
179+
def test_extract_timings_invalid_content(tmp_path: Path) -> None:
180+
stdout_file = tmp_path / "stdout.txt"
181+
stdout_file.write_text("Invalid content without timing information\n")
182+
183+
timings = extract_timings(stdout_file)
184+
assert timings == [], "Timings extraction should return an empty list for invalid content."
185+
186+
187+
def test_extract_timings_file_not_found(tmp_path: Path) -> None:
188+
stdout_file = tmp_path / "nonexistent_stdout.txt"
189+
190+
timings = extract_timings(stdout_file)
191+
assert timings == [], "Timings extraction should return an empty list when the file does not exist."
192+
193+
194+
def test_generate_report_no_timings(slurm_system: SlurmSystem, nemo_tr: TestRun, tmp_path: Path) -> None:
195+
nemo_tr.output_path = tmp_path
196+
stdout_file = nemo_tr.output_path / "stdout.txt"
197+
stdout_file.write_text("No valid timing information\n")
198+
199+
strategy = NeMoRunReportGenerationStrategy(slurm_system, nemo_tr)
200+
strategy.generate_report()
201+
202+
summary_file = nemo_tr.output_path / "report.txt"
203+
assert not summary_file.exists(), "Report should not be generated if no valid timings are found."
204+
205+
206+
def test_generate_report_partial_timings(slurm_system: SlurmSystem, nemo_tr: TestRun, tmp_path: Path) -> None:
207+
nemo_tr.output_path = tmp_path
208+
stdout_file = nemo_tr.output_path / "stdout.txt"
209+
stdout_file.write_text(
210+
"Training epoch 0, iteration 17/99 | train_step_timing in s: 12.64 | global_step: 17\n"
211+
"Invalid line without timing\n"
212+
"Training epoch 0, iteration 18/99 | train_step_timing in s: 12.65 | global_step: 18\n"
213+
)
214+
215+
strategy = NeMoRunReportGenerationStrategy(slurm_system, nemo_tr)
216+
strategy.generate_report()
217+
218+
summary_file = nemo_tr.output_path / "report.txt"
219+
assert summary_file.is_file(), "Report should be generated even with partial valid timings."
220+
221+
summary_content = summary_file.read_text().strip().split("\n")
222+
assert len(summary_content) == 4, "Summary file should contain four lines (avg, median, min, max)."
223+
224+
expected_values = {
225+
"Average": 12.645,
226+
"Median": 12.645,
227+
"Min": 12.64,
228+
"Max": 12.65,
229+
}
230+
231+
for line in summary_content:
232+
key, value = line.split(": ")
233+
assert pytest.approx(float(value), 0.01) == expected_values[key], f"{key} value mismatch."

0 commit comments

Comments
 (0)