Commit 266e3b2

Merge pull request #385 from NVIDIA/am/slurm-system
Remove node list definition from slurm partition
2 parents fa7c36a + f9276bf commit 266e3b2

File tree: 12 files changed, +95 -163 lines changed

conf/common/system/example_slurm_cluster.toml

Lines changed: 0 additions & 2 deletions
@@ -27,7 +27,6 @@ ntasks_per_node = 8
 
 [[partitions]]
 name = "partition_1"
-nodes = ["node-[001-100]"]
 
 [[partitions.groups]]
 name = "group_1"
@@ -47,7 +46,6 @@ nodes = ["node-[001-100]"]
 
 [[partitions]]
 name = "partition_2"
-nodes = ["node-[101-200]"]
 
 [global_env_vars]
 # NCCL Specific Configurations
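For reference, a minimal sketch (not part of the diff) of the equivalent in-code configuration after this change: node ranges are declared per group, and the partition-level nodes key is gone. The import path is assumed; the field names come from the SlurmPartition/SlurmGroup models in slurm_system.py below.

# Minimal sketch, assuming this import path exists in the package.
from cloudai.systems.slurm.slurm_system import SlurmGroup, SlurmPartition

partition = SlurmPartition(
    name="partition_1",
    groups=[SlurmGroup(name="group_1", nodes=["node-[001-100]"])],
)
# partition.slurm_nodes starts empty; SlurmSystem.update() fills it in later
# from live sinfo/squeue output (see slurm_system.py below).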

src/cloudai/systems/slurm/slurm_system.py

Lines changed: 46 additions & 72 deletions
@@ -76,25 +76,8 @@ class SlurmPartition(BaseModel):
 
     model_config = ConfigDict(extra="forbid")
     name: str
-    nodes: List[str]
     groups: List[SlurmGroup] = []
-
-    _slurm_nodes: List[SlurmNode] = []
-
-    @property
-    def slurm_nodes(self) -> List[SlurmNode]:
-        if self._slurm_nodes:
-            return self._slurm_nodes
-
-        node_names = set()
-        for nodes_list in self.nodes:
-            node_names.update(set(parse_node_list(nodes_list)))
-
-        self._slurm_nodes = [
-            SlurmNode(name=node_name, partition=self.name, state=SlurmNodeState.UNKNOWN_STATE)
-            for node_name in node_names
-        ]
-        return self._slurm_nodes
+    slurm_nodes: list[SlurmNode] = Field(default_factory=list[SlurmNode], exclude=True)
 
 
 class SlurmSystem(BaseModel, System):
@@ -147,7 +130,10 @@ def groups(self) -> Dict[str, Dict[str, List[SlurmNode]]]:
                 node_names = set()
                 for group_nodes in group.nodes:
                     node_names.update(set(parse_node_list(group_nodes)))
-                groups[part.name][group.name] = [node for node in part.slurm_nodes if node.name in node_names]
+                groups[part.name][group.name] = [
+                    SlurmNode(name=node_name, partition=self.name, state=SlurmNodeState.UNKNOWN_STATE)
+                    for node_name in node_names
+                ]
 
         return groups
 
@@ -163,7 +149,10 @@ def update(self) -> None:
         commands, and correlating this information to determine the state of each node and the user running jobs on
        each node.
        """
-        self.update_node_states()
+        squeue_output, _ = self.fetch_command_output("squeue -o '%N|%u' --noheader")
+        sinfo_output, _ = self.fetch_command_output("sinfo")
+        node_user_map = self.parse_squeue_output(squeue_output)
+        self.parse_sinfo_output(sinfo_output, node_user_map)
 
     def is_job_running(self, job: BaseJob, retry_threshold: int = 3) -> bool:
         """
@@ -373,7 +362,7 @@ def get_available_nodes_from_group(
         """
         self.validate_partition_and_group(partition_name, group_name)
 
-        self.update_node_states()
+        self.update()
 
         grouped_nodes = self.group_nodes_by_state(partition_name, group_name)
 
@@ -490,18 +479,6 @@ def allocate_nodes(
 
         return allocated_nodes
 
-    def is_node_in_system(self, node_name: str) -> bool:
-        """
-        Check if a given node is part of the Slurm system.
-
-        Args:
-            node_name (str): The name of the node to check.
-
-        Returns:
-            True if the node is part of the system, otherwise False.
-        """
-        return any(any(node.name == node_name for node in part.slurm_nodes) for part in self.partitions)
-
     def scancel(self, job_id: int) -> None:
         """
         Terminates a specified Slurm job by sending a cancellation command.
@@ -511,39 +488,6 @@ def scancel(self, job_id: int) -> None:
         """
         self.cmd_shell.execute(f"scancel {job_id}")
 
-    def update_node_states(self) -> None:
-        """
-        Update the states of nodes in the Slurm system.
-
-        By querying the current state of each node using the 'sinfo' command, and correlates this with 'squeue' to
-        determine which user is running jobs on each node. This method parses the output of these commands, identifies
-        the state of nodes and the users, and updates the corresponding SlurmNode instances in the system.
-        """
-        squeue_output = self.get_squeue()
-        sinfo_output = self.get_sinfo()
-        node_user_map = self.parse_squeue_output(squeue_output)
-        self.parse_sinfo_output(sinfo_output, node_user_map)
-
-    def get_squeue(self) -> str:
-        """
-        Fetch the output from the 'squeue' command.
-
-        Returns
-            str: The stdout from the 'squeue' command execution.
-        """
-        squeue_output, _ = self.fetch_command_output("squeue -o '%N|%u' --noheader")
-        return squeue_output
-
-    def get_sinfo(self) -> str:
-        """
-        Fetch the output from the 'sinfo' command.
-
-        Returns
-            str: The stdout from the 'sinfo' command execution.
-        """
-        sinfo_output, _ = self.fetch_command_output("sinfo")
-        return sinfo_output
-
     def fetch_command_output(self, command: str) -> Tuple[str, str]:
         """
         Execute a system command and return its output.
@@ -614,12 +558,25 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
         for part in self.partitions:
             if part.name != partition:
                 continue
+
+            found = False
             for node in part.slurm_nodes:
                 if node.name == node_name:
+                    found = True
                     node.state = state_enum
                     node.user = node_user_map.get(node_name, "N/A")
                     break
 
+            if not found:
+                part.slurm_nodes.append(
+                    SlurmNode(
+                        name=node_name,
+                        partition=partition,
+                        state=state_enum,
+                        user=node_user_map.get(node_name, "N/A"),
+                    )
+                )
+
     def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
         """
         Convert a Slurm node state string to its corresponding enum member.
@@ -709,13 +666,30 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
                 group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes)
                 parsed_nodes += [node.name for node in group_nodes]
             else:
-                # Handle both individual node names and ranges
-                if self.is_node_in_system(node_spec) or "[" in node_spec:
-                    expanded_nodes = parse_node_list(node_spec)
-                    parsed_nodes += expanded_nodes
-                else:
-                    raise ValueError(f"Node '{node_spec}' not found.")
+                expanded_nodes = parse_node_list(node_spec)
+                parsed_nodes += expanded_nodes
 
         # Remove duplicates while preserving order
         parsed_nodes = list(dict.fromkeys(parsed_nodes))
         return parsed_nodes
+
+    def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list[str]]:
+        """
+        Retrieve a list of node names based on specifications.
+
+        When nodes is empty, returns `(num_nodes, [])`, otherwise parses the node specifications and returns the number
+        of nodes and a list of node names.
+
+        Args:
+            num_nodes (int): The number of nodes, can't be `0`.
+            nodes (list[str]): A list of node names specifications, slurm format or `PARTITION:GROUP:NUM_NODES`.
+
+        Returns:
+            Tuple[int, list[str]]: The number of nodes and a list of node names.
+        """
+        num_nodes, node_list = num_nodes, []
+        parsed_nodes = self.parse_nodes(nodes)
+        if parsed_nodes:
+            num_nodes = len(parsed_nodes)
+            node_list = parsed_nodes
+        return num_nodes, node_list
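A short usage sketch of the new helper (illustrative, not from the PR), assuming a SlurmSystem instance named system built like the tests/conftest.py fixture further down:

# Illustrative only; behaviour follows the get_nodes_by_spec docstring above.
num_nodes, node_list = system.get_nodes_by_spec(2, [])
# -> (2, []): no node spec given, the requested count passes through unchanged.

num_nodes, node_list = system.get_nodes_by_spec(2, ["node-[033-036]"])
# -> (4, ["node-033", ..., "node-036"]): the Slurm range is expanded via
#    parse_node_list and the node count is derived from the expansion.

# A PARTITION:GROUP:NUM_NODES spec resolves through get_available_nodes_from_group,
# which now refreshes node state via update() (sinfo/squeue) before picking nodes.
num_nodes, node_list = system.get_nodes_by_spec(1, ["main:group1:2"])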

src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py

Lines changed: 2 additions & 13 deletions
@@ -42,14 +42,6 @@ def __init__(self, system: SlurmSystem, cmd_args: Dict[str, Any]) -> None:
         """
         super().__init__(system, cmd_args)
         self.system = system
-        if not self.system.default_partition:
-            raise ValueError(
-                "Default partition not set in the Slurm system object. "
-                "The 'default_partition' attribute should be properly defined in the Slurm system configuration. "
-                "Please ensure that 'default_partition' is set correctly in the corresponding system configuration "
-                "(e.g., system.toml)."
-            )
-
         self.docker_image_url = self.cmd_args.get("docker_image_url", "")
 
     @abstractmethod
@@ -130,15 +122,12 @@ def _parse_slurm_args(
             KeyError: If partition or essential node settings are missing.
         """
         job_name = self.job_name(job_name_prefix)
-
-        parsed_nodes = self.system.parse_nodes(tr.nodes)
-        num_nodes = len(parsed_nodes) if parsed_nodes else tr.num_nodes
-        node_list_str = ",".join(parsed_nodes) if parsed_nodes else ""
+        num_nodes, node_list = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
 
         slurm_args = {
             "job_name": job_name,
             "num_nodes": num_nodes,
-            "node_list_str": node_list_str,
+            "node_list_str": ",".join(node_list),
         }
         if tr.time_limit:
             slurm_args["time_limit"] = tr.time_limit
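A sketch of the resulting sbatch arguments, with hypothetical TestRun values (tr.num_nodes = 2, tr.nodes = ["node-[001-002]"]); system and the job name are placeholders:

# Hypothetical values, for illustration of the new flow only.
num_nodes, node_list = system.get_nodes_by_spec(2, ["node-[001-002]"])
slurm_args = {
    "job_name": "job_prefix_1",            # placeholder; real value comes from job_name()
    "num_nodes": num_nodes,                # 2, derived from the expanded range
    "node_list_str": ",".join(node_list),  # "node-001,node-002"
}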

src/cloudai/workloads/nemo_launcher/slurm_command_gen_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def _container_mounts(self, tr: TestRun) -> list[str]:
     def gen_exec_command(self, tr: TestRun) -> str:
         self._prepare_environment(tr.test.cmd_args, tr.test.extra_env_vars, tr.output_path)
 
-        nodes = self.system.parse_nodes(tr.nodes)
+        _, nodes = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
         self._set_node_config(nodes, tr.num_nodes)
 
         tdef: NeMoLauncherTestDefinition = cast(NeMoLauncherTestDefinition, tr.test.test_definition)

src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def generate_test_command(
             "-y",
         ]
 
-        num_nodes = len(self.system.parse_nodes(tr.nodes)) if tr.nodes else tr.num_nodes
+        num_nodes, _ = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
 
         if cmd_args_dict["trainer"]["num_nodes"] and cmd_args_dict["trainer"]["num_nodes"] > num_nodes:
             err = (

tests/conftest.py

Lines changed: 0 additions & 2 deletions
@@ -39,15 +39,13 @@ def slurm_system(tmp_path: Path) -> SlurmSystem:
         partitions=[
             SlurmPartition(
                 name="main",
-                nodes=["node-[033-064]"],
                 groups=[
                     SlurmGroup(name="group1", nodes=["node-[033-048]"]),
                     SlurmGroup(name="group2", nodes=["node-[049-064]"]),
                 ],
             ),
             SlurmPartition(
                 name="backup",
-                nodes=["node0[1-8]"],
                 groups=[
                     SlurmGroup(name="group1", nodes=["node0[1-4]"]),
                     SlurmGroup(name="group2", nodes=["node0[5-8]"]),

tests/slurm_command_gen_strategy/conftest.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py

Lines changed: 1 addition & 13 deletions
@@ -75,7 +75,7 @@ def test_filename_generation(strategy_fixture: SlurmCommandGenStrategy, testrun_
     assert "test_job" in file_contents
     assert "node1,node2" in file_contents
     assert "srun" in file_contents
-    assert "--mpi=fake-mpi" in file_contents
+    assert f"--mpi={strategy_fixture.system.mpi}" in file_contents
 
 
 def test_num_nodes_and_nodes(strategy_fixture: SlurmCommandGenStrategy):
@@ -131,18 +131,6 @@ def test_time_limit(time_limit: Optional[str], strategy_fixture: SlurmCommandGen
     assert "time_limit" not in slurm_args
 
 
-def test_raises_if_no_default_partition(slurm_system: SlurmSystem):
-    slurm_system.default_partition = ""
-    with pytest.raises(ValueError) as exc_info:
-        MySlurmCommandGenStrategy(slurm_system, {})
-    assert (
-        "Default partition not set in the Slurm system object. "
-        "The 'default_partition' attribute should be properly defined in the Slurm "
-        "system configuration. Please ensure that 'default_partition' is set correctly "
-        "in the corresponding system configuration (e.g., system.toml)."
-    ) in str(exc_info.value)
-
-
 @pytest.mark.parametrize(
     "pre_test,post_test,expected_script_lines",
     [

tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py

Lines changed: 2 additions & 1 deletion
@@ -59,13 +59,14 @@ def grok_test(self) -> GrokTestDefinition:
     @pytest.mark.parametrize("test_fixture", ["gpt_test", "grok_test"])
     def test_gen_exec_command(
         self,
-        slurm_system,
+        slurm_system: SlurmSystem,
         cmd_gen_strategy: JaxToolboxSlurmCommandGenStrategy,
         tmp_path: Path,
         request,
         test_fixture,
     ) -> None:
         test_def = request.getfixturevalue(test_fixture)
+        slurm_system.output_path.mkdir(parents=True, exist_ok=True)
 
         test = Test(test_definition=test_def, test_template=TestTemplate(slurm_system, "name"))
         test_run = TestRun(

tests/test_job_submission_error.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def slurm_system(tmp_path: Path):
         install_path=tmp_path,
         output_path=tmp_path,
         default_partition="main",
-        partitions=[SlurmPartition(name="main", nodes=["nodeA001", "nodeB001"])],
+        partitions=[SlurmPartition(name="main")],
     )
     return system
 