Skip to content

Commit b6491d3

Browse files
authored
Merge pull request #596 from NVIDIA/am/per-tr-cmd-gen
Create CmdGenStrategy per usage
2 parents 872f0a5 + 997c10d commit b6491d3

34 files changed

+286
-277
lines changed

src/cloudai/_core/base_runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
from typing import Dict, List
2323

2424
from .base_job import BaseJob
25+
from .command_gen_strategy import CommandGenStrategy
2526
from .exceptions import JobFailureError, JobSubmissionError
2627
from .job_status_result import JobStatusResult
28+
from .registry import Registry
2729
from .system import System
2830
from .test_scenario import TestRun, TestScenario
2931

@@ -113,9 +115,7 @@ async def submit_test(self, tr: TestRun):
113115
exit(1)
114116

115117
def on_job_submit(self, tr: TestRun) -> None:
116-
if tr.test.test_template._command_gen_strategy is not None:
117-
cmd_gen = tr.test.test_template.command_gen_strategy
118-
cmd_gen.store_test_run(tr)
118+
return
119119

120120
async def delayed_submit_test(self, tr: TestRun, delay: int = 5):
121121
"""
@@ -372,3 +372,7 @@ async def delayed_kill_job(self, job: BaseJob, delay: int = 0):
372372
await asyncio.sleep(delay)
373373
job.terminated_by_dependency = True
374374
self.system.kill(job)
375+
376+
def get_cmd_gen_strategy(self, system: System, test_run: TestRun) -> CommandGenStrategy:
377+
strategy_cls = Registry().get_command_gen_strategy(type(system), type(test_run.test.test_definition))
378+
return strategy_cls(system, test_run)

src/cloudai/_core/command_gen_strategy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from abc import ABC, abstractmethod
1818

19+
from .system import System
1920
from .test_scenario import TestRun
2021
from .test_template_strategy import TestTemplateStrategy
2122

@@ -25,6 +26,10 @@ class CommandGenStrategy(TestTemplateStrategy, ABC):
2526

2627
TEST_RUN_DUMP_FILE_NAME: str = "test-run.toml"
2728

29+
def __init__(self, system: System, test_run: TestRun) -> None:
30+
super().__init__(system)
31+
self.test_run = test_run
32+
2833
@abstractmethod
2934
def gen_exec_command(self, tr: TestRun) -> str:
3035
"""

src/cloudai/_core/registry.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ..reporter import Reporter
2626
from .base_installer import BaseInstaller
2727
from .base_runner import BaseRunner
28+
from .command_gen_strategy import CommandGenStrategy
2829
from .grading_strategy import GradingStrategy
2930
from .report_generation_strategy import ReportGenerationStrategy
3031
from .system import System
@@ -66,6 +67,7 @@ class Registry(metaclass=Singleton):
6667
scenario_reports: ClassVar[dict[str, type[Reporter]]] = {}
6768
report_configs: ClassVar[dict[str, ReportConfig]] = {}
6869
reward_functions_map: ClassVar[dict[str, RewardFunction]] = {}
70+
command_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[CommandGenStrategy]]] = {}
6971

7072
def add_runner(self, name: str, value: Type[BaseRunner]) -> None:
7173
"""
@@ -250,3 +252,25 @@ def get_reward_function(self, name: str) -> RewardFunction:
250252
f"Reward function '{name}' not found. Available functions: {list(self.reward_functions_map.keys())}"
251253
)
252254
return self.reward_functions_map[name]
255+
256+
def add_command_gen_strategy(
257+
self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[CommandGenStrategy]
258+
) -> None:
259+
if (system_type, tdef_type) in self.command_gen_strategies_map:
260+
raise ValueError(
261+
f"Duplicating implementation for '{system_type.__name__}, {tdef_type.__name__}', use 'update()' "
262+
"for replacement."
263+
)
264+
self.update_command_gen_strategy(system_type, tdef_type, value)
265+
266+
def update_command_gen_strategy(
267+
self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[CommandGenStrategy]
268+
) -> None:
269+
self.command_gen_strategies_map[(system_type, tdef_type)] = value
270+
271+
def get_command_gen_strategy(
272+
self, system_type: Type[System], tdef_type: Type[TestDefinition]
273+
) -> Type[CommandGenStrategy]:
274+
if (system_type, tdef_type) not in self.command_gen_strategies_map:
275+
raise KeyError(f"Command gen strategy for '{system_type.__name__}, {tdef_type.__name__}' not found.")
276+
return self.command_gen_strategies_map[(system_type, tdef_type)]

src/cloudai/_core/test_template.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from pathlib import Path
1818
from typing import Any, Dict, Optional
1919

20-
from .command_gen_strategy import CommandGenStrategy
2120
from .grading_strategy import GradingStrategy
2221
from .json_gen_strategy import JsonGenStrategy
2322
from .system import System
@@ -30,13 +29,6 @@ class TestTemplate:
3029
3130
Providing a framework for test execution, including installation, uninstallation, and execution command generation
3231
based on system configurations and test parameters.
33-
34-
Attributes
35-
cmd_args (Dict[str, Any]): Default command-line arguments.
36-
logger (logging.Logger): Logger for the test template.
37-
command_gen_strategy (CommandGenStrategy): Strategy for generating execution commands.
38-
json_gen_strategy (JsonGenStrategy): Strategy for generating json string.
39-
grading_strategy (GradingStrategy): Strategy for grading performance based on test outcomes.
4032
"""
4133

