1111from pathlib import Path
1212from typing import Any
1313
14+ from google .protobuf import json_format
1415from iris .cli .bug_report import gather_bug_report
1516from iris .cli .job import build_job_summary
1617from iris .cli .token_store import cluster_name_from_url , load_any_token , load_token
3031DEFAULT_PROFILE_SECONDS = 1
3132MAX_LIST_JOBS_PAGE_SIZE = 500
3233DEFAULT_LIST_JOBS_LIMIT = 100
34+ _PROTO_TO_DICT_OPTIONS = dict (preserving_proto_field_name = True )
3335
3436_ZEPHYR_PROGRESS_RE = re .compile (
3537 r"\[(?P<stage>[^\]]+)\]\s+"
4547_TPU_XLA_RE = re .compile (r"\b(tpu|xla|hlo).*\b(bad|fault|hardware|unavailable|failed)\b" , re .IGNORECASE )
4648_DEAD_WORKER_RE = re .compile (r"\b(heartbeat timeout|dead worker|worker.*lost|worker.*crashed)\b" , re .IGNORECASE )
4749_TERMINATED_BY_USER_RE = re .compile (r"terminated by user" , re .IGNORECASE )
50+ _ZEPHYR_COORDINATOR_LOOP_FRAME = "_coordinator_loop"
51+ _ZEPHYR_WAIT_FOR_STAGE_FRAME = "_wait_for_stage"
52+ _ZEPHYR_WORKER_POOL_FRAME = "_worker"
4853
4954
5055@dataclass (frozen = True )
@@ -183,16 +188,12 @@ def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatu
183188 }
184189
185190
186- def _attribute_value_to_json (value ) -> Any :
187- kind = value .WhichOneof ("value" )
188- if kind is None :
189- return None
190- return getattr (value , kind )
def _worker_metadata_to_json(metadata: job_pb2.WorkerMetadata) -> dict[str, Any]:
    """Serialize a WorkerMetadata proto into a JSON-safe dict.

    Uses protobuf's MessageToDict with _PROTO_TO_DICT_OPTIONS so field names
    stay snake_case and the output schema tracks the proto definition
    automatically instead of being hand-maintained here.
    """
    return json_format.MessageToDict(metadata, **_PROTO_TO_DICT_OPTIONS)
191193
192194
193195def worker_status_to_json (worker : controller_pb2 .Controller .WorkerHealthStatus ) -> dict [str , Any ]:
194196 """Serialize Iris worker health into stable JSON."""
195- metadata = worker .metadata
196197 return {
197198 "worker_id" : worker .worker_id ,
198199 "healthy" : bool (worker .healthy ),
@@ -201,23 +202,7 @@ def worker_status_to_json(worker: controller_pb2.Controller.WorkerHealthStatus)
201202 "running_job_ids" : list (worker .running_job_ids ),
202203 "address" : worker .address ,
203204 "status_message" : worker .status_message ,
204- "metadata" : {
205- "hostname" : metadata .hostname ,
206- "ip_address" : metadata .ip_address ,
207- "cpu_count" : int (metadata .cpu_count ),
208- "memory_bytes" : int (metadata .memory_bytes ),
209- "disk_bytes" : int (metadata .disk_bytes ),
210- "device" : _device_config_to_json (metadata .device ) if metadata .HasField ("device" ) else _cpu_device_json (),
211- "tpu_name" : metadata .tpu_name ,
212- "tpu_worker_id" : metadata .tpu_worker_id ,
213- "gpu_count" : int (metadata .gpu_count ),
214- "gpu_name" : metadata .gpu_name ,
215- "gpu_memory_mb" : int (metadata .gpu_memory_mb ),
216- "gce_instance_name" : metadata .gce_instance_name ,
217- "gce_zone" : metadata .gce_zone ,
218- "git_hash" : metadata .git_hash ,
219- "attributes" : {key : _attribute_value_to_json (value ) for key , value in metadata .attributes .items ()},
220- },
205+ "metadata" : _worker_metadata_to_json (worker .metadata ),
221206 }
222207
223208
@@ -274,6 +259,27 @@ def parse_zephyr_progress(lines: Iterable[str]) -> list[dict[str, Any]]:
274259 return list (snapshots_by_stage .values ())
275260
276261
def parse_zephyr_thread_state(thread_dump: str) -> dict[str, Any]:
    """Classify a Zephyr coordinator thread dump into a compact liveness state."""
    # No dump at all: nothing to classify.
    if not thread_dump:
        return {"state": "unknown", "evidence": ["empty thread dump"]}

    waiting_on_stage = _ZEPHYR_WAIT_FOR_STAGE_FRAME in thread_dump
    coordinator_running = _ZEPHYR_COORDINATOR_LOOP_FRAME in thread_dump

    # Either coordinator frame is a positive liveness signal.
    if waiting_on_stage or coordinator_running:
        reasons: list[str] = []
        if waiting_on_stage:
            reasons.append("waiting for stage completion")
        if coordinator_running:
            reasons.append("coordinator loop thread present")
        return {"state": "active", "evidence": reasons}

    # Worker pool frames with no coordinator loop suggests the coordinator died
    # while its workers linger.
    if _ZEPHYR_WORKER_POOL_FRAME in thread_dump:
        return {"state": "zombie_suspected", "evidence": ["worker pool frames without coordinator loop"]}

    return {"state": "unknown", "evidence": ["no Zephyr coordinator frames found"]}
