Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 54 additions & 26 deletions lib/marin/src/marin/mcp/babysitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from typing import Any

from google.protobuf import json_format
from iris.cli.bug_report import gather_bug_report
from iris.cli.job import build_job_summary
from iris.cli.token_store import cluster_name_from_url, load_any_token, load_token
Expand All @@ -30,6 +31,7 @@
DEFAULT_PROFILE_SECONDS = 1
MAX_LIST_JOBS_PAGE_SIZE = 500
DEFAULT_LIST_JOBS_LIMIT = 100
_PROTO_TO_DICT_OPTIONS = dict(preserving_proto_field_name=True)

_ZEPHYR_PROGRESS_RE = re.compile(
r"\[(?P<stage>[^\]]+)\]\s+"
Expand All @@ -45,6 +47,9 @@
_TPU_XLA_RE = re.compile(r"\b(tpu|xla|hlo).*\b(bad|fault|hardware|unavailable|failed)\b", re.IGNORECASE)
_DEAD_WORKER_RE = re.compile(r"\b(heartbeat timeout|dead worker|worker.*lost|worker.*crashed)\b", re.IGNORECASE)
_TERMINATED_BY_USER_RE = re.compile(r"terminated by user", re.IGNORECASE)
_ZEPHYR_COORDINATOR_LOOP_FRAME = "_coordinator_loop"
_ZEPHYR_WAIT_FOR_STAGE_FRAME = "_wait_for_stage"
_ZEPHYR_WORKER_POOL_FRAME = "_worker"


@dataclass(frozen=True)
Expand Down Expand Up @@ -183,16 +188,12 @@ def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatu
}


def _attribute_value_to_json(value) -> Any:
kind = value.WhichOneof("value")
if kind is None:
return None
return getattr(value, kind)
def _worker_metadata_to_json(metadata: job_pb2.WorkerMetadata) -> dict[str, Any]:
    """Serialize a WorkerMetadata proto into a JSON-compatible dict.

    Field naming follows the module-wide proto-to-dict options (original
    proto field names are preserved rather than lowerCamelCased).
    """
    as_dict = json_format.MessageToDict(metadata, **_PROTO_TO_DICT_OPTIONS)
    return as_dict


def worker_status_to_json(worker: controller_pb2.Controller.WorkerHealthStatus) -> dict[str, Any]:
"""Serialize Iris worker health into stable JSON."""
metadata = worker.metadata
return {
"worker_id": worker.worker_id,
"healthy": bool(worker.healthy),
Expand All @@ -201,23 +202,7 @@ def worker_status_to_json(worker: controller_pb2.Controller.WorkerHealthStatus)
"running_job_ids": list(worker.running_job_ids),
"address": worker.address,
"status_message": worker.status_message,
"metadata": {
"hostname": metadata.hostname,
"ip_address": metadata.ip_address,
"cpu_count": int(metadata.cpu_count),
"memory_bytes": int(metadata.memory_bytes),
"disk_bytes": int(metadata.disk_bytes),
"device": _device_config_to_json(metadata.device) if metadata.HasField("device") else _cpu_device_json(),
"tpu_name": metadata.tpu_name,
"tpu_worker_id": metadata.tpu_worker_id,
"gpu_count": int(metadata.gpu_count),
"gpu_name": metadata.gpu_name,
"gpu_memory_mb": int(metadata.gpu_memory_mb),
"gce_instance_name": metadata.gce_instance_name,
"gce_zone": metadata.gce_zone,
"git_hash": metadata.git_hash,
"attributes": {key: _attribute_value_to_json(value) for key, value in metadata.attributes.items()},
},
"metadata": _worker_metadata_to_json(worker.metadata),
}


Expand Down Expand Up @@ -274,6 +259,27 @@ def parse_zephyr_progress(lines: Iterable[str]) -> list[dict[str, Any]]:
return list(snapshots_by_stage.values())


def parse_zephyr_thread_state(thread_dump: str) -> dict[str, Any]:
    """Classify a Zephyr coordinator thread dump into a compact liveness state.

    Returns a dict with a ``state`` of ``active``, ``zombie_suspected``, or
    ``unknown``, plus an ``evidence`` list explaining the classification.
    """
    if not thread_dump:
        return {"state": "unknown", "evidence": ["empty thread dump"]}

    waiting_on_stage = _ZEPHYR_WAIT_FOR_STAGE_FRAME in thread_dump
    loop_present = _ZEPHYR_COORDINATOR_LOOP_FRAME in thread_dump

    # Either coordinator frame is proof of life; collect whichever matched.
    if waiting_on_stage or loop_present:
        evidence: list[str] = []
        if waiting_on_stage:
            evidence.append("waiting for stage completion")
        if loop_present:
            evidence.append("coordinator loop thread present")
        return {"state": "active", "evidence": evidence}

    # Worker-pool frames with no coordinator frames suggest a stuck process.
    if _ZEPHYR_WORKER_POOL_FRAME in thread_dump:
        return {"state": "zombie_suspected", "evidence": ["worker pool frames without coordinator loop"]}

    return {"state": "unknown", "evidence": ["no Zephyr coordinator frames found"]}


