Skip to content

Commit 73a6e47

Browse files
authored
Merge pull request #167 from jeffnvidia/simplify_slurm_system
Simplify slurm system
2 parents 48c5b02 + 9c44cc1 commit 73a6e47

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

src/cloudai/systems/slurm/slurm_system.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import getpass
1817
import logging
1918
import re
2019
from pathlib import Path
@@ -468,10 +467,9 @@ def get_available_nodes_from_group(
468467
available nodes.
469468
"""
470469
self.validate_partition_and_group(partition_name, group_name)
471-
current_user = getpass.getuser()
472470
self.update_node_states()
473471

474-
grouped_nodes = self.group_nodes_by_state(partition_name, group_name, current_user)
472+
grouped_nodes = self.group_nodes_by_state(partition_name, group_name)
475473
allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, group_name)
476474

477475
# Log allocation details
@@ -502,9 +500,7 @@ def validate_partition_and_group(self, partition_name: str, group_name: str) ->
502500
if group_name not in self.groups[partition_name]:
503501
raise ValueError(f"Group '{group_name}' not found in partition '{partition_name}'.")
504502

505-
def group_nodes_by_state(
506-
self, partition_name: str, group_name: str, current_user: str
507-
) -> Dict[SlurmNodeState, List[SlurmNode]]:
503+
def group_nodes_by_state(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]:
508504
"""
509505
Group nodes by their states.
510506
@@ -524,11 +520,7 @@ def group_nodes_by_state(
524520

525521
for node in self.groups[partition_name][group_name]:
526522
if node.state in grouped_nodes:
527-
# Exclude nodes allocated to the current user
528-
if node.state == SlurmNodeState.ALLOCATED and node.user == current_user:
529-
continue
530-
if node.state in grouped_nodes:
531-
grouped_nodes[node.state].append(node)
523+
grouped_nodes[node.state].append(node)
532524

533525
return grouped_nodes
534526

@@ -552,26 +544,27 @@ def allocate_nodes(
552544
"""
553545
# Allocate nodes: for "max_avail", take idle then completing nodes; for an int count, draw from the grouped states in order.
554546
allocated_nodes = []
555-
available_states = [SlurmNodeState.IDLE, SlurmNodeState.COMPLETING, SlurmNodeState.ALLOCATED]
556-
557547
if isinstance(number_of_nodes, str) and number_of_nodes == "max_avail":
558-
for state in available_states:
559-
allocated_nodes.extend(grouped_nodes[state])
560-
548+
allocated_nodes.extend(grouped_nodes[SlurmNodeState.IDLE])
549+
allocated_nodes.extend(grouped_nodes[SlurmNodeState.COMPLETING])
561550
if len(allocated_nodes) == 0:
562551
raise ValueError(f"No available nodes in group '{group_name}'.")
563552

564553
elif isinstance(number_of_nodes, int):
565-
for state in available_states:
554+
for state in grouped_nodes:
566555
while grouped_nodes[state] and len(allocated_nodes) < number_of_nodes:
567556
allocated_nodes.append(grouped_nodes[state].pop(0))
568557

569558
if len(allocated_nodes) < number_of_nodes:
570559
raise ValueError(
571-
"Requested number of nodes ({}) exceeds the number of " "available nodes in group '{}'.".format(
560+
"Requested number of nodes ({}) exceeds the number of nodes in group '{}'.".format(
572561
number_of_nodes, group_name
573562
)
574563
)
564+
else:
565+
raise ValueError(
566+
f"number of nodes should either be an int or 'max_avail', number of nodes : {number_of_nodes}"
567+
)
575568

576569
return allocated_nodes
577570

tests/test_slurm_system.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,7 @@ def grouped_nodes() -> dict[SlurmNodeState, list[SlurmNode]]:
160160
SlurmNodeState.COMPLETING: [
161161
SlurmNode(name="node04", partition=partition_name, state=SlurmNodeState.COMPLETING)
162162
],
163-
SlurmNodeState.ALLOCATED: [],
164-
SlurmNodeState.DOWN: [SlurmNode(name="node05", partition=partition_name, state=SlurmNodeState.DOWN)],
163+
SlurmNodeState.ALLOCATED: [SlurmNode(name="node05", partition=partition_name, state=SlurmNodeState.ALLOCATED)],
165164
}
166165

167166
return grouped_nodes
@@ -178,9 +177,11 @@ def test_allocate_nodes_max_avail(slurm_system: SlurmSystem, grouped_nodes: dict
178177
]
179178
returned_node_names = [node.name for node in available_nodes]
180179

181-
assert set(returned_node_names) == set(expected_node_names), "Should return all available nodes except DOWN nodes"
182-
down_node_name = grouped_nodes[SlurmNodeState.DOWN][0].name
183-
assert down_node_name not in returned_node_names, "DOWN node should not be included"
180+
assert set(returned_node_names) == set(
181+
expected_node_names
182+
), "Should return all available nodes except ALLOCATED nodes"
183+
allocated_node_name = grouped_nodes[SlurmNodeState.ALLOCATED][0].name
184+
assert allocated_node_name not in returned_node_names, "ALLOCATED node should not be included"
184185

185186

186187
def test_allocate_nodes_num_nodes_integers(
@@ -200,11 +201,12 @@ def test_allocate_nodes_exceeding_limit(
200201
slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]
201202
):
202203
group_name = "group_name"
204+
num_nodes = 5
203205

204206
with pytest.raises(
205207
ValueError,
206208
match=re.escape(
207-
f"Requested number of nodes (4) exceeds the number of available nodes in group '{group_name}'."
209+
f"Requested number of nodes ({num_nodes}) exceeds the number of nodes in group '{group_name}'."
208210
),
209211
):
210-
slurm_system.allocate_nodes(grouped_nodes, 4, group_name)
212+
slurm_system.allocate_nodes(grouped_nodes, num_nodes, group_name)

0 commit comments

Comments
 (0)