
Commit c051a27

Add GPU directive support check to SlurmSystem and use it in command gen (#541)
1 parent 4600e19 commit c051a27

File tree: 4 files changed (+80, -1 lines)


src/cloudai/systems/slurm/slurm_system.py

Lines changed: 24 additions & 0 deletions
@@ -139,6 +139,7 @@ class SlurmSystem(BaseModel, System):
     cmd_shell: CommandShell = Field(default=CommandShell(), exclude=True)
     extra_srun_args: Optional[str] = None
     extra_sbatch_args: list[str] = []
+    supports_gpu_directives_cache: Optional[bool] = Field(default=None, exclude=True)
 
     data_repository: Optional[DataRepositoryConfig] = None
 
@@ -165,6 +166,29 @@ def groups(self) -> Dict[str, Dict[str, List[SlurmNode]]]:
 
         return groups
 
+    @property
+    def supports_gpu_directives(self) -> bool:
+        if self.supports_gpu_directives_cache is not None:
+            return self.supports_gpu_directives_cache
+
+        stdout, stderr = self.fetch_command_output("scontrol show config")
+        if stderr:
+            logging.warning(f"Error checking GPU support: {stderr}")
+            self.supports_gpu_directives_cache = True
+            return True
+
+        has_gres_gpu = False
+        has_gpu_gres_type = False
+
+        for line in stdout.splitlines():
+            if "AccountingStorageTRES" in line and "gres/gpu" in line:
+                has_gres_gpu = True
+            if "GresTypes" in line and "gpu" in line and "(null)" not in line:
+                has_gpu_gres_type = True
+
+        self.supports_gpu_directives_cache = has_gres_gpu and has_gpu_gres_type
+        return self.supports_gpu_directives_cache
+
     @field_serializer("install_path", "output_path")
     def _path_serializer(self, v: Path) -> str:
         return str(v)
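
For readers who want to poke at the detection heuristic outside of CloudAI, the sketch below reimplements the same "scontrol show config" parsing as a standalone function. The function name and the direct subprocess call are assumptions made for illustration; the commit itself routes the call through SlurmSystem.fetch_command_output and caches the result on the instance, falling back to True if the probe errors.

# Standalone sketch, not part of the commit: same heuristic as
# SlurmSystem.supports_gpu_directives, runnable on any node with scontrol installed.
import subprocess


def cluster_supports_gpu_directives() -> bool:
    result = subprocess.run(["scontrol", "show", "config"], capture_output=True, text=True)
    if result.returncode != 0 or result.stderr:
        # Mirror the commit's fail-open behaviour: assume support if the probe fails.
        return True

    has_gres_gpu = False
    has_gpu_gres_type = False
    for line in result.stdout.splitlines():
        # GPUs must be tracked in accounting (AccountingStorageTRES contains gres/gpu) ...
        if "AccountingStorageTRES" in line and "gres/gpu" in line:
            has_gres_gpu = True
        # ... and GresTypes must actually list gpu rather than "(null)".
        if "GresTypes" in line and "gpu" in line and "(null)" not in line:
            has_gpu_gres_type = True
    return has_gres_gpu and has_gpu_gres_type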

src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py

Lines changed: 2 additions & 1 deletion
@@ -378,9 +378,10 @@ def _append_sbatch_directives(self, batch_script_content: List[str], tr: TestRun
 
         hostfile = self._append_nodes_related_directives(batch_script_content, tr)
 
-        if self.system.gpus_per_node:
+        if self.system.gpus_per_node and self.system.supports_gpu_directives:
             batch_script_content.append(f"#SBATCH --gpus-per-node={self.system.gpus_per_node}")
             batch_script_content.append(f"#SBATCH --gres=gpu:{self.system.gpus_per_node}")
+
         if self.system.ntasks_per_node:
             batch_script_content.append(f"#SBATCH --ntasks-per-node={self.system.ntasks_per_node}")
         if tr.time_limit:
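
The effect of the new guard is easiest to see on the emitted directives. A minimal illustration, assuming gpus_per_node is 8 (the value is invented for the example):

# Illustration only: the two directives appended when the guard passes,
# i.e. gpus_per_node is set and supports_gpu_directives is True.
gpus_per_node = 8  # assumed value, stands in for self.system.gpus_per_node
batch_script_content = []
batch_script_content.append(f"#SBATCH --gpus-per-node={gpus_per_node}")
batch_script_content.append(f"#SBATCH --gres=gpu:{gpus_per_node}")
print("\n".join(batch_script_content))
# #SBATCH --gpus-per-node=8
# #SBATCH --gres=gpu:8
# On clusters where supports_gpu_directives is False, both lines are now skipped.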

src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py

Lines changed: 13 additions & 0 deletions
@@ -134,6 +134,19 @@ def generate_test_command(
 
         cmd_args_dict["trainer"]["num_nodes"] = num_nodes
 
+        if self.system.gpus_per_node:
+            trainer_config = cmd_args_dict.get("trainer", {})
+            if "devices" in trainer_config:
+                user_devices = trainer_config["devices"]
+                if user_devices != self.system.gpus_per_node:
+                    logging.warning(
+                        f"User-specified trainer.devices ({user_devices}) differs from "
+                        f"system gpus_per_node ({self.system.gpus_per_node})"
+                    )
+            cmd_args_dict["trainer"]["devices"] = self.system.gpus_per_node
+        else:
+            logging.debug("SlurmSystem.gpus_per_node is not set. Skipping trainer.devices injection.")
+
         self.append_flattened_dict("", cmd_args_dict, command)
 
         if tr.test.extra_cmd_args:
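
A self-contained sketch of the override behaviour on a plain dict, assuming the system reports 8 GPUs per node while the user config already sets trainer.devices to 4 (both values are invented for the example):

import logging

gpus_per_node = 8  # stands in for self.system.gpus_per_node
cmd_args_dict = {"trainer": {"num_nodes": 2, "devices": 4}}  # user config with a conflicting value

trainer_config = cmd_args_dict.get("trainer", {})
if "devices" in trainer_config and trainer_config["devices"] != gpus_per_node:
    logging.warning(
        "User-specified trainer.devices (%s) differs from system gpus_per_node (%s)",
        trainer_config["devices"],
        gpus_per_node,
    )
cmd_args_dict["trainer"]["devices"] = gpus_per_node  # the system value wins after the warning
print(cmd_args_dict["trainer"])  # {'num_nodes': 2, 'devices': 8}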

tests/test_slurm_system.py

Lines changed: 41 additions & 0 deletions
@@ -530,3 +530,44 @@ def test_per_step_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem):
         res = strategy.get_cached_nodes_spec(test_run)
         assert mock_get_nodes.call_count == 2
         assert res == (2, ["node03", "node04"])
+
+
+@pytest.mark.parametrize(
+    "scontrol_output,expected_support",
+    [
+        # Case 1: gres/gpu in AccountingStorageTRES and gpu in GresTypes - should be supported
+        (
+            """Configuration data as of 2023-06-14T16:28:09
+AccountingStorageTRES = cpu,mem,energy,node,billing,fs/disk,vmem,pages,gres/gpu,gres/gpumem,gres/gpuutil
+GresTypes = gpu""",
+            True,
+        ),
+        # Case 2: gres/gpu in AccountingStorageTRES but GresTypes is (null) - should NOT be supported
+        (
+            """Configuration data as of 2023-06-14T16:28:09
+AccountingStorageTRES = cpu,mem,energy,node,billing,fs/disk,vmem,pages,gres/gpu,gres/gpumem,gres/gpuutil
+GresTypes = (null)""",
+            False,
+        ),
+        # Case 3: No gres/gpu in AccountingStorageTRES - should NOT be supported
+        (
+            """Configuration data as of 2023-06-14T16:28:09
+AccountingStorageTRES = cpu,mem,energy,node,billing,fs/disk,vmem,pages
+GresTypes = gpu""",
+            False,
+        ),
+        # Case 4: No gres/gpu in AccountingStorageTRES and GresTypes is (null) - should NOT be supported
+        (
+            """Configuration data as of 2023-06-14T16:28:09
+AccountingStorageTRES = cpu,mem,energy,node,billing,fs/disk,vmem,pages
+GresTypes = (null)""",
+            False,
+        ),
+    ],
+)
+@patch("cloudai.systems.slurm.slurm_system.SlurmSystem.fetch_command_output")
+def test_supports_gpu_directives(
+    mock_fetch_command_output, scontrol_output: str, expected_support: bool, slurm_system: SlurmSystem
+):
+    mock_fetch_command_output.return_value = (scontrol_output, "")
+    assert slurm_system.supports_gpu_directives == expected_support
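
Assuming the repository's usual pytest setup, the new cases can be exercised in isolation with something like: pytest tests/test_slurm_system.py -k supports_gpu_directives.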
