
Commit bd189af

Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT

Summary: ManagedDeviceMesh allows users to manipulate a DeviceMesh with the TorchFT ManagedProcessGroup.

ghstack-source-id: 9b2bdf3aa301a643726c8d8fb43f385bb022ba96
Pull Request resolved: #56

1 parent f31d3b1 commit bd189af
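A minimal usage sketch (not part of the commit): it assumes an already-constructed torchft Manager and uses the ft_init_device_mesh helper added in the diff below.

```python
# Hypothetical example: a 2D mesh whose first dimension ("dp_replicate")
# is fault tolerant via TorchFT, while "dp_shard" is a normal mesh dim.
from torchft.process_group import ft_init_device_mesh

mesh = ft_init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 4),
    mesh_dim_names=("dp_replicate", "dp_shard"),
    replicate_dim=0,
    manager=manager,  # assumed: an existing torchft.Manager instance
)
replicate_pg = mesh.get_group("dp_replicate")  # the ManagedProcessGroup
shard_mesh = mesh["dp_shard"]                  # regular DeviceMesh slice
```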

File tree

2 files changed: +271 −35 lines changed

torchft/process_group.py: +216 −27
```diff
@@ -20,7 +20,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
```
```diff
@@ -31,15 +31,16 @@
 from torch.distributed import (
     BroadcastOptions,
     DeviceMesh,
+    get_rank,
+    init_device_mesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
     Store,
     TCPStore,
-    get_rank,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import _world, Work
 from torch.futures import Future

 if TYPE_CHECKING:
```
```diff
@@ -130,17 +131,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"

         # This is needed for DeviceMesh and functional collectives to work.
```
```diff
@@ -158,6 +149,21 @@ def create_pg(
             devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)

+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
```
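A hedged sketch of what the refactored register path enables, assuming a single-process Gloo setup; the group name "ft_replicate" is made up:

```python
import torch.distributed as dist
from torchft.process_group import ProcessGroupGloo

# Assumption: a default process group must exist before register() is
# called, since it wraps dist.new_group under the hood.
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# Registering under a unique name makes the torchft group resolvable by
# DeviceMesh and functional collectives; internally this registers a
# world-size-1 backend named "torchft-gloo:ft_replicate".
pg = ProcessGroupGloo()
group = pg.register("ft_replicate")
```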
```diff
@@ -244,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
     This is a reconfigurable version of ProcessGroupGloo.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo

     def getBackendName(self) -> str:
         return "torchft-gloo"
```
```diff
@@ -263,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     abort when reconfiguring, we need to ensure this is safe.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL

     def getBackendName(self) -> str:
         return "torchft-nccl"
```
```diff
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()

+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+

 class _BabyWork(Work):
     def __init__(
```
```diff
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
```
```diff
@@ -737,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
     ProcessGroupBabyNCCL.
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo

     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
```
```diff
@@ -761,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
     WORK_CLASS = _BabyWorkNCCL

     def getBackendName(self) -> str:
```
```diff
@@ -797,3 +805,184 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ):
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes: Dict[str, DeviceMesh] = {}
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name in mesh_dim_names:
+                return self.mesh[mesh_dim_names]
+            else:
+                return ManagedDeviceMesh(
+                    self.mesh[mesh_dim_names],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        if mesh_dim is None:
+            assert self.mesh is None
+            return self.replicate_pg
+        elif mesh_dim == self.replicate_dim_name:
+            return self.replicate_pg
+        else:
+            return self.mesh.get_group(mesh_dim)
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
```
```diff
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            return self.mesh.size(mesh_dim)
+
+    @property
+    def ndim(self) -> int:
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        ret = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            other_dim_size = self.mesh.size()
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
+            return get_rank(self.replicate_pg)
+        else:
+            return self.mesh.get_local_rank(mesh_dim)
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
```
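The flattened rank arithmetic in get_local_rank assumes the replicate dimension comes first; a small standalone check of the formula, with illustrative sizes:

```python
# Mirrors `other_dim_size * replicate_pg_rank + other_dim_rank` from
# ManagedDeviceMesh.get_local_rank for a 2 (replicate) x 4 (inner) layout.
replicate_pg_size = 2
other_dim_size = 4

flat_ranks = [
    other_dim_size * replicate_pg_rank + other_dim_rank
    for replicate_pg_rank in range(replicate_pg_size)
    for other_dim_rank in range(other_dim_size)
]
assert flat_ranks == list(range(8))  # each global rank appears exactly once
```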
```diff
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh):
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
```
```diff
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+) -> ManagedDeviceMesh:
+    # We have to lie to DeviceMesh that the replicate_dim has only
+    # 1 rank.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError(f"unsupported device_type: {device_type}")
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # We have to use MultiProcessTestCase in tests, otherwise c10d will
+    # complain that the same backend has been registered.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )
```
