Skip to content

Commit be63d53

Browse files
rjpower and Helw150
authored and committed
Extract retry_with_backoff to rigging for reuse outside RPC (#4499)
Add retry_with_backoff() to rigging.timing alongside ExponentialBackoff. It accepts a retryable predicate and on_retry(exc, attempt) callback, making it usable for any exception type—not just gRPC ConnectErrors. Refactor iris.call_with_retry and marin.call_with_hf_backoff to delegate to it. Delete fn_utils.with_retries (fixed delay, no backoff) and update its two callers in hf_upload.py.
1 parent 7eed5bc commit be63d53

6 files changed

Lines changed: 212 additions & 136 deletions

File tree

lib/iris/src/iris/rpc/errors.py

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from iris.rpc import errors_pb2
1818
from iris.time_proto import timestamp_to_proto
19-
from rigging.timing import Deadline, ExponentialBackoff, Timestamp
19+
from rigging.timing import Deadline, ExponentialBackoff, Timestamp, retry_with_backoff
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -171,56 +171,22 @@ def call_with_retry(
171171
Raises:
172172
Exception from call_fn if all retries exhausted or error is not retryable
173173
"""
174-
if backoff is None:
175-
backoff = ExponentialBackoff(initial=0.5, maximum=10.0, factor=2.0)
176-
else:
177-
backoff = backoff.copy()
178-
last_exception = None
179-
start_time = time.monotonic()
180-
181-
for attempt in range(max_attempts):
182-
try:
183-
return call_fn()
184-
except Exception as e:
185-
last_exception = e
186-
if not is_retryable_error(e):
187-
raise
188-
189-
if on_retry is not None:
190-
on_retry(e)
191-
192-
elapsed = time.monotonic() - start_time
193-
attempts_exhausted = attempt + 1 >= max_attempts
194-
time_exhausted = max_elapsed is not None and elapsed >= max_elapsed
195-
196-
if attempts_exhausted or time_exhausted:
197-
logger.exception(
198-
"Operation %s failed after %d attempts (%.1fs elapsed): %s",
199-
operation,
200-
attempt + 1,
201-
elapsed,
202-
e,
203-
)
204-
raise
205-
206-
delay = backoff.next_interval()
207-
if max_elapsed is not None:
208-
remaining = max_elapsed - elapsed
209-
delay = min(delay, max(0, remaining))
210-
211-
logger.exception(
212-
"Operation %s failed (attempt %d/%d, %.1fs elapsed), retrying in %.2fs: %s",
213-
operation,
214-
attempt + 1,
215-
max_attempts,
216-
elapsed,
217-
delay,
218-
e,
219-
)
220-
time.sleep(delay)
221-
222-
assert last_exception is not None
223-
raise last_exception
174+
wrapped_on_retry: Callable[[Exception, int], None] | None = None
175+
if on_retry is not None:
176+
177+
def wrapped_on_retry(exc: Exception, _attempt: int) -> None:
178+
assert on_retry is not None
179+
on_retry(exc)
180+
181+
return retry_with_backoff(
182+
call_fn,
183+
retryable=is_retryable_error,
184+
max_attempts=max_attempts,
185+
max_elapsed=max_elapsed,
186+
backoff=backoff,
187+
on_retry=wrapped_on_retry,
188+
operation=operation,
189+
)
224190

225191

226192
def poll_with_retries(

lib/marin/src/marin/export/hf_upload.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import dataclasses
5-
import functools
65
import io
76
import logging
87
import os
@@ -14,11 +13,11 @@
1413
import humanfriendly
1514
from fsspec.implementations.local import LocalFileSystem
1615
from rigging.filesystem import open_url
16+
from rigging.timing import ExponentialBackoff, retry_with_backoff
1717
from huggingface_hub import create_commit, upload_folder
1818
from tqdm_loggable.auto import tqdm
1919

2020
from marin.execution import ExecutorStep, InputName
21-
from marin.utilities.fn_utils import with_retries
2221
from marin.utils import fsspec_glob
2322

2423
logger = logging.getLogger(__name__)
@@ -222,16 +221,22 @@ def _actually_upload_to_hf(config: UploadToHfConfig):
222221
)
223222

224223

225-
@functools.wraps(upload_folder)
226-
@with_retries()
227224
def retrying_upload_folder(*args, **kwargs):
228-
return upload_folder(*args, **kwargs)
225+
return retry_with_backoff(
226+
lambda: upload_folder(*args, **kwargs),
227+
max_attempts=3,
228+
backoff=ExponentialBackoff(initial=2.0, maximum=30.0, factor=2.0),
229+
operation="upload_folder",
230+
)
229231

230232

231-
@functools.wraps(create_commit)
232-
@with_retries()
233233
def retrying_create_commit(*args, **kwargs):
234-
return create_commit(*args, **kwargs)
234+
return retry_with_backoff(
235+
lambda: create_commit(*args, **kwargs),
236+
max_attempts=3,
237+
backoff=ExponentialBackoff(initial=2.0, maximum=30.0, factor=2.0),
238+
operation="create_commit",
239+
)
235240

236241

237242
def _wrap_in_buffered_base(fileobj):

lib/marin/src/marin/utilities/fn_utils.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

lib/marin/src/marin/utils.py

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import functools
55
import logging
66
import os
7-
import random
8-
import time
97
from collections.abc import Callable
108
from contextlib import contextmanager
119
from dataclasses import fields, is_dataclass
@@ -17,6 +15,7 @@
1715
import fsspec
1816
import requests
1917
from rigging.filesystem import url_to_fs
18+
from rigging.timing import ExponentialBackoff, retry_with_backoff
2019
from huggingface_hub.utils import HfHubHTTPError
2120

2221
logger = logging.getLogger(__name__)
@@ -153,47 +152,22 @@ def _hf_should_retry(exc: Exception) -> bool:
153152
return any(keyword in message for keyword in _HF_RETRY_KEYWORDS)
154153

155154

156-
def _hf_sleep_with_jitter(delay: float, max_delay: float) -> tuple[float, float]:
157-
jitter = random.uniform(0.5, 1.5)
158-
sleep_seconds = min(delay * jitter, max_delay)
159-
time.sleep(sleep_seconds)
160-
next_delay = min(delay * 2, max_delay)
161-
return sleep_seconds, next_delay
162-
163-
164155
def call_with_hf_backoff(
165156
fn: Callable[[], T],
166157
*,
167158
context: str,
168159
max_attempts: int = 6,
169160
initial_delay: float = 2.0,
170161
max_delay: float = 60.0,
171-
logger: logging.Logger | None = None,
172162
) -> T:
173163
"""Call ``fn`` with exponential backoff tuned for HF rate limits."""
174-
175-
log_obj = logger or logging.getLogger(__name__)
176-
delay = initial_delay
177-
178-
for attempt in range(1, max_attempts + 1):
179-
try:
180-
return fn()
181-
except Exception as exc: # pragma: no cover - network failure
182-
retryable = _hf_should_retry(exc)
183-
if not retryable or attempt == max_attempts:
184-
raise
185-
186-
sleep_seconds, delay = _hf_sleep_with_jitter(delay, max_delay)
187-
log_obj.warning(
188-
"HF request failed for %s (attempt %s/%s): %s. Retrying in %.1fs",
189-
context,
190-
attempt,
191-
max_attempts,
192-
exc,
193-
sleep_seconds,
194-
)
195-
196-
raise RuntimeError(f"Exceeded max attempts ({max_attempts}) for HF request: {context}")
164+
return retry_with_backoff(
165+
fn,
166+
retryable=_hf_should_retry,
167+
max_attempts=max_attempts,
168+
backoff=ExponentialBackoff(initial=initial_delay, maximum=max_delay, factor=2.0, jitter=0.25),
169+
operation=context,
170+
)
197171

198172

199173
def load_dataset_with_backoff(
@@ -202,7 +176,6 @@ def load_dataset_with_backoff(
202176
max_attempts: int = 6,
203177
initial_delay: float = 2.0,
204178
max_delay: float = 120.0,
205-
logger: logging.Logger | None = None,
206179
**dataset_kwargs: Any,
207180
):
208181
return call_with_hf_backoff(
@@ -211,7 +184,6 @@ def load_dataset_with_backoff(
211184
max_attempts=max_attempts,
212185
initial_delay=initial_delay,
213186
max_delay=max_delay,
214-
logger=logger,
215187
)
216188

217189

@@ -222,7 +194,6 @@ def load_tokenizer_with_backoff(
222194
max_attempts: int = 6,
223195
initial_delay: float = 2.0,
224196
max_delay: float = 60.0,
225-
logger: logging.Logger | None = None,
226197
):
227198
from levanter.tokenizers import load_tokenizer
228199

@@ -233,7 +204,6 @@ def load_tokenizer_with_backoff(
233204
max_attempts=max_attempts,
234205
initial_delay=initial_delay,
235206
max_delay=max_delay,
236-
logger=logger,
237207
)
238208

239209

lib/rigging/src/rigging/timing.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
from collections.abc import Callable, Iterator
1111
from datetime import datetime, timedelta, timezone
12+
from typing import TypeVar
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -554,3 +555,73 @@ def wait_until_or_raise(
554555
) -> None:
555556
if not self.wait_until(condition, timeout):
556557
raise TimeoutError(error_message)
558+
559+
560+
T = TypeVar("T")
561+
562+
563+
def retry_with_backoff(
564+
call_fn: Callable[[], T],
565+
*,
566+
retryable: Callable[[Exception], bool] = lambda _: True,
567+
max_attempts: int = 10,
568+
max_elapsed: float | None = None,
569+
backoff: ExponentialBackoff | None = None,
570+
on_retry: Callable[[Exception, int], None] | None = None,
571+
operation: str = "",
572+
) -> T:
573+
"""Execute call_fn with exponential backoff retry.
574+
575+
Args:
576+
call_fn: Function to call and potentially retry.
577+
retryable: Returns True to retry this exception, False to re-raise immediately.
578+
Defaults to retrying all exceptions.
579+
max_attempts: Maximum total attempts (default 10).
580+
max_elapsed: Wall-clock budget in seconds; None means no time limit.
581+
backoff: Backoff schedule; a fresh copy is made internally so the caller's
582+
instance is not mutated. Defaults to
583+
ExponentialBackoff(initial=0.5, maximum=10.0, factor=2.0).
584+
on_retry: Called with (exception, attempt_index) before each sleep.
585+
operation: Description used in log messages.
586+
"""
587+
if backoff is None:
588+
backoff = ExponentialBackoff(initial=0.5, maximum=10.0, factor=2.0)
589+
else:
590+
backoff = backoff.copy()
591+
592+
start_time = time.monotonic()
593+
594+
for attempt in range(max_attempts):
595+
try:
596+
return call_fn()
597+
except Exception as e:
598+
if not retryable(e):
599+
raise
600+
601+
if on_retry is not None:
602+
on_retry(e, attempt)
603+
604+
elapsed = time.monotonic() - start_time
605+
attempts_exhausted = attempt + 1 >= max_attempts
606+
time_exhausted = max_elapsed is not None and elapsed >= max_elapsed
607+
608+
if attempts_exhausted or time_exhausted:
609+
raise
610+
611+
delay = backoff.next_interval()
612+
if max_elapsed is not None:
613+
remaining = max_elapsed - elapsed
614+
delay = min(delay, max(0, remaining))
615+
616+
logger.warning(
617+
"Operation %s failed (attempt %d/%d, %.1fs elapsed), retrying in %.2fs: %s",
618+
operation,
619+
attempt + 1,
620+
max_attempts,
621+
elapsed,
622+
delay,
623+
e,
624+
)
625+
time.sleep(delay)
626+
627+
raise AssertionError("unreachable")

0 commit comments

Comments (0)