def _worker_metadata_to_json(metadata: job_pb2.WorkerMetadata) -> dict[str, Any]:
    """Serialize worker metadata by delegating to protobuf's generic dict conversion.

    The message is passed to ``json_format.MessageToDict`` together with the
    module-level ``_PROTO_TO_DICT_OPTIONS`` keyword arguments.

    Args:
        metadata: The worker metadata message to serialize.

    Returns:
        The dict produced by ``json_format.MessageToDict``.
    """
    # NOTE(review): MessageToDict follows the Proto3 JSON mapping, which by
    # default omits fields holding their default value and renders 64-bit
    # integers as strings. If consumers expect every key to be present or
    # expect numeric byte counts, confirm they tolerate this shape -- TODO
    # confirm against the WorkerMetadata schema and the installed protobuf
    # version.
    return json_format.MessageToDict(metadata, **_PROTO_TO_DICT_OPTIONS)
def parse_zephyr_thread_state(thread_dump: str) -> dict[str, Any]:
    """Classify a Zephyr coordinator thread dump into a compact liveness state.

    Classification is purely textual: it checks whether the frame names in
    ``_ZEPHYR_COORDINATOR_LOOP_FRAME``, ``_ZEPHYR_WAIT_FOR_STAGE_FRAME`` and
    ``_ZEPHYR_WORKER_POOL_FRAME`` occur anywhere in the dump.

    Precedence:
      1. Coordinator loop frame present -> ``"active"``.
      2. No coordinator loop but worker pool frames present ->
         ``"zombie_suspected"``, even if a stage wait is also visible.
      3. Only a stage wait visible -> ``"active"``.
      4. Otherwise -> ``"unknown"``.

    Args:
        thread_dump: Raw thread dump text. May be empty.

    Returns:
        A dict with two keys: ``"state"`` (one of ``"active"``,
        ``"zombie_suspected"``, ``"unknown"``) and ``"evidence"`` (a list of
        human-readable strings explaining the classification).
    """
    if not thread_dump:
        return {"state": "unknown", "evidence": ["empty thread dump"]}

    has_wait_for_stage = _ZEPHYR_WAIT_FOR_STAGE_FRAME in thread_dump
    has_coordinator_loop = _ZEPHYR_COORDINATOR_LOOP_FRAME in thread_dump
    has_worker_pool = _ZEPHYR_WORKER_POOL_FRAME in thread_dump

    if has_coordinator_loop:
        evidence: list[str] = []
        if has_wait_for_stage:
            evidence.append("waiting for stage completion")
        evidence.append("coordinator loop thread present")
        return {"state": "active", "evidence": evidence}

    # The zombie rule must win over a pending stage wait: a caller blocked in
    # the stage-wait frame does not show that the coordinator loop is still
    # running. This keeps the rule "no coordinator loop + worker pool frames
    # => zombie" independent of whether a stage wait is visible.
    if has_worker_pool:
        return {"state": "zombie_suspected", "evidence": ["worker pool frames without coordinator loop"]}

    if has_wait_for_stage:
        return {"state": "active", "evidence": ["waiting for stage completion"]}

    return {"state": "unknown", "evidence": ["no Zephyr coordinator frames found"]}
def zephyr_coordinator_status(self, *, coord_job_id: str) -> dict[str, Any]:
    """Report summary, stage progress, thread liveness and diagnosis for a coordinator job.

    Combines ``self.job_summary``, ``self.zephyr_stage_progress`` and a
    ``"threads"`` profile of the target ``f"{coord_job_id}/0"`` into a single
    envelope.

    Args:
        coord_job_id: Identifier of the Zephyr coordinator job.

    Returns:
        The result of ``self.envelope`` over a dict with the keys
        ``"summary"``, ``"progress"``, ``"cursor"``, ``"thread_liveness"``
        and ``"diagnosis"``. Warnings reported by the thread profile are
        forwarded as the envelope's ``warnings``.
    """
    summary = self.job_summary(coord_job_id)["data"]
    progress_payload = self.zephyr_stage_progress(coord_job_id=coord_job_id)["data"]

    # NOTE(review): the "/0" suffix presumably addresses the job's first
    # task -- confirm against the profile_task target format.
    thread_target = f"{coord_job_id}/0"
    thread_profile = self.profile_task(
        target=thread_target,
        profile_type="threads",
        duration_seconds=DEFAULT_PROFILE_SECONDS,
    )
    thread_warnings = list(thread_profile["warnings"])

    if thread_warnings:
        # The profile is reported as unavailable, so its payload must not
        # feed the diagnosis either; otherwise "thread_liveness" and
        # "diagnosis" could contradict each other. The payload is also not
        # dereferenced on this path, since it may be missing.
        thread_dump = ""
        thread_state = {
            "state": "unavailable",
            "evidence": thread_warnings,
        }
    else:
        # `or ""` keeps a None text payload from becoming the string "None".
        thread_dump = str(thread_profile["data"].get("text") or "")
        thread_state = parse_zephyr_thread_state(thread_dump)

    diagnosis = classify_diagnosis(job=summary, logs=[], workers=[], thread_dump=thread_dump)
    return self.envelope(
        {
            "summary": summary,
            "progress": progress_payload["progress"],
            "cursor": progress_payload["cursor"],
            "thread_liveness": {
                "target": thread_target,
                **thread_state,
            },
            "diagnosis": diagnosis,
        },
        warnings=thread_warnings,
    )
def test_parse_zephyr_thread_state_classifies_active_and_zombie_dumps():
    """An active dump and a worker-pool-only dump map to their expected states."""
    # Dump showing both the stage-wait frame and the coordinator loop frame.
    active_dump = """
        Thread actor-method_0:
          File "zephyr/execution.py", line 873, in _wait_for_stage
        Thread zephyr-coordinator-loop:
          File "zephyr/execution.py", line 444, in _coordinator_loop
        """
    # Dump showing only a thread-pool worker frame.
    zombie_dump = """
        Thread worker-pool-0:
          File "concurrent/futures/thread.py", line 58, in _worker
        """

    active_result = parse_zephyr_thread_state(active_dump)
    zombie_result = parse_zephyr_thread_state(zombie_dump)

    assert active_result["state"] == "active"
    assert "waiting for stage completion" in active_result["evidence"]
    assert zombie_result["state"] == "zombie_suspected"
    assert "worker pool frames without coordinator loop" in zombie_result["evidence"]