Skip to content

Commit 13fe9fd

Browse files
committed
R&R Slurm integration
1 parent b47dd3a commit 13fe9fd

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/main.tf

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ locals {
191191
"mig_flex.py",
192192
"resume_wrapper.sh",
193193
"resume.py",
194+
"repair.py",
194195
"setup_network_storage.py",
195196
"setup.py",
196197
"slurmsync.py",
@@ -215,6 +216,7 @@ locals {
215216
"setup.py",
216217
"slurmsync.py",
217218
"sort_nodes.py",
219+
"repair.py",
218220
"suspend.py",
219221
"tpu.py",
220222
"util.py",
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (C) SchedMD LLC.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import fcntl
18+
import json
19+
import logging
20+
from datetime import datetime, timezone
21+
from pathlib import Path
22+
import util
23+
from util import run, to_hostlist, lookup
24+
log = logging.getLogger()
25+
REPAIR_FILE = Path("/slurm/repair_operations.json")
26+
REPAIR_REASONS = frozenset(["PERFORMANCE", "SDC", "XID", "unspecified"])
27+
28+
def is_node_being_repaired(node):
29+
"""Check if a node is currently being repaired."""
30+
operations = get_operations()
31+
return node in operations and operations[node]["status"] == "REPAIR_IN_PROGRESS"
32+
33+
def get_operations():
34+
"""Get all repair operations from the file."""
35+
if not REPAIR_FILE.exists():
36+
return {}
37+
with open(REPAIR_FILE, 'r') as f:
38+
try:
39+
return json.load(f)
40+
except json.JSONDecodeError:
41+
return {}
42+
43+
def store_operations(operations):
44+
"""Store the operations to the file."""
45+
with open(REPAIR_FILE, 'w') as f:
46+
fcntl.lockf(f, fcntl.LOCK_EX)
47+
try:
48+
json.dump(operations, f, indent=4)
49+
finally:
50+
fcntl.lockf(f, fcntl.LOCK_UN)
51+
52+
def store_operation(node, operation_id, reason):
53+
"""Store a single repair operation."""
54+
operations = get_operations()
55+
operations[node] = {
56+
"operation_id": operation_id,
57+
"reason": reason,
58+
"status": "REPAIR_IN_PROGRESS",
59+
"timestamp": datetime.now(timezone.utc).isoformat(),
60+
}
61+
store_operations(operations)
62+
63+
def call_rr_api(node, reason):
64+
"""Call the R&R API for a given node."""
65+
log.info(f"Calling R&R API for node {node} with reason {reason}")
66+
inst = lookup().instance(node)
67+
if not inst:
68+
log.error(f"Instance {node} not found, cannot report fault.")
69+
return None
70+
cmd = f"gcloud compute instances report-host-as-faulty {node} --async --disruption-schedule=IMMEDIATE --fault-reasons=behavior={reason},description='VM is managed by Slurm' --zone={inst.zone} --format=json"
71+
try:
72+
result = run(cmd)
73+
op = json.loads(result.stdout)
74+
if isinstance(op, list):
75+
op = op[0]
76+
return op["name"]
77+
except Exception as e:
78+
log.error(f"Failed to call R&R API for {node}: {e}")
79+
return None
80+
81+
def get_operation_status(operation_id):
82+
"""Get the status of a GCP operation."""
83+
cmd = f'gcloud compute operations list --filter="name={operation_id}" --format=json'
84+
try:
85+
result = run(cmd)
86+
operations_list = json.loads(result.stdout)
87+
if operations_list and len(operations_list) > 0:
88+
return operations_list[0]
89+
90+
return None
91+
except Exception as e:
92+
log.error(f"Failed to get operation status for {operation_id}: {e}")
93+
return None
94+
95+
def poll_operations():
96+
"""Poll the status of ongoing repair operations."""
97+
operations = get_operations()
98+
if not operations:
99+
return
100+
101+
log.info("Polling repair operations")
102+
for node, op_details in operations.items():
103+
if op_details["status"] == "REPAIR_IN_PROGRESS":
104+
op_status = get_operation_status(op_details["operation_id"])
105+
if not op_status:
106+
continue
107+
108+
if op_status.get("status") == "DONE":
109+
if op_status.get("error"):
110+
log.error(f"Repair operation for {node} failed: {op_status['error']}")
111+
op_details["status"] = "FAILURE"
112+
run(f"{lookup().scontrol} update nodename={node} state=down reason='Repair failed'")
113+
else:
114+
log.info(f"Repair operation for {node} succeeded.")
115+
op_details["status"] = "SUCCESS"
116+
if op_details["status"] == "SUCCESS":
117+
inst = lookup().instance(node)
118+
if inst and inst.status == "RUNNING":
119+
log.info(f"Node {node} is back online, undraining.")
120+
op_details["status"] = "RECOVERED"
121+
122+
store_operations(operations)
123+

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py

100755100644
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import tpu
4848
import conf
4949
import watch_delete_vm_op
50+
import repair
5051

5152
log = logging.getLogger()
5253

@@ -119,6 +120,17 @@ def apply(self, nodes: List[str]) -> None:
119120
log.info(f"{len(nodes)} nodes set down ({hostlist}) with reason={self.reason}")
120121
run(f"{lookup().scontrol} update nodename={hostlist} state=down reason={shlex.quote(self.reason)}")
121122

123+
@dataclass(frozen=True)
124+
class NodeActionRepair():
125+
reason: str
126+
def apply(self, nodes: List[str]) -> None:
127+
hostlist = util.to_hostlist(nodes)
128+
log.info(f"{len(nodes)} nodes to repair ({hostlist}) with reason={self.reason}")
129+
for node in nodes:
130+
op_id = repair.call_rr_api(node, self.reason)
131+
if op_id:
132+
repair.store_operation(node, op_id, self.reason)
133+
122134
@dataclass(frozen=True)
123135
class NodeActionUnknown():
124136
slurm_state: Optional[NodeState]
@@ -237,10 +249,33 @@ def _find_tpu_node_action(nodename, state) -> NodeAction:
237249

238250
return NodeActionUnchanged()
239251

252+
def get_node_reason(nodename: str) -> Optional[str]:
253+
"""Get the reason for a node's state."""
254+
try:
255+
result = run(f"{lookup().scontrol} show node {nodename}")
256+
for line in result.stdout.splitlines():
257+
if "Reason=" in line:
258+
reason = line.split("Reason=")[1].strip()
259+
if "[" in reason:
260+
return reason.split("[")[0].strip()
261+
return reason
262+
except Exception as e:
263+
log.error(f"Failed to get reason for node {nodename}: {e}")
264+
return None
265+
266+
240267
def get_node_action(nodename: str) -> NodeAction:
241268
"""Determine node/instance status that requires action"""
242269
lkp = lookup()
243270
state = lkp.node_state(nodename)
271+
if state is not None and "DRAIN" in state.flags:
272+
reason = get_node_reason(nodename)
273+
if reason in repair.REPAIR_REASONS:
274+
if repair.is_node_being_repaired(nodename):
275+
return NodeActionUnchanged()
276+
inst = lkp.instance(nodename.split(".")[0])
277+
if inst:
278+
return NodeActionRepair(reason=reason)
244279

245280
if lkp.node_is_gke(nodename):
246281
return NodeActionUnchanged()
@@ -631,6 +666,11 @@ def main():
631666
except Exception:
632667
log.exception("failed to sync instances")
633668

669+
try:
670+
repair.poll_operations()
671+
except Exception:
672+
log.exception("failed to poll repair operations")
673+
634674
try:
635675
sync_flex_migs(lkp)
636676
except Exception:

0 commit comments

Comments
 (0)