def classify_diagnosis(
*,
job: dict[str, Any],
Expand Down Expand Up @@ -368,11 +374,12 @@ def add(signal: str, severity: str, evidence: list[str], escalation_hint: str) -
"Inspect involved workers and recent process logs.",
)

if thread_dump and "_coordinator_loop" not in thread_dump and "_worker" in thread_dump:
thread_state = parse_zephyr_thread_state(thread_dump)
if thread_state["state"] == "zombie_suspected":
add(
"zombie_coordinator",
"error",
["thread dump lacks _coordinator_loop and shows worker pool frames"],
thread_state["evidence"],
"Restart only after confirming with the user.",
)

Expand Down Expand Up @@ -602,12 +609,33 @@ def zephyr_stage_progress(self, *, coord_job_id: str, max_lines: int = DEFAULT_Z
def zephyr_coordinator_status(self, *, coord_job_id: str) -> dict[str, Any]:
summary = self.job_summary(coord_job_id)["data"]
progress_payload = self.zephyr_stage_progress(coord_job_id=coord_job_id)["data"]
thread_target = f"{coord_job_id}/0"
thread_profile = self.profile_task(
target=thread_target,
profile_type="threads",
duration_seconds=DEFAULT_PROFILE_SECONDS,
)
thread_dump = str(thread_profile["data"].get("text", ""))
thread_state = parse_zephyr_thread_state(thread_dump)
thread_warnings = list(thread_profile["warnings"])
if thread_warnings:
thread_state = {
"state": "unavailable",
"evidence": thread_warnings,
}
diagnosis = classify_diagnosis(job=summary, logs=[], workers=[], thread_dump=thread_dump)
return self.envelope(
{
"summary": summary,
"progress": progress_payload["progress"],
"cursor": progress_payload["cursor"],
}
"thread_liveness": {
"target": thread_target,
**thread_state,
},
"diagnosis": diagnosis,
},
warnings=thread_warnings,
)

def diagnose(self, *, job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[str, Any]:
Expand Down
28 changes: 27 additions & 1 deletion tests/mcp/test_babysitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_token_provider,
classify_diagnosis,
parse_zephyr_progress,
parse_zephyr_thread_state,
task_status_to_json,
)

Expand Down Expand Up @@ -161,6 +162,28 @@ def test_parse_zephyr_progress_keeps_latest_stage_snapshot():
assert progress[1]["stage"] == "stage1-Reduce"


def test_parse_zephyr_thread_state_classifies_active_and_zombie_dumps():
    # A dump containing coordinator frames should classify as active.
    active_state = parse_zephyr_thread_state(
        """
Thread actor-method_0:
File "zephyr/execution.py", line 873, in _wait_for_stage
Thread zephyr-coordinator-loop:
File "zephyr/execution.py", line 444, in _coordinator_loop
"""
    )
    assert active_state["state"] == "active"
    assert "waiting for stage completion" in active_state["evidence"]

    # Worker-pool frames without any coordinator frame suggest a zombie.
    zombie_state = parse_zephyr_thread_state(
        """
Thread worker-pool-0:
File "concurrent/futures/thread.py", line 58, in _worker
"""
    )
    assert zombie_state["state"] == "zombie_suspected"
    assert "worker pool frames without coordinator loop" in zombie_state["evidence"]


def test_classify_diagnosis_reports_common_babysitting_signals():
job = {
"state": "failed",
Expand Down Expand Up @@ -191,12 +214,15 @@ def test_classify_diagnosis_reports_common_babysitting_signals():
}
]

signals = classify_diagnosis(job=job, logs=logs, workers=workers, thread_dump="")
thread_dump = 'File "concurrent/futures/thread.py", line 58, in _worker'

signals = classify_diagnosis(job=job, logs=logs, workers=workers, thread_dump=thread_dump)
names = {signal["signal"] for signal in signals}

assert "oom_or_exit_137" in names
assert "quota_or_backoff" in names
assert "tpu_xla_bad_node" in names
assert "dead_worker" in names
assert "zombie_coordinator" in names
assert "repeated_retries" in names
assert "misleading_terminated_by_user" in names
Loading