17
17
"""
18
18
19
19
import logging
20
+ import queue
20
21
import threading
21
22
from abc import ABC
22
23
from datetime import timedelta
23
- from typing import TYPE_CHECKING , Dict , List , Optional , Type
24
+ from typing import TYPE_CHECKING , Dict , List , Optional , Type , Union
24
25
25
26
import torch
26
27
import torch .distributed as dist
53
54
_FUTURE_EXCEPTION = "fut_exception"
54
55
55
56
56
- def _get (queue : mp .Queue , timeout : float ) -> object :
57
- v = queue .get (timeout = timeout )
57
+ def _get (q : mp .Queue , timeout : Union [float , timedelta ]) -> object :
58
+ """
59
+ Gets an item from a queue with a timeout. If the timeout is exceeded then
60
+ a TimeoutError is raised.
61
+
62
+ If an exception is returned from the queue then it is raised.
63
+
64
+ Args:
65
+ q: queue to get from
66
+ timeout: timeout in seconds
67
+ """
68
+ if isinstance (timeout , timedelta ):
69
+ timeout = timeout .total_seconds ()
70
+ try :
71
+ v = q .get (timeout = timeout )
72
+ except queue .Empty as e :
73
+ raise TimeoutError (f"queue.get() timed out after { timeout } seconds" ) from e
58
74
if isinstance (v , Exception ):
59
75
raise v
60
76
return v
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
95
111
Every time this is called it must be provided with a unique prefixed
96
112
store address, e.g. localhost:1234/my/prefix/1
97
113
114
+ This function will block until the underlying ProcessGroup is created.
115
+ If an error occurs, an exception is raised.
116
+
98
117
Args:
99
118
store_addr: address of the store to use
100
119
rank: rank of this process
@@ -187,7 +206,6 @@ def __repr__(self) -> str:
187
206
188
207
189
208
class ProcessGroupWrapper (ProcessGroup ):
190
- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
191
209
"""
192
210
This is a wrapper around any ProcessGroup with a reconfiguration method.
193
211
"""
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
209
227
210
228
store = create_store_client (store_addr )
211
229
212
- # TODO: set global timeout
213
- # pyre-fixme[20]: expects argument options
214
- self ._pg = self .PG_CLASS (store , rank , world_size )
230
+ self ._pg = self ._create_pg (store , rank , world_size )
231
+
232
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
233
+ raise NotImplementedError ("not implemented" )
215
234
216
235
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
217
236
return self .parent .allreduce (tensors , opts )
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244
263
This is a reconfigurable version of ProcessGroupGloo.
245
264
"""
246
265
247
- PG_CLASS : Type [BaseProcessGroup ] = (
248
- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249
- )
266
+ def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
267
+ super ().__init__ ()
268
+ self ._timeout = timeout
269
+
270
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
271
+ # pyre-fixme[16]: no attribute ProcessGroupGloo
272
+ return BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
250
273
251
274
def getBackendName (self ) -> str :
252
275
return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263
286
abort when reconfiguring, we need to ensure this is safe.
264
287
"""
265
288
266
- PG_CLASS : Type [ BaseProcessGroup ] = (
267
- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268
- )
289
+ def _create_pg ( self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
290
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
291
+ return BaseProcessGroupNCCL ( store , rank , world_size )
269
292
270
293
def getBackendName (self ) -> str :
271
294
return "torchft-nccl"
@@ -549,7 +572,7 @@ class ProcessGroupBaby(ProcessGroup):
549
572
PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
550
573
WORK_CLASS : Type [_BabyWork ] = _BabyWork
551
574
552
- def __init__ (self , timeout : float = 60.0 ) -> None :
575
+ def __init__ (self , timeout : Union [ float , timedelta ] = 60.0 ) -> None :
553
576
super ().__init__ (0 , 1 )
554
577
555
578
self ._world_size = - 1
@@ -562,7 +585,10 @@ def __init__(self, timeout: float = 60.0) -> None:
562
585
self ._futures : Dict [int , Future [object ]] = {}
563
586
self ._futures_lock = threading .Lock ()
564
587
565
- self ._timeout = timeout
588
+ if isinstance (timeout , timedelta ):
589
+ timeout = timeout .total_seconds ()
590
+
591
+ self ._timeout : float = timeout
566
592
567
593
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
568
594
if self ._p is not None :
@@ -581,7 +607,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
581
607
582
608
ctx = mp .get_context ("spawn" )
583
609
self ._tx = ctx .Queue ()
584
- self ._rx = ctx .Queue ()
610
+ self ._rx = rx = ctx .Queue ()
585
611
586
612
# futures need thread to fire callbacks
587
613
self ._future_queue = ctx .Queue ()
@@ -602,6 +628,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
602
628
)
603
629
self ._p .start ()
604
630
631
+ # fetch the status of the PG init
632
+ # if an exception was returned _get will throw
633
+ assert _get (rx , self ._timeout ) is None
634
+
605
635
@classmethod
606
636
def _worker (
607
637
cls ,
@@ -615,8 +645,14 @@ def _worker(
615
645
try :
616
646
store = create_store_client (store_addr )
617
647
618
- # pyre-fixme[20]: expects argument options
619
- pg = cls .PG_CLASS (store , rank , world_size )
648
+ try :
649
+ # pyre-fixme[20]: expects argument options
650
+ pg = cls .PG_CLASS (store , rank , world_size )
651
+ except Exception as e :
652
+ logger .exception (f"got exception in worker: { e } " )
653
+ tx .put (e )
654
+ return
655
+ tx .put (None )
620
656
621
657
work = {}
622
658
next_op_id : int = 0
0 commit comments