17
17
"""
18
18
19
19
import logging
20
+ import queue
20
21
import threading
21
22
from abc import ABC
22
23
from datetime import timedelta
23
- from typing import TYPE_CHECKING , Dict , List , Optional , Type
24
+ from typing import TYPE_CHECKING , Dict , List , Optional , Type , Union
24
25
25
26
import torch
26
27
import torch .distributed as dist
53
54
_FUTURE_EXCEPTION = "fut_exception"
54
55
55
56
56
- def _get (queue : mp .Queue , timeout : float ) -> object :
57
- v = queue .get (timeout = timeout )
57
+ def _get (q : mp .Queue , timeout : Union [float , timedelta ]) -> object :
58
+ """
59
+ Gets an item from a queue with a timeout. If the timeout is exceeded then
60
+ a TimeoutError is raised.
61
+
62
+ If an exception is returned from the queue then it is raised.
63
+
64
+ Args:
65
+ q: queue to get from
66
+ timeout: timeout in seconds
67
+ """
68
+ if isinstance (timeout , timedelta ):
69
+ timeout = timeout .total_seconds ()
70
+ try :
71
+ v = q .get (timeout = timeout )
72
+ except queue .Empty as e :
73
+ raise TimeoutError (f"queue.get() timed out after { timeout } seconds" ) from e
58
74
if isinstance (v , Exception ):
59
75
raise v
60
76
return v
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
95
111
Every time this is called it must be provided with a unique prefixed
96
112
store address. I.e. localhost:1234/my/prefix/1
97
113
114
+ This function will block until the underlying ProcessGroup is created.
115
+ If an error occurs this will throw.
116
+
98
117
Args:
99
118
store_addr: address of the store to use
100
119
rank: rank of this process
@@ -187,7 +206,6 @@ def __repr__(self) -> str:
187
206
188
207
189
208
class ProcessGroupWrapper (ProcessGroup ):
190
- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
191
209
"""
192
210
This is a wrapper around any ProcessGroup with a reconfiguration method.
193
211
"""
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
209
227
210
228
store = create_store_client (store_addr )
211
229
212
- # TODO: set global timeout
213
- # pyre-fixme[20]: expects argument options
214
- self ._pg = self .PG_CLASS (store , rank , world_size )
230
+ self ._pg = self ._create_pg (store , rank , world_size )
231
+
232
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
233
+ raise NotImplementedError ("not implemented" )
215
234
216
235
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
217
236
return self .parent .allreduce (tensors , opts )
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244
263
This is a reconfigurable version of ProcessGroupGloo.
245
264
"""
246
265
247
- PG_CLASS : Type [BaseProcessGroup ] = (
248
- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249
- )
266
+ def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
267
+ super ().__init__ ()
268
+ self ._timeout = timeout
269
+
270
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
271
+ # pyre-fixme[16]: no attribute ProcessGroupGloo
272
+ return BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
250
273
251
274
def getBackendName (self ) -> str :
252
275
return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263
286
abort when reconfiguring, we need to ensure this is safe.
264
287
"""
265
288
266
- PG_CLASS : Type [ BaseProcessGroup ] = (
267
- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268
- )
289
+ def _create_pg ( self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
290
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
291
+ return BaseProcessGroupNCCL ( store , rank , world_size )
269
292
270
293
def getBackendName (self ) -> str :
271
294
return "torchft-nccl"
@@ -546,10 +569,9 @@ class ProcessGroupBaby(ProcessGroup):
546
569
547
570
"""
548
571
549
- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
550
572
WORK_CLASS : Type [_BabyWork ] = _BabyWork
551
573
552
- def __init__ (self , timeout : float = 60.0 ) -> None :
574
+ def __init__ (self , timeout : Union [ float , timedelta ] = 60.0 ) -> None :
553
575
super ().__init__ (0 , 1 )
554
576
555
577
self ._world_size = - 1
@@ -562,7 +584,10 @@ def __init__(self, timeout: float = 60.0) -> None:
562
584
self ._futures : Dict [int , Future [object ]] = {}
563
585
self ._futures_lock = threading .Lock ()
564
586
565
- self ._timeout = timeout
587
+ if isinstance (timeout , timedelta ):
588
+ timeout = timeout .total_seconds ()
589
+
590
+ self ._timeout : float = timeout
566
591
567
592
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
568
593
if self ._p is not None :
@@ -581,7 +606,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
581
606
582
607
ctx = mp .get_context ("spawn" )
583
608
self ._tx = ctx .Queue ()
584
- self ._rx = ctx .Queue ()
609
+ self ._rx = rx = ctx .Queue ()
585
610
586
611
# futures need thread to fire callbacks
587
612
self ._future_queue = ctx .Queue ()
@@ -602,6 +627,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
602
627
)
603
628
self ._p .start ()
604
629
630
+ # fetch the status of the PG init
631
+ # if an exception was returned _get will throw
632
+ assert _get (rx , self ._timeout ) is None
633
+
634
+ @classmethod
635
+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
636
+ """
637
+ This is a class method to avoid pickling the class.
638
+ """
639
+ raise NotImplementedError ("not implemented" )
640
+
605
641
@classmethod
606
642
def _worker (
607
643
cls ,
@@ -615,8 +651,13 @@ def _worker(
615
651
try :
616
652
store = create_store_client (store_addr )
617
653
618
- # pyre-fixme[20]: expects argument options
619
- pg = cls .PG_CLASS (store , rank , world_size )
654
+ try :
655
+ pg = cls ._create_pg (store , rank , world_size )
656
+ except Exception as e :
657
+ logger .exception (f"got exception in worker: { e } " )
658
+ tx .put (e )
659
+ return
660
+ tx .put (None )
620
661
621
662
work = {}
622
663
next_op_id : int = 0
@@ -737,9 +778,10 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
737
778
ProcessGroupBabyNCCL.
738
779
"""
739
780
740
- PG_CLASS : Type [BaseProcessGroup ] = (
741
- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
742
- )
781
+ @classmethod
782
+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
783
+ # pyre-fixme[16]: no attribute ProcessGroupGloo
784
+ return BaseProcessGroupGloo (store , rank , world_size )
743
785
744
786
def getBackendName (self ) -> str :
745
787
return "torchft-baby-gloo"
@@ -761,11 +803,13 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
761
803
tensors may leak in the current PyTorch implementation. TODO fix
762
804
"""
763
805
764
- PG_CLASS : Type [BaseProcessGroup ] = (
765
- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
766
- )
767
806
WORK_CLASS = _BabyWorkNCCL
768
807
808
+ @classmethod
809
+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
810
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
811
+ return BaseProcessGroupNCCL (store , rank , world_size )
812
+
769
813
def getBackendName (self ) -> str :
770
814
return "torchft-baby-nccl"
771
815
0 commit comments