21
21
import threading
22
22
from abc import ABC
23
23
from datetime import timedelta
24
- from typing import TYPE_CHECKING , Dict , List , Optional , Type , Union
24
+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
25
25
26
26
import torch
27
27
import torch .distributed as dist
39
39
Store ,
40
40
TCPStore ,
41
41
get_rank ,
42
+ init_device_mesh ,
42
43
)
43
44
from torch .distributed .distributed_c10d import Work , _world
44
45
from torch .futures import Future
@@ -149,17 +150,7 @@ def size(self) -> int:
149
150
def getBackendName (self ) -> str :
150
151
raise NotImplementedError ("not implemented" )
151
152
152
- def register (self , name : str ) -> "ProcessGroup" :
153
- """
154
- Registers the process group with the global registry. This enables usage
155
- with things like functional_collectives which are compilable.
156
-
157
- This should only be called once.
158
-
159
- Args:
160
- name: name must be a unique name for this process group
161
- """
162
-
153
+ def _register (self , name : str ) -> str :
163
154
group_name = f"{ self .getBackendName ()} :{ name } "
164
155
165
156
# This is needed for DeviceMesh and functional collectives to work.
@@ -177,6 +168,21 @@ def create_pg(
177
168
devices = ["cpu" ]
178
169
dist .Backend .register_backend (group_name , create_pg , devices = devices )
179
170
171
+ return group_name
172
+
173
+ def register (self , name : str ) -> "ProcessGroup" :
174
+ """
175
+ Registers the process group with the global registry. This enables usage
176
+ with things like functional_collectives which are compilable.
177
+
178
+ This should only be called once.
179
+
180
+ Args:
181
+ name: name must be a unique name for this process group
182
+ """
183
+
184
+ group_name = self ._register (name )
185
+
180
186
return dist .new_group (
181
187
ranks = [dist .get_rank ()],
182
188
backend = group_name ,
@@ -519,6 +525,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
519
525
def size (self ) -> int :
520
526
return self ._manager .num_participants ()
521
527
528
+ def getBackendName (self ) -> str :
529
+ return self ._manager ._pg .getBackendName ()
530
+
522
531
523
532
class _BabyWork (Work ):
524
533
def __init__ (
@@ -730,7 +739,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
730
739
logger .exception (f"got unexpected error in future handler: { e } " )
731
740
732
741
def _get_future (self , op_id : int ) -> Future [object ]:
733
-
734
742
with self ._futures_lock :
735
743
fut = Future () # pyre-fixme[29]: is not a function
736
744
self ._futures [op_id ] = fut
@@ -841,3 +849,231 @@ def extend_device_mesh(
841
849
mesh = mesh .mesh .unsqueeze (dim ),
842
850
mesh_dim_names = tuple (mesh_dim_names ),
843
851
)
852
+
853
+
854
+ class ManagedDeviceMesh (DeviceMesh ):
855
+ def __init__ (
856
+ self ,
857
+ mesh : Optional [DeviceMesh ],
858
+ mesh_dim_names : Tuple [str , ...],
859
+ replicate_pg : ManagedProcessGroup ,
860
+ replicate_dim : int ,
861
+ parent : Optional ["ManagedDeviceMesh" ],
862
+ ) -> None :
863
+ if mesh is None and parent is None :
864
+ raise ValueError (
865
+ "ManagedDeviceMesh doesn't support both mesh and parent are None."
866
+ )
867
+ self .mesh = mesh
868
+ self .mesh_dim_names = mesh_dim_names
869
+ self .replicate_pg = replicate_pg
870
+ self .replicate_dim = replicate_dim
871
+ self .replicate_dim_name : str = mesh_dim_names [replicate_dim ]
872
+ self .parent = parent
873
+ self .flatten_meshes : Dict [str , DeviceMesh ] = {}
874
+ self .device_type : str
875
+ if mesh is not None :
876
+ self .device_type = mesh .device_type
877
+ else :
878
+ assert parent is not None
879
+ self .device_type = parent .device_type
880
+ self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
881
+ self ._thread_id : Optional [int ] = None
882
+
883
+ def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
884
+ if isinstance (mesh_dim_names , str ):
885
+ if mesh_dim_names == self .replicate_dim_name :
886
+ return ManagedDeviceMesh (
887
+ mesh = None ,
888
+ mesh_dim_names = (mesh_dim_names ,),
889
+ replicate_pg = self .replicate_pg ,
890
+ replicate_dim = 0 ,
891
+ parent = self ,
892
+ )
893
+ elif mesh_dim_names in self .flatten_meshes :
894
+ return self .flatten_meshes [mesh_dim_names ]
895
+ else :
896
+ assert self .mesh is not None
897
+ return self .mesh [mesh_dim_names ]
898
+ else :
899
+ assert isinstance (mesh_dim_names , tuple )
900
+ if self .replicate_dim_name in mesh_dim_names :
901
+ assert self .mesh is not None
902
+ return self .mesh [mesh_dim_names ]
903
+ else :
904
+ assert self .mesh is not None
905
+ return ManagedDeviceMesh (
906
+ self .mesh [mesh_dim_names ],
907
+ mesh_dim_names ,
908
+ self .replicate_pg ,
909
+ mesh_dim_names .index (self .replicate_dim_name ),
910
+ parent = self ,
911
+ )
912
+
913
+ def _real_mesh_dim (self , mesh_dim : int ) -> int :
914
+ return mesh_dim - 1 if mesh_dim > self .replicate_dim else mesh_dim
915
+
916
+ def get_group (self , mesh_dim : Optional [Union [int , str ]] = None ) -> BaseProcessGroup :
917
+ if isinstance (mesh_dim , str ):
918
+ dim = self .mesh_dim_names .index (mesh_dim )
919
+ else :
920
+ dim = 0 if mesh_dim is None else int (mesh_dim )
921
+
922
+ if mesh_dim is None :
923
+ return self .replicate_pg
924
+ elif dim == self .replicate_dim :
925
+ return self .replicate_pg
926
+ else :
927
+ assert self .mesh is not None
928
+ return self .mesh .get_group (self ._real_mesh_dim (dim ))
929
+
930
+ def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
931
+ flatten_mesh = _FlattenDeviceMesh (self )
932
+ if mesh_dim_name is None :
933
+ raise ValueError ("ManagedDeviceMesh._flatten requires `mesh_dim_name`" )
934
+ if self .parent is None :
935
+ self .flatten_meshes [mesh_dim_name ] = flatten_mesh
936
+ else :
937
+ self .parent .flatten_meshes [mesh_dim_name ] = flatten_mesh
938
+ return flatten_mesh
939
+
940
+ def size (self , mesh_dim : Optional [int ] = None ) -> int :
941
+ if mesh_dim is None :
942
+ if self .mesh is None :
943
+ return self .replicate_pg .size ()
944
+ else :
945
+ assert self .mesh is not None
946
+ return self .mesh .size () * self .replicate_pg .size ()
947
+ elif mesh_dim == self .replicate_dim :
948
+ return self .replicate_pg .size ()
949
+ else :
950
+ assert self .mesh is not None
951
+ return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
952
+
953
+ @property
954
+ def ndim (self ) -> int :
955
+ assert self .mesh is not None
956
+ return self .mesh .ndim + 1
957
+
958
+ @property
959
+ def shape (self ) -> Tuple [int , ...]:
960
+ assert self .mesh is not None
961
+ ret : List [int ] = list (self .mesh .shape )
962
+ ret .insert (self .replicate_dim , self .replicate_pg .size ())
963
+ return tuple (ret )
964
+
965
+ def get_rank (self ) -> int :
966
+ assert self .mesh is not None
967
+ return self .mesh .get_rank ()
968
+
969
+ def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
970
+ if isinstance (mesh_dim , str ):
971
+ dim = self .mesh_dim_names .index (mesh_dim )
972
+ else :
973
+ dim = 0 if mesh_dim is None else int (mesh_dim )
974
+
975
+ if mesh_dim is None :
976
+ if self .mesh is None :
977
+ return get_rank (self .replicate_pg )
978
+
979
+ assert self .replicate_dim == 0 , "replicate_dim must be the first one"
980
+ assert self .mesh is not None
981
+ other_dim_size = self .mesh .size ()
982
+ assert self .mesh is not None
983
+ other_dim_rank = self .mesh .get_local_rank ()
984
+ replicate_pg_rank = get_rank (self .replicate_pg )
985
+ return other_dim_size * replicate_pg_rank + other_dim_rank
986
+ elif dim == self .replicate_dim :
987
+ return get_rank (self .replicate_pg )
988
+ else :
989
+ assert self .mesh is not None
990
+ return self .mesh .get_local_rank (self ._real_mesh_dim (dim ))
991
+
992
+ def get_coordinate (self ) -> Optional [List [int ]]:
993
+ """
994
+ Return the relative indices of this rank relative to all
995
+ dimensions of the mesh. If this rank is not part of the mesh, return None.
996
+ """
997
+ assert self .mesh is not None
998
+ return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
999
+
1000
+ def get_all_groups (self ) -> List [BaseProcessGroup ]:
1001
+ raise NotImplementedError
1002
+
1003
+
1004
+ class _FlattenDeviceMesh (DeviceMesh ):
1005
+ def __init__ (self , managed_mesh : ManagedDeviceMesh ) -> None :
1006
+ self .managed_mesh = managed_mesh
1007
+
1008
+ def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
1009
+ raise NotImplementedError
1010
+
1011
+ def get_group (self , mesh_dim : Optional [Union [int , str ]] = None ) -> BaseProcessGroup :
1012
+ raise NotImplementedError
1013
+
1014
+ def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
1015
+ raise NotImplementedError
1016
+
1017
+ def size (self , mesh_dim : Optional [int ] = None ) -> int :
1018
+ assert mesh_dim is None
1019
+ return self .managed_mesh .size ()
1020
+
1021
+ @property
1022
+ def ndim (self ) -> int :
1023
+ raise NotImplementedError
1024
+
1025
+ @property
1026
+ def shape (self ) -> Tuple [int , ...]:
1027
+ raise NotImplementedError
1028
+
1029
+ def get_rank (self ) -> int :
1030
+ raise NotImplementedError
1031
+
1032
+ def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
1033
+ assert mesh_dim is None
1034
+ return self .managed_mesh .get_local_rank ()
1035
+
1036
+ def get_all_groups (self ) -> List [BaseProcessGroup ]:
1037
+ raise NotImplementedError
1038
+
1039
+
1040
+ def ft_init_device_mesh (
1041
+ * ,
1042
+ device_type : str ,
1043
+ mesh_shape : Tuple [int , ...],
1044
+ mesh_dim_names : Tuple [str , ...],
1045
+ replicate_dim : int ,
1046
+ manager : "Manager" ,
1047
+ ) -> "ManagedDeviceMesh" :
1048
+ # We need to mislead DeviceMesh into thinking that replicate_dim has only
1049
+ # 1 rank.
1050
+ _mesh_shape = list (mesh_shape )
1051
+ _mesh_shape .pop (replicate_dim )
1052
+ _mesh_dim_names = list (mesh_dim_names )
1053
+ _mesh_dim_names .pop (replicate_dim )
1054
+ mesh = init_device_mesh (
1055
+ device_type ,
1056
+ mesh_shape = tuple (_mesh_shape ),
1057
+ mesh_dim_names = tuple (_mesh_dim_names ),
1058
+ )
1059
+
1060
+ if device_type == "cpu" :
1061
+ pg = ProcessGroupGloo ()
1062
+ elif device_type == "cuda" :
1063
+ pg = ProcessGroupNCCL ()
1064
+ else :
1065
+ raise ValueError ()
1066
+
1067
+ manager ._pg = pg
1068
+ replicate_pg = ManagedProcessGroup (manager )
1069
+ # We have to use MultiProcessTestCase, otherwise c10d will complain
1070
+ # the same backend has been registered.
1071
+ replicate_pg .register (mesh_dim_names [replicate_dim ])
1072
+
1073
+ return ManagedDeviceMesh (
1074
+ mesh = mesh ,
1075
+ mesh_dim_names = mesh_dim_names ,
1076
+ replicate_pg = replicate_pg ,
1077
+ replicate_dim = replicate_dim ,
1078
+ parent = None ,
1079
+ )
0 commit comments