88import re
99from collections .abc import Iterable
1010from dataclasses import asdict , dataclass
11+ from pathlib import Path
1112from typing import Any
1213
1314from iris .cli .bug_report import gather_bug_report
1415from iris .cli .job import build_job_summary
16+ from iris .cli .token_store import cluster_name_from_url , load_any_token , load_token
1517from iris .cluster .log_store import build_log_source
1618from iris .cluster .runtime .profile import SYSTEM_PROCESS_TARGET
1719from iris .cluster .types import JobName
@@ -51,7 +53,6 @@ class IrisConnectionConfig:
5153
5254 controller_url : str
5355 cluster : str = "default"
54- iris_token : str | None = None
5556 timeout_ms : int = 30_000
5657
5758
@@ -391,7 +392,7 @@ class IrisBabysitter:
391392
392393 def __init__ (self , config : IrisConnectionConfig ):
393394 self .config = config
394- self .token_provider = _token_provider (config .iris_token )
395+ self .token_provider = _token_provider (config .cluster )
395396 interceptors = [AuthTokenInjector (self .token_provider )] if self .token_provider else []
396397 self .controller = ControllerServiceClientSync (
397398 config .controller_url ,
@@ -423,6 +424,7 @@ def list_jobs(
423424 jobs : list [dict [str , Any ]] = []
424425 offset = 0
425426 capped_limit = max (1 , limit )
427+ prefix_job = JobName .from_wire (prefix ) if prefix else None
426428 while len (jobs ) < capped_limit :
427429 query = controller_pb2 .Controller .JobQuery (
428430 state_filter = state_filter ,
@@ -434,7 +436,7 @@ def list_jobs(
434436 )
435437 response = self .controller .list_jobs (controller_pb2 .Controller .ListJobsRequest (query = query ))
436438 for job in response .jobs :
437- if prefix and not job .job_id . startswith ( prefix ):
439+ if prefix_job is not None and not _job_matches_prefix ( job .job_id , prefix_job ):
438440 continue
439441 jobs .append (job_status_to_json (job ))
440442 if len (jobs ) >= capped_limit :
@@ -447,9 +449,7 @@ def list_jobs(
447449 def job_summary (self , job_id : str ) -> dict [str , Any ]:
448450 job_response = self .controller .get_job_status (controller_pb2 .Controller .GetJobStatusRequest (job_id = job_id ))
449451 tasks_response = self .controller .list_tasks (controller_pb2 .Controller .ListTasksRequest (job_id = job_id ))
450- summary = build_job_summary (job_response .job , list (tasks_response .tasks ))
451- summary .update (job_status_to_json (job_response .job , tasks_response .tasks ))
452- return self .envelope (summary )
452+ return self .envelope (_job_summary_payload (job_response .job , list (tasks_response .tasks )))
453453
454454 def job_tree (self , job_id : str ) -> dict [str , Any ]:
455455 root = JobName .from_wire (job_id )
@@ -620,6 +620,7 @@ def diagnose(self, *, job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[s
620620 def _jobs_with_prefix (self , prefix : str ) -> list [job_pb2 .JobStatus ]:
621621 jobs : list [job_pb2 .JobStatus ] = []
622622 offset = 0
623+ root = JobName .from_wire (prefix )
623624 while True :
624625 query = controller_pb2 .Controller .JobQuery (
625626 sort_field = controller_pb2 .Controller .JOB_SORT_FIELD_DATE ,
@@ -628,16 +629,30 @@ def _jobs_with_prefix(self, prefix: str) -> list[job_pb2.JobStatus]:
628629 limit = MAX_LIST_JOBS_PAGE_SIZE ,
629630 )
630631 response = self .controller .list_jobs (controller_pb2 .Controller .ListJobsRequest (query = query ))
631- jobs .extend (job for job in response .jobs if job .job_id . startswith ( prefix ))
632+ jobs .extend (job for job in response .jobs if _job_matches_prefix ( job .job_id , root ))
632633 if not response .has_more :
633634 return jobs
634635 offset += len (response .jobs )
635636
636637
637- def _token_provider (token : str | None ) -> TokenProvider | None :
638- if not token :
638+ def _job_summary_payload (job : job_pb2 .JobStatus , tasks : list [job_pb2 .TaskStatus ]) -> dict [str , Any ]:
639+ summary = build_job_summary (job , tasks )
640+ for key , value in job_status_to_json (job ).items ():
641+ summary .setdefault (key , value )
642+ return summary
643+
644+
645+ def _job_matches_prefix (job_id : str , prefix : JobName ) -> bool :
646+ return prefix .is_ancestor_of (JobName .from_wire (job_id ), include_self = True )
647+
648+
649+ def _token_provider (cluster : str , * , store_path : Path | None = None ) -> TokenProvider | None :
650+ credential = load_token (cluster , store_path = store_path )
651+ if credential is None :
652+ credential = load_any_token (store_path = store_path )
653+ if credential is None :
639654 return None
640- return StaticTokenProvider (token )
655+ return StaticTokenProvider (credential . token )
641656
642657
643658def _normalize_state_filter (state : str ) -> str :
@@ -769,19 +784,18 @@ def diagnose(job_id: str, log_lines: int = DEFAULT_LOG_LINES) -> dict[str, Any]:
769784def main (argv : list [str ] | None = None ) -> None :
770785 parser = argparse .ArgumentParser (description = "Run the Marin Iris/Zephyr babysitting MCP server." )
771786 parser .add_argument ("--controller-url" , required = True , help = "Iris controller URL." )
772- parser .add_argument ("--cluster" , default = "default" , help = "Cluster label returned in tool responses." )
773- parser .add_argument ("--iris-token" , default = None , help = "Bearer token for Iris controllers with auth enabled." )
787+ parser .add_argument ("--cluster" , default = None , help = "Cluster label and Iris token-store key." )
774788 parser .add_argument ("--timeout-ms" , type = int , default = 30_000 , help = "Controller RPC timeout in milliseconds." )
775789 parser .add_argument ("--transport" , choices = ("stdio" , "sse" , "streamable-http" ), default = "stdio" )
776790 parser .add_argument ("--host" , default = "127.0.0.1" , help = "HTTP host for SSE/streamable-http transports." )
777791 parser .add_argument ("--port" , type = int , default = 8000 , help = "HTTP port for SSE/streamable-http transports." )
778792 args = parser .parse_args (argv )
793+ cluster = args .cluster or cluster_name_from_url (args .controller_url )
779794
780795 service = IrisBabysitter (
781796 IrisConnectionConfig (
782797 controller_url = args .controller_url ,
783- cluster = args .cluster ,
784- iris_token = args .iris_token ,
798+ cluster = cluster ,
785799 timeout_ms = args .timeout_ms ,
786800 )
787801 )
0 commit comments