Commit b617bd2

Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT (#56)
ManagedDeviceMesh allows users to manipulate a DeviceMesh through a TorchFT ManagedProcessGroup. This currently works for a simple HSDP case, but the actual integration and e2e tests are likely to expose more issues, e.g., checkpointing.
1 parent bc99344 commit b617bd2
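
For context, a minimal usage sketch of the new API for a 2-replica x 2-shard HSDP setup, closely mirroring the test added in torchft/fsdp_test.py below (the test substitutes a Mock for the Manager, so treat the Manager setup here as an assumption rather than part of this commit):

    import torch
    from torch import nn
    from torch.distributed._composable.fsdp import fully_shard
    from unittest.mock import Mock

    from torchft.manager import Manager
    from torchft.process_group import ft_init_device_mesh

    # The new test mocks the Manager; a real run would construct one instead.
    manager = Mock(spec=Manager)

    # Build a DeviceMesh whose replicate dimension is backed by a fault-tolerant
    # ManagedProcessGroup rather than a static process group.
    device_mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=(2, 2),                            # (replicas, shards)
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,                              # which dim is fault tolerant
        manager=manager,
    )
    manager.num_participants.return_value = 1

    # FSDP2 shards along "dp_shard" and replicates across the managed dimension.
    model = nn.Linear(128, 128).cuda()
    sharded = fully_shard(model, mesh=device_mesh)
    sharded(torch.randn(4, 128).cuda()).mean().backward()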

File tree

4 files changed (+379, -16 lines)

pyproject.toml

+3-1
@@ -25,7 +25,9 @@ dev = [
     "pytest",
     "black",
     "pyre-check",
-    "parameterized"
+    "parameterized",
+    "expecttest",
+    "numpy"
 ]

 [tool.maturin]

torchft/fsdp_test.py

+74
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import multiprocessing
+import os
+import unittest
+from concurrent.futures import ProcessPoolExecutor
+from typing import Any, Dict, Tuple
+from unittest.mock import Mock
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    _resolve_process_group,
+)
+from torch.distributed import (
+    ReduceOp,
+    TCPStore,
+    Work,
+    _functional_collectives,
+    get_world_size,
+)
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+
+from torchft.manager import Manager
+from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh
+
+
+class FSDPTest(unittest.TestCase):
+    @staticmethod
+    def _test_fsdp(world_size: int, rank: int) -> None:
+        torch.cuda.set_device(rank)
+
+        group_size = world_size // 2
+        group = rank // group_size
+        group_rank = rank % group_size
+
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346 + group)
+        os.environ["RANK"] = str(group_rank)
+        os.environ["WORLD_SIZE"] = str(group_size)
+
+        manager = Mock(spec=Manager)
+        device_mesh = ft_init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(2, 2),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        manager.num_participants.return_value = 1
+        model = nn.Linear(128, 128).cuda()
+        batch = torch.randn(4, 128).cuda()
+        shard_model = fully_shard(model, mesh=device_mesh)
+        shard_model(batch).mean().backward()
+
+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
+    def test_fsdp(self) -> None:
+        multiprocessing.set_start_method("spawn")
+        with ProcessPoolExecutor(max_workers=4) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(self._test_fsdp, 4, i)
+                futures.append(future)
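
As a quick sanity check of the rank bookkeeping in _test_fsdp above: with world_size=4 and group_size = world_size // 2 = 2, the four processes form two replica groups of two ranks each, and each group rendezvous on its own port:

    # rank -> (group, group_rank, MASTER_PORT)
    #  0   -> (0, 0, 12346)
    #  1   -> (0, 1, 12346)
    #  2   -> (1, 0, 12347)
    #  3   -> (1, 1, 12347)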

torchft/process_group.py

+249-13
@@ -21,7 +21,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist
@@ -39,6 +39,7 @@
     Store,
     TCPStore,
     get_rank,
+    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import Work, _world
 from torch.futures import Future
@@ -149,17 +150,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"

         # This is needed for DeviceMesh and functional collectives to work.
@@ -177,6 +168,21 @@ def create_pg(
             devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)

+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
@@ -519,6 +525,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()

+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+

 class _BabyWork(Work):
     def __init__(
@@ -730,7 +739,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -841,3 +849,231 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ) -> None:
+        if mesh is None and parent is None:
+            raise ValueError(
+                "ManagedDeviceMesh doesn't support both mesh and parent are None."
+            )
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes: Dict[str, DeviceMesh] = {}
+        self.device_type: str
+        if mesh is not None:
+            self.device_type = mesh.device_type
+        else:
+            assert parent is not None
+            self.device_type = parent.device_type
+        self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
+        self._thread_id: Optional[int] = None
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                assert self.mesh is not None
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name in mesh_dim_names:
+                assert self.mesh is not None
+                return self.mesh[mesh_dim_names]
+            else:
+                assert self.mesh is not None
+                return ManagedDeviceMesh(
+                    self.mesh[mesh_dim_names],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def _real_mesh_dim(self, mesh_dim: int) -> int:
+        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
+
+    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
+        if isinstance(mesh_dim, str):
+            dim = self.mesh_dim_names.index(mesh_dim)
+        else:
+            dim = 0 if mesh_dim is None else int(mesh_dim)
+
+        if mesh_dim is None:
+            return self.replicate_pg
+        elif dim == self.replicate_dim:
+            return self.replicate_pg
+        else:
+            assert self.mesh is not None
+            return self.mesh.get_group(self._real_mesh_dim(dim))
+
+    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if mesh_dim_name is None:
+            raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`")
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                assert self.mesh is not None
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            assert self.mesh is not None
+            return self.mesh.size(self._real_mesh_dim(mesh_dim))
+
+    @property
+    def ndim(self) -> int:
+        assert self.mesh is not None
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        assert self.mesh is not None
+        ret: List[int] = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        assert self.mesh is not None
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if isinstance(mesh_dim, str):
+            dim = self.mesh_dim_names.index(mesh_dim)
+        else:
+            dim = 0 if mesh_dim is None else int(mesh_dim)
+
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            assert self.mesh is not None
+            other_dim_size = self.mesh.size()
+            assert self.mesh is not None
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif dim == self.replicate_dim:
+            return get_rank(self.replicate_pg)
+        else:
+            assert self.mesh is not None
+            return self.mesh.get_local_rank(self._real_mesh_dim(dim))
+
+    def get_coordinate(self) -> Optional[List[int]]:
+        """
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
+        """
+        assert self.mesh is not None
+        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+
+    def get_all_groups(self) -> List[BaseProcessGroup]:
+        raise NotImplementedError
+
+
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[BaseProcessGroup]:
+        raise NotImplementedError
+
+
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+) -> "ManagedDeviceMesh":
+    # We need to mislead DeviceMesh into thinking that replicate_dim has only
+    # 1 rank.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError()
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # We have to use MultiProcessTestCase, otherwise c10d will complain
+    # the same backend has been registered.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )
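
To make the replicate-dimension bookkeeping above concrete, a rough sketch of what the wrapper reports for the HSDP configuration used in the new test, assuming the manager currently reports 2 participants (these numbers are illustrative and not asserted anywhere in this commit):

    # ft_init_device_mesh(mesh_shape=(2, 2), mesh_dim_names=("dp_replicate", "dp_shard"),
    #                     replicate_dim=0, manager=manager)
    #
    # The replicate dim is popped before init_device_mesh, so the wrapped DeviceMesh
    # only covers ("dp_shard",) with shape (2,). ManagedDeviceMesh re-inserts the
    # fault-tolerant dim virtually:
    #
    #   mesh.shape                      -> (2, 2)  # replicate_pg.size() inserted at dim 0
    #   mesh.size()                     -> 4       # inner mesh size * replicate_pg.size()
    #   mesh.size(0)                    -> 2       # == manager.num_participants()
    #   mesh.get_group("dp_replicate")  -> the ManagedProcessGroup
    #   mesh.get_group("dp_shard")      -> inner mesh group, via _real_mesh_dim(1) == 0
    #
    # Since ManagedProcessGroup.size() delegates to manager.num_participants(), the
    # reported replicate extent follows the live participant count rather than a
    # fixed world size.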
