@@ -63,13 +63,25 @@ def _on_commit_callback(
6363 )
6464
6565
def _check_for_timeout(
    thread_start_time: float,
    timeout_secs: int,
) -> float:
  """Returns the seconds left before the timeout, raising if none remain.

  Args:
    thread_start_time: `time.time()` timestamp from which the timeout budget
      is measured.
    timeout_secs: Total time budget in seconds, counted from
      `thread_start_time`.

  Returns:
    The remaining time budget in seconds; always strictly positive.

  Raises:
    TimeoutError: If `timeout_secs` or more seconds have already elapsed since
      `thread_start_time`.
  """
  elapsed_secs = time.time() - thread_start_time
  time_remaining_secs = timeout_secs - elapsed_secs
  if time_remaining_secs <= 0:
    # Include the elapsed time and the budget so the failure is diagnosable
    # from the exception alone.
    raise TimeoutError(
        f'Timed out after {elapsed_secs:.2f}s (timeout_secs={timeout_secs}).'
    )
  return time_remaining_secs
75+
76+
6677def _background_wait_for_commit_futures (
6778 directory : epath .Path ,
6879 commit_futures : Sequence [future .Future ],
6980 on_commit_callback : Callable [[], None ],
7081 * ,
7182 barrier_sync_key_prefix : str ,
72- sync_fn : Callable [[str ], None ],
83+ sync_fn : Callable [[str , int ], None ],
84+ timeout_secs : int ,
7385 primary_host : int | None ,
7486):
7587 """A function to be run in a background thread that waits for futures."""
@@ -85,7 +97,8 @@ def _background_wait_for_commit_futures(
8597
8698 # Wait for commit operations to complete.
8799 for commit_future in commit_futures :
88- commit_future .result ()
100+ time_remaining_secs = _check_for_timeout (thread_start_time , timeout_secs )
101+ commit_future .result (timeout = int (time_remaining_secs ))
89102 commit_duration_secs = time .time () - thread_start_time
90103 logging .info (
91104 '[process=%s][thread=%s] %d Handler Commit operations completed. Time'
@@ -111,12 +124,14 @@ def _background_wait_for_commit_futures(
111124 # All processes will wait at the barrier. When all processes are at the
112125 # barrier, the barrier will be satisfied. If not, then it will timeout.
113126 try :
127+ time_remaining_secs = _check_for_timeout (thread_start_time , timeout_secs )
114128 sync_fn (
115129 multihost .unique_barrier_key (
116130 'async_write_complete' ,
117131 prefix = barrier_sync_key_prefix ,
118132 suffix = f'{ directory .name } ' ,
119- )
133+ ),
134+ int (time_remaining_secs * 1000 ),
120135 )
121136 except jax .errors .JaxRuntimeError as e :
122137 if sys .version_info >= (3 , 11 ):
@@ -128,12 +143,14 @@ def _background_wait_for_commit_futures(
128143 on_commit_callback ()
129144 if process_count > 1 :
130145 # Block until process 0 completes on_commit_callback.
146+ time_remaining_secs = _check_for_timeout (thread_start_time , timeout_secs )
131147 sync_fn (
132148 multihost .unique_barrier_key (
133149 'async_commit_complete' ,
134150 prefix = barrier_sync_key_prefix ,
135151 suffix = f'{ directory .name } ' ,
136- )
152+ ),
153+ int (time_remaining_secs * 1000 ),
137154 )
138155
139156 thread_duration_secs = time .time () - thread_start_time
@@ -190,9 +207,8 @@ def __init__(
190207 self ._thread = None
191208 self ._exception = None
192209
193- timeout_in_ms = self ._timeout_secs * 1000
194- self ._sync_fn : Callable [[str ], None ] = lambda key : barrier_sync_fn (
195- key = key , timeout_ms = timeout_in_ms
210+ self ._sync_fn : Callable [[str , int ], None ] = (
211+ lambda key , timeout_ms : barrier_sync_fn (key = key , timeout_ms = timeout_ms )
196212 )
197213
198214 def __del__ (self ):
@@ -218,6 +234,7 @@ def _thread_func(
218234 on_commit_callback ,
219235 barrier_sync_key_prefix = self ._barrier_sync_key_prefix ,
220236 sync_fn = self ._sync_fn ,
237+ timeout_secs = self ._timeout_secs ,
221238 primary_host = self ._primary_host ,
222239 )
223240 except Exception as e : # pylint: disable=broad-exception-caught
0 commit comments