Skip to content

Commit 4c9adbd

Browse files
committed
Improve MCP babysitting diagnostics
1 parent 7dd69d7 commit 4c9adbd

2 files changed

Lines changed: 81 additions & 27 deletions

File tree

lib/marin/src/marin/mcp/babysitter.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212
from typing import Any
1313

14+
from google.protobuf import json_format
1415
from iris.cli.bug_report import gather_bug_report
1516
from iris.cli.job import build_job_summary
1617
from iris.cli.token_store import cluster_name_from_url, load_any_token, load_token
@@ -30,6 +31,7 @@
3031
DEFAULT_PROFILE_SECONDS = 1
3132
MAX_LIST_JOBS_PAGE_SIZE = 500
3233
DEFAULT_LIST_JOBS_LIMIT = 100
34+
_PROTO_TO_DICT_OPTIONS = dict(preserving_proto_field_name=True)
3335

3436
_ZEPHYR_PROGRESS_RE = re.compile(
3537
r"\[(?P<stage>[^\]]+)\]\s+"
@@ -45,6 +47,9 @@
4547
_TPU_XLA_RE = re.compile(r"\b(tpu|xla|hlo).*\b(bad|fault|hardware|unavailable|failed)\b", re.IGNORECASE)
4648
_DEAD_WORKER_RE = re.compile(r"\b(heartbeat timeout|dead worker|worker.*lost|worker.*crashed)\b", re.IGNORECASE)
4749
_TERMINATED_BY_USER_RE = re.compile(r"terminated by user", re.IGNORECASE)
50+
_ZEPHYR_COORDINATOR_LOOP_FRAME = "_coordinator_loop"
51+
_ZEPHYR_WAIT_FOR_STAGE_FRAME = "_wait_for_stage"
52+
_ZEPHYR_WORKER_POOL_FRAME = "_worker"
4853

4954

5055
@dataclass(frozen=True)
@@ -183,16 +188,12 @@ def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatu
183188
}
184189

185190

186-
def _attribute_value_to_json(value) -> Any:
187-
kind = value.WhichOneof("value")
188-
if kind is None:
189-
return None
190-
return getattr(value, kind)
191+
def _worker_metadata_to_json(metadata: job_pb2.WorkerMetadata) -> dict[str, Any]:
    """Convert a worker-metadata protobuf message into a plain dict.

    The conversion is delegated to ``json_format.MessageToDict``, called with
    the module-level ``_PROTO_TO_DICT_OPTIONS`` keyword options.

    NOTE(review): ``MessageToDict`` follows the proto3 JSON mapping, so
    default-valued fields may be omitted and 64-bit integers may be rendered
    as strings — confirm that consumers of this payload tolerate that.
    """
    as_dict = json_format.MessageToDict(metadata, **_PROTO_TO_DICT_OPTIONS)
    return as_dict
191193

192194

193195
def worker_status_to_json(worker: controller_pb2.Controller.WorkerHealthStatus) -> dict[str, Any]:
194196
"""Serialize Iris worker health into stable JSON."""
195-
metadata = worker.metadata
196197
return {
197198
"worker_id": worker.worker_id,
198199
"healthy": bool(worker.healthy),
@@ -201,23 +202,7 @@ def worker_status_to_json(worker: controller_pb2.Controller.WorkerHealthStatus)
201202
"running_job_ids": list(worker.running_job_ids),
202203
"address": worker.address,
203204
"status_message": worker.status_message,
204-
"metadata": {
205-
"hostname": metadata.hostname,
206-
"ip_address": metadata.ip_address,
207-
"cpu_count": int(metadata.cpu_count),
208-
"memory_bytes": int(metadata.memory_bytes),
209-
"disk_bytes": int(metadata.disk_bytes),
210-
"device": _device_config_to_json(metadata.device) if metadata.HasField("device") else _cpu_device_json(),
211-
"tpu_name": metadata.tpu_name,
212-
"tpu_worker_id": metadata.tpu_worker_id,
213-
"gpu_count": int(metadata.gpu_count),
214-
"gpu_name": metadata.gpu_name,
215-
"gpu_memory_mb": int(metadata.gpu_memory_mb),
216-
"gce_instance_name": metadata.gce_instance_name,
217-
"gce_zone": metadata.gce_zone,
218-
"git_hash": metadata.git_hash,
219-
"attributes": {key: _attribute_value_to_json(value) for key, value in metadata.attributes.items()},
220-
},
205+
"metadata": _worker_metadata_to_json(worker.metadata),
221206
}
222207

