Skip to content

Commit 36ab8b2

Browse files
authored
Merge pull request #661 from NVIDIA/am/scenario-reports-cfg
Configure reports via scenario config
2 parents 13b83f2 + 57adfaa commit 36ab8b2

File tree

8 files changed

+102
-3
lines changed

8 files changed

+102
-3
lines changed

doc/reporting.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ Per-test reports are linked to a particular workload type (e.g. `NcclTest`). All
1010
To list all available reports, one can use the `cloudai list-reports` command. Use verbose output to also print report configurations.
1111

1212

13+
## Notes and general flow
14+
1. All reports should be registered via `Registry()` (`.add_report()` or `.add_scenario_report()`).
15+
1. Scenario reports are configurable via system config (Slurm-only for now) and scenario config.
16+
1. Configuration in a scenario config has the highest priority. Next, system config is checked. Then it defaults to report config from the registry.
17+
1. Finally, the report is generated (or not) according to this final config.
18+
19+
1320
## Enable, disable and configure reports
1421
**NOTE** Only scenario-level reports can be configured today.
1522

src/cloudai/_core/test_scenario.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .test_template_strategy import TestTemplateStrategy
2727

2828
if TYPE_CHECKING:
29+
from ..models.scenario import ReportConfig
2930
from .report_generation_strategy import ReportGenerationStrategy
3031
from .test import Test
3132

@@ -184,18 +185,26 @@ class TestScenario:
184185

185186
__test__ = False
186187

187-
def __init__(self, name: str, test_runs: List[TestRun], job_status_check: bool = True) -> None:
188+
def __init__(
189+
self,
190+
name: str,
191+
test_runs: List[TestRun],
192+
job_status_check: bool = True,
193+
reports: dict[str, ReportConfig] | None = None,
194+
) -> None:
188195
"""
189196
Initialize a TestScenario instance.
190197
191198
Args:
192199
name (str): Name of the test scenario.
193200
test_runs (List[TestRun]): List of tests in the scenario with custom run options.
194201
job_status_check (bool): Flag indicating whether to check the job status or not.
202+
reports (Optional[dict[str, ReportConfig]]): Reports to be generated for the scenario.
195203
"""
196204
self.name = name
197205
self.test_runs = test_runs
198206
self.job_status_check = job_status_check
207+
self.reports = reports or {}
199208

200209
def __repr__(self) -> str:
201210
"""

src/cloudai/cli/handlers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def generate_reports(system: System, test_scenario: TestScenario, result_dir: Pa
157157
logging.debug(f"Generating report '{name}' ({reporter_class.__name__})")
158158

159159
cfg = registry.report_configs.get(name, ReportConfig(enable=False))
160-
if isinstance(system, SlurmSystem) and system.reports and name in system.reports:
160+
if scenario_cfg := test_scenario.reports.get(name):
161+
cfg = scenario_cfg
162+
elif isinstance(system, SlurmSystem) and system.reports and name in system.reports:
161163
cfg = system.reports[name]
162164
logging.debug(f"Report '{name}' config is: {cfg.model_dump_json(indent=None)}")
163165

src/cloudai/models/scenario.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class TestScenarioModel(BaseModel):
155155
tests: list[TestRunModel] = Field(alias="Tests", min_length=1)
156156
pre_test: Optional[str] = None
157157
post_test: Optional[str] = None
158+
reports: dict[str, ReportConfig] = Field(default_factory=dict)
158159

159160
@model_validator(mode="after")
160161
def check_no_self_dependency(self):
@@ -188,6 +189,11 @@ def check_all_dependencies_are_known(self):
188189

189190
return self
190191

192+
@field_validator("reports", mode="before")
193+
@classmethod
194+
def parse_reports(cls, value: dict[str, Any] | None) -> dict[str, ReportConfig] | None:
195+
return parse_reports_spec(value)
196+
191197

192198
class TestRunDetails(BaseModel):
193199
"""

src/cloudai/test_scenario_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario:
167167
name=ts_model.name,
168168
test_runs=list(test_runs_by_id.values()),
169169
job_status_check=ts_model.job_status_check,
170+
reports=ts_model.reports,
170171
)
171172