4234
__test__ = False
@@ -49,23 +41,9 @@ def __init__(self, system: System) -> None:
4941
system (System): System configuration for the test template.
5042
"""
5143
self.system = system
52-
self._command_gen_strategy: Optional[CommandGenStrategy] = None
5344
self._json_gen_strategy: Optional[JsonGenStrategy] = None
5445
self.grading_strategy: Optional[GradingStrategy] = None
5546

56-
@property
57-
def command_gen_strategy(self) -> CommandGenStrategy:
58-
if self._command_gen_strategy is None:
59-
raise ValueError(
60-
"command_gen_strategy is missing. Ensure the strategy is registered in the Registry "
61-
"by calling the appropriate registration function for the system type."
62-
)
63-
return self._command_gen_strategy
64-
65-
@command_gen_strategy.setter
66-
def command_gen_strategy(self, value: CommandGenStrategy) -> None:
67-
self._command_gen_strategy = value
68-
6947
@property
7048
def json_gen_strategy(self) -> JsonGenStrategy:
7149
if self._json_gen_strategy is None:
@@ -79,18 +57,6 @@ def json_gen_strategy(self) -> JsonGenStrategy:
7957
def json_gen_strategy(self, value: JsonGenStrategy) -> None:
8058
self._json_gen_strategy = value
8159

82-
def gen_exec_command(self, tr: TestRun) -> str:
83-
"""
84-
Generate an execution command for a test using this template.
85-
86-
Args:
87-
tr (TestRun): Contains the test and its run-specific configurations.
88-
89-
Returns:
90-
str: The generated execution command.
91-
"""
92-
return self.command_gen_strategy.gen_exec_command(tr)
93-
9460
def gen_json(self, tr: TestRun) -> Dict[Any, Any]:
9561
"""
9662
Generate a JSON string representing the Kubernetes job specification for this test using this template.

src/cloudai/registration.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def register_all():
2323
inverse_reward,
2424
negative_reward,
2525
)
26-
from cloudai.core import CommandGenStrategy, GradingStrategy, JsonGenStrategy, Registry
26+
from cloudai.core import GradingStrategy, JsonGenStrategy, Registry
2727
from cloudai.models.scenario import ReportConfig
2828
from cloudai.reporter import PerTestReporter, StatusReporter, TarballReporter
2929

@@ -117,31 +117,23 @@ def register_all():
117117
Registry().add_runner("lsf", LSFRunner)
118118
Registry().add_runner("runai", RunAIRunner)
119119

