Commit 20b5f0c
Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT
Summary: ManagedDeviceMesh allows users to manipulate a DeviceMesh with the TorchFT ManagedProcessGroup.

ghstack-source-id: b1ed52b20adff13f2389aa554f20e150e6e375b8
Pull Request resolved: #56
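For context, a minimal usage sketch of the API this commit adds (hypothetical, not part of the diff; it assumes a live torchft Manager named `manager`, and the shape/dim names mirror the test below):

    # Hypothetical sketch based on ft_init_device_mesh as added in this commit.
    from torchft.process_group import ft_init_device_mesh

    device_mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=(2, 4),  # 2 replicas x 4 shards
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,  # assumed: an initialized torchft Manager
    )

    # "dp_replicate" is backed by the fault-tolerant ManagedProcessGroup,
    # while "dp_shard" is an ordinary DeviceMesh dimension.
    replicate_pg = device_mesh.get_group("dp_replicate")
    shard_pg = device_mesh.get_group("dp_shard")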
1 parent: f31d3b1

File tree: 2 files changed (+261 −25)

torchft/process_group.py (+214 −25)
@@ -20,7 +20,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -38,6 +38,7 @@
     Store,
     TCPStore,
     get_rank,
+    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import Work, _world
 from torch.futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"
 
         # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
             devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)
 
+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
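In rough terms, the refactor splits the old `register` in two: `_register` only registers the torchft backend name with c10d so DeviceMesh and functional collectives can resolve it, while `register` additionally creates a single-rank group bound to that backend. An illustrative sketch (hypothetical names, assuming a configured ProcessGroupGloo; not part of the diff):

    # Illustrative only.
    pg = ProcessGroupGloo()
    backend_name = pg._register("example")  # "torchft-gloo:example"; backend name only
    group = pg.register("example2")         # backend name + single-rank dist group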
@@ -244,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
     This is a reconfigurable version of ProcessGroupGloo.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
 
     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -263,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     abort when reconfiguring, we need to ensure this is safe.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
 
     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()
 
+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+
 
 class _BabyWork(Work):
     def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -737,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
     ProcessGroupBabyNCCL.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
 
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -761,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    PG_CLASS: Type[
+        BaseProcessGroup
+    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
     WORK_CLASS = _BabyWorkNCCL
 
     def getBackendName(self) -> str:
@@ -797,3 +805,184 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ) -> None:
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes: Dict[str, DeviceMesh] = {}
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name not in mesh_dim_names:
+                return self.mesh[mesh_dim_names]
+            else:
+                return ManagedDeviceMesh(
+                    self.mesh[
+                        tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
+                    ],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        if mesh_dim is None:
+            assert self.mesh is None
+            return self.replicate_pg
+        elif mesh_dim == self.replicate_dim_name:
+            return self.replicate_pg
+        else:
+            return self.mesh.get_group(mesh_dim)
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            return self.mesh.size(mesh_dim)
+
+    @property
+    def ndim(self) -> int:
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        ret = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            other_dim_size = self.mesh.size()
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
+            return get_rank(self.replicate_pg)
+        else:
+            return self.mesh.get_local_rank(mesh_dim)
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+) -> ManagedDeviceMesh:
+    # We have to lie to DeviceMesh: the replicate_dim is excluded from the
+    # real mesh and is added back virtually by ManagedDeviceMesh.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError(f"unsupported device type: {device_type}")
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # Tests have to use MultiProcessTestCase, otherwise c10d will complain
+    # that the same backend has been registered twice.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )
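As a sanity check on the rank math in ManagedDeviceMesh.get_local_rank, a worked example with assumed values (illustrative, not part of the diff):

    # flat rank = other_dim_size * replicate_pg_rank + other_dim_rank,
    # given replicate_dim == 0 (replicate-major layout).
    other_dim_size = 4     # size of the real (shard) mesh
    replicate_pg_rank = 1  # this worker's rank in the replicate group
    other_dim_rank = 2     # this worker's rank in the shard mesh
    assert other_dim_size * replicate_pg_rank + other_dim_rank == 6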

torchft/process_group_test.py (+47 −0)
@@ -28,10 +28,12 @@
     get_world_size,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torch.testing._internal.common_distributed import MultiProcessTestCase
 
 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
+    ManagedDeviceMesh,
     ManagedProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
@@ -44,6 +46,7 @@
     _ErrorSwallowingWork,
     _ManagedWork,
     extend_device_mesh,
+    ft_init_device_mesh,
 )
 
 
@@ -234,6 +237,7 @@ def test_device_mesh(self) -> None:
         pg.configure(store_addr, 0, 1)
 
         mesh_2d = extend_device_mesh(mesh_1d, pg)
+        mesh_2d.get_group("dp")
         assert mesh_2d.ndim == 2
 
         pg.unregister()
@@ -299,3 +303,46 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
+
+
+class DeviceMeshTest(MultiProcessTestCase):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def setUp(self) -> None:
+        super().setUp()
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    def test_init_device_mesh(self) -> None:
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(self.rank)
+        os.environ["WORLD_SIZE"] = str(4)
+
+        manager = Mock(spec=Manager)
+        # Even though we only have 4 workers, we can still initialize a (2, 4)
+        # mesh. That's because the replicate group is NOT physically created in
+        # the real mesh but is virtually added to the mesh via ManagedDeviceMesh.
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(2, 4),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        self.assertTrue(
+            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
+        )
+        self.assertTrue(
+            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
+        )
+        replicate_group = device_mesh.get_group("dp_replicate")
+        self.assertEqual(replicate_group._manager, manager)
+        replicate_mesh = device_mesh["dp_replicate"]
+        self.assertEqual(replicate_mesh.get_group(), replicate_group)
+        flatten_mesh = device_mesh._flatten("dp")
+        manager.num_participants.return_value = 1
+        self.assertEqual(flatten_mesh.size(), 4)
+        self.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())