@@ -250,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
250
250
This is a reconfigurable version of ProcessGroupGloo.
251
251
"""
252
252
253
- PG_CLASS : Type [
254
- BaseProcessGroup
255
- ] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
253
+ PG_CLASS : Type [BaseProcessGroup ] = (
254
+ BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
255
+ )
256
256
257
257
def getBackendName(self) -> str:
    """Return the backend name string for this process group wrapper."""
    backend_name = "torchft-gloo"
    return backend_name
@@ -269,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
269
269
abort when reconfiguring, we need to ensure this is safe.
270
270
"""
271
271
272
- PG_CLASS : Type [
273
- BaseProcessGroup
274
- ] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
272
+ PG_CLASS : Type [BaseProcessGroup ] = (
273
+ BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
274
+ )
275
275
276
276
def getBackendName(self) -> str:
    """Return the backend name string for this process group wrapper."""
    backend_name = "torchft-nccl"
    return backend_name
@@ -745,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
745
745
ProcessGroupBabyNCCL.
746
746
"""
747
747
748
- PG_CLASS : Type [
749
- BaseProcessGroup
750
- ] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
748
+ PG_CLASS : Type [BaseProcessGroup ] = (
749
+ BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
750
+ )
751
751
752
752
def getBackendName(self) -> str:
    """Return the backend name string for this process group wrapper."""
    backend_name = "torchft-baby-gloo"
    return backend_name
@@ -769,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
769
769
tensors may leak in the current PyTorch implementation. TODO fix
770
770
"""
771
771
772
- PG_CLASS : Type [
773
- BaseProcessGroup
774
- ] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
772
+ PG_CLASS : Type [BaseProcessGroup ] = (
773
+ BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
774
+ )
775
775
WORK_CLASS = _BabyWorkNCCL
776
776
777
777
def getBackendName (self ) -> str :
@@ -807,27 +807,34 @@ def extend_device_mesh(
807
807
)
808
808
809
809
810
- class ManagedDeviceMesh (DeviceMesh ):
810
+ class _ManagedDeviceMesh (DeviceMesh ):
811
811
def __init__(
    self,
    mesh: Optional[DeviceMesh],
    mesh_dim_names: Tuple[str, ...],
    replicate_pg: ManagedProcessGroup,
    replicate_dim: int,
    parent: Optional["_ManagedDeviceMesh"],
):
    """
    Build a device mesh whose replicate dimension is backed by ``replicate_pg``.

    Args:
        mesh: the wrapped DeviceMesh covering the non-replicate dimensions,
            or None when this object represents only the replicate
            dimension (in which case ``parent`` must be given).
        mesh_dim_names: names of all dimensions, including the replicate one.
        replicate_pg: process group used for the replicate dimension.
        replicate_dim: index of the replicate dimension in ``mesh_dim_names``.
        parent: the managed mesh this one was sliced from, if any.

    Raises:
        ValueError: if both ``mesh`` and ``parent`` are None, since there
            would be nothing to take the device type from.
    """
    # Reject only the case where neither source of a device type exists;
    # a mesh-less object with a parent is a valid replicate-only sub-mesh.
    if mesh is None and parent is None:
        raise ValueError(
            "_ManagedDeviceMesh doesn't support both mesh and parent are None."
        )
    # NOTE(review): DeviceMesh.__init__ is not called here; this looks
    # intentional (the object only mimics the DeviceMesh interface) -- confirm.
    self.mesh = mesh
    self.mesh_dim_names = mesh_dim_names
    self.replicate_pg = replicate_pg
    self.replicate_dim = replicate_dim
    self.replicate_dim_name = mesh_dim_names[replicate_dim]
    self.parent = parent
    self.flatten_meshes = {}
    if mesh is not None:
        self.device_type = mesh.device_type
    else:
        # The guard above guarantees parent is not None on this path.
        self.device_type = parent.device_type
    self._flatten_mesh_list = ()
    self._thread_id = None