120-
Registry().add_strategy(
121-
CommandGenStrategy, [StandaloneSystem], [SleepTestDefinition], SleepStandaloneCommandGenStrategy
122-
)
123-
Registry().add_strategy(CommandGenStrategy, [LSFSystem], [SleepTestDefinition], SleepLSFCommandGenStrategy)
124-
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [SleepTestDefinition], SleepSlurmCommandGenStrategy)
120+
Registry().add_command_gen_strategy(StandaloneSystem, SleepTestDefinition, SleepStandaloneCommandGenStrategy)
121+
Registry().add_command_gen_strategy(LSFSystem, SleepTestDefinition, SleepLSFCommandGenStrategy)
122+
Registry().add_command_gen_strategy(SlurmSystem, SleepTestDefinition, SleepSlurmCommandGenStrategy)
125123
Registry().add_strategy(JsonGenStrategy, [KubernetesSystem], [SleepTestDefinition], SleepKubernetesJsonGenStrategy)
126124
Registry().add_strategy(
127125
JsonGenStrategy, [KubernetesSystem], [NCCLTestDefinition], NcclTestKubernetesJsonGenStrategy
128126
)
129127
Registry().add_strategy(JsonGenStrategy, [RunAISystem], [NCCLTestDefinition], NcclTestRunAIJsonGenStrategy)
130128
Registry().add_strategy(GradingStrategy, [SlurmSystem], [NCCLTestDefinition], NcclTestGradingStrategy)
131129

132-
Registry().add_strategy(
133-
CommandGenStrategy, [SlurmSystem], [MegatronRunTestDefinition], MegatronRunSlurmCommandGenStrategy
134-
)
135-
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [NCCLTestDefinition], NcclTestSlurmCommandGenStrategy)
130+
Registry().add_command_gen_strategy(SlurmSystem, MegatronRunTestDefinition, MegatronRunSlurmCommandGenStrategy)
131+
Registry().add_command_gen_strategy(SlurmSystem, NCCLTestDefinition, NcclTestSlurmCommandGenStrategy)
136132
Registry().add_strategy(GradingStrategy, [SlurmSystem], [SleepTestDefinition], SleepGradingStrategy)
137133

138-
Registry().add_strategy(
139-
CommandGenStrategy, [SlurmSystem], [NeMoLauncherTestDefinition], NeMoLauncherSlurmCommandGenStrategy
140-
)
141-
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [NeMoRunTestDefinition], NeMoRunSlurmCommandGenStrategy)
142-
Registry().add_strategy(
143-
CommandGenStrategy, [SlurmSystem], [NIXLBenchTestDefinition], NIXLBenchSlurmCommandGenStrategy
144-
)
134+
Registry().add_command_gen_strategy(SlurmSystem, NeMoLauncherTestDefinition, NeMoLauncherSlurmCommandGenStrategy)
135+
Registry().add_command_gen_strategy(SlurmSystem, NeMoRunTestDefinition, NeMoRunSlurmCommandGenStrategy)
136+
Registry().add_command_gen_strategy(SlurmSystem, NIXLBenchTestDefinition, NIXLBenchSlurmCommandGenStrategy)
145137

146138
Registry().add_strategy(GradingStrategy, [SlurmSystem], [NeMoLauncherTestDefinition], NeMoLauncherGradingStrategy)
147139
Registry().add_strategy(
@@ -151,29 +143,20 @@ def register_all():
151143
JaxToolboxGradingStrategy,
152144
)
153145
Registry().add_strategy(GradingStrategy, [SlurmSystem], [UCCTestDefinition], UCCTestGradingStrategy)
154-
Registry().add_strategy(
155-
CommandGenStrategy,
156-
[SlurmSystem],
157-
[GPTTestDefinition, GrokTestDefinition, NemotronTestDefinition],
158-
JaxToolboxSlurmCommandGenStrategy,
159-
)
146+
Registry().add_command_gen_strategy(SlurmSystem, GPTTestDefinition, JaxToolboxSlurmCommandGenStrategy)
147+
Registry().add_command_gen_strategy(SlurmSystem, GrokTestDefinition, JaxToolboxSlurmCommandGenStrategy)
148+
Registry().add_command_gen_strategy(SlurmSystem, NemotronTestDefinition, JaxToolboxSlurmCommandGenStrategy)
160149

161-
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [UCCTestDefinition], UCCTestSlurmCommandGenStrategy)
150+
Registry().add_command_gen_strategy(SlurmSystem, UCCTestDefinition, UCCTestSlurmCommandGenStrategy)
162151

