@@ -76,25 +76,8 @@ class SlurmPartition(BaseModel):
7676
7777 model_config = ConfigDict (extra = "forbid" )
7878 name : str
79- nodes : List [str ]
8079 groups : List [SlurmGroup ] = []
81-
82- _slurm_nodes : List [SlurmNode ] = []
83-
84- @property
85- def slurm_nodes (self ) -> List [SlurmNode ]:
86- if self ._slurm_nodes :
87- return self ._slurm_nodes
88-
89- node_names = set ()
90- for nodes_list in self .nodes :
91- node_names .update (set (parse_node_list (nodes_list )))
92-
93- self ._slurm_nodes = [
94- SlurmNode (name = node_name , partition = self .name , state = SlurmNodeState .UNKNOWN_STATE )
95- for node_name in node_names
96- ]
97- return self ._slurm_nodes
80+ slurm_nodes : list [SlurmNode ] = Field (default_factory = list [SlurmNode ], exclude = True )
9881
9982
10083class SlurmSystem (BaseModel , System ):
@@ -147,7 +130,10 @@ def groups(self) -> Dict[str, Dict[str, List[SlurmNode]]]:
147130 node_names = set ()
148131 for group_nodes in group .nodes :
149132 node_names .update (set (parse_node_list (group_nodes )))
150- groups [part .name ][group .name ] = [node for node in part .slurm_nodes if node .name in node_names ]
133+ groups [part .name ][group .name ] = [
134+ SlurmNode (name = node_name , partition = self .name , state = SlurmNodeState .UNKNOWN_STATE )
135+ for node_name in node_names
136+ ]
151137
152138 return groups
153139
@@ -163,7 +149,10 @@ def update(self) -> None:
163149 commands, and correlating this information to determine the state of each node and the user running jobs on
164150 each node.
165151 """
166- self .update_node_states ()
152+ squeue_output , _ = self .fetch_command_output ("squeue -o '%N|%u' --noheader" )
153+ sinfo_output , _ = self .fetch_command_output ("sinfo" )
154+ node_user_map = self .parse_squeue_output (squeue_output )
155+ self .parse_sinfo_output (sinfo_output , node_user_map )
167156
168157 def is_job_running (self , job : BaseJob , retry_threshold : int = 3 ) -> bool :
169158 """
@@ -373,7 +362,7 @@ def get_available_nodes_from_group(
373362 """
374363 self .validate_partition_and_group (partition_name , group_name )
375364
376- self .update_node_states ()
365+ self .update ()
377366
378367 grouped_nodes = self .group_nodes_by_state (partition_name , group_name )
379368
@@ -490,18 +479,6 @@ def allocate_nodes(
490479
491480 return allocated_nodes
492481
493- def is_node_in_system (self , node_name : str ) -> bool :
494- """
495- Check if a given node is part of the Slurm system.
496-
497- Args:
498- node_name (str): The name of the node to check.
499-
500- Returns:
501- True if the node is part of the system, otherwise False.
502- """
503- return any (any (node .name == node_name for node in part .slurm_nodes ) for part in self .partitions )
504-
505482 def scancel (self , job_id : int ) -> None :
506483 """
507484 Terminates a specified Slurm job by sending a cancellation command.
@@ -511,39 +488,6 @@ def scancel(self, job_id: int) -> None:
511488 """
512489 self .cmd_shell .execute (f"scancel { job_id } " )
513490
514- def update_node_states (self ) -> None :
515- """
516- Update the states of nodes in the Slurm system.
517-
518- By querying the current state of each node using the 'sinfo' command, and correlates this with 'squeue' to
519- determine which user is running jobs on each node. This method parses the output of these commands, identifies
520- the state of nodes and the users, and updates the corresponding SlurmNode instances in the system.
521- """
522- squeue_output = self .get_squeue ()
523- sinfo_output = self .get_sinfo ()
524- node_user_map = self .parse_squeue_output (squeue_output )
525- self .parse_sinfo_output (sinfo_output , node_user_map )
526-
527- def get_squeue (self ) -> str :
528- """
529- Fetch the output from the 'squeue' command.
530-
531- Returns
532- str: The stdout from the 'squeue' command execution.
533- """
534- squeue_output , _ = self .fetch_command_output ("squeue -o '%N|%u' --noheader" )
535- return squeue_output
536-
537- def get_sinfo (self ) -> str :
538- """
539- Fetch the output from the 'sinfo' command.
540-
541- Returns
542- str: The stdout from the 'sinfo' command execution.
543- """
544- sinfo_output , _ = self .fetch_command_output ("sinfo" )
545- return sinfo_output
546-
547491 def fetch_command_output (self , command : str ) -> Tuple [str , str ]:
548492 """
549493 Execute a system command and return its output.
@@ -614,12 +558,25 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
614558 for part in self .partitions :
615559 if part .name != partition :
616560 continue
561+
562+ found = False
617563 for node in part .slurm_nodes :
618564 if node .name == node_name :
565+ found = True
619566 node .state = state_enum
620567 node .user = node_user_map .get (node_name , "N/A" )
621568 break
622569
570+ if not found :
571+ part .slurm_nodes .append (
572+ SlurmNode (
573+ name = node_name ,
574+ partition = partition ,
575+ state = state_enum ,
576+ user = node_user_map .get (node_name , "N/A" ),
577+ )
578+ )
579+
623580 def convert_state_to_enum (self , state_str : str ) -> SlurmNodeState :
624581 """
625582 Convert a Slurm node state string to its corresponding enum member.
@@ -709,13 +666,30 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
709666 group_nodes = self .get_available_nodes_from_group (partition_name , group_name , num_nodes )
710667 parsed_nodes += [node .name for node in group_nodes ]
711668 else :
712- # Handle both individual node names and ranges
713- if self .is_node_in_system (node_spec ) or "[" in node_spec :
714- expanded_nodes = parse_node_list (node_spec )
715- parsed_nodes += expanded_nodes
716- else :
717- raise ValueError (f"Node '{ node_spec } ' not found." )
669+ expanded_nodes = parse_node_list (node_spec )
670+ parsed_nodes += expanded_nodes
718671
719672 # Remove duplicates while preserving order
720673 parsed_nodes = list (dict .fromkeys (parsed_nodes ))
721674 return parsed_nodes
675+
def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list[str]]:
    """
    Resolve node specifications into a node count and a list of node names.

    The specifications are handed to `parse_nodes`. When that yields at least one name, the reported count is the
    number of names returned; otherwise the caller's `num_nodes` is passed through together with an empty list.

    Args:
        num_nodes (int): The count to report when `nodes` resolves to no names.
        nodes (list[str]): Node specifications forwarded to `parse_nodes` (slurm format or
            `PARTITION:GROUP:NUM_NODES`).

    Returns:
        Tuple[int, list[str]]: `(len(names), names)` when names were resolved, otherwise `(num_nodes, [])`.
    """
    resolved = self.parse_nodes(nodes)
    if not resolved:
        # Nothing was resolved: keep the requested count and report no explicit nodes.
        return num_nodes, []
    return len(resolved), resolved
0 commit comments