Commit 82ee1dd

Merge pull request #577 from NVIDIA/am/group-nodes-alloc

Re-work slurm node status update

2 parents: 9079e9c + f296db0

File tree: 4 files changed, +255 −192 lines

4 files changed

+255
-192
lines changed

src/cloudai/systems/slurm/slurm_node.py (+4 −0)

@@ -140,6 +140,10 @@ def allocatable(self, free_only: bool = True) -> bool:
             SlurmNodeState.RESERVED,
         ]
 
+    def __hash__(self) -> int:
+        """Provide a hash of the Slurm node based on its name, partition, state, and user."""
+        return hash((self.name, self.partition, self.state, self.user))
+
     def __repr__(self) -> str:
         """
         Provide a structured string representation of the Slurm node, including its name, state, and partition.
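The new tests rely on value-based set arithmetic over nodes (`set(all_nodes) - set(taken_nodes)`), which only works once `SlurmNode` instances hash by content. A minimal sketch of the idea, using a hypothetical frozen-dataclass stand-in rather than the real `SlurmNode`:

# Sketch only: Node is a simplified, hypothetical stand-in for SlurmNode.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str
    partition: str
    state: str
    user: str = "N/A"

all_nodes = [Node(f"node0{i}", "main", "allocated" if i <= 2 else "idle") for i in range(1, 6)]
taken = [Node("node01", "main", "allocated"), Node("node02", "main", "allocated")]

# frozen=True gives value-based __hash__/__eq__, so set difference drops the taken nodes.
free = sorted(n.name for n in set(all_nodes) - set(taken))
print(free)  # ['node03', 'node04', 'node05']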

src/cloudai/systems/slurm/slurm_system.py (+54 −78)
@@ -196,10 +196,59 @@ def update(self) -> None:
         commands, and correlating this information to determine the state of each node and the user running jobs on
         each node.
         """
-        squeue_output, _ = self.fetch_command_output("squeue -o '%N|%u' --noheader")
-        sinfo_output, _ = self.fetch_command_output("sinfo")
-        node_user_map = self.parse_squeue_output(squeue_output)
-        self.parse_sinfo_output(sinfo_output, node_user_map)
+        all_nodes = self.nodes_from_sinfo()
+        self.update_nodes_state_and_user(all_nodes, insert_new=True)
+        self.update_nodes_state_and_user(self.nodes_from_squeue())
+
+    def nodes_from_sinfo(self) -> list[SlurmNode]:
+        sinfo_output, _ = self.fetch_command_output("sinfo -o '%P|%t|%u|%N'")
+        nodes: list[SlurmNode] = []
+        for line in sinfo_output.split("\n"):
+            if not line.strip():
+                continue
+            parts = line.split("|")
+            if len(parts) < 4:
+                continue
+            partition, state, user, nodelist = parts[:4]
+            partition = partition.rstrip("*").strip()
+            node_names = parse_node_list(nodelist)
+            logging.debug(f"{partition=}, {state=}, {nodelist=}, {node_names=}")
+            for node_name in node_names:
+                nodes.append(
+                    SlurmNode(name=node_name, partition=partition, state=self.convert_state_to_enum(state), user=user)
+                )
+        return nodes
+
+    def nodes_from_squeue(self) -> list[SlurmNode]:
+        squeue_output, _ = self.fetch_command_output("squeue --states=running,pending --noheader -o '%P|%T|%N|%u'")
+        nodes: list[SlurmNode] = []
+        for line in squeue_output.split("\n"):
+            parts = line.split("|")
+            if len(parts) < 4:
+                continue
+            partition, _, nodelist, user = parts[:4]
+            node_names = parse_node_list(nodelist)
+            for node in node_names:
+                nodes.append(SlurmNode(name=node, partition=partition, state=SlurmNodeState.ALLOCATED, user=user))
+        return nodes
+
+    def update_nodes_state_and_user(self, nodes: list[SlurmNode], insert_new: bool = False) -> None:
+        for node in nodes:
+            for part in self.partitions:
+                if part.name != node.partition:
+                    continue
+
+                found = False
+                for pnode in part.slurm_nodes:
+                    if pnode.name != node.name:
+                        continue
+                    pnode.state = node.state
+                    pnode.user = node.user
+                    found = True
+                    break
+
+                if not found and insert_new:
+                    part.slurm_nodes.append(node)
 
     def is_job_running(self, job: BaseJob, retry_threshold: int = 3) -> bool:
         """
@@ -580,79 +629,6 @@ def fetch_command_output(self, command: str) -> Tuple[str, str]:
             logging.error(f"Error executing command '{command}': {stderr}")
         return stdout, stderr
 
-    def parse_squeue_output(self, squeue_output: str) -> Dict[str, str]:
-        """
-        Parse the output from the 'squeue' command to map nodes to users.
-
-        The expected format of squeue_output is lines of 'node_spec|user', where node_spec can include comma-separated
-        node names or ranges.
-
-        Args:
-            squeue_output (str): The raw output from the squeue command.
-
-        Returns:
-            Dict[str, str]: A dictionary mapping node names to usernames.
-        """
-        node_user_map = {}
-        for line in squeue_output.split("\n"):
-            if line.strip():
-                # Split the line into node list and user, handling only the first '|'
-                parts = line.split("|")
-                if len(parts) < 2:
-                    continue  # Skip malformed lines
-
-                node_list_part, user = parts[0], "|".join(parts[1:])
-                # Handle cases where multiple node groups or ranges are specified
-                for node in parse_node_list(node_list_part):
-                    node_user_map[node] = user.strip()
-
-        return node_user_map
-
-    def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -> None:
-        """
-        Parse the output from the 'sinfo' command to update node states.
-
-        Args:
-            sinfo_output (str): The output from the sinfo command.
-            node_user_map (dict): A dictionary mapping node names to users.
-        """
-        for line in sinfo_output.split("\n")[1:]:  # Skip the header line
-            if not line.strip():
-                continue
-            parts = line.split()
-            if len(parts) < 6:
-                continue
-            partition, _, _, _, state, nodelist = parts[:6]
-            partition = partition.rstrip("*")
-            node_names = parse_node_list(nodelist)
-
-            # Convert state to enum, handling states with suffixes
-            state_enum = self.convert_state_to_enum(state)
-
-            for node_name in node_names:
-                # Find the partition and node to update the state
-                for part in self.partitions:
-                    if part.name != partition:
-                        continue
-
-                    found = False
-                    for node in part.slurm_nodes:
-                        if node.name == node_name:
-                            found = True
-                            node.state = state_enum
-                            node.user = node_user_map.get(node_name, "N/A")
-                            break
-
-                    if not found:
-                        part.slurm_nodes.append(
-                            SlurmNode(
-                                name=node_name,
-                                partition=partition,
-                                state=state_enum,
-                                user=node_user_map.get(node_name, "N/A"),
-                            )
-                        )
-
     def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
         """
         Convert a Slurm node state string to its corresponding enum member.
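Compared with the removed `parse_sinfo_output`, the new flow decouples parsing from merging: `update_nodes_state_and_user` is effectively an upsert keyed on (partition, node name), where the sinfo pass may insert previously unknown nodes (`insert_new=True`) and the squeue pass only overwrites state and user on nodes already tracked. A toy model of that merge, using simplified stand-ins for the real cloudai classes:

# Toy model of the upsert in update_nodes_state_and_user; Partition/Node
# are hypothetical, simplified stand-ins for the real cloudai classes.
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    partition: str
    state: str
    user: str = "N/A"

@dataclass
class Partition:
    name: str
    nodes: list[Node] = field(default_factory=list)

def upsert(partitions: list[Partition], incoming: list[Node], insert_new: bool = False) -> None:
    for node in incoming:
        for part in partitions:
            if part.name != node.partition:
                continue
            for existing in part.nodes:
                if existing.name == node.name:
                    # Known node: update state/user in place.
                    existing.state, existing.user = node.state, node.user
                    break
            else:
                # Unknown node: only the sinfo pass may add it.
                if insert_new:
                    part.nodes.append(node)

parts = [Partition("main")]
upsert(parts, [Node("node01", "main", "idle")], insert_new=True)  # inserted
upsert(parts, [Node("node01", "main", "alloc", "alice")])         # updated in place
upsert(parts, [Node("node09", "main", "alloc", "bob")])           # ignored: insert_new=False
print(parts[0].nodes)  # [Node(name='node01', partition='main', state='alloc', user='alice')]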
@@ -768,7 +744,7 @@ def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list
         if parsed_nodes:
             num_nodes = len(parsed_nodes)
             node_list = parsed_nodes
-        return num_nodes, node_list
+        return num_nodes, sorted(node_list)
 
     def system_installables(self) -> list[Installable]:
        return [File(Path(__file__).parent.absolute() / "slurm-metadata.sh")]
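The only behavioral change in `get_nodes_by_spec` is the `sorted()` on the return value: node names collected via set operations upstream come back in arbitrary order, so sorting makes the returned allocation list deterministic and lets the new tests compare against `sorted(...)` directly. A trivial illustration:

# Illustration only: iteration order of a set is arbitrary; sorting stabilizes it.
node_list = list({"node03", "node01", "node02"})  # order varies between runs
print(sorted(node_list))  # always ['node01', 'node02', 'node03']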

tests/test_slurm_allocation.py (+76 −0)

@@ -0,0 +1,76 @@
+# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from cloudai.systems.slurm import SlurmGroup, SlurmNode, SlurmNodeState, SlurmPartition, SlurmSystem, parse_node_list
+
+
+class TestGroupAllocation:
+    def prepare(
+        self, slurm_system: SlurmSystem, taken_node_names: list[str], monkeypatch: pytest.MonkeyPatch
+    ) -> tuple[SlurmSystem, list[SlurmNode], list[SlurmNode]]:
+        slurm_system.partitions = [
+            SlurmPartition(name="main", groups=[SlurmGroup(name="group1", nodes=["node0[1-5]"])])
+        ]
+        all_nodes = [
+            SlurmNode(name=name, partition="main", state=SlurmNodeState.IDLE)
+            for name in parse_node_list(slurm_system.partitions[0].groups[0].nodes[0])
+        ]
+        taken_nodes = [
+            SlurmNode(name=node.name, partition="main", state=SlurmNodeState.ALLOCATED)
+            for node in all_nodes
+            if node.name in taken_node_names
+        ]
+
+        mod_path = "cloudai.systems.slurm.slurm_system.SlurmSystem"
+        monkeypatch.setattr(f"{mod_path}.nodes_from_sinfo", lambda *args, **kwargs: all_nodes)
+        monkeypatch.setattr(f"{mod_path}.nodes_from_squeue", lambda *args, **kwargs: taken_nodes)
+        return slurm_system, all_nodes, taken_nodes
+
+    def test_all_nodes_in_group_are_idle(self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch):
+        system, *_ = self.prepare(slurm_system, [], monkeypatch)
+        nnodes, nodes_list = system.get_nodes_by_spec(1, ["main:group1:5"])
+        assert nodes_list == parse_node_list(slurm_system.partitions[0].groups[0].nodes[0])
+        assert nnodes == len(nodes_list)
+
+    def test_enough_free_nodes_for_allocation(self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch):
+        system, all_nodes, taken_nodes = self.prepare(slurm_system, ["node01", "node02"], monkeypatch)
+        nnodes, nodes_list = system.get_nodes_by_spec(1, ["main:group1:3"])
+        assert nnodes == 3
+        assert nodes_list == sorted([n.name for n in set(all_nodes) - set(taken_nodes)])
+
+    def test_not_enough_nodes_for_allocation(self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch):
+        """In this scenario we still return the required number of nodes so the job can be queued."""
+        system, all_nodes, _ = self.prepare(slurm_system, ["node01", "node02"], monkeypatch)
+        nnodes, nodes_list = system.get_nodes_by_spec(1, ["main:group1:5"])
+        assert nnodes == 5
+        assert nodes_list == sorted([n.name for n in all_nodes])
+
+    @pytest.mark.xfail(reason="This is a bug in the code, RM4471870")
+    def test_two_cases_one_group(self, slurm_system: SlurmSystem, monkeypatch: pytest.MonkeyPatch):
+        # system has 5 nodes in the group
+        system, *_ = self.prepare(slurm_system, [], monkeypatch)
+
+        # first case asks for 2 nodes
+        nnodes, nodes_list1 = system.get_nodes_by_spec(1, ["main:group1:2"])
+        assert nnodes == 2
+
+        # second case asks for another 2 nodes
+        nnodes, nodes_list2 = system.get_nodes_by_spec(1, ["main:group1:2"])
+        assert nnodes == 2
+
+        assert nodes_list1 != nodes_list2, "Same nodes were allocated for two different requests"
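Because `prepare` monkeypatches `nodes_from_sinfo` and `nodes_from_squeue` directly, these tests exercise the new update and allocation path without shelling out to Slurm; assuming a standard pytest setup, they can be run in isolation with `pytest tests/test_slurm_allocation.py`. The `xfail` case documents a known issue (tracked as RM4471870) where two back-to-back requests against the same group can be handed the same nodes.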
