Skip to content

Commit b74554c

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 878104670
1 parent 960ca54 commit b74554c

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,25 @@ def _on_commit_callback(
6363
)
6464

6565

66+
def _check_for_timeout(
67+
thread_start_time: float,
68+
timeout_secs: int,
69+
):
70+
"""Checks if the timeout has been exceeded."""
71+
time_remaining_secs = timeout_secs - (time.time() - thread_start_time)
72+
if time_remaining_secs <= 0:
73+
raise TimeoutError()
74+
return time_remaining_secs
75+
76+
6677
def _background_wait_for_commit_futures(
6778
directory: epath.Path,
6879
commit_futures: Sequence[future.Future],
6980
on_commit_callback: Callable[[], None],
7081
*,
7182
barrier_sync_key_prefix: str,
72-
sync_fn: Callable[[str], None],
83+
sync_fn: Callable[[str, int], None],
84+
timeout_secs: int,
7385
primary_host: int | None,
7486
):
7587
"""A function to be run in a background thread that waits for futures."""
@@ -85,7 +97,8 @@ def _background_wait_for_commit_futures(
8597

8698
# Wait for commit operations to complete.
8799
for commit_future in commit_futures:
88-
commit_future.result()
100+
time_remaining_secs = _check_for_timeout(thread_start_time, timeout_secs)
101+
commit_future.result(timeout=int(time_remaining_secs))
89102
commit_duration_secs = time.time() - thread_start_time
90103
logging.info(
91104
'[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -111,12 +124,14 @@ def _background_wait_for_commit_futures(
111124
# All processes will wait at the barrier. When all processes are at the
112125
# barrier, the barrier will be satisfied. If not, then it will timeout.
113126
try:
127+
time_remaining_secs = _check_for_timeout(thread_start_time, timeout_secs)
114128
sync_fn(
115129
multihost.unique_barrier_key(
116130
'async_write_complete',
117131
prefix=barrier_sync_key_prefix,
118132
suffix=f'{directory.name}',
119-
)
133+
),
134+
int(time_remaining_secs * 1000),
120135
)
121136
except jax.errors.JaxRuntimeError as e:
122137
if sys.version_info >= (3, 11):
@@ -128,12 +143,14 @@ def _background_wait_for_commit_futures(
128143
on_commit_callback()
129144
if process_count > 1:
130145
# Block until process 0 completes on_commit_callback.
146+
time_remaining_secs = _check_for_timeout(thread_start_time, timeout_secs)
131147
sync_fn(
132148
multihost.unique_barrier_key(
133149
'async_commit_complete',
134150
prefix=barrier_sync_key_prefix,
135151
suffix=f'{directory.name}',
136-
)
152+
),
153+
int(time_remaining_secs * 1000),
137154
)
138155

139156
thread_duration_secs = time.time() - thread_start_time
@@ -190,9 +207,8 @@ def __init__(
190207
self._thread = None
191208
self._exception = None
192209

193-
timeout_in_ms = self._timeout_secs * 1000
194-
self._sync_fn: Callable[[str], None] = lambda key: barrier_sync_fn(
195-
key=key, timeout_ms=timeout_in_ms
210+
self._sync_fn: Callable[[str, int], None] = (
211+
lambda key, timeout_ms: barrier_sync_fn(key=key, timeout_ms=timeout_ms)
196212
)
197213

198214
def __del__(self):
@@ -218,6 +234,7 @@ def _thread_func(
218234
on_commit_callback,
219235
barrier_sync_key_prefix=self._barrier_sync_key_prefix,
220236
sync_fn=self._sync_fn,
237+
timeout_secs=self._timeout_secs,
221238
primary_host=self._primary_host,
222239
)
223240
except Exception as e: # pylint: disable=broad-exception-caught

0 commit comments

Comments
 (0)