163152
Registry().add_strategy(GradingStrategy, [SlurmSystem], [ChakraReplayTestDefinition], ChakraReplayGradingStrategy)
164-
Registry().add_strategy(
165-
CommandGenStrategy, [SlurmSystem], [ChakraReplayTestDefinition], ChakraReplaySlurmCommandGenStrategy
166-
)
167-
Registry().add_strategy(
168-
CommandGenStrategy, [SlurmSystem], [SlurmContainerTestDefinition], SlurmContainerCommandGenStrategy
169-
)
170-
Registry().add_strategy(
171-
CommandGenStrategy, [SlurmSystem], [TritonInferenceTestDefinition], TritonInferenceSlurmCommandGenStrategy
153+
Registry().add_command_gen_strategy(SlurmSystem, ChakraReplayTestDefinition, ChakraReplaySlurmCommandGenStrategy)
154+
Registry().add_command_gen_strategy(SlurmSystem, SlurmContainerTestDefinition, SlurmContainerCommandGenStrategy)
155+
Registry().add_command_gen_strategy(
156+
SlurmSystem, TritonInferenceTestDefinition, TritonInferenceSlurmCommandGenStrategy
172157
)
173158

174-
Registry().add_strategy(
175-
CommandGenStrategy, [SlurmSystem], [AIDynamoTestDefinition], AIDynamoSlurmCommandGenStrategy
176-
)
159+
Registry().add_command_gen_strategy(SlurmSystem, AIDynamoTestDefinition, AIDynamoSlurmCommandGenStrategy)
177160

178161
Registry().add_installer("slurm", SlurmInstaller)
179162
Registry().add_installer("standalone", StandaloneInstaller)

src/cloudai/systems/lsf/lsf_command_gen_strategy.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# limitations under the License.
1616

1717
from datetime import datetime
18-
from typing import Any, Dict, List, Union
18+
from typing import Any, Dict, List, Union, cast
1919

20-
from cloudai.core import CommandGenStrategy, TestRun
20+
from cloudai.core import CommandGenStrategy, System, TestRun
2121

2222
from .lsf_system import LSFSystem
2323

@@ -31,15 +31,16 @@ class LSFCommandGenStrategy(CommandGenStrategy):
3131
properties and methods.
3232
"""
3333

34-
def __init__(self, system: LSFSystem) -> None:
34+
def __init__(self, system: System, test_run: TestRun) -> None:
3535
"""
3636
Initialize a new LSFCommandGenStrategy instance.
3737
3838
Args:
3939
system (LSFSystem): The system schema object.
40+
test_run (TestRun): The test run object.
4041
"""
41-
super().__init__(system)
42-
self.system = system
42+
super().__init__(system, test_run)
43+
self.system = cast(LSFSystem, system)
4344

4445
def gen_exec_command(self, tr: TestRun) -> str:
4546
"""

src/cloudai/systems/lsf/lsf_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_job_id(self, stdout: str, stderr: str) -> int | None:
5454
def _submit_test(self, tr: TestRun) -> LSFJob:
5555
logging.info(f"Running test: {tr.name}")
5656
tr.output_path = self.get_job_output_path(tr)
57-
exec_cmd = tr.test.test_template.gen_exec_command(tr)
57+
exec_cmd = self.get_cmd_gen_strategy(self.system, tr).gen_exec_command(tr)
5858
logging.debug(f"Executing command for test {tr.name}: {exec_cmd}")
5959
job_id = 0
6060
if self.mode == "run":

src/cloudai/systems/slurm/single_sbatch_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ def aux_commands(self) -> list[str]:
108108
tr.output_path = self.scenario_root
109109
max_nodes, _ = self.extract_sbatch_nodes_spec()
110110
tr.num_nodes = max_nodes
111-
cmd_gen = cast(SlurmCommandGenStrategy, tr.test.test_template.command_gen_strategy)
111+
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr))
112112
return [cmd_gen._metadata_cmd(tr), cmd_gen._ranks_mapping_cmd(tr)]
113113

114114
def get_single_tr_block(self, tr: TestRun) -> str:
115-
cmd_gen = cast(SlurmCommandGenStrategy, tr.test.test_template.command_gen_strategy)
115+
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr))
116116
srun_cmd = cmd_gen.gen_srun_command(tr)
117117
nnodes, node_list = self.system.get_nodes_by_spec(tr.nnodes, tr.nodes)
118118
node_arg = f"--nodelist={','.join(node_list)}" if node_list else f"-N{nnodes}"
@@ -161,7 +161,7 @@ def gen_sbatch_content(self) -> str:
161161

