Skip to content

Commit b305b86

Browse files
authored
Merge pull request #210 from TaekyungHeo/nemo-bug-fix-slurm-args
Add Support for Cluster Account and gpus_per_node in Command Generation
2 parents 16e8314 + 98274a5 commit b305b86

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ def gen_exec_command(
7878
del self.final_cmd_args["repository_commit_hash"]
7979
del self.final_cmd_args["docker_image_url"]
8080

81+
if self.slurm_system.account is not None:
82+
self.final_cmd_args["cluster.account"] = self.slurm_system.account
83+
self.final_cmd_args["cluster.job_name_prefix"] = f"{self.slurm_system.account}-cloudai.nemo:"
84+
self.final_cmd_args["cluster.gpus_per_node"] = (
85+
self.slurm_system.gpus_per_node if self.slurm_system.gpus_per_node is not None else "null"
86+
)
87+
88+
if ("data_dir" in self.final_cmd_args) and (self.final_cmd_args["data_dir"] == "DATA_DIR"):
89+
raise ValueError(
90+
"The 'data_dir' field of the NeMo launcher test contains the placeholder 'DATA_DIR'. "
91+
"Please update the test schema TOML file with a valid path to the dataset. "
92+
"The 'data_dir' field must point to an actual dataset location, not a placeholder."
93+
)
94+
8195
cmd_args_str = self._generate_cmd_args_str(self.final_cmd_args, nodes)
8296

8397
full_cmd = f"python {launcher_path}/launcher_scripts/main.py {cmd_args_str}"

tests/test_slurm_command_gen_strategy.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,101 @@ def test_invalid_tokenizer_path(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenS
298298
nodes=[],
299299
)
300300

301+
def test_account_in_command(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
    """Verify that a configured Slurm account shows up in the generated command.

    With ``slurm_system.account`` set, the command string must contain both
    the ``cluster.account`` override and the account-prefixed
    ``cluster.job_name_prefix`` override.
    """
    nemo_cmd_gen.slurm_system.account = "test_account"

    generated = nemo_cmd_gen.gen_exec_command(
        cmd_args={
            "docker_image_url": "fake",
            "repository_url": "fake",
            "repository_commit_hash": "fake",
        },
        extra_env_vars={"TEST_VAR_1": "value1"},
        extra_cmd_args="",
        output_path=Path(""),
        num_nodes=1,
        nodes=[],
    )

    for expected_override in (
        "cluster.account=test_account",
        "cluster.job_name_prefix=test_account-cloudai.nemo:",
    ):
        assert expected_override in generated
321+
322+
def test_no_account_in_command(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
    """Verify that no account overrides are emitted when the account is unset.

    With ``slurm_system.account`` set to ``None``, neither ``cluster.account``
    nor ``cluster.job_name_prefix`` may appear in the command string.
    """
    nemo_cmd_gen.slurm_system.account = None

    generated = nemo_cmd_gen.gen_exec_command(
        cmd_args={
            "docker_image_url": "fake",
            "repository_url": "fake",
            "repository_commit_hash": "fake",
        },
        extra_env_vars={"TEST_VAR_1": "value1"},
        extra_cmd_args="",
        output_path=Path(""),
        num_nodes=1,
        nodes=[],
    )

    for forbidden_key in ("cluster.account", "cluster.job_name_prefix"):
        assert forbidden_key not in generated
342+
343+
def test_gpus_per_node_value(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
    """Verify how ``slurm_system.gpus_per_node`` is rendered in the command.

    An integer value must be emitted verbatim, while ``None`` must be emitted
    as the literal string ``null``.
    """
    env_vars = {"TEST_VAR_1": "value1"}
    base_args = {
        "docker_image_url": "fake",
        "repository_url": "fake",
        "repository_commit_hash": "fake",
    }

    # The same argument dictionaries are deliberately reused for both calls,
    # first with a concrete GPU count and then with the value unset.
    scenarios = (
        (4, "cluster.gpus_per_node=4"),
        (None, "cluster.gpus_per_node=null"),
    )
    for gpu_count, expected_override in scenarios:
        nemo_cmd_gen.slurm_system.gpus_per_node = gpu_count
        generated = nemo_cmd_gen.gen_exec_command(
            cmd_args=base_args,
            extra_env_vars=env_vars,
            extra_cmd_args="",
            output_path=Path(""),
            num_nodes=1,
            nodes=[],
        )

        assert expected_override in generated
374+
375+
def test_data_dir_validation(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
    """Verify that an unreplaced ``DATA_DIR`` placeholder is rejected.

    Passing ``data_dir="DATA_DIR"`` must make ``gen_exec_command`` raise a
    ``ValueError`` whose message names the placeholder.
    """
    placeholder_args = {
        "docker_image_url": "fake",
        "repository_url": "fake",
        "repository_commit_hash": "fake",
        # Left at the template placeholder on purpose to trigger validation.
        "data_dir": "DATA_DIR",
    }
    expected_message = "The 'data_dir' field of the NeMo launcher test contains the placeholder 'DATA_DIR'."

    with pytest.raises(ValueError, match=expected_message):
        nemo_cmd_gen.gen_exec_command(
            cmd_args=placeholder_args,
            extra_env_vars={"TEST_VAR_1": "value1"},
            extra_cmd_args="",
            output_path=Path(""),
            num_nodes=1,
            nodes=[],
        )
395+
301396

302397
class TestWriteSbatchScript:
303398
MANDATORY_ARGS = {

0 commit comments

Comments
 (0)