223208

@@ -274,6 +259,27 @@ def parse_zephyr_progress(lines: Iterable[str]) -> list[dict[str, Any]]:
274259
return list(snapshots_by_stage.values())
275260

276261

262+
def parse_zephyr_thread_state(thread_dump: str) -> dict[str, Any]:
    """Classify a Zephyr coordinator thread dump into a compact liveness state.

    Args:
        thread_dump: Raw thread-dump text. ``None``, an empty string, or a
            whitespace-only string are all treated as "no dump".

    Returns:
        A dict with two keys:
          * ``"state"`` - ``"active"`` when the wait-for-stage or coordinator
            loop frame is present, ``"zombie_suspected"`` when only the
            worker-pool frame is present, otherwise ``"unknown"``.
          * ``"evidence"`` - human-readable strings explaining the verdict.
    """
    if not thread_dump or not thread_dump.strip():
        return {"state": "unknown", "evidence": ["empty thread dump"]}

    def has_frame(frame_name: str) -> bool:
        # Match the frame name as a whole identifier. A plain substring test
        # would let a short name match inside a longer one (e.g. ``_worker``
        # inside ``_worker_loop``) and raise a false zombie suspicion.
        return re.search(rf"\b{re.escape(frame_name)}\b", thread_dump) is not None

    has_wait_for_stage = has_frame(_ZEPHYR_WAIT_FOR_STAGE_FRAME)
    has_coordinator_loop = has_frame(_ZEPHYR_COORDINATOR_LOOP_FRAME)

    evidence: list[str] = []
    if has_wait_for_stage:
        evidence.append("waiting for stage completion")
    if has_coordinator_loop:
        evidence.append("coordinator loop thread present")
    if evidence:
        # Either coordinator-side frame is enough to call the coordinator alive.
        return {"state": "active", "evidence": evidence}

    if has_frame(_ZEPHYR_WORKER_POOL_FRAME):
        return {"state": "zombie_suspected", "evidence": ["worker pool frames without coordinator loop"]}
    return {"state": "unknown", "evidence": ["no Zephyr coordinator frames found"]}
281+
282+
277283
def classify_diagnosis(
278284
*,
279285
job: dict[str, Any],
@@ -368,11 +374,12 @@ def add(signal: str, severity: str, evidence: list[str], escalation_hint: str) -
368374
"Inspect involved workers and recent process logs.",
369375
)
370376

371-
if thread_dump and "_coordinator_loop" not in thread_dump and "_worker" in thread_dump:
377+
thread_state = parse_zephyr_thread_state(thread_dump)
378+
if thread_state["state"] == "zombie_suspected":
372379
add(
373380
"zombie_coordinator",
374381
"error",
375-
["thread dump lacks _coordinator_loop and shows worker pool frames"],
382+
thread_state["evidence"],
376383
"Restart only after confirming with the user.",
377384
)
378385