162162
def add_pre_tests(self, pre_tc: TestScenario, base_tr: TestRun) -> str:
163163
content = []
164-
cmd_gen = cast(SlurmCommandGenStrategy, base_tr.test.test_template.command_gen_strategy)
164+
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, base_tr))
165165
content.append(cmd_gen.gen_pre_test(pre_tc, self.scenario_root))
166166
content.append("if [ $PRE_TEST_SUCCESS -ne 1 ]; then")
167167
content.append(" exit 1")

src/cloudai/systems/slurm/slurm_command_gen_strategy.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import toml
2525

26-
from cloudai.core import CommandGenStrategy, Registry, TestRun, TestScenario
26+
from cloudai.core import CommandGenStrategy, Registry, System, TestRun, TestScenario
2727
from cloudai.models.scenario import TestRunDetails
2828

2929
from .slurm_system import SlurmSystem
@@ -38,15 +38,16 @@ class SlurmCommandGenStrategy(CommandGenStrategy):
3838
properties and methods.
3939
"""
4040

41-
def __init__(self, system: SlurmSystem) -> None:
41+
def __init__(self, system: System, test_run: TestRun) -> None:
4242
"""
4343
Initialize a new SlurmCommandGenStrategy instance.
4444
4545
Args:
4646
system (SlurmSystem): The system schema object.
47+
test_run (TestRun): The test run object.
4748
"""
48-
super().__init__(system)
49-
self.system = system
49+
super().__init__(system, test_run)
50+
self.system = cast(SlurmSystem, system)
5051

5152
self._node_spec_cache: dict[str, tuple[int, list[str]]] = {}
5253

@@ -143,11 +144,8 @@ def _get_cmd_gen_strategy(self, tr: TestRun) -> "SlurmCommandGenStrategy":
143144
Returns:
144145
CommandGenStrategy: The strategy instance.
145146
"""
146-
registry = Registry()
147-
key = (CommandGenStrategy, type(self.system), type(tr.test.test_definition))
148-
strategy_cls = registry.strategies_map[key]
149-
strategy_cls_typed = cast(type[SlurmCommandGenStrategy], strategy_cls)
150-
strategy = strategy_cls_typed(self.system)
147+
strategy_cls = Registry().get_command_gen_strategy(type(self.system), type(tr.test.test_definition))
148+
strategy = cast(SlurmCommandGenStrategy, strategy_cls(self.system, tr))
151149
return strategy
152150

153151
def _set_hook_output_path(self, tr: TestRun, base_output_path: Path) -> None:

src/cloudai/systems/slurm/slurm_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_job_id(self, stdout: str, stderr: str) -> int | None:
5656

5757
def _submit_test(self, tr: TestRun) -> SlurmJob:
5858
logging.info(f"Running test: {tr.name}")
59-
exec_cmd = tr.test.test_template.gen_exec_command(tr)
59+
exec_cmd = self.get_cmd_gen_strategy(self.system, tr).gen_exec_command(tr)
6060
logging.debug(f"Executing command for test {tr.name}: {exec_cmd}")
6161
job_id = 0
6262
if self.mode == "run":
@@ -73,6 +73,10 @@ def _submit_test(self, tr: TestRun) -> SlurmJob:
7373
logging.info(f"Submitted slurm job: {job_id}")
7474
return SlurmJob(tr, id=job_id)
7575

76+
def on_job_submit(self, tr: TestRun) -> None:
77+
cmd_gen = self.get_cmd_gen_strategy(self.system, tr)
78+
cmd_gen.store_test_run(tr)
79+
7680
def on_job_completion(self, job: BaseJob) -> None:
7781
logging.debug(f"Job completion callback for job {job.id}")
7882
self.system.complete_job(cast(SlurmJob, job))
@@ -94,7 +98,7 @@ def _mock_job_metadata(self) -> SlurmStepMetadata:
9498
def _get_job_metadata(
9599
self, job: SlurmJob, steps_metadata: list[SlurmStepMetadata]
96100
) -> tuple[Path, SlurmJobMetadata]:
97-
cmd_gen = cast(SlurmCommandGenStrategy, job.test_run.test.test_template.command_gen_strategy)
101+
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, job.test_run))
98102
return job.test_run.output_path / "slurm-job.toml", SlurmJobMetadata(
99103
job_id=int(job.id),
100104
name=steps_metadata[0].name,

0 commit comments

Comments
 (0)