|
22 | 22 | from cloudai import Test, TestRun |
23 | 23 | from cloudai.systems.slurm.slurm_system import SlurmSystem |
24 | 24 | from cloudai.workloads.nemo_run import NeMoRunCmdArgs, NeMoRunReportGenerationStrategy, NeMoRunTestDefinition |
| 25 | +from cloudai.workloads.nemo_run.report_generation_strategy import extract_timings |
25 | 26 |
|
26 | 27 |
|
27 | 28 | @pytest.fixture |
@@ -154,3 +155,79 @@ def test_metrics(nemo_tr: TestRun, slurm_system: SlurmSystem, metric: str): |
154 | 155 | nemo_tr.test.test_definition.agent_metric = metric |
155 | 156 | value = nemo_tr.get_metric_value(slurm_system) |
156 | 157 | assert value == 12.72090909090909 |
| 158 | + |
| 159 | + |
def test_extract_timings_valid_file(tmp_path: Path) -> None:
    """extract_timings parses per-step timings out of a well-formed stdout log."""
    log_lines = [
        "Training epoch 0, iteration 17/99 | train_step_timing in s: 12.64 | global_step: 17",
        "Training epoch 0, iteration 18/99 | train_step_timing in s: 12.65 | global_step: 18",
        "Training epoch 0, iteration 19/99 | train_step_timing in s: 12.66 | global_step: 19",
    ]
    log_path = tmp_path / "stdout.txt"
    log_path.write_text("\n".join(log_lines) + "\n")

    # NOTE(review): the first step's timing (12.64) is expected to be absent from
    # the result — presumably a warm-up step is skipped; confirm in extract_timings.
    result = extract_timings(log_path)
    assert result == [12.65, 12.66], "Timings extraction failed for valid file."
| 170 | + |
| 171 | + |
def test_extract_timings_missing_file(tmp_path: Path) -> None:
    """A path that was never created yields an empty timing list, not an error."""
    absent_path = tmp_path / "missing_stdout.txt"

    result = extract_timings(absent_path)
    assert result == [], "Timings extraction should return an empty list for missing file."
| 177 | + |
| 178 | + |
def test_extract_timings_invalid_content(tmp_path: Path) -> None:
    """A log with no recognizable timing lines yields an empty timing list."""
    log_path = tmp_path / "stdout.txt"
    log_path.write_text("Invalid content without timing information\n")

    result = extract_timings(log_path)
    assert result == [], "Timings extraction should return an empty list for invalid content."
| 185 | + |
| 186 | + |
def test_extract_timings_file_not_found(tmp_path: Path) -> None:
    """A nonexistent stdout path is tolerated and produces no timings.

    NOTE(review): this exercises the same behavior as
    test_extract_timings_missing_file — consider consolidating the two.
    """
    absent_path = tmp_path / "nonexistent_stdout.txt"

    result = extract_timings(absent_path)
    assert result == [], "Timings extraction should return an empty list when the file does not exist."
| 192 | + |
| 193 | + |
def test_generate_report_no_timings(slurm_system: SlurmSystem, nemo_tr: TestRun, tmp_path: Path) -> None:
    """No report file is written when stdout carries no usable timing lines."""
    nemo_tr.output_path = tmp_path
    (tmp_path / "stdout.txt").write_text("No valid timing information\n")

    NeMoRunReportGenerationStrategy(slurm_system, nemo_tr).generate_report()

    report_path = tmp_path / "report.txt"
    assert not report_path.exists(), "Report should not be generated if no valid timings are found."
| 204 | + |
| 205 | + |
def test_generate_report_partial_timings(slurm_system: SlurmSystem, nemo_tr: TestRun, tmp_path: Path) -> None:
    """Report is generated from the parseable timing lines, skipping malformed ones.

    Writes a stdout log mixing two valid timing lines with one invalid line,
    runs the report strategy, and checks the four summary statistics
    (Average/Median/Min/Max) computed over [12.64, 12.65].
    """
    nemo_tr.output_path = tmp_path
    stdout_file = nemo_tr.output_path / "stdout.txt"
    stdout_file.write_text(
        "Training epoch 0, iteration 17/99 | train_step_timing in s: 12.64 | global_step: 17\n"
        "Invalid line without timing\n"
        "Training epoch 0, iteration 18/99 | train_step_timing in s: 12.65 | global_step: 18\n"
    )

    strategy = NeMoRunReportGenerationStrategy(slurm_system, nemo_tr)
    strategy.generate_report()

    summary_file = nemo_tr.output_path / "report.txt"
    assert summary_file.is_file(), "Report should be generated even with partial valid timings."

    summary_content = summary_file.read_text().strip().split("\n")
    assert len(summary_content) == 4, "Summary file should contain four lines (avg, median, min, max)."

    expected_values = {
        "Average": 12.645,
        "Median": 12.645,
        "Min": 12.64,
        "Max": 12.65,
    }

    # maxsplit=1 keeps the parse stable even if a value ever contained ": ".
    parsed = dict(line.split(": ", 1) for line in summary_content)
    # Assert the label set explicitly so a drifted report fails with a readable
    # diff instead of a bare KeyError from expected_values[key].
    assert set(parsed) == set(expected_values), "Report labels mismatch."
    for key, expected in expected_values.items():
        # Same tolerance as before: the second positional arg of approx is `rel`.
        assert float(parsed[key]) == pytest.approx(expected, rel=0.01), f"{key} value mismatch."
0 commit comments