@@ -602,12 +609,33 @@ def zephyr_stage_progress(self, *, coord_job_id: str, max_lines: int = DEFAULT_Z
602609
def zephyr_coordinator_status(self, *, coord_job_id: str) -> dict[str, Any]:
    """Report summary, stage progress, thread liveness and diagnosis for a coordinator job.

    Args:
        coord_job_id: Job id of the Zephyr coordinator to inspect.

    Returns:
        The result of ``self.envelope`` wrapping a dict with the keys
        ``summary``, ``progress``, ``cursor``, ``thread_liveness`` and
        ``diagnosis``. Any warnings from the thread profile are forwarded as
        the envelope's ``warnings``.
    """
    summary = self.job_summary(coord_job_id)["data"]
    progress_payload = self.zephyr_stage_progress(coord_job_id=coord_job_id)["data"]

    # NOTE(review): "/0" presumably addresses the job's first task - confirm.
    thread_target = f"{coord_job_id}/0"
    thread_profile = self.profile_task(
        target=thread_target,
        profile_type="threads",
        duration_seconds=DEFAULT_PROFILE_SECONDS,
    )

    # Guard against an explicit None: str(None) would yield the literal text
    # "None", which would then be parsed as a non-empty thread dump.
    raw_text = thread_profile["data"].get("text")
    thread_dump = "" if raw_text is None else str(raw_text)

    thread_warnings = list(thread_profile["warnings"])
    if thread_warnings:
        # The profile reported problems, so the dump is not trusted for a
        # liveness verdict. Copy the list so the evidence does not alias the
        # warnings list handed to the envelope below.
        thread_state = {"state": "unavailable", "evidence": list(thread_warnings)}
    else:
        thread_state = parse_zephyr_thread_state(thread_dump)

    diagnosis = classify_diagnosis(job=summary, logs=[], workers=[], thread_dump=thread_dump)
    return self.envelope(
        {
            "summary": summary,
            "progress": progress_payload["progress"],
            "cursor": progress_payload["cursor"],
            "thread_liveness": {
                "target": thread_target,
                **thread_state,
            },
            "diagnosis": diagnosis,
        },
        warnings=thread_warnings,
    )
612640

613641
def diagnose(self, *, job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[str, Any]:

tests/mcp/test_babysitter.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_token_provider,
1212
classify_diagnosis,
1313
parse_zephyr_progress,
14+
parse_zephyr_thread_state,
1415
task_status_to_json,
1516
)
1617

@@ -161,6 +162,28 @@ def test_parse_zephyr_progress_keeps_latest_stage_snapshot():
161162
assert progress[1]["stage"] == "stage1-Reduce"
162163

163164

165+
def test_parse_zephyr_thread_state_classifies_active_and_zombie_dumps():
    """A dump with coordinator frames is active; one with only pool frames is a suspected zombie."""
    # Dump containing both the wait-for-stage frame and the coordinator loop frame.
    active_dump = """
        Thread actor-method_0:
          File "zephyr/execution.py", line 873, in _wait_for_stage
        Thread zephyr-coordinator-loop:
          File "zephyr/execution.py", line 444, in _coordinator_loop
        """
    # Dump containing only a thread-pool worker frame.
    zombie_dump = """
        Thread worker-pool-0:
          File "concurrent/futures/thread.py", line 58, in _worker
        """

    active = parse_zephyr_thread_state(active_dump)
    zombie = parse_zephyr_thread_state(zombie_dump)

    assert active["state"] == "active"
    assert "waiting for stage completion" in active["evidence"]

    assert zombie["state"] == "zombie_suspected"
    assert "worker pool frames without coordinator loop" in zombie["evidence"]
185+
186+
164187
def test_classify_diagnosis_reports_common_babysitting_signals():
165188
job = {
166189
"state": "failed",
@@ -191,12 +214,15 @@ def test_classify_diagnosis_reports_common_babysitting_signals():
191214
}
192215
]
193216

194-
signals = classify_diagnosis(job=job, logs=logs, workers=workers, thread_dump="")
217+
thread_dump = 'File "concurrent/futures/thread.py", line 58, in _worker'
218+
219+
signals = classify_diagnosis(job=job, logs=logs, workers=workers, thread_dump=thread_dump)
195220
names = {signal["signal"] for signal in signals}
196221

197222
assert "oom_or_exit_137" in names
198223
assert "quota_or_backoff" in names
199224
assert "tpu_xla_bad_node" in names
200225
assert "dead_worker" in names
226+
assert "zombie_coordinator" in names
201227
assert "repeated_retries" in names
202228
assert "misleading_terminated_by_user" in names

0 commit comments

Comments
 (0)