Skip to content

Commit bce5e6f

Browse files
authored
Merge pull request #598 from NVIDIA/am/cmd-gen-use-tr-member
Rely on member test run object instead of args
2 parents 996645f + a99e0d9 commit bce5e6f

File tree

37 files changed

+576
-655
lines changed

37 files changed

+576
-655
lines changed

src/cloudai/_core/command_gen_strategy.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,35 @@ class CommandGenStrategy(TestTemplateStrategy, ABC):
2929
def __init__(self, system: System, test_run: TestRun) -> None:
3030
super().__init__(system)
3131
self.test_run = test_run
32+
self._final_env_vars: dict[str, str | list[str]] = {}
3233

3334
@abstractmethod
34-
def gen_exec_command(self, tr: TestRun) -> str:
35+
def gen_exec_command(self) -> str:
3536
"""
3637
Generate the execution command for a test based on the given parameters.
3738
38-
Args:
39-
tr (TestRun): Contains the test and its run-specific configurations.
40-
4139
Returns:
4240
str: The generated execution command.
4341
"""
4442
pass
4543

4644
@abstractmethod
47-
def store_test_run(self, tr: TestRun) -> None:
45+
def store_test_run(self) -> None:
4846
"""
4947
Store the test run information in output folder.
5048
5149
Only at command generation time, CloudAI has all the information to store the test run.
52-
53-
Args:
54-
tr (TestRun): The test run object to be stored.
5550
"""
5651
pass
52+
53+
@property
54+
def final_env_vars(self) -> dict[str, str | list[str]]:
55+
if not self._final_env_vars:
56+
final_env_vars = self.system.global_env_vars.copy()
57+
final_env_vars.update(self.test_run.test.extra_env_vars)
58+
self._final_env_vars = final_env_vars
59+
return self._final_env_vars
60+
61+
@final_env_vars.setter
62+
def final_env_vars(self, value: dict[str, str | list[str]]) -> None:
63+
self._final_env_vars = value

src/cloudai/systems/lsf/lsf_command_gen_strategy.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, system: System, test_run: TestRun) -> None:
4242
super().__init__(system, test_run)
4343
self.system = cast(LSFSystem, system)
4444

45-
def gen_exec_command(self, tr: TestRun) -> str:
45+
def gen_exec_command(self) -> str:
4646
"""
4747
Generate the execution command for the test run.
4848
@@ -52,11 +52,13 @@ def gen_exec_command(self, tr: TestRun) -> str:
5252
Returns:
5353
str: The generated LSF command.
5454
"""
55-
env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars)
56-
cmd_args = self._flatten_dict(tr.test.cmd_args)
57-
lsf_args = self._parse_lsf_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr)
55+
env_vars = self.final_env_vars
56+
cmd_args = self._flatten_dict(self.test_run.test.cmd_args)
57+
lsf_args = self._parse_lsf_args(
58+
self.test_run.test.test_template.__class__.__name__, env_vars, cmd_args, self.test_run
59+
)
5860

59-
bsub_command = self._gen_bsub_command(lsf_args, env_vars, cmd_args, tr)
61+
bsub_command = self._gen_bsub_command(lsf_args, env_vars, cmd_args, self.test_run)
6062

6163
return bsub_command.strip()
6264

src/cloudai/systems/lsf/lsf_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_job_id(self, stdout: str, stderr: str) -> int | None:
5454
def _submit_test(self, tr: TestRun) -> LSFJob:
5555
logging.info(f"Running test: {tr.name}")
5656
tr.output_path = self.get_job_output_path(tr)
57-
exec_cmd = self.get_cmd_gen_strategy(self.system, tr).gen_exec_command(tr)
57+
exec_cmd = self.get_cmd_gen_strategy(self.system, tr).gen_exec_command()
5858
logging.debug(f"Executing command for test {tr.name}: {exec_cmd}")
5959
job_id = 0
6060
if self.mode == "run":

src/cloudai/systems/slurm/single_sbatch_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ def aux_commands(self) -> list[str]:
109109
max_nodes, _ = self.extract_sbatch_nodes_spec()
110110
tr.num_nodes = max_nodes
111111
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr))
112-
return [cmd_gen._metadata_cmd(tr), cmd_gen._ranks_mapping_cmd(tr)]
112+
return [cmd_gen._metadata_cmd(), cmd_gen._ranks_mapping_cmd()]
113113

114114
def get_single_tr_block(self, tr: TestRun) -> str:
115115
cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr))
116-
srun_cmd = cmd_gen.gen_srun_command(tr)
116+
srun_cmd = cmd_gen.gen_srun_command()
117117
nnodes, node_list = self.system.get_nodes_by_spec(tr.nnodes, tr.nodes)
118118
node_arg = f"--nodelist={','.join(node_list)}" if node_list else f"-N{nnodes}"
119119
extra_args = (

0 commit comments

Comments
 (0)