
Commit 6360978

Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT
Summary: ManagedDeviceMesh allows users to manipulate a DeviceMesh with the TorchFT ManagedProcessGroup.

ghstack-source-id: e7afbf22208758883e0cc3abacab4f6deb60c41b
Pull Request resolved: #56
1 parent f31d3b1 commit 6360978
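
For context before the diff: the sketch below condenses the intended usage, following the new torchft/fsdp_test.py in this commit. The Mock Manager stands in for a fully configured torchft Manager, and a real job would also set MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE per replica group; the mesh shape and dimension names are taken from the test itself.

    from unittest.mock import Mock

    import torch
    from torch import nn
    from torch.distributed._composable.fsdp import fully_shard

    from torchft.manager import Manager
    from torchft.process_group import ft_init_device_mesh

    # Stand-in for a fully configured torchft Manager (mirrors the test).
    manager = Mock(spec=Manager)
    manager.num_participants.return_value = 1

    # The replicate dimension ("dp_replicate") is backed by the TorchFT
    # ManagedProcessGroup; the shard dimension ("dp_shard") is a regular mesh dim.
    device_mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=(2, 2),
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )

    model = fully_shard(nn.Linear(128, 128).cuda(), mesh=device_mesh)
    model(torch.randn(4, 128).cuda()).mean().backward()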

File tree

4 files changed: +375 -14 lines changed


pyproject.toml

+3-1
@@ -25,7 +25,9 @@ dev = [
     "pytest",
     "black",
     "pyre-check",
-    "parameterized"
+    "parameterized",
+    "expecttest",
+    "numpy"
 ]

 [tool.maturin]

torchft/fsdp_test.py

+70
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Tuple
+from unittest import TestCase, skipUnless
+from unittest.mock import Mock
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    _resolve_process_group,
+)
+from torch.distributed import (
+    ReduceOp,
+    TCPStore,
+    Work,
+    _functional_collectives,
+    get_world_size,
+)
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+
+from torchft.manager import Manager
+from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh
+
+
+class FSDPTest(MultiProcessTestCase):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def setUp(self) -> None:
+        super().setUp()
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    def test_fsdp(self) -> None:
+        group_size = self.world_size // 2
+        group = self.rank // group_size
+        group_rank = self.rank % group_size
+
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346 + group)
+        os.environ["RANK"] = str(group_rank)
+        os.environ["WORLD_SIZE"] = str(group_size)
+
+        manager = Mock(spec=Manager)
+        device_mesh = ft_init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(2, 2),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        manager.num_participants.return_value = 1
+        model = nn.Linear(128, 128).cuda()
+        batch = torch.randn(4, 128).cuda()
+        shard_model = fully_shard(model, mesh=device_mesh)
+        shard_model(batch).mean().backward()
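
For reference, the rank bookkeeping in test_fsdp splits the 4-process world into 2 replica groups of 2 ranks each, with one TCPStore port per group. A small illustrative script (not part of the commit):

    world_size = 4
    group_size = world_size // 2
    for rank in range(world_size):
        group = rank // group_size      # which replica group this rank joins
        group_rank = rank % group_size  # rank within that replica group
        print(f"rank {rank}: replica group {group}, local rank {group_rank}, "
              f"MASTER_PORT {12346 + group}")
    # rank 0 -> group 0, local rank 0    rank 2 -> group 1, local rank 0
    # rank 1 -> group 0, local rank 1    rank 3 -> group 1, local rank 1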

torchft/process_group.py

+250-13
@@ -20,7 +20,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist
@@ -38,6 +38,7 @@
     Store,
     TCPStore,
     get_rank,
+    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import Work, _world
 from torch.futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"

         # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
             devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)

+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()

+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+

 class _BabyWork(Work):
     def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -797,3 +805,232 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ) -> None:
+        if mesh is None and parent is None:
+            raise ValueError(
+                "ManagedDeviceMesh doesn't support both mesh and parent being None."
+            )
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes: Dict[str, DeviceMesh] = {}
+        self.device_type: str
+        if mesh is not None:
+            self.device_type = mesh.device_type
+        else:
+            assert parent is not None
+            self.device_type = parent.device_type
+        self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
+        self._thread_id: Optional[int] = None
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                assert self.mesh is not None
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name not in mesh_dim_names:
+                assert self.mesh is not None
+                return self.mesh[mesh_dim_names]
+            else:
+                assert self.mesh is not None
+                return ManagedDeviceMesh(
+                    self.mesh[mesh_dim_names],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def _real_mesh_dim(self, mesh_dim: int) -> int:
+        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
+
+    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
+        if isinstance(mesh_dim, str):
+            dim = self.mesh_dim_names.index(mesh_dim)
+        else:
+            dim = 0 if mesh_dim is None else int(mesh_dim)
+
+        if mesh_dim is None:
+            assert self.mesh is not None
+            return self.replicate_pg
+        elif dim == self.replicate_dim:
+            return self.replicate_pg
+        else:
+            assert self.mesh is not None
+            return self.mesh.get_group(self._real_mesh_dim(dim))
+
+    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if mesh_dim_name is None:
+            raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`")
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                assert self.mesh is not None
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            assert self.mesh is not None
+            return self.mesh.size(self._real_mesh_dim(mesh_dim))
+
+    @property
+    def ndim(self) -> int:
+        assert self.mesh is not None
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        assert self.mesh is not None
+        ret: List[int] = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        assert self.mesh is not None
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if isinstance(mesh_dim, str):
+            dim = self.mesh_dim_names.index(mesh_dim)
+        else:
+            dim = 0 if mesh_dim is None else int(mesh_dim)
+
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            assert self.mesh is not None
+            other_dim_size = self.mesh.size()
+            assert self.mesh is not None
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif dim == self.replicate_dim:
+            return get_rank(self.replicate_pg)
+        else:
+            assert self.mesh is not None
+            return self.mesh.get_local_rank(self._real_mesh_dim(dim))
+
+    def get_coordinate(self) -> Optional[List[int]]:
+        """
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
+        """
+        assert self.mesh is not None
+        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+
+    def get_all_groups(self) -> List[BaseProcessGroup]:
+        raise NotImplementedError
+
+
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[BaseProcessGroup]:
+        raise NotImplementedError
+
+
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+) -> "ManagedDeviceMesh":
+    # We need to mislead DeviceMesh into thinking that replicate_dim has only
+    # 1 rank.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError()
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # We have to use MultiProcessTestCase, otherwise c10d will complain
+    # the same backend has been registered.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )
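
To summarize the bookkeeping added above: ft_init_device_mesh initializes the underlying DeviceMesh without the replicate dimension, and ManagedDeviceMesh re-inserts that dimension via the ManagedProcessGroup when reporting shape, size(), and ranks. A minimal sketch of the size and rank composition (the helper names below are made up for illustration; the replicate dimension is assumed to be dim 0, as the code asserts):

    def managed_size(shard_mesh_size: int, replicate_pg_size: int) -> int:
        # Mirrors ManagedDeviceMesh.size(None): product of the wrapped mesh
        # size and the replicate process group size.
        return shard_mesh_size * replicate_pg_size

    def managed_local_rank(replicate_rank: int, shard_rank: int, shard_mesh_size: int) -> int:
        # Mirrors ManagedDeviceMesh.get_local_rank(None): ranks are laid out
        # replicate-major because replicate_dim must be dim 0.
        return shard_mesh_size * replicate_rank + shard_rank

    # e.g. a (2, 2) mesh: 2 replica groups x 2 shards -> 4 logical ranks;
    # the process with replicate rank 1 and shard rank 0 is logical rank 2.
    assert managed_size(2, 2) == 4
    assert managed_local_rank(replicate_rank=1, shard_rank=0, shard_mesh_size=2) == 2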
