Skip to content

Commit 28f2170

Browse files
authored
[iris] Use lightweight GetJobState for Fray actor polling (#5021)
FrayActorJob.wait_ready and is_done were calling the heavy GetJobStatus RPC on every 0.5s tick just to read .state, pairing 1:1 with the GetJobState polls from wait_for_job. Switch both to the lightweight state RPC, back off from 0.1s to 5s between polls, and add IrisClient.job_state(job_id) plus wire Job.state through it. The full JobStatus is now fetched only on the terminal-error path where the error message is actually needed.
1 parent c2df1e7 commit 28f2170

File tree

3 files changed

+47
-54
lines changed

3 files changed

+47
-54
lines changed

lib/fray/src/fray/v2/iris_backend.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from iris.cluster.types import CoschedulingConfig, EnvironmentSpec, ResourceSpec, is_job_finished
3737
from iris.cluster.types import Entrypoint as IrisEntrypoint
3838
from iris.rpc import job_pb2
39+
from rigging.timing import ExponentialBackoff
3940

4041
from fray.v2.actor import (
4142
ActorContext,
@@ -477,7 +478,7 @@ def wait_ready(self, count: int | None = None, timeout: float = 900.0) -> list[A
477478
"""
478479
target = count if count is not None else self._count
479480
start = time.monotonic()
480-
sleep_secs = 0.5
481+
backoff = ExponentialBackoff(initial=0.1, maximum=5.0)
481482

482483
while True:
483484
self.discover_new(target=target)
@@ -487,28 +488,29 @@ def wait_ready(self, count: int | None = None, timeout: float = 900.0) -> list[A
487488

488489
# Fail fast if the underlying job has terminated (e.g. crash, OOM,
489490
# missing interpreter). Without this check we'd spin for the full
490-
# timeout waiting for endpoints that will never appear.
491+
# timeout waiting for endpoints that will never appear. Use the
492+
# lightweight state-only RPC and fetch the full status only when we
493+
# actually need the error message.
491494
client = self._get_client()
492-
job_status = client.status(self._job_id)
493-
if is_job_finished(job_status.state):
494-
error = job_status.error or "unknown error"
495+
state = client.job_state(self._job_id)
496+
if is_job_finished(state):
497+
error = client.status(self._job_id).error or "unknown error"
495498
raise RuntimeError(
496499
f"Actor job {self._job_id} finished before all actors registered "
497500
f"({len(self._discovered_names)}/{target} ready). "
498-
f"Job state={job_status.state}, error={error}"
501+
f"Job state={state}, error={error}"
499502
)
500503

501504
elapsed = time.monotonic() - start
502505
if elapsed >= timeout:
503506
raise TimeoutError(f"Only {len(self._discovered_names)}/{target} actors ready after {timeout}s")
504507

505-
time.sleep(sleep_secs)
508+
time.sleep(backoff.next_interval())
506509

507510
def is_done(self) -> bool:
508511
"""Return True if the Iris worker job has permanently terminated."""
509512
client = self._get_client()
510-
job_status = client.status(self._job_id)
511-
return is_job_finished(job_status.state)
513+
return is_job_finished(client.job_state(self._job_id))
512514

513515
def shutdown(self) -> None:
514516
"""Terminate the actor job."""

lib/iris/src/iris/client/client.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from contextvars import ContextVar
2424
from dataclasses import dataclass
2525
from pathlib import Path
26-
from typing import Protocol
26+
from typing import Protocol, cast
2727

2828
from connectrpc.code import Code
2929
from connectrpc.errors import ConnectError
@@ -224,18 +224,14 @@ def status(self) -> job_pb2.JobStatus:
224224
"""
225225
return self._client._cluster_client.get_job_status(self._job_id)
226226

227-
def state_only(self) -> int:
227+
def state_only(self) -> job_pb2.JobState:
228228
"""Lightweight state query that avoids loading tasks/attempts/workers."""
229-
states = self._client._cluster_client.get_job_states([self._job_id])
230-
wire_id = self._job_id.to_wire()
231-
if wire_id not in states:
232-
raise KeyError(f"Job {wire_id} not found")
233-
return states[wire_id]
229+
return self._client.job_state(self._job_id)
234230

235231
@property
236232
def state(self) -> job_pb2.JobState:
237-
"""Get current job state (shortcut for status().state)."""
238-
return self.status().state
233+
"""Get current job state via the lightweight state-only RPC."""
234+
return self.state_only()
239235

240236
def tasks(self) -> list[Task]:
241237
"""Get all tasks for this job.
@@ -706,6 +702,17 @@ def status(self, job_id: JobName) -> job_pb2.JobStatus:
706702
"""
707703
return self._cluster_client.get_job_status(job_id)
708704

705+
def job_state(self, job_id: JobName) -> job_pb2.JobState:
706+
"""Lightweight state query that avoids loading tasks/attempts/workers.
707+
708+
Prefer this over ``status(job_id).state`` for polling loops.
709+
"""
710+
states = self._cluster_client.get_job_states([job_id])
711+
wire_id = job_id.to_wire()
712+
if wire_id not in states:
713+
raise KeyError(f"Job {wire_id} not found")
714+
return cast(job_pb2.JobState, states[wire_id])
715+
709716
def terminate(self, job_id: JobName) -> None:
710717
"""Terminate a running job.
711718

0 commit comments

Comments
 (0)