826
833
827
834
def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
828
835
if isinstance (mesh_dim_names , str ):
829
836
if mesh_dim_names == self .replicate_dim_name :
830
- return ManagedDeviceMesh (
837
+ return _ManagedDeviceMesh (
831
838
mesh = None ,
832
839
mesh_dim_names = (mesh_dim_names ,),
833
840
replicate_pg = self .replicate_pg ,
@@ -843,22 +850,25 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
843
850
if self .replicate_dim_name in mesh_dim_names :
844
851
return self .mesh [mesh_dim_names ]
845
852
else :
846
- return ManagedDeviceMesh (
853
+ return _ManagedDeviceMesh (
847
854
self .mesh [mesh_dim_names ],
848
855
mesh_dim_names ,
849
856
self .replicate_pg ,
850
857
mesh_dim_name .index (self .replicate_dim_name ),
851
858
parent = self ,
852
859
)
853
860
861
+ def _real_mesh_dim (self , mesh_dim : int ) -> int :
862
+ return mesh_dim - 1 if mesh_dim > self .replicate_dim else mesh_dim
863
+
854
864
def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
    """
    Return the process group for one dimension of the managed mesh.

    Args:
        mesh_dim: dimension name or integer index. None is only accepted
            when this object has no wrapped mesh (``self.mesh is None``).

    Returns:
        ``self.replicate_pg`` for the replicate dimension (or for None),
        otherwise the group obtained from the wrapped mesh.
    """
    if mesh_dim is None:
        assert self.mesh is None
        return self.replicate_pg

    if isinstance(mesh_dim, str):
        if mesh_dim == self.replicate_dim_name:
            return self.replicate_pg
        # Names are forwarded unchanged: only integer indices shift when the
        # replicate dimension is removed, and comparing a str against
        # self.replicate_dim in _real_mesh_dim would raise TypeError.
        # NOTE(review): assumes the wrapped mesh uses the same dim names -- confirm.
        return self.mesh.get_group(mesh_dim)

    if mesh_dim == self.replicate_dim:
        return self.replicate_pg
    # Integer index: translate into the wrapped mesh's index space.
    return self.mesh.get_group(self._real_mesh_dim(mesh_dim))
862
872
863
873
def _flatten (self , mesh_dim_name : str ) -> "DeviceMesh" :
864
874
flatten_mesh = _FlattenDeviceMesh (self )
@@ -877,7 +887,7 @@ def size(self, mesh_dim: Optional[int] = None) -> int:
877
887
elif mesh_dim == self .replicate_dim :
878
888
return self .replicate_pg .size ()
879
889
else :
880
- return self .mesh .size (mesh_dim )
890
+ return self .mesh .size (self . _real_mesh_dim ( mesh_dim ) )
881
891
882
892
@property
883
893
def ndim (self ) -> int :
@@ -904,14 +914,21 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
904
914
elif mesh_dim in (self .replicate_dim , self .replicate_dim_name ):
905
915
return get_rank (self .replicate_pg )
906
916
else :
907
- return self .mesh .get_local_rank (mesh_dim )
917
+ return self .mesh .get_local_rank (self ._real_mesh_dim (mesh_dim ))
918
+
919
def get_coordinate(self) -> Optional[List[int]]:
    """
    Return the wrapped mesh's ``_coordinate_on_dim`` for this rank, or None
    when that value is empty or unset.
    """
    # NOTE(review): self.mesh can be None for a replicate-only mesh, in which
    # case this raises AttributeError -- confirm callers never hit that path.
    # NOTE(review): the result comes from the wrapped mesh only, so it appears
    # not to include an entry for the replicate dimension -- confirm intent.
    coordinate = self.mesh._coordinate_on_dim
    if coordinate:
        return coordinate
    return None
908
925
909
926
def get_all_groups(self) -> List[ProcessGroup]:
    """Not supported by this mesh; always raises NotImplementedError."""
    raise NotImplementedError
911
928
912
929
913
930
class _FlattenDeviceMesh (DeviceMesh ):
914
- def __init__ (self , managed_mesh : ManagedDeviceMesh ):
931
+ def __init__ (self , managed_mesh : _ManagedDeviceMesh ):
915
932
self .managed_mesh = managed_mesh
916
933
917
934
def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
@@ -954,7 +971,7 @@ def ft_init_device_mesh(
954
971
replicate_dim : int ,
955
972
manager : "Manager" ,
956
973
):
957
- # We have to lie DeviceMesh that the replicate_dim has only
974
+ # We need to mislead DeviceMesh into thinking that replicate_dim has only
958
975
# 1 rank.
959
976
_mesh_shape = list (mesh_shape )
960
977
_mesh_shape .pop (replicate_dim )
@@ -979,7 +996,7 @@ def ft_init_device_mesh(
979
996
# the same backend has been registered.
980
997
replicate_pg .register (mesh_dim_names [replicate_dim ])
981
998
982
- return ManagedDeviceMesh (
999
+ return _ManagedDeviceMesh (
983
1000
mesh = mesh ,
984
1001
mesh_dim_names = mesh_dim_names ,
985
1002
replicate_pg = replicate_pg ,
0 commit comments