172173
def _create_test_run(

src/cloudai/workloads/nccl_test/nccl_comparisson_report.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
class NcclComparissonReportConfig(ReportConfig):
4848
"""Configuration for NCCL comparisson report."""
4949

50+
enable: bool = True
5051
group_by: list[str] = Field(default_factory=lambda: ["subtest_name"])
5152

5253

tests/test_reporter.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,49 @@ def test_disabled_on_system_level(self, slurm_system: SlurmSystem) -> None:
208208
slurm_system.reports = {"sr1": ReportConfig(enable=False)}
209209
generate_reports(slurm_system, TestScenario(name="ts", test_runs=[]), slurm_system.output_path)
210210
assert MY_REPORT_CALLED == 0
211+
212+
213+
class TestGenerateReportPriority:
214+
@pytest.fixture(autouse=True)
215+
def setup(self):
216+
reg = Registry()
217+
orig_reports = copy.deepcopy(reg.scenario_reports)
218+
reg.scenario_reports.clear()
219+
220+
global MY_REPORT_CALLED
221+
MY_REPORT_CALLED = 0
222+
223+
yield
224+
225+
reg.scenario_reports.clear()
226+
reg.scenario_reports.update(orig_reports)
227+
228+
def test_non_registered_report_is_ignored(self, slurm_system: SlurmSystem) -> None:
229+
generate_reports(slurm_system, TestScenario(name="ts", test_runs=[]), slurm_system.output_path)
230+
assert MY_REPORT_CALLED == 0
231+
232+
def test_report_is_enabled_on_system_level(self, slurm_system: SlurmSystem) -> None:
233+
Registry().add_scenario_report("sr1", MyReporter, ReportConfig(enable=True))
234+
slurm_system.reports = {"sr1": ReportConfig(enable=True)}
235+
generate_reports(slurm_system, TestScenario(name="ts", test_runs=[]), slurm_system.output_path)
236+
assert MY_REPORT_CALLED == 1
237+
238+
def test_report_is_enabled_on_scenario_level(self, slurm_system: SlurmSystem) -> None:
239+
Registry().add_scenario_report("sr1", MyReporter, ReportConfig(enable=True))
240+
slurm_system.reports = {}
241+
generate_reports(
242+
slurm_system,
243+
TestScenario(name="ts", test_runs=[], reports={"sr1": ReportConfig(enable=True)}),
244+
slurm_system.output_path,
245+
)
246+
assert MY_REPORT_CALLED == 1
247+
248+
def test_report_scenario_has_highest_priority(self, slurm_system: SlurmSystem) -> None:
249+
Registry().add_scenario_report("sr1", MyReporter, ReportConfig(enable=True))
250+
slurm_system.reports = {"sr1": ReportConfig(enable=False)}
251+
generate_reports(
252+
slurm_system,
253+
TestScenario(name="ts", test_runs=[], reports={"sr1": ReportConfig(enable=True)}),
254+
slurm_system.output_path,
255+
)
256+
assert MY_REPORT_CALLED == 1

tests/test_test_scenario_parser.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
# limitations under the License.
1616

1717
from pathlib import Path
18-
from typing import List, Optional
18+
from typing import List, Optional, cast
1919
from unittest.mock import create_autospec
2020

2121
import pytest
22+
import toml
2223

2324
from cloudai.core import Test, TestDefinition, TestRun, TestScenario
25+
from cloudai.models.scenario import TestScenarioModel
2426
from cloudai.test_scenario_parser import calculate_total_time_limit
27+
from cloudai.workloads.nccl_test.nccl_comparisson_report import NcclComparissonReportConfig
2528

2629

2730
class DummyTestRun(TestRun):
@@ -71,3 +74,27 @@ def test_calculate_total_time_limit(
7174
test_hooks: List[TestScenario], time_limit: Optional[str], expected: Optional[str]
7275
) -> None:
7376
assert calculate_total_time_limit(test_hooks, time_limit) == expected
77+
78+
79+
def test_report_spec_is_parsed() -> None:
80+
model = TestScenarioModel.model_validate(
81+
toml.loads("""
82+
name = "scenario"
83+
84+
[reports]
85+
nccl_comparisson = { enable = false, group_by = ["my_field"] }
86+
87+
[[Tests]]
88+
id = "1"
89+
num_nodes = 2
90+
91+
name = "name"
92+
description = "desc"
93+
test_template_name = "NcclTest"
94+
""")
95+
)
96+
97+
assert len(model.reports) == 1
98+
cfg = cast(NcclComparissonReportConfig, model.reports["nccl_comparisson"])
99+
assert cfg.enable is False
100+
assert cfg.group_by == ["my_field"]

0 commit comments

Comments
 (0)