|
18 | 18 |
|
19 | 19 | import logging |
20 | 20 | import re |
| 21 | +from copy import copy |
21 | 22 | from pathlib import Path |
22 | | -from typing import Any, Dict, List, Optional, Tuple, Union |
| 23 | +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
23 | 24 |
|
24 | 25 | from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator |
25 | 26 |
|
26 | 27 | from cloudai.core import BaseJob, File, Installable, System |
27 | 28 | from cloudai.models.scenario import ReportConfig, parse_reports_spec |
28 | 29 | from cloudai.util import CommandShell |
29 | 30 |
|
| 31 | +from .slurm_job import SlurmJob |
30 | 32 | from .slurm_metadata import SlurmStepMetadata |
31 | 33 | from .slurm_node import SlurmNode, SlurmNodeState |
32 | 34 |
|
@@ -137,6 +139,8 @@ class SlurmSystem(BaseModel, System): |
137 | 139 | data_repository: Optional[DataRepositoryConfig] = None |
138 | 140 | reports: Optional[dict[str, ReportConfig]] = None |
139 | 141 |
|
| 142 | + group_allocated: set[SlurmNode] = Field(default_factory=set, exclude=True) |
| 143 | + |
140 | 144 | @field_validator("reports", mode="before") |
141 | 145 | @classmethod |
142 | 146 | def parse_reports(cls, value: dict[str, Any] | None) -> dict[str, ReportConfig] | None: |
@@ -199,6 +203,7 @@ def update(self) -> None: |
199 | 203 | all_nodes = self.nodes_from_sinfo() |
200 | 204 | self.update_nodes_state_and_user(all_nodes, insert_new=True) |
201 | 205 | self.update_nodes_state_and_user(self.nodes_from_squeue()) |
| 206 | + self.update_nodes_state_and_user(self.group_allocated) |
202 | 207 |
|
203 | 208 | def nodes_from_sinfo(self) -> list[SlurmNode]: |
204 | 209 | sinfo_output, _ = self.fetch_command_output("sinfo -o '%P|%t|%u|%N'") |
@@ -232,7 +237,7 @@ def nodes_from_squeue(self) -> list[SlurmNode]: |
232 | 237 | nodes.append(SlurmNode(name=node, partition=partition, state=SlurmNodeState.ALLOCATED, user=user)) |
233 | 238 | return nodes |
234 | 239 |
|
235 | | - def update_nodes_state_and_user(self, nodes: list[SlurmNode], insert_new: bool = False) -> None: |
| 240 | + def update_nodes_state_and_user(self, nodes: Iterable[SlurmNode], insert_new: bool = False) -> None: |
236 | 241 | for node in nodes: |
237 | 242 | for part in self.partitions: |
238 | 243 | if part.name != node.partition: |
@@ -595,13 +600,18 @@ def allocate_nodes( |
595 | 600 | f"and ensure there are enough resources to meet the requested node count. Additionally, " |
596 | 601 | f"verify that the system can accommodate the number of nodes required by the test scenario." |
597 | 602 | ) |
| 603 | + |
598 | 604 | else: |
599 | 605 | raise ValueError( |
600 | 606 | f"The 'number_of_nodes' argument must be either an integer specifying the number of nodes to allocate," |
601 | 607 | f" or 'max_avail' to allocate all available nodes. Received: '{number_of_nodes}'. " |
602 | 608 | "Please correct the input." |
603 | 609 | ) |
604 | 610 |
|
| 611 | + for node in allocated_nodes: |
| 612 | + node.state = SlurmNodeState.ALLOCATED |
| 613 | + self.group_allocated.update(copy(node) for node in allocated_nodes) |
| 614 | + |
605 | 615 | return allocated_nodes |
606 | 616 |
|
607 | 617 | def scancel(self, job_id: int) -> None: |
@@ -748,3 +758,10 @@ def get_nodes_by_spec(self, num_nodes: int, nodes: list[str]) -> Tuple[int, list |
748 | 758 |
|
749 | 759 | def system_installables(self) -> list[Installable]: |
750 | 760 | return [File(Path(__file__).parent.absolute() / "slurm-metadata.sh")] |
| 761 | + |
| 762 | + def complete_job(self, job: SlurmJob) -> None: |
| 763 | + out, _ = self.fetch_command_output(f"sacct -j {job.id} -p --noheader -X --format=NodeList") |
| 764 | + spec = out.splitlines()[0] if out.splitlines() else out |
| 765 | + nodelist = set(parse_node_list(spec.strip().replace("|", ""))) |
| 766 | + to_unlock = [node for node in self.group_allocated if node.name in nodelist] |
| 767 | + self.group_allocated.difference_update(to_unlock) |
0 commit comments