282+
277283def classify_diagnosis (
278284 * ,
279285 job : dict [str , Any ],
@@ -368,11 +374,12 @@ def add(signal: str, severity: str, evidence: list[str], escalation_hint: str) -
368374 "Inspect involved workers and recent process logs." ,
369375 )
370376
371- if thread_dump and "_coordinator_loop" not in thread_dump and "_worker" in thread_dump :
377+ thread_state = parse_zephyr_thread_state (thread_dump )
378+ if thread_state ["state" ] == "zombie_suspected" :
372379 add (
373380 "zombie_coordinator" ,
374381 "error" ,
375- [ "thread dump lacks _coordinator_loop and shows worker pool frames " ],
382+ thread_state [ "evidence " ],
376383 "Restart only after confirming with the user." ,
377384 )
378385
@@ -602,12 +609,33 @@ def zephyr_stage_progress(self, *, coord_job_id: str, max_lines: int = DEFAULT_Z
602609 def zephyr_coordinator_status (self , * , coord_job_id : str ) -> dict [str , Any ]:
603610 summary = self .job_summary (coord_job_id )["data" ]
604611 progress_payload = self .zephyr_stage_progress (coord_job_id = coord_job_id )["data" ]
612+ thread_target = f"{ coord_job_id } /0"
613+ thread_profile = self .profile_task (
614+ target = thread_target ,
615+ profile_type = "threads" ,
616+ duration_seconds = DEFAULT_PROFILE_SECONDS ,
617+ )
618+ thread_dump = str (thread_profile ["data" ].get ("text" , "" ))
619+ thread_state = parse_zephyr_thread_state (thread_dump )
620+ thread_warnings = list (thread_profile ["warnings" ])
621+ if thread_warnings :
622+ thread_state = {
623+ "state" : "unavailable" ,
624+ "evidence" : thread_warnings ,
625+ }
626+ diagnosis = classify_diagnosis (job = summary , logs = [], workers = [], thread_dump = thread_dump )
605627 return self .envelope (
606628 {
607629 "summary" : summary ,
608630 "progress" : progress_payload ["progress" ],
609631 "cursor" : progress_payload ["cursor" ],
610- }
632+ "thread_liveness" : {
633+ "target" : thread_target ,
634+ ** thread_state ,
635+ },
636+ "diagnosis" : diagnosis ,
637+ },
638+ warnings = thread_warnings ,
611639 )
612640
613641 def diagnose (self , * , job_id : str , log_lines : int = DEFAULT_LOG_LINES ) -> dict [str , Any ]:
0 commit comments