Commit 5f44126

Committed by github-actions[bot], rjpower, and claude
Deduplicate unavailability handling into poll_with_retries in errors.py
Move the retry-with-backoff-on-unavailable logic from wait_for_job and wait_for_job_with_streaming into a shared poll_with_retries() function in the errors.py retry library. The new function respects the caller's deadline: if the timeout expires during controller unavailability, it raises TimeoutError instead of continuing to retry for the full tolerance window.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent: 260706c

3 files changed: 197 additions & 84 deletions
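
For orientation, here is a minimal sketch (not part of the commit) of how a monitoring loop drives the new helper. The get_job_status and is_job_finished names stand in for the RemoteClient pieces shown in the diff below, and the job name and durations are illustrative.

from iris.rpc.errors import poll_with_retries
from iris.time_utils import Deadline

deadline = Deadline.from_seconds(600.0)  # overall budget for the wait
while True:
    # Each iteration issues one status RPC; transient controller outages are
    # retried inside poll_with_retries until the tolerance or deadline is hit.
    job_info = poll_with_retries(
        "demo-job",                          # operation name used in log messages
        lambda: get_job_status("demo-job"),  # the RPC to poll
        deadline=deadline,
        unavailable_tolerance=600.0,         # seconds of continuous outage to tolerate
    )
    if is_job_finished(job_info.state):
        break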

lib/iris/src/iris/cluster/client/remote_client.py

Lines changed: 19 additions & 82 deletions
@@ -17,7 +17,7 @@
 from iris.cluster.types import Entrypoint, EnvironmentSpec, JobName, TaskAttempt, adjust_tpu_replicas, is_job_finished
 from iris.rpc import cluster_pb2
 from iris.rpc.cluster_connect import ControllerServiceClientSync
-from iris.rpc.errors import call_with_retry, format_connect_error, is_retryable_error
+from iris.rpc.errors import call_with_retry, format_connect_error, poll_with_retries
 from iris.time_utils import Deadline, Duration, ExponentialBackoff

 logger = logging.getLogger(__name__)
@@ -139,8 +139,9 @@ def wait_for_job(
         """Wait for job to complete with exponential backoff polling.

         If the controller becomes unavailable, retries with backoff for up to
-        ``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds before giving up. The
-        unavailable timer resets each time a status check succeeds.
+        ``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds or until the caller's
+        *timeout* expires — whichever comes first. The unavailability timer
+        resets each time a status check succeeds.

         Args:
             job_id: Full job ID
@@ -155,46 +156,14 @@ def wait_for_job(
         """
         deadline = Deadline.from_seconds(timeout)
         backoff = ExponentialBackoff(initial=0.1, maximum=poll_interval)
-        unavailable_backoff = ExponentialBackoff(initial=1.0, maximum=60.0, factor=2.0)
-        unavailable_since: float | None = None

         while True:
-            try:
-                job_info = self.get_job_status(job_id)
-            except Exception as e:
-                if not is_retryable_error(e):
-                    raise
-                now = time.monotonic()
-                if unavailable_since is None:
-                    unavailable_since = now
-                elapsed_unavailable = now - unavailable_since
-                if elapsed_unavailable >= CONTROLLER_UNAVAILABLE_TOLERANCE:
-                    logger.error(
-                        "Controller unavailable for %.0fs, giving up on %s",
-                        elapsed_unavailable,
-                        job_id,
-                    )
-                    raise
-                logger.warning(
-                    "Controller unavailable for %s (%.0fs), job is still running server-side: %s",
-                    job_id,
-                    elapsed_unavailable,
-                    e,
-                )
-                interval = unavailable_backoff.next_interval()
-                time.sleep(min(interval, deadline.remaining_seconds()))
-                continue
-
-            # Controller responded — reset unavailability tracking.
-            if unavailable_since is not None:
-                elapsed_unavailable = time.monotonic() - unavailable_since
-                logger.info(
-                    "Controller back online for %s after %.0fs of unavailability",
-                    job_id,
-                    elapsed_unavailable,
-                )
-                unavailable_since = None
-                unavailable_backoff.reset()
+            job_info = poll_with_retries(
+                str(job_id),
+                lambda: self.get_job_status(job_id),
+                deadline=deadline,
+                unavailable_tolerance=CONTROLLER_UNAVAILABLE_TOLERANCE,
+            )

             if is_job_finished(job_info.state):
                 return job_info
@@ -222,9 +191,9 @@ def wait_for_job_with_streaming(
         credentials and endpoint configuration), avoiding client-side S3 access.

         If the controller becomes unavailable, retries with backoff for up to
-        ``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds before giving up. Log fetch
-        failures are treated the same way — they do not count toward a hard
-        failure limit while the controller is unreachable.
+        ``CONTROLLER_UNAVAILABLE_TOLERANCE`` seconds or until the caller's
+        *timeout* expires — whichever comes first. Log fetch failures are
+        non-fatal — they log a warning but never abort monitoring.

         Child job statuses are delivered inline in ``GetTaskLogsResponse`` (when
         *include_children* is True), so detecting state transitions requires no
@@ -236,52 +205,20 @@ def wait_for_job_with_streaming(
         """
         deadline = Deadline.from_seconds(timeout)
         terminal_status: cluster_pb2.JobStatus | None = None
-        unavailable_backoff = ExponentialBackoff(initial=1.0, maximum=60.0, factor=2.0)
-        unavailable_since: float | None = None
         # Track child job states so we fire callbacks once per transition.
         child_job_states: dict[str, int] = {}
         cursor: int = 0

         while True:
-            try:
-                status = self.get_job_status(job_id)
-            except Exception as e:
-                if not is_retryable_error(e):
-                    raise
-                now = time.monotonic()
-                if unavailable_since is None:
-                    unavailable_since = now
-                elapsed_unavailable = now - unavailable_since
-                if elapsed_unavailable >= CONTROLLER_UNAVAILABLE_TOLERANCE:
-                    logger.error(
-                        "Controller unavailable for %.0fs, giving up on %s",
-                        elapsed_unavailable,
-                        job_id,
-                    )
-                    raise
-                logger.warning(
-                    "Controller unavailable for %s (%.0fs), job is still running server-side: %s",
-                    job_id,
-                    elapsed_unavailable,
-                    e,
-                )
-                interval = unavailable_backoff.next_interval()
-                time.sleep(min(interval, deadline.remaining_seconds()))
-                continue
+            status = poll_with_retries(
+                str(job_id),
+                lambda: self.get_job_status(job_id),
+                deadline=deadline,
+                unavailable_tolerance=CONTROLLER_UNAVAILABLE_TOLERANCE,
+            )

             state_name = cluster_pb2.JobState.Name(status.state)

-            # Controller responded — reset unavailability tracking.
-            if unavailable_since is not None:
-                elapsed_unavailable = time.monotonic() - unavailable_since
-                logger.info(
-                    "Controller back online for %s after %.0fs of unavailability",
-                    job_id,
-                    elapsed_unavailable,
-                )
-                unavailable_since = None
-                unavailable_backoff.reset()
-
             try:
                 log_response = self.fetch_task_logs(
                     job_id,

lib/iris/src/iris/rpc/errors.py

Lines changed: 93 additions & 1 deletion
@@ -15,7 +15,7 @@
 from google.protobuf.any_pb2 import Any as AnyProto

 from iris.rpc import errors_pb2
-from iris.time_utils import ExponentialBackoff, Timestamp
+from iris.time_utils import Deadline, ExponentialBackoff, Timestamp

 logger = logging.getLogger(__name__)

@@ -220,3 +220,95 @@ def call_with_retry(

     assert last_exception is not None
     raise last_exception
+
+
+def poll_with_retries(
+    operation: str,
+    poll_fn: Callable[[], T],
+    *,
+    deadline: Deadline,
+    unavailable_tolerance: float = 3600.0,
+    backoff: ExponentialBackoff | None = None,
+) -> T:
+    """Poll an RPC endpoint, tolerating transient unavailability.
+
+    Calls ``poll_fn`` in a loop. On retryable errors the function backs off
+    and keeps trying for up to ``unavailable_tolerance`` seconds **or** until
+    ``deadline`` expires — whichever comes first. When the call succeeds the
+    unavailability timer resets.
+
+    This is designed for monitoring loops (e.g. ``wait_for_job``) where the
+    server-side work continues regardless of client polling failures.
+
+    Args:
+        operation: Human-readable description for log messages.
+        poll_fn: Callable that performs the RPC. Should raise on failure.
+        deadline: Caller-supplied deadline — polling stops with ``TimeoutError``
+            if the deadline expires, even during unavailability.
+        unavailable_tolerance: Maximum seconds to tolerate continuous
+            controller unavailability before re-raising the RPC error.
+        backoff: Backoff for unavailability retries. Defaults to 1 s → 60 s.
+
+    Returns:
+        The successful result of ``poll_fn``.
+
+    Raises:
+        TimeoutError: If *deadline* expires while the controller is unavailable.
+        Exception: The last RPC error if unavailability exceeds the tolerance,
+            or any non-retryable error from ``poll_fn``.
+    """
+
+    if backoff is None:
+        backoff = ExponentialBackoff(initial=1.0, maximum=60.0, factor=2.0)
+    else:
+        backoff = backoff.copy()
+
+    unavailable_since: float | None = None
+
+    while True:
+        try:
+            result = poll_fn()
+        except Exception as e:
+            if not is_retryable_error(e):
+                raise
+
+            now = time.monotonic()
+            if unavailable_since is None:
+                unavailable_since = now
+            elapsed_unavailable = now - unavailable_since
+
+            if elapsed_unavailable >= unavailable_tolerance:
+                logger.error(
+                    "Controller unavailable for %.0fs, giving up on %s",
+                    elapsed_unavailable,
+                    operation,
+                )
+                raise
+
+            if deadline.expired():
+                raise TimeoutError(
+                    f"{operation}: deadline expired after {elapsed_unavailable:.0f}s of controller unavailability"
+                ) from e
+
+            logger.warning(
+                "Controller unavailable for %s (%.0fs), job is still running server-side: %s",
+                operation,
+                elapsed_unavailable,
+                e,
+            )
+            interval = backoff.next_interval()
+            time.sleep(min(interval, deadline.remaining_seconds()))
+            continue
+
+        # Success — reset unavailability tracking.
+        if unavailable_since is not None:
+            elapsed_unavailable = time.monotonic() - unavailable_since
+            logger.info(
+                "Controller back online for %s after %.0fs of unavailability",
+                operation,
+                elapsed_unavailable,
+            )
+            unavailable_since = None
+            backoff.reset()
+
+        return result
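
The two failure paths differ for callers: deadline expiry surfaces as TimeoutError, while exhausting unavailable_tolerance re-raises the last RPC error. A minimal sketch of telling them apart, assuming a hypothetical zero-argument fetch_status callable and illustrative durations:

from iris.rpc.errors import poll_with_retries
from iris.time_utils import Deadline

try:
    status = poll_with_retries(
        "status-check",
        fetch_status,  # hypothetical RPC callable; raises on failure
        deadline=Deadline.from_seconds(30.0),
        unavailable_tolerance=300.0,
    )
except TimeoutError:
    # The caller's 30 s deadline ran out while the controller was still unreachable.
    status = None
except Exception:
    # A non-retryable RPC error, or the controller stayed down past 300 s.
    raise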

lib/iris/tests/rpc/test_errors.py

Lines changed: 85 additions & 1 deletion
@@ -12,8 +12,9 @@
     connect_error_sanitized,
     connect_error_with_traceback,
     extract_error_details,
+    poll_with_retries,
 )
-from iris.time_utils import ExponentialBackoff
+from iris.time_utils import Deadline, ExponentialBackoff


 def test_connect_error_with_traceback_populates_timestamp() -> None:
@@ -172,3 +173,86 @@ def fail_then_succeed():
     )
     assert result == "recovered"
     assert call_count == 4
+
+
+# -- poll_with_retries tests --
+
+
+def test_poll_with_retries_succeeds_immediately() -> None:
+    result = poll_with_retries(
+        "test",
+        lambda: "ok",
+        deadline=Deadline.from_seconds(5.0),
+    )
+    assert result == "ok"
+
+
+def test_poll_with_retries_retries_then_succeeds() -> None:
+    call_count = 0
+
+    def flaky():
+        nonlocal call_count
+        call_count += 1
+        if call_count <= 2:
+            raise ConnectError(Code.UNAVAILABLE, "down")
+        return "recovered"
+
+    result = poll_with_retries(
+        "test",
+        flaky,
+        deadline=Deadline.from_seconds(5.0),
+        backoff=ExponentialBackoff(initial=0.01, maximum=0.05),
+    )
+    assert result == "recovered"
+    assert call_count == 3
+
+
+def test_poll_with_retries_respects_deadline() -> None:
+    """Deadline expiry during unavailability raises TimeoutError, not the RPC error."""
+
+    def always_fail():
+        raise ConnectError(Code.UNAVAILABLE, "down")
+
+    with pytest.raises(TimeoutError, match="deadline expired"):
+        poll_with_retries(
+            "test",
+            always_fail,
+            deadline=Deadline.from_seconds(0.3),
+            unavailable_tolerance=3600.0,
+            backoff=ExponentialBackoff(initial=0.01, maximum=0.05),
+        )
+
+
+def test_poll_with_retries_respects_unavailable_tolerance() -> None:
+    """Unavailability tolerance expiry re-raises the RPC error."""
+
+    def always_fail():
+        raise ConnectError(Code.UNAVAILABLE, "down")
+
+    with pytest.raises(ConnectError) as exc_info:
+        poll_with_retries(
+            "test",
+            always_fail,
+            deadline=Deadline.from_seconds(10.0),
+            unavailable_tolerance=0.3,
+            backoff=ExponentialBackoff(initial=0.01, maximum=0.05),
+        )
+    assert exc_info.value.code == Code.UNAVAILABLE
+
+
+def test_poll_with_retries_raises_non_retryable_immediately() -> None:
+    call_count = 0
+
+    def not_found():
+        nonlocal call_count
+        call_count += 1
+        raise ConnectError(Code.NOT_FOUND, "gone")
+
+    with pytest.raises(ConnectError) as exc_info:
+        poll_with_retries(
+            "test",
+            not_found,
+            deadline=Deadline.from_seconds(5.0),
+        )
+    assert exc_info.value.code == Code.NOT_FOUND
+    assert call_count == 1