Skip to content

Commit b598794

Browse files
JustinPan-goog authored and Orbax Authors committed
Add overall timeout in AsyncCheckpointer.
PiperOrigin-RevId: 878104670
1 parent 36f6735 commit b598794

File tree

2 files changed

+61
-29
lines changed

2 files changed

+61
-29
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""AsyncCheckpointer."""
1616

17+
import datetime
1718
import sys
1819
import threading
1920
import time
@@ -69,23 +70,27 @@ def _background_wait_for_commit_futures(
6970
on_commit_callback: Callable[[], None],
7071
*,
7172
barrier_sync_key_prefix: str,
72-
sync_fn: Callable[[str], None],
73+
sync_fn: Callable[[str, int], None],
74+
timeout_secs: int,
7375
primary_host: int | None,
7476
):
7577
"""A function to be run in a background thread that waits for futures."""
7678
current_process = multihost.process_index()
7779
current_thread_id = threading.current_thread().name
7880
process_count = jax.process_count()
7981
logging.info(
80-
'[process=%s][thread=%s] Background save thread started.',
82+
'[process=%s][thread=%s] Background save thread started. Deadline for'
83+
' this save operation is %s',
8184
current_process,
8285
current_thread_id,
86+
datetime.datetime.now() + datetime.timedelta(seconds=timeout_secs),
8387
)
8488
thread_start_time = time.time()
8589

8690
# Wait for commit operations to complete.
87-
for commit_future in commit_futures:
88-
commit_future.result()
91+
future.ChainedFuture(commit_futures, cb=lambda: None).result(
92+
timeout=timeout_secs
93+
)
8994
commit_duration_secs = time.time() - thread_start_time
9095
logging.info(
9196
'[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -111,30 +116,48 @@ def _background_wait_for_commit_futures(
111116
# All processes will wait at the barrier. When all processes are at the
112117
# barrier, the barrier will be satisfied. If not, then it will timeout.
113118
try:
119+
time_remaining_secs = future.get_remaining_time(
120+
thread_start_time, timeout_secs
121+
)
114122
sync_fn(
115123
multihost.unique_barrier_key(
116124
'async_write_complete',
117125
prefix=barrier_sync_key_prefix,
118126
suffix=f'{directory.name}',
119-
)
127+
),
128+
int(time_remaining_secs * 1000),
120129
)
121130
except jax.errors.JaxRuntimeError as e:
122131
if sys.version_info >= (3, 11):
123132
if 'DEADLINE_EXCEEDED' in str(e):
124133
_add_deadline_exceeded_notes(e)
125-
raise
134+
raise TimeoutError(
135+
'Timed out while waiting for async_write_complete barrier.'
136+
) from e
126137

127138
if utils.is_primary_host(primary_host):
128139
on_commit_callback()
129140
if process_count > 1:
130141
# Block until process 0 completes on_commit_callback.
131-
sync_fn(
132-
multihost.unique_barrier_key(
133-
'async_commit_complete',
134-
prefix=barrier_sync_key_prefix,
135-
suffix=f'{directory.name}',
136-
)
137-
)
142+
try:
143+
time_remaining_secs = future.get_remaining_time(
144+
thread_start_time, timeout_secs
145+
)
146+
sync_fn(
147+
multihost.unique_barrier_key(
148+
'async_commit_complete',
149+
prefix=barrier_sync_key_prefix,
150+
suffix=f'{directory.name}',
151+
),
152+
int(time_remaining_secs * 1000),
153+
)
154+
except jax.errors.JaxRuntimeError as e:
155+
if sys.version_info >= (3, 11):
156+
if 'DEADLINE_EXCEEDED' in str(e):
157+
_add_deadline_exceeded_notes(e)
158+
raise TimeoutError(
159+
'Timed out while waiting for async_commit_complete barrier.'
160+
) from e
138161

139162
thread_duration_secs = time.time() - thread_start_time
140163
jax.monitoring.record_event_duration_secs(
@@ -190,9 +213,8 @@ def __init__(
190213
self._thread = None
191214
self._exception = None
192215

193-
timeout_in_ms = self._timeout_secs * 1000
194-
self._sync_fn: Callable[[str], None] = lambda key: barrier_sync_fn(
195-
key=key, timeout_ms=timeout_in_ms
216+
self._sync_fn: Callable[[str, int], None] = (
217+
lambda key, timeout_ms: barrier_sync_fn(key=key, timeout_ms=timeout_ms)
196218
)
197219

198220
def __del__(self):
@@ -218,6 +240,7 @@ def _thread_func(
218240
on_commit_callback,
219241
barrier_sync_key_prefix=self._barrier_sync_key_prefix,
220242
sync_fn=self._sync_fn,
243+
timeout_secs=self._timeout_secs,
221244
primary_host=self._primary_host,
222245
)
223246
except Exception as e: # pylint: disable=broad-exception-caught

checkpoint/orbax/checkpoint/_src/futures/future.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,14 @@ class Future(Protocol):
170170
future, but merely wait for an ongoing operation to complete.
171171
"""
172172

