20
20
import threading
21
21
from abc import ABC
22
22
from datetime import timedelta
23
- from typing import TYPE_CHECKING , Dict , List , Optional , Type
23
+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
24
24
25
25
import torch
26
26
import torch .distributed as dist
38
38
Store ,
39
39
TCPStore ,
40
40
get_rank ,
41
+ init_device_mesh ,
41
42
)
42
43
from torch .distributed .distributed_c10d import Work , _world
43
44
from torch .futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
130
131
def getBackendName (self ) -> str :
131
132
raise NotImplementedError ("not implemented" )
132
133
133
- def register (self , name : str ) -> "ProcessGroup" :
134
- """
135
- Registers the process group with the global registry. This enables usage
136
- with things like functional_collectives which are compilable.
137
-
138
- This should only be called once.
139
-
140
- Args:
141
- name: name must be a unique name for this process group
142
- """
143
-
134
+ def _register (self , name : str ) -> str :
144
135
group_name = f"{ self .getBackendName ()} :{ name } "
145
136
146
137
# This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
158
149
devices = ["cpu" ]
159
150
dist .Backend .register_backend (group_name , create_pg , devices = devices )
160
151
152
+ return group_name
153
+
154
+ def register (self , name : str ) -> "ProcessGroup" :
155
+ """
156
+ Registers the process group with the global registry. This enables usage
157
+ with things like functional_collectives which are compilable.
158
+
159
+ This should only be called once.
160
+
161
+ Args:
162
+ name: name must be a unique name for this process group
163
+ """
164
+
165
+ group_name = self ._register (name )
166
+
161
167
return dist .new_group (
162
168
ranks = [dist .get_rank ()],
163
169
backend = group_name ,
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
496
502
def size (self ) -> int :
497
503
return self ._manager .num_participants ()
498
504
505
+ def getBackendName (self ) -> str :
506
+ return self ._manager ._pg .getBackendName ()
507
+
499
508
500
509
class _BabyWork (Work ):
501
510
def __init__ (
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
689
698
logger .exception (f"got unexpected error in future handler: { e } " )
690
699
691
700
def _get_future (self , op_id : int ) -> Future [object ]:
692
-
693
701
with self ._futures_lock :
694
702
fut = Future () # pyre-fixme[29]: is not a function
695
703
self ._futures [op_id ] = fut
@@ -797,3 +805,232 @@ def extend_device_mesh(
797
805
mesh = mesh .mesh .unsqueeze (dim ),
798
806
mesh_dim_names = tuple (mesh_dim_names ),
799
807
)
808
+
809
+
810
+ class ManagedDeviceMesh (DeviceMesh ):
811
+ def __init__ (
812
+ self ,
813
+ mesh : Optional [DeviceMesh ],
814
+ mesh_dim_names : Tuple [str , ...],
815
+ replicate_pg : ManagedProcessGroup ,
816
+ replicate_dim : int ,
817
+ parent : Optional ["ManagedDeviceMesh" ],
818
+ ) -> None :
819
+ if mesh is None and parent is not None :
820
+ raise ValueError (
821
+ "ManagedDeviceMesh doesn't support both mesh and parent are None."
822
+ )
823
+ self .mesh = mesh
824
+ self .mesh_dim_names = mesh_dim_names
825
+ self .replicate_pg = replicate_pg
826
+ self .replicate_dim = replicate_dim
827
+ self .replicate_dim_name : str = mesh_dim_names [replicate_dim ]
828
+ self .parent = parent
829
+ self .flatten_meshes : Dict [str , DeviceMesh ] = {}
830
+ self .device_type : str
831
+ if mesh is not None :
832
+ self .device_type = mesh .device_type
833
+ else :
834
+ assert parent is not None
835
+ self .device_type = parent .device_type
836
+ self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
837
+ self ._thread_id : Optional [int ] = None
838
+
839
+ def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
840
+ if isinstance (mesh_dim_names , str ):
841
+ if mesh_dim_names == self .replicate_dim_name :
842
+ return ManagedDeviceMesh (
843
+ mesh = None ,
844
+ mesh_dim_names = (mesh_dim_names ,),
845
+ replicate_pg = self .replicate_pg ,
846
+ replicate_dim = 0 ,
847
+ parent = self ,
848
+ )
849
+ elif mesh_dim_names in self .flatten_meshes :
850
+ return self .flatten_meshes [mesh_dim_names ]
851
+ else :
852
+ assert self .mesh is not None
853
+ return self .mesh [mesh_dim_names ]
854
+ else :
855
+ assert isinstance (mesh_dim_names , tuple )
856
+ if self .replicate_dim_name in mesh_dim_names :
857
+ assert self .mesh is not None
858
+ return self .mesh [mesh_dim_names ]
859
+ else :
860
+ assert self .mesh is not None
861
+ return ManagedDeviceMesh (
862
+ self .mesh [mesh_dim_names ],
863
+ mesh_dim_names ,
864
+ self .replicate_pg ,
865
+ mesh_dim_names .index (self .replicate_dim_name ),
866
+ parent = self ,
867
+ )
868
+
869
+ def _real_mesh_dim (self , mesh_dim : int ) -> int :
870
+ return mesh_dim - 1 if mesh_dim > self .replicate_dim else mesh_dim
871
+
872
+ def get_group (self , mesh_dim : Optional [Union [int , str ]] = None ) -> BaseProcessGroup :
873
+ if isinstance (mesh_dim , str ):
874
+ dim = self .mesh_dim_names .index (mesh_dim )
875
+ else :
876
+ dim = 0 if mesh_dim is None else int (mesh_dim )
877
+
878
+ if mesh_dim is None :
879
+ assert self .mesh is not None
880
+ return self .replicate_pg
881
+ elif dim == self .replicate_dim :
882
+ return self .replicate_pg
883
+ else :
884
+ assert self .mesh is not None
885
+ return self .mesh .get_group (self ._real_mesh_dim (dim ))
886
+
887
+ def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
888
+ flatten_mesh = _FlattenDeviceMesh (self )
889
+ if mesh_dim_name is None :
890
+ raise ValueError ("ManagedDeviceMesh._flatten requires `mesh_dim_name`" )
891
+ if self .parent is None :
892
+ self .flatten_meshes [mesh_dim_name ] = flatten_mesh
893
+ else :
894
+ self .parent .flatten_meshes [mesh_dim_name ] = flatten_mesh
895
+ return flatten_mesh
896
+
897
+ def size (self , mesh_dim : Optional [int ] = None ) -> int :
898
+ if mesh_dim is None :
899
+ if self .mesh is None :
900
+ return self .replicate_pg .size ()
901
+ else :
902
+ assert self .mesh is not None
903
+ return self .mesh .size () * self .replicate_pg .size ()
904
+ elif mesh_dim == self .replicate_dim :
905
+ return self .replicate_pg .size ()
906
+ else :
907
+ assert self .mesh is not None
908
+ return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
909
+
910
+ @property
911
+ def ndim (self ) -> int :
912
+ assert self .mesh is not None
913
+ return self .mesh .ndim + 1
914
+
915
+ @property
916
+ def shape (self ) -> Tuple [int , ...]:
917
+ assert self .mesh is not None
918
+ ret : List [int ] = list (self .mesh .shape )
919
+ ret .insert (self .replicate_dim , self .replicate_pg .size ())
920
+ return tuple (ret )
921
+
922
+ def get_rank (self ) -> int :
923
+ assert self .mesh is not None
924
+ return self .mesh .get_rank ()
925
+
926
+ def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
927
+ if isinstance (mesh_dim , str ):
928
+ dim = self .mesh_dim_names .index (mesh_dim )
929
+ else :
930
+ dim = 0 if mesh_dim is None else int (mesh_dim )
931
+
932
+ if mesh_dim is None :
933
+ if self .mesh is None :
934
+ return get_rank (self .replicate_pg )
935
+
936
+ assert self .replicate_dim == 0 , "replicate_dim must be the first one"
937
+ assert self .mesh is not None
938
+ other_dim_size = self .mesh .size ()
939
+ assert self .mesh is not None
940
+ other_dim_rank = self .mesh .get_local_rank ()
941
+ replicate_pg_rank = get_rank (self .replicate_pg )
942
+ return other_dim_size * replicate_pg_rank + other_dim_rank
943
+ elif dim == self .replicate_dim :
944
+ return get_rank (self .replicate_pg )
945
+ else :
946
+ assert self .mesh is not None
947
+ return self .mesh .get_local_rank (self ._real_mesh_dim (dim ))
948
+
949
+ def get_coordinate (self ) -> Optional [List [int ]]:
950
+ """
951
+ Return the relative indices of this rank relative to all
952
+ dimensions of the mesh. If this rank is not part of the mesh, return None.
953
+ """
954
+ assert self .mesh is not None
955
+ return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
956
+
957
+ def get_all_groups (self ) -> List [BaseProcessGroup ]:
958
+ raise NotImplementedError
959
+
960
+
961
+ class _FlattenDeviceMesh (DeviceMesh ):
962
+ def __init__ (self , managed_mesh : ManagedDeviceMesh ) -> None :
963
+ self .managed_mesh = managed_mesh
964
+
965
+ def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
966
+ raise NotImplementedError
967
+
968
+ def get_group (self , mesh_dim : Optional [Union [int , str ]] = None ) -> BaseProcessGroup :
969
+ raise NotImplementedError
970
+
971
+ def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
972
+ raise NotImplementedError
973
+
974
+ def size (self , mesh_dim : Optional [int ] = None ) -> int :
975
+ assert mesh_dim is None
976
+ return self .managed_mesh .size ()
977
+
978
+ @property
979
+ def ndim (self ) -> int :
980
+ raise NotImplementedError
981
+
982
+ @property
983
+ def shape (self ) -> Tuple [int , ...]:
984
+ raise NotImplementedError
985
+
986
+ def get_rank (self ) -> int :
987
+ raise NotImplementedError
988
+
989
+ def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
990
+ assert mesh_dim is None
991
+ return self .managed_mesh .get_local_rank ()
992
+
993
+ def get_all_groups (self ) -> List [BaseProcessGroup ]:
994
+ raise NotImplementedError
995
+
996
+
997
+ def ft_init_device_mesh (
998
+ * ,
999
+ device_type : str ,
1000
+ mesh_shape : Tuple [int , ...],
1001
+ mesh_dim_names : Tuple [str , ...],
1002
+ replicate_dim : int ,
1003
+ manager : "Manager" ,
1004
+ ) -> "ManagedDeviceMesh" :
1005
+ # We need to mislead DeviceMesh into thinking that replicate_dim has only
1006
+ # 1 rank.
1007
+ _mesh_shape = list (mesh_shape )
1008
+ _mesh_shape .pop (replicate_dim )
1009
+ _mesh_dim_names = list (mesh_dim_names )
1010
+ _mesh_dim_names .pop (replicate_dim )
1011
+ mesh = init_device_mesh (
1012
+ device_type ,
1013
+ mesh_shape = tuple (_mesh_shape ),
1014
+ mesh_dim_names = tuple (_mesh_dim_names ),
1015
+ )
1016
+
1017
+ if device_type == "cpu" :
1018
+ pg = ProcessGroupGloo ()
1019
+ elif device_type == "cuda" :
1020
+ pg = ProcessGroupNCCL ()
1021
+ else :
1022
+ raise ValueError ()
1023
+
1024
+ manager ._pg = pg
1025
+ replicate_pg = ManagedProcessGroup (manager )
1026
+ # We have to use MultiProcessTestCase, otherwise c10d will complain
1027
+ # the same backend has been registered.
1028
+ replicate_pg .register (mesh_dim_names [replicate_dim ])
1029
+
1030
+ return ManagedDeviceMesh (
1031
+ mesh = mesh ,
1032
+ mesh_dim_names = mesh_dim_names ,
1033
+ replicate_pg = replicate_pg ,
1034
+ replicate_dim = replicate_dim ,
1035
+ parent = None ,
1036
+ )
0 commit comments