Skip to content

Commit 825f8a2

Browse files
committed
Address babysitter MCP review feedback
1 parent 9824376 commit 825f8a2

2 files changed

Lines changed: 95 additions & 15 deletions

File tree

lib/marin/src/marin/mcp/babysitter.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import re
99
from collections.abc import Iterable
1010
from dataclasses import asdict, dataclass
11+
from pathlib import Path
1112
from typing import Any
1213

1314
from iris.cli.bug_report import gather_bug_report
1415
from iris.cli.job import build_job_summary
16+
from iris.cli.token_store import cluster_name_from_url, load_any_token, load_token
1517
from iris.cluster.log_store import build_log_source
1618
from iris.cluster.runtime.profile import SYSTEM_PROCESS_TARGET
1719
from iris.cluster.types import JobName
@@ -51,7 +53,6 @@ class IrisConnectionConfig:
5153

5254
controller_url: str
5355
cluster: str = "default"
54-
iris_token: str | None = None
5556
timeout_ms: int = 30_000
5657

5758

@@ -391,7 +392,7 @@ class IrisBabysitter:
391392

392393
def __init__(self, config: IrisConnectionConfig):
393394
self.config = config
394-
self.token_provider = _token_provider(config.iris_token)
395+
self.token_provider = _token_provider(config.cluster)
395396
interceptors = [AuthTokenInjector(self.token_provider)] if self.token_provider else []
396397
self.controller = ControllerServiceClientSync(
397398
config.controller_url,
@@ -423,6 +424,7 @@ def list_jobs(
423424
jobs: list[dict[str, Any]] = []
424425
offset = 0
425426
capped_limit = max(1, limit)
427+
prefix_job = JobName.from_wire(prefix) if prefix else None
426428
while len(jobs) < capped_limit:
427429
query = controller_pb2.Controller.JobQuery(
428430
state_filter=state_filter,
@@ -434,7 +436,7 @@ def list_jobs(
434436
)
435437
response = self.controller.list_jobs(controller_pb2.Controller.ListJobsRequest(query=query))
436438
for job in response.jobs:
437-
if prefix and not job.job_id.startswith(prefix):
439+
if prefix_job is not None and not _job_matches_prefix(job.job_id, prefix_job):
438440
continue
439441
jobs.append(job_status_to_json(job))
440442
if len(jobs) >= capped_limit:
@@ -447,9 +449,7 @@ def list_jobs(
447449
def job_summary(self, job_id: str) -> dict[str, Any]:
448450
job_response = self.controller.get_job_status(controller_pb2.Controller.GetJobStatusRequest(job_id=job_id))
449451
tasks_response = self.controller.list_tasks(controller_pb2.Controller.ListTasksRequest(job_id=job_id))
450-
summary = build_job_summary(job_response.job, list(tasks_response.tasks))
451-
summary.update(job_status_to_json(job_response.job, tasks_response.tasks))
452-
return self.envelope(summary)
452+
return self.envelope(_job_summary_payload(job_response.job, list(tasks_response.tasks)))
453453

454454
def job_tree(self, job_id: str) -> dict[str, Any]:
455455
root = JobName.from_wire(job_id)
@@ -620,6 +620,7 @@ def diagnose(self, *, job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[s
620620
def _jobs_with_prefix(self, prefix: str) -> list[job_pb2.JobStatus]:
621621
jobs: list[job_pb2.JobStatus] = []
622622
offset = 0
623+
root = JobName.from_wire(prefix)
623624
while True:
624625
query = controller_pb2.Controller.JobQuery(
625626
sort_field=controller_pb2.Controller.JOB_SORT_FIELD_DATE,
@@ -628,16 +629,30 @@ def _jobs_with_prefix(self, prefix: str) -> list[job_pb2.JobStatus]:
628629
limit=MAX_LIST_JOBS_PAGE_SIZE,
629630
)
630631
response = self.controller.list_jobs(controller_pb2.Controller.ListJobsRequest(query=query))
631-
jobs.extend(job for job in response.jobs if job.job_id.startswith(prefix))
632+
jobs.extend(job for job in response.jobs if _job_matches_prefix(job.job_id, root))
632633
if not response.has_more:
633634
return jobs
634635
offset += len(response.jobs)
635636

636637

637-
def _token_provider(token: str | None) -> TokenProvider | None:
638-
if not token:
638+
def _job_summary_payload(job: job_pb2.JobStatus, tasks: list[job_pb2.TaskStatus]) -> dict[str, Any]:
639+
summary = build_job_summary(job, tasks)
640+
for key, value in job_status_to_json(job).items():
641+
summary.setdefault(key, value)
642+
return summary
643+
644+
645+
def _job_matches_prefix(job_id: str, prefix: JobName) -> bool:
646+
return prefix.is_ancestor_of(JobName.from_wire(job_id), include_self=True)
647+
648+
649+
def _token_provider(cluster: str, *, store_path: Path | None = None) -> TokenProvider | None:
650+
credential = load_token(cluster, store_path=store_path)
651+
if credential is None:
652+
credential = load_any_token(store_path=store_path)
653+
if credential is None:
639654
return None
640-
return StaticTokenProvider(token)
655+
return StaticTokenProvider(credential.token)
641656

642657

643658
def _normalize_state_filter(state: str) -> str:
@@ -769,19 +784,18 @@ def diagnose(job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[str, Any]:
769784
def main(argv: list[str] | None = None) -> None:
770785
parser = argparse.ArgumentParser(description="Run the Marin Iris/Zephyr babysitting MCP server.")
771786
parser.add_argument("--controller-url", required=True, help="Iris controller URL.")
772-
parser.add_argument("--cluster", default="default", help="Cluster label returned in tool responses.")
773-
parser.add_argument("--iris-token", default=None, help="Bearer token for Iris controllers with auth enabled.")
787+
parser.add_argument("--cluster", default=None, help="Cluster label and Iris token-store key.")
774788
parser.add_argument("--timeout-ms", type=int, default=30_000, help="Controller RPC timeout in milliseconds.")
775789
parser.add_argument("--transport", choices=("stdio", "sse", "streamable-http"), default="stdio")
776790
parser.add_argument("--host", default="127.0.0.1", help="HTTP host for SSE/streamable-http transports.")
777791
parser.add_argument("--port", type=int, default=8000, help="HTTP port for SSE/streamable-http transports.")
778792
args = parser.parse_args(argv)
793+
cluster = args.cluster or cluster_name_from_url(args.controller_url)
779794

780795
service = IrisBabysitter(
781796
IrisConnectionConfig(
782797
controller_url=args.controller_url,
783-
cluster=args.cluster,
784-
iris_token=args.iris_token,
798+
cluster=cluster,
785799
timeout_ms=args.timeout_ms,
786800
)
787801
)

tests/mcp/test_babysitter.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# Copyright The Marin Authors
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from iris.rpc import job_pb2, time_pb2
4+
from iris.cli.token_store import store_token
5+
from iris.rpc import controller_pb2, job_pb2, time_pb2
56

67
from marin.mcp.babysitter import (
8+
IrisBabysitter,
9+
IrisConnectionConfig,
10+
_job_summary_payload,
11+
_token_provider,
712
classify_diagnosis,
813
parse_zephyr_progress,
914
task_status_to_json,
@@ -14,6 +19,21 @@ def _timestamp(epoch_ms: int):
1419
return time_pb2.Timestamp(epoch_ms=epoch_ms)
1520

1621

22+
class _ListJobsController:
23+
def __init__(self, jobs: list[job_pb2.JobStatus]):
24+
self.jobs = jobs
25+
26+
def list_jobs(self, _request):
27+
return controller_pb2.Controller.ListJobsResponse(jobs=self.jobs, has_more=False)
28+
29+
30+
def _service_with_controller(controller):
31+
service = object.__new__(IrisBabysitter)
32+
service.config = IrisConnectionConfig(controller_url="http://controller", cluster="test")
33+
service.controller = controller
34+
return service
35+
36+
1737
def test_task_status_json_includes_attempts_timestamps_and_usage():
1838
task = job_pb2.TaskStatus(
1939
task_id="/alice/train/0",
@@ -71,6 +91,52 @@ def test_task_status_json_includes_attempts_timestamps_and_usage():
7191
assert payload["attempts"][1]["exit_code"] == 137
7292

7393

94+
def test_job_summary_payload_preserves_summary_task_fields():
95+
job = job_pb2.JobStatus(
96+
job_id="/alice/train",
97+
name="train",
98+
state=job_pb2.JOB_STATE_RUNNING,
99+
task_count=1,
100+
)
101+
running_task = job_pb2.TaskStatus(
102+
task_id="/alice/train/0",
103+
state=job_pb2.TASK_STATE_RUNNING,
104+
exit_code=0,
105+
)
106+
107+
payload = _job_summary_payload(job, [running_task])
108+
109+
assert payload["tasks"][0]["index"] == "0"
110+
assert payload["tasks"][0]["exit_code"] is None
111+
assert "resource_requests" in payload
112+
113+
114+
def test_jobs_with_prefix_excludes_string_prefix_siblings():
115+
service = _service_with_controller(
116+
_ListJobsController(
117+
[
118+
job_pb2.JobStatus(job_id="/alice/train"),
119+
job_pb2.JobStatus(job_id="/alice/train/child"),
120+
job_pb2.JobStatus(job_id="/alice/train-v2"),
121+
]
122+
)
123+
)
124+
125+
jobs = service._jobs_with_prefix("/alice/train")
126+
127+
assert [job.job_id for job in jobs] == ["/alice/train", "/alice/train/child"]
128+
129+
130+
def test_token_provider_loads_iris_token_store(tmp_path):
131+
store_path = tmp_path / "tokens.json"
132+
store_token("iris-prod", "https://controller.example.com", "stored-token", store_path=store_path)
133+
134+
provider = _token_provider("iris-prod", store_path=store_path)
135+
136+
assert provider is not None
137+
assert provider.get_token() == "stored-token"
138+
139+
74140
def test_parse_zephyr_progress_keeps_latest_stage_snapshot():
75141
lines = [
76142
"noise: pull_task worker-7",

0 commit comments

Comments
 (0)