1414
1515"""AsyncCheckpointer."""
1616
17+ import datetime
1718import sys
1819import threading
1920import time
@@ -69,23 +70,27 @@ def _background_wait_for_commit_futures(
6970 on_commit_callback : Callable [[], None ],
7071 * ,
7172 barrier_sync_key_prefix : str ,
72- sync_fn : Callable [[str ], None ],
73+ sync_fn : Callable [[str , int ], None ],
74+ timeout_secs : int ,
7375 primary_host : int | None ,
7476):
7577 """A function to be run in a background thread that waits for futures."""
7678 current_process = multihost .process_index ()
7779 current_thread_id = threading .current_thread ().name
7880 process_count = jax .process_count ()
7981 logging .info (
80- '[process=%s][thread=%s] Background save thread started.' ,
82+ '[process=%s][thread=%s] Background save thread started. Deadline for'
83+ ' this save operation is %s' ,
8184 current_process ,
8285 current_thread_id ,
86+ datetime .datetime .now () + datetime .timedelta (seconds = timeout_secs ),
8387 )
8488 thread_start_time = time .time ()
8589
8690 # Wait for commit operations to complete.
87- for commit_future in commit_futures :
88- commit_future .result ()
91+ future .ChainedFuture (commit_futures , cb = lambda : None ).result (
92+ timeout = timeout_secs
93+ )
8994 commit_duration_secs = time .time () - thread_start_time
9095 logging .info (
9196 '[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -111,30 +116,48 @@ def _background_wait_for_commit_futures(
111116 # All processes will wait at the barrier. When all processes are at the
112117 # barrier, the barrier will be satisfied. If not, then it will time out.
113118 try :
119+ time_remaining_secs = future .get_remaining_time (
120+ thread_start_time , timeout_secs
121+ )
114122 sync_fn (
115123 multihost .unique_barrier_key (
116124 'async_write_complete' ,
117125 prefix = barrier_sync_key_prefix ,
118126 suffix = f'{ directory .name } ' ,
119- )
127+ ),
128+ int (time_remaining_secs * 1000 ),
120129 )
121130 except jax .errors .JaxRuntimeError as e :
122131 if sys .version_info >= (3 , 11 ):
123132 if 'DEADLINE_EXCEEDED' in str (e ):
124133 _add_deadline_exceeded_notes (e )
125- raise
134+ raise TimeoutError (
135+ 'Timed out while waiting for async_write_complete barrier.'
136+ ) from e
126137
127138 if utils .is_primary_host (primary_host ):
128139 on_commit_callback ()
129140 if process_count > 1 :
130141 # Block until process 0 completes on_commit_callback.
131- sync_fn (
132- multihost .unique_barrier_key (
133- 'async_commit_complete' ,
134- prefix = barrier_sync_key_prefix ,
135- suffix = f'{ directory .name } ' ,
136- )
137- )
142+ try :
143+ time_remaining_secs = future .get_remaining_time (
144+ thread_start_time , timeout_secs
145+ )
146+ sync_fn (
147+ multihost .unique_barrier_key (
148+ 'async_commit_complete' ,
149+ prefix = barrier_sync_key_prefix ,
150+ suffix = f'{ directory .name } ' ,
151+ ),
152+ int (time_remaining_secs * 1000 ),
153+ )
154+ except jax .errors .JaxRuntimeError as e :
155+ if sys .version_info >= (3 , 11 ):
156+ if 'DEADLINE_EXCEEDED' in str (e ):
157+ _add_deadline_exceeded_notes (e )
158+ raise TimeoutError (
159+ 'Timed out while waiting for async_commit_complete barrier.'
160+ ) from e
138161
139162 thread_duration_secs = time .time () - thread_start_time
140163 jax .monitoring .record_event_duration_secs (
@@ -190,9 +213,8 @@ def __init__(
190213 self ._thread = None
191214 self ._exception = None
192215
193- timeout_in_ms = self ._timeout_secs * 1000
194- self ._sync_fn : Callable [[str ], None ] = lambda key : barrier_sync_fn (
195- key = key , timeout_ms = timeout_in_ms
216+ self ._sync_fn : Callable [[str , int ], None ] = (
217+ lambda key , timeout_ms : barrier_sync_fn (key = key , timeout_ms = timeout_ms )
196218 )
197219
198220 def __del__ (self ):
@@ -218,6 +240,7 @@ def _thread_func(
218240 on_commit_callback ,
219241 barrier_sync_key_prefix = self ._barrier_sync_key_prefix ,
220242 sync_fn = self ._sync_fn ,
243+ timeout_secs = self ._timeout_secs ,
221244 primary_host = self ._primary_host ,
222245 )
223246 except Exception as e : # pylint: disable=broad-exception-caught
0 commit comments