diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/main.tf b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/main.tf
index 7b380cb5f3..c6a967c1a3 100644
--- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/main.tf
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/main.tf
@@ -192,6 +192,7 @@ locals {
     "mig_flex.py",
     "resume_wrapper.sh",
     "resume.py",
+    "repair.py",
     "setup_network_storage.py",
     "setup.py",
     "slurmsync.py",
@@ -218,6 +219,7 @@ locals {
     "setup.py",
     "slurmsync.py",
     "sort_nodes.py",
+    "repair.py",
     "suspend.py",
     "tpu.py",
     "util.py",
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/repair.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/repair.py
new file mode 100644
index 0000000000..a550200c94
--- /dev/null
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/repair.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 "Google LLC"
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fcntl
+import json
+import logging
+from datetime import datetime, timezone
+from pathlib import Path
+import subprocess
+import util
+from util import run, to_hostlist, lookup
+log = logging.getLogger()
+REPAIR_FILE = Path("/slurm/repair_operations.json")
+REPAIR_REASONS = frozenset(["PERFORMANCE", "SDC", "XID", "unspecified"])
+
+def is_node_being_repaired(node):
+    """Check if a node is currently being repaired."""
+    operations = _get_operations()
+    return node in operations and operations[node]["status"] == "REPAIR_IN_PROGRESS"
+
+def _get_operations():
+    """Get all repair operations from the file."""
+    if not REPAIR_FILE.exists():
+        return {}
+    with open(REPAIR_FILE, 'r', encoding='utf-8') as f:
+        try:
+            return json.load(f)
+        except json.JSONDecodeError:
+            log.error(f"Failed to decode JSON from {REPAIR_FILE}, returning empty operations list.")
+            return {}
+
+def _write_all_operations(operations):
+    """Store the operations to the file safely."""
+    try:
+        with open(REPAIR_FILE, 'a', encoding='utf-8') as f:
+            try:
+                fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
+            except (IOError, BlockingIOError):
+                log.warning(f"Could not acquire lock on {REPAIR_FILE}. Another process may be running.")
+                return False
+
+            try:
+                f.seek(0)
+                f.truncate()
+                json.dump(operations, f, indent=4)
+                f.flush()
+                return True
+            finally:
+                fcntl.lockf(f, fcntl.LOCK_UN)
+    except (IOError, TypeError) as e:
+        log.error(f"Failed to store repair operations to {REPAIR_FILE}: {e}")
+        return False
+
+def store_operation(node, operation_id, reason):
+    """Store a single repair operation."""
+    operations = _get_operations()
+    operations[node] = {
+        "operation_id": operation_id,
+        "reason": reason,
+        "status": "REPAIR_IN_PROGRESS",
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+    }
+    if not _write_all_operations(operations):
+        log.error(f"Failed to persist repair operation for node {node}.")
+
+def call_rr_api(node, reason):
+    """Call the R&R API for a given node."""
+    log.info(f"Calling R&R API for node {node} with reason {reason}")
+    inst = lookup().instance(node)
+    if not inst:
+        log.error(f"Instance {node} not found, cannot report fault.")
+        return None
+    cmd = f"gcloud compute instances report-host-as-faulty {node} --async --disruption-schedule=IMMEDIATE --fault-reasons=behavior={reason},description='VM is managed by Slurm' --zone={inst.zone} --format=json"
+    try:
+        result = run(cmd)
+        log.info(f"gcloud compute instances report-host-as-faulty stdout: {result.stdout.strip()}")
+        op = json.loads(result.stdout)
+        if isinstance(op, list):
+            op = op[0]
+        return op["name"]
+    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
+        log.error(f"Failed to call or parse R&R API response for {node}: {e}")
+        return None
+    except Exception as e:
+        log.error(f"An unexpected error occurred while calling R&R API for {node}: {e}")
+        return None
+
+def _get_operation_status(operation_id):
+    """Get the status of a GCP operation."""
+    cmd = f'gcloud compute operations list --filter="name={operation_id}" --format=json'
+    try:
+        result = run(cmd)
+        operations_list = json.loads(result.stdout)
+        if operations_list:
+            return operations_list[0]
+
+        return None
+    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
+        log.error(f"Failed to get or parse operation status for {operation_id}: {e}")
+        return None
+    except Exception as e:
+        log.error(f"An unexpected error occurred while getting operation status for {operation_id}: {e}")
+        return None
+
+def poll_operations():
+    """Poll the status of ongoing repair operations."""
+    operations = _get_operations()
+    if not operations:
+        return
+
+    log.info("Polling repair operations")
+    for node, op_details in operations.items():
+        if op_details["status"] == "REPAIR_IN_PROGRESS":
+            gcp_op_status = _get_operation_status(op_details["operation_id"])
+            if not gcp_op_status:
+                continue
+
+            if gcp_op_status.get("status") == "DONE":
+                if gcp_op_status.get("error"):
+                    log.error(f"Repair operation for {node} failed: {gcp_op_status['error']}")
+                    op_details["status"] = "FAILURE"
+                    run(f"{lookup().scontrol} update nodename={node} state=down reason='Repair failed'")
+                else:
+                    log.info(f"Repair operation for {node} succeeded. Powering down the VM")
+                    run(f"{lookup().scontrol} update nodename={node} state=power_down reason='Repair succeeded'")
+                    op_details["status"] = "SUCCESS"
+
+    if not _write_all_operations(operations):
+        log.error("Failed to persist updated repair operations state after polling.")
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py
old mode 100755
new mode 100644
index 9e5d2519a3..fc6c64fde8
--- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py
@@ -21,6 +21,7 @@
 import re
 import sys
 import shlex
+import subprocess
 from datetime import datetime, timedelta
 from itertools import chain
 from pathlib import Path
@@ -48,6 +49,7 @@
 import conf
 import conf_v2411
 import watch_delete_vm_op
+import repair
 
 log = logging.getLogger()
 
@@ -120,6 +122,17 @@ def apply(self, nodes: List[str]) -> None:
         log.info(f"{len(nodes)} nodes set down ({hostlist}) with reason={self.reason}")
         run(f"{lookup().scontrol} update nodename={hostlist} state=down reason={shlex.quote(self.reason)}")
 
+@dataclass(frozen=True)
+class NodeActionRepair():
+    reason: str
+    def apply(self, nodes: List[str]) -> None:
+        hostlist = util.to_hostlist(nodes)
+        log.info(f"{len(nodes)} nodes to repair ({hostlist}) with reason={self.reason}")
+        for node in nodes:
+            op_id = repair.call_rr_api(node, self.reason)
+            if op_id:
+                repair.store_operation(node, op_id, self.reason)
+
 @dataclass(frozen=True)
 class NodeActionUnknown():
     slurm_state: Optional[NodeState]
@@ -238,11 +251,32 @@ def _find_tpu_node_action(nodename, state) -> NodeAction:
 
     return NodeActionUnchanged()
 
 
+def get_node_reason(nodename: str) -> Optional[str]:
+    """Get the reason for a node's state using JSON output."""
+    try:
+        # Use --json to get structured data
+        result = run(f"{lookup().scontrol} show node {nodename} --json")
+        data = json.loads(result.stdout)
+
+        # Access the reason field directly from the JSON structure
+        nodes = data.get('nodes', [])
+        if nodes:
+            reason = nodes[0].get('reason')
+            # Handle the specific formatting logic for brackets if needed
+            if reason and "[" in reason:
+                return reason.split("[")[0].strip()
+            return reason
+    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
+        log.error(f"Failed to execute scontrol or parse its JSON output for node {nodename}: {e}")
+    except Exception as e:
+        log.error(f"An unexpected error occurred while getting reason for node {nodename}: {e}")
+    return None
+
+
 def get_node_action(nodename: str) -> NodeAction:
     """Determine node/instance status that requires action"""
     lkp = lookup()
     state = lkp.node_state(nodename)
-
     if lkp.node_is_gke(nodename):
         return NodeActionUnchanged()
@@ -264,6 +298,14 @@ def get_node_action(nodename: str) -> NodeAction:
         ("POWER_DOWN", "POWERING_UP", "POWERING_DOWN", "POWERED_DOWN")
     ) & (state.flags if state is not None else set())
 
+    if state is not None and "DRAIN" in state.flags:
+        reason = get_node_reason(nodename)
+        if reason in repair.REPAIR_REASONS:
+            if repair.is_node_being_repaired(nodename):
+                return NodeActionUnchanged()
+            if inst:
+                return NodeActionRepair(reason=reason)
+
     if (state is None) and (inst is None):
         # Should never happen
         return NodeActionUnknown(None, None)
@@ -643,6 +685,11 @@ def main():
     except Exception:
         log.exception("failed to sync instances")
 
+    try:
+        repair.poll_operations()
+    except Exception:
+        log.exception("failed to poll repair operations")
+
     try:
         sync_flex_migs(lkp)
     except Exception:
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_repair.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_repair.py
new file mode 100644
index 0000000000..9e84bfa23c
--- /dev/null
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_repair.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 "Google LLC"
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, mock_open, MagicMock
+import json
+import fcntl
+from datetime import datetime, timezone
+from pathlib import Path
+import subprocess
+
+import repair
+
+class RepairScriptTest(unittest.TestCase):
+
+    def setUp(self):
+        # Reset the REPAIR_FILE path for each test
+        repair.REPAIR_FILE = Path("/slurm/repair_operations.json")
+
+    @patch('repair._get_operations')
+    def test_is_node_being_repaired(self, mock_get_operations):
+        mock_get_operations.return_value = {
+            "node-1": {"status": "REPAIR_IN_PROGRESS"},
+            "node-2": {"status": "SUCCESS"}
+        }
+        self.assertTrue(repair.is_node_being_repaired("node-1"))
+        self.assertFalse(repair.is_node_being_repaired("node-2"))
+        self.assertFalse(repair.is_node_being_repaired("node-3"))
+
+    @patch('builtins.open', new_callable=mock_open, read_data='{"node-1": {"status": "REPAIR_IN_PROGRESS"}}')
+    @patch('pathlib.Path.exists', return_value=True)
+    def test_get_operations_success(self, mock_exists, mock_open_file):
+        ops = repair._get_operations()
+        self.assertEqual(ops, {"node-1": {"status": "REPAIR_IN_PROGRESS"}})
+        mock_open_file.assert_called_with(repair.REPAIR_FILE, 'r', encoding='utf-8')
+
+    @patch('builtins.open', new_callable=mock_open, read_data='invalid json')
+    @patch('pathlib.Path.exists', return_value=True)
+    def test_get_operations_json_decode_error(self, mock_exists, mock_open_file):
+        ops = repair._get_operations()
+        self.assertEqual(ops, {})
+
+    @patch('pathlib.Path.exists', return_value=False)
+    def test_get_operations_file_not_found(self, mock_exists):
+        ops = repair._get_operations()
+        self.assertEqual(ops, {})
+
+    @patch('builtins.open', new_callable=mock_open)
+    @patch('fcntl.lockf')
+    def test_write_all_operations(self, mock_lockf, mock_open_file):
+        operations = {"node-1": {"status": "SUCCESS"}}
+        repair._write_all_operations(operations)
+        mock_open_file.assert_called_with(repair.REPAIR_FILE, 'a', encoding='utf-8')
+        handle = mock_open_file()
+        expected_json_string = json.dumps(operations, indent=4)
+        written_data_parts = [call_args[0][0] for call_args in handle.write.call_args_list]
+        written_data = ''.join(written_data_parts)
+        self.assertEqual(written_data, expected_json_string)
+        mock_lockf.assert_any_call(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
+        mock_lockf.assert_any_call(handle, fcntl.LOCK_UN)
+
+    @patch('repair._get_operations', return_value={})
+    @patch('repair._write_all_operations')
+    @patch('repair.datetime')
+    def test_store_operation(self, mock_datetime, mock_write_all_operations, mock_get_operations):
+        mock_now = datetime(2025, 1, 1, tzinfo=timezone.utc)
+        mock_datetime.now.return_value = mock_now
+
+        repair.store_operation("node-1", "op-123", "PERFORMANCE")
+
+        expected_operations = {
+            "node-1": {
+                "operation_id": "op-123",
+                "reason": "PERFORMANCE",
+                "status": "REPAIR_IN_PROGRESS",
+                "timestamp": mock_now.isoformat(),
+            }
+        }
+        mock_write_all_operations.assert_called_with(expected_operations)
+
+    @patch('repair.lookup')
+    @patch('repair.run')
+    def test_call_rr_api_success(self, mock_run, mock_lookup):
+        mock_instance = MagicMock()
+        mock_instance.zone = "us-central1-a"
+        mock_lookup.return_value.instance.return_value = mock_instance
+
+        mock_run.return_value = MagicMock(
+            stdout='[{"name": "op-123"}]',
+            stderr='',
+            returncode=0
+        )
+
+        op_id = repair.call_rr_api("node-1", "XID")
+        self.assertEqual(op_id, "op-123")
+        mock_run.assert_called_once()
+
+    @patch('repair.lookup')
+    def test_call_rr_api_instance_not_found(self, mock_lookup):
+        mock_lookup.return_value.instance.return_value = None
+        op_id = repair.call_rr_api("node-1", "XID")
+        self.assertIsNone(op_id)
+
+    @patch('repair.lookup')
+    @patch('repair.run', side_effect=subprocess.CalledProcessError(1, 'cmd'))
+    def test_call_rr_api_run_error(self, mock_run, mock_lookup):
+        mock_instance = MagicMock()
+        mock_instance.zone = "us-central1-a"
+        mock_lookup.return_value.instance.return_value = mock_instance
+        op_id = repair.call_rr_api("node-1", "XID")
+        self.assertIsNone(op_id)
+
+    @patch('repair.run')
+    def test_get_operation_status_success(self, mock_run):
+        mock_run.return_value.stdout = '[{"status": "DONE"}]'
+        status = repair._get_operation_status("op-123")
+        self.assertEqual(status, {"status": "DONE"})
+
+    @patch('repair.run')
+    def test_get_operation_status_empty_list(self, mock_run):
+        mock_run.return_value.stdout = '[]'
+        status = repair._get_operation_status("op-123")
+        self.assertIsNone(status)
+
+    @patch('repair._get_operations')
+    @patch('repair._write_all_operations')
+    @patch('repair._get_operation_status')
+    @patch('repair.lookup')
+    @patch('repair.run')
+    def test_poll_operations(self, mock_run, mock_lookup, mock_get_op_status, mock_store_ops, mock_get_ops):
+        # Setup initial operations data
+        mock_get_ops.return_value = {
+            "node-1": {"operation_id": "op-1", "status": "REPAIR_IN_PROGRESS"},
+            "node-2": {"operation_id": "op-2", "status": "REPAIR_IN_PROGRESS"},
+            "node-3": {"status": "SUCCESS"},
+            "node-4": {"status": "RECOVERED"}
+        }
+
+        # Mock responses for get_operation_status
+        mock_get_op_status.side_effect = [
+            {"status": "DONE"},
+            {"status": "DONE", "error": "Something went wrong"},
+        ]
+
+        # Mock instance status for the SUCCESS case
+        mock_instance = MagicMock()
+        mock_instance.status = "RUNNING"
+        mock_lookup.return_value.instance.return_value = mock_instance
+
+        # Run the poll
+        repair.poll_operations()
+
+        # Check the stored operations
+        final_ops = mock_store_ops.call_args[0][0]
+        self.assertEqual(final_ops["node-1"]["status"], "SUCCESS")
+        self.assertEqual(final_ops["node-2"]["status"], "FAILURE")
+        self.assertEqual(final_ops["node-3"]["status"], "SUCCESS")
+        self.assertEqual(final_ops["node-4"]["status"], "RECOVERED")
+
+        # Check that scontrol was called correctly
+        self.assertIn("update nodename=node-1 state=power_down", mock_run.call_args_list[0][0][0])
+        self.assertIn("update nodename=node-2 state=down", mock_run.call_args_list[1][0][0])
+
+if __name__ == '__main__':
+    unittest.main()
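
A minimal sketch of how the new pieces fit together, based only on the functions added above. This is illustrative and not part of the change itself: it assumes slurmsync.py runs periodically on the controller, and the node name "node-1" and reason "XID" are placeholders.

# A health check or operator drains the node with a reason from repair.REPAIR_REASONS:
#     scontrol update nodename=node-1 state=drain reason=XID
# On the next slurmsync pass, get_node_action() sees the DRAIN flag plus a matching
# reason and returns NodeActionRepair, whose apply() does the equivalent of:
import repair

op_id = repair.call_rr_api("node-1", "XID")           # report the host as faulty via gcloud
if op_id:
    repair.store_operation("node-1", op_id, "XID")    # persist it in /slurm/repair_operations.json

# Later slurmsync passes poll the pending operations; on success the node is set to
# power_down, and on error it is set down with reason 'Repair failed'.
repair.poll_operations()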