@@ -19,8 +19,20 @@
 import logging
 import queue
 import threading
+from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -29,7 +41,6 @@
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    BroadcastOptions,
     DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
@@ -40,7 +51,14 @@
     get_rank,
     init_device_mesh,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+    _world,
+)
 from torch.futures import Future
 
 if TYPE_CHECKING:
@@ -54,6 +72,9 @@
 _FUTURE_EXCEPTION = "fut_exception"
 
 
+T = TypeVar("T")
+
+
 def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
     """
     Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +143,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         """
         Gathers tensors from the whole group in a list.
@@ -140,7 +163,9 @@ def allgather(
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
         """
         Broadcasts the tensor to the whole group.
 
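Note on the signature changes in the two hunks above: the `opts: object` placeholders are narrowed to the concrete c10d option types, and `allreduce` accepts either a full `AllreduceOptions` struct or a bare `ReduceOp`, per the new `Union` annotation. A small, hedged illustration of both call forms using only the stock `torch.distributed` types (`pg` stands for any configured process group):

    from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp

    full_opts = AllreduceOptions()
    full_opts.reduceOp = ReduceOp.SUM  # explicit options struct

    # Per the Union[AllreduceOptions, ReduceOp] annotation, both are valid:
    #   pg.allreduce([tensor], full_opts)
    #   pg.allreduce([tensor], ReduceOp.SUM)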
@@ -567,6 +592,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
+    def __del__(self) -> None:
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
+
 
 class _BabyWorkNCCL(_BabyWork):
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +723,18 @@ def _worker(
                 cmd = op[0]
                 if cmd == "func":
                     func_name, args, kwargs = op[1:]
+                    args = _PickleSafeOptions.unsafe_args(args)
                     fn = getattr(pg, func_name)
                     work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
                     work[op_id].wait()
-                    del work[op_id]
                     tx.put(op_id)
+                elif cmd == "del":
+                    op_id: int = op[1]
+                    del work[op_id]
                 elif cmd == "future":
                     op_id: int = op[1]
 
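The new `__del__` hook on `_BabyWork` and the `"del"` command in the worker loop work together: the worker now keeps each `Work` handle until the parent-side `_BabyWork` is garbage collected and sends `("del", op_id)`, instead of dropping the handle inside `wait()`. A rough sketch of that ownership pattern with hypothetical names (not the torchft implementation):

    import multiprocessing as mp
    from typing import Any, Dict, Tuple


    def worker_loop(rx: "mp.Queue", handles: Dict[int, Any]) -> None:
        while True:
            cmd, op_id = rx.get()
            if cmd == "wait":
                ...  # wait on handles[op_id], but keep it registered
            elif cmd == "del":
                handles.pop(op_id, None)  # released only when the proxy is collected
            elif cmd == "stop":
                return


    class WorkProxy:
        """Parent-side stand-in for a Work object living in the worker process."""

        def __init__(self, tx: "mp.Queue", op_id: int) -> None:
            self._tx, self._op_id = tx, op_id

        def __del__(self) -> None:
            try:
                self._tx.put(("del", self._op_id))
            except Exception:
                pass  # the queue may already be gone at interpreter shutdown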
@@ -775,7 +806,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None
 
-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        tx.put(
+            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            timeout=self._timeout,
+        )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
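`_run_func` now routes arguments through `_PickleSafeOptions.safe_args` before they go onto the queue, and the worker reverses this with `unsafe_args`. The reason is that the pybind-defined option structs generally do not support pickling, so they cannot cross a multiprocessing queue as-is; the wrapper (added further down) records the option class and its public fields in a plain dataclass and rebuilds an equivalent object on the worker side. A hedged illustration of the failure it works around:

    import pickle

    from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp

    opts = AllreduceOptions()
    opts.reduceOp = ReduceOp.SUM

    try:
        pickle.dumps(opts)
    except TypeError as exc:
        # Typically raised for pybind objects that lack pickle support.
        print(f"cannot pickle options directly: {exc}")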
@@ -784,7 +818,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
             pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
         )
 
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Union[dist.AllreduceOptions, dist.ReduceOp],
+    ) -> Work:
         assert isinstance(tensors, list), "input must be list"
 
         for tensor in tensors:
@@ -793,10 +831,84 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
 
         return self._run_func("allreduce", tensors, opts)
 
+    def allgather(
+        self,
+        output_tensors: List[List[torch.Tensor]],
+        input_tensor: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "input must be list"
+        assert isinstance(input_tensor, list), "input must be list"
+
+        for tensor_list in output_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+
+        for tensor in input_tensor:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("allgather", output_tensors, input_tensor, opts)
+
+    def broadcast(
+        self,
+        tensor_list: List[torch.Tensor],
+        opts: BroadcastOptions,
+    ) -> Work:
+        assert isinstance(tensor_list, list), "input must be list"
+
+        for tensor in tensor_list:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("broadcast", tensor_list, opts)
+
     def size(self) -> int:
         return self._world_size
 
 
+@dataclass
+class _PickleSafeOptions:
+    func: Callable[[], object]
+    fields: Dict[str, object]
+
+    @classmethod
+    def safe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.safe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.safe_args(arg) for arg in args]
+        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+            return cls.from_torch(args)
+        else:
+            return args
+
+    @classmethod
+    def unsafe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.unsafe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.unsafe_args(arg) for arg in args]
+        elif isinstance(args, cls):
+            return args.to_torch()
+        else:
+            return args
+
+    @classmethod
+    def from_torch(cls, opts: object) -> "_PickleSafeOptions":
+        return cls(
+            func=opts.__class__,
+            fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
+        )
+
+    def to_torch(self) -> object:
+        opts = self.func()
+        for k, v in self.fields.items():
+            setattr(opts, k, v)
+        return opts
+
+
 class ProcessGroupBabyGloo(ProcessGroupBaby):
     """
     This is a ProcessGroup that runs Gloo in a subprocess.
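One more note on the new `allgather`/`broadcast` overrides in the last hunk: before handing tensors to the worker subprocess they make sure each tensor's storage is in shared memory, so the subprocess and the caller operate on the same underlying buffer and in-place results remain visible to the caller. A minimal sketch of that precondition for CPU tensors (`share_memory_()` is documented as a no-op for storage that is already shared and for CUDA tensors, which are shared through a different mechanism):

    import torch

    t = torch.zeros(4)
    assert not t.is_shared()

    t.share_memory_()      # move the underlying storage into shared memory
    assert t.is_shared()   # the overrides skip the call when this is already True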