@@ -102,7 +102,10 @@ def __init__(
            min_replica_size: minimum number of replicas on each step
            port: if rank==0, the port to run the manager server on
            use_async_quorum: whether to run the quorum asynchronously during the forward pass
-            timeout: timeout for all operations
+            timeout:
+                the default timeout for all operations; if you're using per-request
+                timeouts this should be longer than the longest request
+                timeout.
            rank: the replica group local rank
            world_size: the replica group local world size
            store_addr: TCPStore address for this replica group
@@ -279,7 +282,10 @@ def errored(self) -> Optional[Exception]:
        return self._errored

    def wrap_future(
-        self, fut: torch.futures.Future[T], default: T
+        self,
+        fut: torch.futures.Future[T],
+        default: T,
+        timeout: Optional[timedelta] = None,
    ) -> torch.futures.Future[T]:
        """
        Wrap a Future and swallow any errors that occur and report them to the manager.
@@ -289,10 +295,11 @@ def wrap_future(
        Args:
            fut: the Future to wrap
            default: the default value to complete the Future with if an error occurs
+            timeout: the timeout for the Future, if None, the manager's timeout will be used
        """

        # add a timeout to the future
-        fut = future_timeout(fut, self._timeout)
+        fut = future_timeout(fut, timeout or self._timeout)

        # schedule error handling as a continuation on the Future
        def callback(
@@ -313,7 +320,12 @@ def callback(
        self._pending_work.append(cast(torch.futures.Future[object], fut))
        return fut

-    def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
+    def start_quorum(
+        self,
+        room_id: str = "default",
+        allow_heal: bool = True,
+        timeout: Optional[timedelta] = None,
+    ) -> None:
        """
        .. note::
            We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -331,6 +343,7 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
                calls. All replicas must pass the same value to allow_heal.
            room_id: (experimental) the room id to use for quorum, this allows
                for multiple quorums to be used within the same job.
+            timeout: the timeout for quorum and recovery operations, if None, the manager's timeout will be used
        """

        # wait for previous quorum to complete
@@ -345,7 +358,10 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
        # block to allow gracefully recovering from issues in PG setup and quorum.

        self._quorum_future = self._executor.submit(
-            self._async_quorum, room_id=room_id, allow_heal=allow_heal
+            self._async_quorum,
+            room_id=room_id,
+            allow_heal=allow_heal,
+            timeout=timeout or self._timeout,
        )
        if not self._use_async_quorum:
            self.wait_quorum()
@@ -369,7 +385,7 @@ def wait_quorum(self) -> None:
        ), "must call start_quorum before wait_quorum"
        self._quorum_future.result()

-    def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
+    def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
        (
            quorum_id,
            replica_rank,
@@ -385,6 +401,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
            rank=self._rank,
            step=self._step,
            checkpoint_server_addr=self._ckpt_server.address(),
+            timeout=timeout,
        )

        # When using async quorum we need to take the recovered workers.
@@ -422,8 +439,10 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
            self._logger.info(
                f"healing required, fetching checkpoint server address from {address=} {max_step=}"
            )
-            primary_client = ManagerClient(address, timeout=self._timeout)
-            checkpoint_server_address = primary_client.checkpoint_address(self._rank)
+            primary_client = ManagerClient(address, timeout=timeout)
+            checkpoint_server_address = primary_client.checkpoint_address(
+                self._rank, timeout=timeout
+            )

            self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
            self._pending_state_dict = CheckpointServer.load_from_address(
@@ -449,7 +468,7 @@ def _apply_pending_state_dict(self) -> None:
        self._load_state_dict(self._pending_state_dict["user"])
        self._pending_state_dict = None

-    def should_commit(self) -> bool:
+    def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
        """
        .. note::
            We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -486,7 +505,10 @@ def should_commit(self) -> bool:
        enough_replicas = self.num_participants() >= self._min_replica_size
        local_should_commit = enough_replicas and self._errored is None
        should_commit = self._client.should_commit(
-            self._rank, self._step, local_should_commit
+            self._rank,
+            self._step,
+            local_should_commit,
+            timeout=timeout or self._timeout,
        )
        self._logger.info(
            f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"