53 changes: 33 additions & 20 deletions lib/iris/src/iris/cluster/client/remote_client.py
@@ -17,11 +17,16 @@
from iris.cluster.types import Entrypoint, EnvironmentSpec, JobName, TaskAttempt, adjust_tpu_replicas, is_job_finished
from iris.rpc import cluster_pb2
from iris.rpc.cluster_connect import ControllerServiceClientSync
from iris.rpc.errors import call_with_retry, format_connect_error
from iris.rpc.errors import call_with_retry, format_connect_error, poll_with_retries
from iris.time_utils import Deadline, Duration, ExponentialBackoff

logger = logging.getLogger(__name__)

# How long to tolerate controller unavailability before giving up on monitoring.
# The job itself keeps running server-side; this only affects the client's ability
# to poll status. One hour gives ample time for controller restarts/upgrades.
CONTROLLER_UNAVAILABLE_TOLERANCE = 3600.0


class RemoteClusterClient:
"""Cluster client via RPC to controller.
@@ -133,6 +138,11 @@ def wait_for_job(
) -> cluster_pb2.JobStatus:
"""Wait for job to complete with exponential backoff polling.

If the controller becomes unavailable, retries with backoff for up to
``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds or until the caller's
*timeout* expires — whichever comes first. The unavailability timer
resets each time a status check succeeds.

Args:
job_id: Full job ID
timeout: Maximum time to wait in seconds
@@ -148,7 +158,13 @@ def wait_for_job(
backoff = ExponentialBackoff(initial=0.1, maximum=poll_interval)

while True:
job_info = self.get_job_status(job_id)
job_info = poll_with_retries(
str(job_id),
lambda: self.get_job_status(job_id),
deadline=deadline,
unavailable_tolerance=CONTROLLER_UNAVAILABLE_TOLERANCE,
)

if is_job_finished(job_info.state):
return job_info

@@ -174,6 +190,11 @@ def wait_for_job_with_streaming(
Delegates log reading to the controller (which has the correct storage
credentials and endpoint configuration), avoiding client-side S3 access.

If the controller becomes unavailable, retries with backoff for up to
``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds or until the caller's
*timeout* expires — whichever comes first. Log fetch failures are
non-fatal — they log a warning but never abort monitoring.

Child job statuses are delivered inline in ``GetTaskLogsResponse`` (when
*include_children* is True), so detecting state transitions requires no
additional RPC calls.
@@ -184,15 +205,18 @@
"""
deadline = Deadline.from_seconds(timeout)
terminal_status: cluster_pb2.JobStatus | None = None
log_fetch_backoff = ExponentialBackoff(initial=1.0, maximum=30.0)
consecutive_log_failures = 0
max_log_failures = 5
# Track child job states so we fire callbacks once per transition.
child_job_states: dict[str, int] = {}
cursor: int = 0

while True:
status = self.get_job_status(job_id)
status = poll_with_retries(
str(job_id),
lambda: self.get_job_status(job_id),
deadline=deadline,
unavailable_tolerance=CONTROLLER_UNAVAILABLE_TOLERANCE,
)

state_name = cluster_pb2.JobState.Name(status.state)

try:
@@ -203,20 +227,10 @@
cursor=cursor,
min_level=min_level,
)
consecutive_log_failures = 0
log_fetch_backoff.reset()
except Exception as e:
consecutive_log_failures += 1
# Log fetch failures are non-fatal — we still have the job status.
msg = format_connect_error(e) if isinstance(e, ConnectError) else str(e)
logger.warning(
"Failed to fetch logs for %s (%d/%d), will retry:\n%s",
job_id,
consecutive_log_failures,
max_log_failures,
msg,
)
if consecutive_log_failures >= max_log_failures:
raise
logger.warning("Failed to fetch logs for %s, will retry: %s", job_id, msg)
log_response = None

if log_response is not None:
@@ -252,8 +266,7 @@
continue

deadline.raise_if_expired(f"Job {job_id} did not complete in {timeout}s")
sleep_time = log_fetch_backoff.next_interval() if consecutive_log_failures > 0 else poll_interval
time.sleep(sleep_time)
time.sleep(poll_interval)

def terminate_job(self, job_id: JobName) -> None:
request = cluster_pb2.Controller.TerminateJobRequest(job_id=job_id.to_wire())
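As a rough usage sketch of the new behaviour (the RemoteClusterClient constructor arguments and the timeout value below are illustrative assumptions, not part of this change): a monitoring call now survives controller restarts of up to an hour, provided the caller's overall timeout has not expired.

    # Hypothetical caller; the constructor arguments are assumptions.
    client = RemoteClusterClient(controller_address="http://controller:10000")
    job_id: JobName = ...  # obtained from job submission elsewhere

    # Polls with exponential backoff. If the controller becomes unreachable
    # mid-wait, the client retries for up to CONTROLLER_UNAVAILABLE_TOLERANCE
    # (3600 s) before re-raising; the job keeps running server-side throughout.
    status = client.wait_for_job(job_id, timeout=7200.0)
    print(cluster_pb2.JobState.Name(status.state))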
125 changes: 114 additions & 11 deletions lib/iris/src/iris/rpc/errors.py
@@ -15,7 +15,7 @@
from google.protobuf.any_pb2 import Any as AnyProto

from iris.rpc import errors_pb2
from iris.time_utils import ExponentialBackoff, Timestamp
from iris.time_utils import Deadline, ExponentialBackoff, Timestamp

logger = logging.getLogger(__name__)

@@ -143,17 +143,23 @@ def call_with_retry(
*,
on_retry: Callable[[Exception], None] | None = None,
max_attempts: int = 20,
max_elapsed: float | None = None,
backoff: ExponentialBackoff | None = None,
) -> T:
"""Execute an RPC call with exponential backoff retry.

Retries stop when either ``max_attempts`` is exhausted **or**
``max_elapsed`` seconds have passed, whichever comes first.

Args:
operation: Description of the operation for logging
call_fn: Callable that performs the RPC
on_retry: Optional callback invoked with the exception on every retryable
failure, including the final attempt. Useful for clearing cached
connections so subsequent calls can re-resolve endpoints.
max_attempts: Maximum number of attempts (default: 20)
max_elapsed: Maximum wall-clock seconds to keep retrying. ``None``
means no time limit (only ``max_attempts`` is used).
backoff: Backoff configuration. A fresh copy is made internally so the
caller's instance is not mutated. Defaults to
ExponentialBackoff(initial=0.5, maximum=10.0, factor=2.0).
@@ -169,43 +175,140 @@
else:
backoff = backoff.copy()
last_exception = None
start_time = time.monotonic()

for attempt in range(max_attempts):
try:
return call_fn()
except Exception as e:
last_exception = e
if not is_retryable_error(e):
# Non-retryable error, fail immediately
raise

# Always clear stale state on retryable errors, even on the final
# attempt, so the next call from the caller can re-resolve.
if on_retry is not None:
on_retry(e)

if attempt + 1 >= max_attempts:
# Final attempt failed, raise
elapsed = time.monotonic() - start_time
attempts_exhausted = attempt + 1 >= max_attempts
time_exhausted = max_elapsed is not None and elapsed >= max_elapsed

if attempts_exhausted or time_exhausted:
logger.exception(
"Operation %s failed after %d attempts: %s",
"Operation %s failed after %d attempts (%.1fs elapsed): %s",
operation,
max_attempts,
attempt + 1,
elapsed,
e,
)
raise

# Log and retry
delay = backoff.next_interval()
if max_elapsed is not None:
remaining = max_elapsed - elapsed
delay = min(delay, max(0, remaining))

logger.exception(
"Operation %s failed (attempt %d/%d), retrying in %.2fs: %s",
"Operation %s failed (attempt %d/%d, %.1fs elapsed), retrying in %.2fs: %s",
operation,
attempt + 1,
max_attempts,
elapsed,
delay,
e,
)
time.sleep(delay)

# Should not reach here due to raise in loop, but satisfy type checker
assert last_exception is not None
raise last_exception


def poll_with_retries(
operation: str,
poll_fn: Callable[[], T],
*,
deadline: Deadline,
unavailable_tolerance: float = 3600.0,
backoff: ExponentialBackoff | None = None,
) -> T:
"""Poll an RPC endpoint, tolerating transient unavailability.

Calls ``poll_fn`` in a loop. On retryable errors the function backs off
and keeps trying for up to ``unavailable_tolerance`` seconds **or** until
``deadline`` expires — whichever comes first. When the call succeeds the
unavailability timer resets.

This is designed for monitoring loops (e.g. ``wait_for_job``) where the
server-side work continues regardless of client polling failures.

Args:
operation: Human-readable description for log messages.
poll_fn: Callable that performs the RPC. Should raise on failure.
deadline: Caller-supplied deadline — polling stops with ``TimeoutError``
if the deadline expires, even during unavailability.
unavailable_tolerance: Maximum seconds to tolerate continuous
controller unavailability before re-raising the RPC error.
backoff: Backoff for unavailability retries. Defaults to 1 s → 60 s.

Returns:
The successful result of ``poll_fn``.

Raises:
TimeoutError: If *deadline* expires while the controller is unavailable.
Exception: The last RPC error if unavailability exceeds the tolerance,
or any non-retryable error from ``poll_fn``.
"""

if backoff is None:
backoff = ExponentialBackoff(initial=1.0, maximum=60.0, factor=2.0)
else:
backoff = backoff.copy()

unavailable_since: float | None = None

while True:
try:
result = poll_fn()
except Exception as e:
if not is_retryable_error(e):
raise

now = time.monotonic()
if unavailable_since is None:
unavailable_since = now
elapsed_unavailable = now - unavailable_since

if elapsed_unavailable >= unavailable_tolerance:
logger.error(
"Controller unavailable for %.0fs, giving up on %s",
elapsed_unavailable,
operation,
)
raise

if deadline.expired():
raise TimeoutError(
f"{operation}: deadline expired after {elapsed_unavailable:.0f}s of controller unavailability"
) from e

logger.warning(
"Controller unavailable for %s (%.0fs), job is still running server-side: %s",
operation,
elapsed_unavailable,
e,
)
interval = backoff.next_interval()
time.sleep(min(interval, deadline.remaining_seconds()))
continue

# Success — reset unavailability tracking.
if unavailable_since is not None:
elapsed_unavailable = time.monotonic() - unavailable_since
logger.info(
"Controller back online for %s after %.0fs of unavailability",
operation,
elapsed_unavailable,
)
unavailable_since = None
backoff.reset()

return result
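To contrast the two helpers, a minimal sketch (the fetch_status callable, operation names, and numeric values are illustrative assumptions):

    # fetch_status stands in for any RPC callable that raises on failure.
    def fetch_status():
        ...

    # call_with_retry: bounded by attempt count and, optionally, wall-clock
    # time. Suited to one-shot RPCs where the caller needs an answer soon.
    result = call_with_retry("get_job_status", fetch_status, max_attempts=10, max_elapsed=30.0)

    # poll_with_retries: bounded by the caller's Deadline plus an
    # unavailability tolerance that resets on every success. Suited to
    # monitoring loops where the server-side work continues even if
    # individual polls fail.
    deadline = Deadline.from_seconds(600.0)
    status = poll_with_retries(
        "wait for job example-job",
        fetch_status,
        deadline=deadline,
        unavailable_tolerance=3600.0,
    )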