 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import (
     BroadcastOptions,
     DeviceMesh,
+    get_rank,
+    init_device_mesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
     Store,
     TCPStore,
-    get_rank,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import _world, Work
 from torch.futures import Future

 if TYPE_CHECKING:
@@ -130,17 +131,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"

         # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
             devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)

+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
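For context, a minimal sketch of how the registered group might be used with functional collectives. This is not part of the diff: the store/rank bootstrap and `pg.configure(...)` call are elided, and the `_functional_collectives` usage is an illustrative assumption.

```python
# Sketch only (not in this PR): assumes torch.distributed is already initialized
# and that ProcessGroupGloo is the reconfigurable wrapper defined in this module.
import torch
import torch.distributed._functional_collectives as ft_c

pg = ProcessGroupGloo()
group = pg.register("example")  # registers "torchft-gloo:example", returns a group

# After pg.configure(...) has set up the underlying communicator (elided here),
# the returned group can be passed to traceable/compilable functional collectives.
t = torch.ones(4)
t = ft_c.all_reduce(t, "sum", group)
```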
@@ -244,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
     This is a reconfigurable version of ProcessGroupGloo.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo

     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -263,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     abort when reconfiguring, we need to ensure this is safe.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL

     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()

+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+

 class _BabyWork(Work):
     def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -737,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
     ProcessGroupBabyNCCL.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo

     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -761,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
     WORK_CLASS = _BabyWorkNCCL

     def getBackendName(self) -> str:
@@ -797,3 +805,184 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ):
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes = {}
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name in mesh_dim_names:
+                return self.mesh[mesh_dim_names]
+            else:
+                return ManagedDeviceMesh(
+                    self.mesh[mesh_dim_names],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        if mesh_dim is None:
+            assert self.mesh is None
+            return self.replicate_pg
+        elif mesh_dim == self.replicate_dim_name:
+            return self.replicate_pg
+        else:
+            return self.mesh.get_group(mesh_dim)
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            return self.mesh.size(mesh_dim)
+
+    @property
+    def ndim(self) -> int:
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        ret = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            other_dim_size = self.mesh.size()
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
+            return get_rank(self.replicate_pg)
+        else:
+            return self.mesh.get_local_rank(mesh_dim)
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh):
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+):
+    # We have to lie to DeviceMesh that the replicate_dim has only
+    # 1 rank.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError()
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # We have to use MultiProcessTestCase, otherwise c10d will complain that
+    # the same backend has been registered.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )