import logging
import queue
import threading
+from dataclasses import dataclass
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)

import torch
import torch.distributed as dist

# pyre-fixme[21]: no attribute ProcessGroupNCCL
# pyre-fixme[21]: no attribute ProcessGroupGloo
from torch.distributed import (
-    BroadcastOptions,
    DeviceMesh,
    PrefixStore,
    ProcessGroup as BaseProcessGroup,
    get_rank,
    init_device_mesh,
)
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+    _world,
+)
from torch.futures import Future

if TYPE_CHECKING:
_FUTURE_EXCEPTION = "fut_exception"


+T = TypeVar("T")
+
+
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
    """
    Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +144,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def allgather(
        self,
        output_tensors: List[List[torch.Tensor]],
        input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
    ) -> Work:
        """
        Gathers tensors from the whole group in a list.
@@ -140,7 +164,9 @@ def allgather(
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
        """
        Broadcasts the tensor to the whole group.

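Aside: the overrides above now take the concrete c10d option types instead of object. A minimal sketch of how a caller builds one of these options objects; pg and the tensor are placeholders for any configured process group from this module:

    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM   # standard c10d field on AllreduceOptions
    t = torch.ones(4)
    work = pg.allreduce([t], opts)      # matches the Union[AllreduceOptions, ReduceOp] override
    work.wait()
    # Passing a bare reduce op is also allowed by the new signature:
    # work = pg.allreduce([t], dist.ReduceOp.SUM)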
@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
    def get_future(self) -> Future[object]:
        return self._pg._get_future(self._op_id)

+    def __del__(self) -> None:
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
+

class _BabyWorkNCCL(_BabyWork):
    def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +724,18 @@ def _worker(
                cmd = op[0]
                if cmd == "func":
                    func_name, args, kwargs = op[1:]
+                    args = _PickleSafeOptions.unsafe_args(args)
                    fn = getattr(pg, func_name)
                    work[next_op_id] = fn(*args, **kwargs)
                    tx.put(next_op_id)
                    next_op_id += 1
                elif cmd == "wait":
                    op_id: int = op[1]
                    work[op_id].wait()
-                    del work[op_id]
                    tx.put(op_id)
+                elif cmd == "del":
+                    op_id: int = op[1]
+                    del work[op_id]
                elif cmd == "future":
                    op_id: int = op[1]

@@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None:

                    del work[op_id]
                    tx.put((op_id, event))
+                elif cmd == "num_active_work":
+                    tx.put(len(work))
                else:
                    raise ValueError(f"unknown cmd: {cmd}")

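Aside: with this change the worker no longer drops a Work handle when "wait" completes; the handle is released only when the owning _BabyWork is garbage collected and its __del__ sends "del" (see the earlier hunk). A rough summary of the request/response protocol the loop above implements, written as comments rather than runnable code:

    # parent -> worker on the tx queue, worker -> parent on the rx queue
    #   ("func", name, args, kwargs)  -> op_id            schedules pg.<name>(*args, **kwargs)
    #   ("wait", op_id)               -> op_id            blocks until that Work finishes
    #   ("del", op_id)                -> no reply         drops the worker-side Work handle
    #   ("future", op_id)             -> (op_id, event)   resolves through a Future callback
    #   ("num_active_work",)          -> len(work)        number of Work handles still held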
@@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
        assert rx is not None
        assert tx is not None

-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        tx.put(
+            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            timeout=self._timeout,
+        )

        op_id = _get(rx, self._timeout)
        assert isinstance(op_id, int), f"invalid return {op_id}"
784
821
pg = self , tx = tx , rx = rx , op_id = op_id , timeout = self ._timeout
785
822
)
786
823
787
- def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
824
+ def allreduce (
825
+ self ,
826
+ tensors : List [torch .Tensor ],
827
+ opts : Union [dist .AllreduceOptions , dist .ReduceOp ],
828
+ ) -> Work :
788
829
assert isinstance (tensors , list ), "input must be list"
789
830
790
831
for tensor in tensors :
@@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
793
834
794
835
return self ._run_func ("allreduce" , tensors , opts )
795
836
837
+ def allgather (
838
+ self ,
839
+ output_tensors : List [List [torch .Tensor ]],
840
+ input_tensor : List [torch .Tensor ],
841
+ opts : AllgatherOptions ,
842
+ ) -> Work :
843
+ assert isinstance (output_tensors , list ), "input must be list"
844
+ assert isinstance (input_tensor , list ), "input must be list"
845
+
846
+ for tensor_list in output_tensors :
847
+ for tensor in tensor_list :
848
+ if not tensor .is_shared ():
849
+ tensor .share_memory_ ()
850
+
851
+ for tensor in input_tensor :
852
+ if not tensor .is_shared ():
853
+ tensor .share_memory_ ()
854
+
855
+ return self ._run_func ("allgather" , output_tensors , input_tensor , opts )
856
+
857
+ def broadcast (
858
+ self ,
859
+ tensor_list : List [torch .Tensor ],
860
+ opts : BroadcastOptions ,
861
+ ) -> Work :
862
+ assert isinstance (tensor_list , list ), "input must be list"
863
+
864
+ for tensor in tensor_list :
865
+ if not tensor .is_shared ():
866
+ tensor .share_memory_ ()
867
+
868
+ return self ._run_func ("broadcast" , tensor_list , opts )
869
+
796
870
def size (self ) -> int :
797
871
return self ._world_size
798
872
873
+ def num_active_work (self ) -> int :
874
+ assert self ._tx is not None
875
+ self ._tx .put (("num_active_work" ,), timeout = self ._timeout )
876
+
877
+ assert self ._rx is not None
878
+ return cast (int , _get (self ._rx , self ._timeout ))
879
+
880
+
881
+ @dataclass
882
+ class _PickleSafeOptions :
883
+ func : Callable [[], object ]
884
+ fields : Dict [str , object ]
885
+
886
+ @classmethod
887
+ def safe_args (cls , args : T ) -> T :
888
+ if isinstance (args , tuple ):
889
+ return tuple (cls .safe_args (arg ) for arg in args )
890
+ elif isinstance (args , list ):
891
+ return [cls .safe_args (arg ) for arg in args ]
892
+ elif isinstance (args , (AllreduceOptions , AllgatherOptions , BroadcastOptions )):
893
+ return cls .from_torch (args )
894
+ else :
895
+ return args
896
+
897
+ @classmethod
898
+ def unsafe_args (cls , args : T ) -> T :
899
+ if isinstance (args , tuple ):
900
+ return tuple (cls .unsafe_args (arg ) for arg in args )
901
+ elif isinstance (args , list ):
902
+ return [cls .unsafe_args (arg ) for arg in args ]
903
+ elif isinstance (args , cls ):
904
+ return args .to_torch ()
905
+ else :
906
+ return args
907
+
908
+ @classmethod
909
+ def from_torch (cls , opts : object ) -> "_PickleSafeOptions" :
910
+ return cls (
911
+ func = opts .__class__ ,
912
+ fields = {k : getattr (opts , k ) for k in dir (opts ) if not k .startswith ("_" )},
913
+ )
914
+
915
+ def to_torch (self ) -> object :
916
+ opts = self .func ()
917
+ for k , v in self .fields .items ():
918
+ setattr (opts , k , v )
919
+ return opts
920
+
799
921
800
922
class ProcessGroupBabyGloo (ProcessGroupBaby ):
801
923
"""
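Aside: a hypothetical end-to-end use of the baby process group with the new signatures and num_active_work. The constructor arguments, store address format, and tensor values are illustrative; only configure(store_addr, rank, world_size) is taken from the diff above:

    pg = ProcessGroupBabyGloo()                      # assumed default constructor
    pg.configure(store_addr="localhost:29500/demo", rank=0, world_size=1)

    t = torch.ones(4)
    work = pg.allreduce([t], dist.ReduceOp.SUM)      # tensors are moved to shared memory internally
    work.wait()

    print(pg.num_active_work())                      # Work handles still held by the worker process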