Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ locals {
"mig_flex.py",
"resume_wrapper.sh",
"resume.py",
"repair.py",
"setup_network_storage.py",
"setup.py",
"slurmsync.py",
Expand All @@ -218,6 +219,7 @@ locals {
"setup.py",
"slurmsync.py",
"sort_nodes.py",
"repair.py",
"suspend.py",
"tpu.py",
"util.py",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3

# Copyright 2026 "Google LLC"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fcntl
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
import subprocess
import util
from util import run, to_hostlist, lookup
# Root logger (no name passed), shared by every function in this module.
log = logging.getLogger()

# JSON file in which the repair operations are persisted, keyed by node name.
REPAIR_FILE = Path("/slurm") / "repair_operations.json"

# Closed set of reason strings this module exposes as repair reasons.
REPAIR_REASONS = frozenset({"PERFORMANCE", "SDC", "XID", "unspecified"})

def is_node_being_repaired(node):
    """Tell whether ``node`` has a repair recorded as still in progress.

    Args:
        node: Node name used as the key in the persisted operations mapping.

    Returns:
        bool: True only when a record exists for ``node`` and its ``status``
        is ``REPAIR_IN_PROGRESS``.
    """
    recorded = _get_operations()
    if node not in recorded:
        return False
    return recorded[node]["status"] == "REPAIR_IN_PROGRESS"

def _get_operations():
    """Load every recorded repair operation from ``REPAIR_FILE``.

    Returns:
        dict: Mapping of node name to its operation record. An empty dict is
        returned when the file does not exist, does not contain valid JSON,
        or contains JSON whose top level is not an object.
    """
    try:
        # Open directly instead of checking exists() first: the file could be
        # removed between the check and the open.
        with open(REPAIR_FILE, 'r', encoding='utf-8') as f:
            operations = json.load(f)
    except FileNotFoundError:
        return {}
    except json.JSONDecodeError:
        log.error(f"Failed to decode JSON from {REPAIR_FILE}, returning empty operations list.")
        return {}

    # Callers index the result by node name and iterate over .items(); a
    # payload such as `null` or a list would make them fail.
    if not isinstance(operations, dict):
        log.error(f"Unexpected content in {REPAIR_FILE} (expected a JSON object), returning empty operations list.")
        return {}
    return operations

def _write_all_operations(operations):
    """Replace the contents of ``REPAIR_FILE`` with ``operations``.

    The file is rewritten under an exclusive, non-blocking ``fcntl`` lock.

    Args:
        operations: JSON-serializable mapping of node name to operation record.

    Returns:
        bool: True when the file was rewritten. False when ``operations``
        cannot be serialized, the lock is held by someone else, or an I/O
        error occurred.
    """
    try:
        # Serialize BEFORE touching the file: if `operations` contains a value
        # json cannot encode, the previously stored state must stay intact
        # rather than being left truncated or half-written.
        payload = json.dumps(operations, indent=4)

        # Append mode creates the file when it is missing but does not
        # truncate it on open, so nothing is lost if the lock is unavailable.
        with open(REPAIR_FILE, 'a', encoding='utf-8') as f:
            try:
                fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except OSError:
                log.warning(f"Could not acquire lock on {REPAIR_FILE}. Another process may be running.")
                return False

            try:
                # Drop the old contents only once the lock is held.
                f.seek(0)
                f.truncate()
                f.write(payload)
                f.flush()
                return True
            finally:
                fcntl.lockf(f, fcntl.LOCK_UN)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError come from json.dumps (unsupported type or
        # circular reference); OSError from opening or writing the file.
        log.error(f"Failed to store repair operations to {REPAIR_FILE}: {e}")
        return False

def store_operation(node, operation_id, reason):
    """Record a newly started repair operation for ``node``.

    Any existing record for the same node is overwritten. The new record is
    stored with status ``REPAIR_IN_PROGRESS`` and a UTC ISO-8601 timestamp.

    Args:
        node: Node name, used as the key of the record.
        operation_id: Identifier of the operation to poll later.
        reason: Repair reason stored alongside the operation.
    """
    all_operations = _get_operations()
    record = {
        "operation_id": operation_id,
        "reason": reason,
        "status": "REPAIR_IN_PROGRESS",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    all_operations[node] = record

    persisted = _write_all_operations(all_operations)
    if persisted:
        return
    log.error(f"Failed to persist repair operation for node {node}.")

def call_rr_api(node, reason):
    """Report the host of ``node`` as faulty through ``gcloud``.

    Args:
        node: Name of the instance to report.
        reason: Value passed as the ``behavior`` of ``--fault-reasons``.

    Returns:
        The ``name`` field of the operation printed by gcloud, or None when
        the instance is unknown, the command fails, or its output cannot be
        interpreted.
    """
    log.info(f"Calling R&R API for node {node} with reason {reason}")

    instance = lookup().instance(node)
    if not instance:
        log.error(f"Instance {node} not found, cannot report fault.")
        return None

    command = (
        f"gcloud compute instances report-host-as-faulty {node} "
        "--async --disruption-schedule=IMMEDIATE "
        f"--fault-reasons=behavior={reason},description='VM is managed by Slurm' "
        f"--zone={instance.zone} --format=json"
    )
    try:
        completed = run(command)
        log.info(f"gcloud compute instances report-host-as-faulty stdout: {completed.stdout.strip()}")
        payload = json.loads(completed.stdout)
        # The output may be a single operation or a list of them; in the
        # latter case the first entry is used.
        operation = payload[0] if isinstance(payload, list) else payload
        return operation["name"]
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        log.error(f"Failed to call or parse R&R API response for {node}: {e}")
    except Exception as e:
        log.error(f"An unexpected error occurred while calling R&R API for {node}: {e}")
    return None

def _get_operation_status(operation_id):
    """Look up a GCP operation by name through ``gcloud``.

    Args:
        operation_id: Operation name used in the ``name=`` list filter.

    Returns:
        The first entry of the JSON list printed by gcloud, or None when the
        list is empty, the command fails, or the output cannot be parsed.
    """
    query = f'gcloud compute operations list --filter="name={operation_id}" --format=json'
    try:
        matches = json.loads(run(query).stdout)
        return matches[0] if matches else None
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        log.error(f"Failed to get or parse operation status for {operation_id}: {e}")
    except Exception as e:
        log.error(f"An unexpected error occurred while getting operation status for {operation_id}: {e}")
    return None

def poll_operations():
    """Poll the status of ongoing repair operations and update Slurm.

    For every record whose status is ``REPAIR_IN_PROGRESS`` the matching GCP
    operation is looked up. Once it reports ``DONE``:

    * with an ``error``: the node is set ``down`` and the record becomes
      ``FAILURE``;
    * without one: the node is set ``power_down`` and the record becomes
      ``SUCCESS``.

    A record is only transitioned after its ``scontrol`` call succeeded. If
    that call fails, the error is logged, the record stays
    ``REPAIR_IN_PROGRESS`` (so it is picked up again by the next poll) and the
    remaining nodes are still processed. The file is rewritten only when at
    least one record changed.
    """
    operations = _get_operations()
    if not operations:
        return

    log.info("Polling repair operations")
    changed = False
    for node, op_details in operations.items():
        if op_details.get("status") != "REPAIR_IN_PROGRESS":
            continue

        operation_id = op_details.get("operation_id")
        if not operation_id:
            # A malformed record must not abort polling of the other nodes.
            log.warning(f"Repair record for {node} has no operation_id, skipping.")
            continue

        gcp_op_status = _get_operation_status(operation_id)
        if not gcp_op_status or gcp_op_status.get("status") != "DONE":
            continue

        error = gcp_op_status.get("error")
        try:
            if error:
                log.error(f"Repair operation for {node} failed: {error}")
                run(f"{lookup().scontrol} update nodename={node} state=down reason='Repair failed'")
                new_status = "FAILURE"
            else:
                log.info(f"Repair operation for {node} succeeded. Powering down the VM")
                run(f"{lookup().scontrol} update nodename={node} state=power_down reason='Repair succeeded'")
                new_status = "SUCCESS"
        except (subprocess.SubprocessError, OSError) as e:
            # Keep going: one failing scontrol call must neither stop the
            # loop nor discard transitions already made for other nodes.
            log.error(f"Failed to update Slurm state for {node} after its repair operation finished: {e}")
            continue

        op_details["status"] = new_status
        changed = True

    # Skip the rewrite when nothing changed, so an unchanged (and possibly
    # stale) snapshot is not written back over the file.
    if changed and not _write_all_operations(operations):
        log.error("Failed to persist updated repair operations state after polling.")
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
import sys
import shlex
import subprocess
from datetime import datetime, timedelta
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -48,6 +49,7 @@
import conf
import conf_v2411
import watch_delete_vm_op
import repair

log = logging.getLogger()

Expand Down Expand Up @@ -120,6 +122,17 @@ def apply(self, nodes: List[str]) -> None:
log.info(f"{len(nodes)} nodes set down ({hostlist}) with reason={self.reason}")
run(f"{lookup().scontrol} update nodename={hostlist} state=down reason={shlex.quote(self.reason)}")

@dataclass(frozen=True)
class NodeActionRepair():
    """Node action that requests a repair for each affected node."""

    # Repair reason forwarded to repair.call_rr_api and stored with the operation.
    reason: str

    def apply(self, nodes: List[str]) -> None:
        """Request a repair for every node and record the started operations.

        Nodes for which ``repair.call_rr_api`` returns no operation id are
        not recorded.
        """
        log.info(f"{len(nodes)} nodes to repair ({util.to_hostlist(nodes)}) with reason={self.reason}")
        for nodename in nodes:
            operation_id = repair.call_rr_api(nodename, self.reason)
            if not operation_id:
                continue
            repair.store_operation(nodename, operation_id, self.reason)

@dataclass(frozen=True)
class NodeActionUnknown():
slurm_state: Optional[NodeState]
Expand Down Expand Up @@ -238,11 +251,32 @@ def _find_tpu_node_action(nodename, state) -> NodeAction:

return NodeActionUnchanged()

def get_node_reason(nodename: str) -> Optional[str]:
    """Return the reason recorded for a node, read from scontrol's JSON output.

    When the reason contains a ``[``, only the text before the first ``[`` is
    returned, stripped of surrounding whitespace (presumably this drops a
    bracketed suffix appended by Slurm; verify against real output).

    Returns:
        The reason of the first node in the output, or None when no node is
        listed, the command fails, or the output cannot be parsed.
    """
    try:
        output = run(f"{lookup().scontrol} show node {nodename} --json").stdout
        node_entries = json.loads(output).get('nodes', [])
        if not node_entries:
            return None

        reason = node_entries[0].get('reason')
        # Missing/empty reasons and reasons without a bracket pass through untouched.
        if not reason or "[" not in reason:
            return reason
        return reason.partition("[")[0].strip()
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        log.error(f"Failed to execute scontrol or parse its JSON output for node {nodename}: {e}")
    except Exception as e:
        log.error(f"An unexpected error occurred while getting reason for node {nodename}: {e}")
    return None


def get_node_action(nodename: str) -> NodeAction:
"""Determine node/instance status that requires action"""
lkp = lookup()
state = lkp.node_state(nodename)

if lkp.node_is_gke(nodename):
return NodeActionUnchanged()

Expand All @@ -264,6 +298,14 @@ def get_node_action(nodename: str) -> NodeAction:
("POWER_DOWN", "POWERING_UP", "POWERING_DOWN", "POWERED_DOWN")
) & (state.flags if state is not None else set())

if state is not None and "DRAIN" in state.flags:
reason = get_node_reason(nodename)
if reason in repair.REPAIR_REASONS:
if repair.is_node_being_repaired(nodename):
return NodeActionUnchanged()
if inst:
return NodeActionRepair(reason=reason)

if (state is None) and (inst is None):
# Should never happen
return NodeActionUnknown(None, None)
Expand Down Expand Up @@ -643,6 +685,11 @@ def main():
except Exception:
log.exception("failed to sync instances")

try:
repair.poll_operations()
except Exception:
log.exception("failed to poll repair operations")

try:
sync_flex_migs(lkp)
except Exception:
Expand Down
Loading
Loading