173-
def result(self, timeout: Optional[int] = None) -> Any:
173+
def result(self, timeout: Optional[float] = None) -> Any:
174174
"""Waits for the future to complete its operation."""
175175
...
176176

177177

178178
class NoopFuture:
179179

180-
def result(self, timeout: Optional[int] = None) -> Any:
180+
def result(self, timeout: Optional[float] = None) -> Any:
181181
del timeout
182182
return None
183183

@@ -189,21 +189,18 @@ def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
189189
self._futures = futures
190190
self._cb = cb
191191

192-
def result(self, timeout: Optional[int] = None) -> Any:
192+
def result(self, timeout: Optional[float] = None) -> Any:
193193
"""Waits for all futures to complete."""
194194
n = len(self._futures)
195195
start = time.time()
196-
time_remaining = timeout
197196
for k, f in enumerate(self._futures):
198-
f.result(timeout=time_remaining)
199-
if time_remaining is not None:
200-
time_elapsed = time.time() - start
201-
time_remaining -= time_elapsed
202-
if time_remaining <= 0:
203-
raise TimeoutError(
204-
'ChainedFuture completed {:d}/{:d} futures but timed out after'
205-
' {:.2f} seconds.'.format(k, n, time_elapsed)
206-
)
197+
try:
198+
f.result(timeout=get_remaining_time(start, timeout))
199+
except TimeoutError as e:
200+
raise TimeoutError(
201+
f'ChainedFuture completed {k}/{n} futures but timed out after'
202+
f' {time.time() - start:.2f} seconds.'
203+
) from e
207204
time_elapsed = time.time() - start
208205
logging.vlog(
209206
1,
@@ -215,6 +212,18 @@ def result(self, timeout: Optional[int] = None) -> Any:
215212
self._cb()
216213

217214

215+
def get_remaining_time(
    start_time: float, timeout_secs: Optional[float]
) -> Optional[float]:
  """Returns the time left before an overall deadline expires.

  Args:
    start_time: Timestamp at which the timed operation began. It is compared
      against `time.time()`, so it must come from that same clock.
    timeout_secs: Total time budget in seconds, measured from `start_time`,
      or None when there is no deadline.

  Returns:
    The strictly positive number of seconds still available, or None if
    `timeout_secs` is None.

  Raises:
    TimeoutError: If at least `timeout_secs` seconds have already elapsed
      since `start_time`.
  """
  if timeout_secs is None:
    return None
  elapsed = time.time() - start_time
  # An exhausted budget is reported as an error rather than as a zero or
  # negative remainder, so callers never forward a non-positive timeout.
  if elapsed >= timeout_secs:
    raise TimeoutError(f'Timed out after {elapsed} seconds.')
  return timeout_secs - elapsed
225+
226+
218227
def wait_for_signals(
219228
receive_signals: Sequence[synchronization.HandlerAwaitableSignal],
220229
*,

0 commit comments

Comments
 (0)