Commit 4d976fd

Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT
Summary: ManagedDeviceMesh allows users to manipulate a DeviceMesh that is backed by a TorchFT ManagedProcessGroup.

ghstack-source-id: 35ed91c90e242aa9f114a7b25f58097d312e274d
Pull Request resolved: #56
1 parent f31d3b1 commit 4d976fd
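
For context, a minimal usage sketch based only on the ft_init_device_mesh signature added in this commit. `manager` is assumed to be an existing torchft.manager.Manager; constructing one (lighthouse, store, etc.) is outside the scope of this diff.

# Usage sketch (assumes a running torchft Manager named `manager`).
from torchft.process_group import ft_init_device_mesh

device_mesh = ft_init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2),  # (replica groups, shards per replica group)
    mesh_dim_names=("dp_replicate", "dp_shard"),
    replicate_dim=0,  # the dim served by the fault-tolerant ManagedProcessGroup
    manager=manager,
)

# The replicate dimension resolves to the ManagedProcessGroup; the other
# dimensions are backed by a regular DeviceMesh.
replicate_pg = device_mesh.get_group("dp_replicate")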

File tree: 3 files changed (+347, −13 lines)

torchft/fsdp_test.py (+72)

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock

import torch
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import (
    AllgatherOptions,
    AllreduceOptions,
    BroadcastOptions,
    ReduceOp,
    _resolve_process_group,
)
from torch.distributed import (
    ReduceOp,
    TCPStore,
    Work,
    _functional_collectives,
    get_world_size,
)
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import MultiProcessTestCase

from torchft.manager import Manager
from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh


class FSDPTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
        self._spawn_processes()

    def test_fsdp(self) -> None:
        group_size = self.world_size // 2
        # pyre-ignore[16]
        group = self.rank // group_size
        # pyre-ignore[16]
        group_rank = self.rank % group_size

        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(12346 + group)
        os.environ["RANK"] = str(group_rank)
        os.environ["WORLD_SIZE"] = str(group_size)

        manager = Mock(spec=Manager)
        device_mesh = ft_init_device_mesh(
            device_type="cuda",
            mesh_shape=(2, 2),
            mesh_dim_names=("dp_replicate", "dp_shard"),
            replicate_dim=0,
            manager=manager,
        )
        manager.num_participants.return_value = 1
        model = nn.Linear(128, 128).cuda()
        batch = torch.randn(4, 128).cuda()
        shard_model = fully_shard(model, mesh=device_mesh)
        shard_model(batch).mean().backward()
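
The test splits the 4-rank world into two replica groups of two ranks each; each group gets its own rendezvous port and a private WORLD_SIZE, while the mocked Manager stands in for the fault-tolerant replicate dimension. For illustration (not part of the commit), the group arithmetic maps ranks as follows:

# Rank-to-group mapping used by test_fsdp (world_size=4, group_size=2):
#   rank 0 -> group 0, group_rank 0
#   rank 1 -> group 0, group_rank 1
#   rank 2 -> group 1, group_rank 0
#   rank 3 -> group 1, group_rank 1
world_size, group_size = 4, 2
for rank in range(world_size):
    print(f"rank {rank}: group={rank // group_size}, group_rank={rank % group_size}")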

torchft/process_group.py (+226, −13)

@@ -20,7 +20,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -38,6 +38,7 @@
     Store,
     TCPStore,
     get_rank,
+    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import Work, _world
 from torch.futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
-    def register(self, name: str) -> "ProcessGroup":
-        """
-        Registers the process group with the global registry. This enables usage
-        with things like functional_collectives which are compilable.
-
-        This should only be called once.
-
-        Args:
-            name: name must be a unique name for this process group
-        """
-
+    def _register(self, name: str) -> str:
         group_name = f"{self.getBackendName()}:{name}"
 
         # This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
         devices = ["cpu"]
         dist.Backend.register_backend(group_name, create_pg, devices=devices)
 
+        return group_name
+
+    def register(self, name: str) -> "ProcessGroup":
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        group_name = self._register(name)
+
         return dist.new_group(
             ranks=[dist.get_rank()],
             backend=group_name,
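
The hunk above factors backend registration out of register(): _register() installs the backend under f"{getBackendName()}:{name}" and returns that group name, while register() additionally materializes a single-rank wrapper group via dist.new_group(). A hedged sketch of the call pattern (assumes torch.distributed is already initialized and that `ProcessGroupGloo` is torchft's wrapper class from this file):

# Sketch only: requires an initialized default process group.
pg = ProcessGroupGloo()
wrapper = pg.register("my_group")  # backend registered as f"{pg.getBackendName()}:my_group"
# `wrapper` is now usable with DeviceMesh and functional collectives;
# ft_init_device_mesh() below relies on this same path for the replicate dim.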
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._manager.num_participants()
 
+    def getBackendName(self) -> str:
+        return self._manager._pg.getBackendName()
+
 
 class _BabyWork(Work):
     def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
-
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -797,3 +805,208 @@ def extend_device_mesh(
         mesh=mesh.mesh.unsqueeze(dim),
         mesh_dim_names=tuple(mesh_dim_names),
     )
+
+
+class ManagedDeviceMesh(DeviceMesh):
+    def __init__(
+        self,
+        mesh: Optional[DeviceMesh],
+        mesh_dim_names: Tuple[str, ...],
+        replicate_pg: ManagedProcessGroup,
+        replicate_dim: int,
+        parent: Optional["ManagedDeviceMesh"],
+    ) -> None:
+        if mesh is None and parent is None:
+            raise ValueError(
+                "ManagedDeviceMesh doesn't support mesh and parent both being None."
+            )
+        self.mesh = mesh
+        self.mesh_dim_names = mesh_dim_names
+        self.replicate_pg = replicate_pg
+        self.replicate_dim = replicate_dim
+        self.replicate_dim_name = mesh_dim_names[replicate_dim]
+        self.parent = parent
+        self.flatten_meshes = {}
+        self.device_type = mesh.device_type if mesh is not None else parent.device_type
+        self._flatten_mesh_list = tuple()
+        self._thread_id = None
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        if isinstance(mesh_dim_names, str):
+            if mesh_dim_names == self.replicate_dim_name:
+                return ManagedDeviceMesh(
+                    mesh=None,
+                    mesh_dim_names=(mesh_dim_names,),
+                    replicate_pg=self.replicate_pg,
+                    replicate_dim=0,
+                    parent=self,
+                )
+            elif mesh_dim_names in self.flatten_meshes:
+                return self.flatten_meshes[mesh_dim_names]
+            else:
+                return self.mesh[mesh_dim_names]
+        else:
+            assert isinstance(mesh_dim_names, tuple)
+            if self.replicate_dim_name not in mesh_dim_names:
+                return self.mesh[mesh_dim_names]
+            else:
+                return ManagedDeviceMesh(
+                    self.mesh[mesh_dim_names],
+                    mesh_dim_names,
+                    self.replicate_pg,
+                    mesh_dim_names.index(self.replicate_dim_name),
+                    parent=self,
+                )
+
+    def _real_mesh_dim(self, mesh_dim: int) -> int:
+        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        if mesh_dim is None:
+            assert self.mesh is None
+            return self.replicate_pg
+        elif mesh_dim == self.replicate_dim_name:
+            return self.replicate_pg
+        else:
+            dim = self.mesh_dim_names.index(mesh_dim)
+            return self.mesh.get_group(self._real_mesh_dim(dim))
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        flatten_mesh = _FlattenDeviceMesh(self)
+        if self.parent is None:
+            self.flatten_meshes[mesh_dim_name] = flatten_mesh
+        else:
+            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
+        return flatten_mesh
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        if mesh_dim is None:
+            if self.mesh is None:
+                return self.replicate_pg.size()
+            else:
+                return self.mesh.size() * self.replicate_pg.size()
+        elif mesh_dim == self.replicate_dim:
+            return self.replicate_pg.size()
+        else:
+            return self.mesh.size(self._real_mesh_dim(mesh_dim))
+
+    @property
+    def ndim(self) -> int:
+        return self.mesh.ndim + 1
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        ret = list(self.mesh.shape)
+        ret.insert(self.replicate_dim, self.replicate_pg.size())
+        return tuple(ret)
+
+    def get_rank(self) -> int:
+        return self.mesh.get_rank()
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        if isinstance(mesh_dim, str):
+            dim = self.mesh_dim_names.index(mesh_dim)
+        else:
+            dim = 0 if mesh_dim is None else int(mesh_dim)
+
+        if mesh_dim is None:
+            if self.mesh is None:
+                return get_rank(self.replicate_pg)
+
+            assert self.replicate_dim == 0, "replicate_dim must be the first one"
+            other_dim_size = self.mesh.size()
+            other_dim_rank = self.mesh.get_local_rank()
+            replicate_pg_rank = get_rank(self.replicate_pg)
+            return other_dim_size * replicate_pg_rank + other_dim_rank
+        elif dim == self.replicate_dim:
+            return get_rank(self.replicate_pg)
+        else:
+            return self.mesh.get_local_rank(self._real_mesh_dim(dim))
+
+    def get_coordinate(self) -> Optional[List[int]]:
+        """
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
+        """
+        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+class _FlattenDeviceMesh(DeviceMesh):
+    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
+        self.managed_mesh = managed_mesh
+
+    def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
+        raise NotImplementedError
+
+    def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
+        raise NotImplementedError
+
+    def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
+        raise NotImplementedError
+
+    def size(self, mesh_dim: Optional[int] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.size()
+
+    @property
+    def ndim(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def get_rank(self) -> int:
+        raise NotImplementedError
+
+    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+        assert mesh_dim is None
+        return self.managed_mesh.get_local_rank()
+
+    def get_all_groups(self) -> List[ProcessGroup]:
+        raise NotImplementedError
+
+
+def ft_init_device_mesh(
+    *,
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    mesh_dim_names: Tuple[str, ...],
+    replicate_dim: int,
+    manager: "Manager",
+) -> "ManagedDeviceMesh":
+    # We need to mislead DeviceMesh into thinking that replicate_dim has only
+    # 1 rank.
+    _mesh_shape = list(mesh_shape)
+    _mesh_shape.pop(replicate_dim)
+    _mesh_dim_names = list(mesh_dim_names)
+    _mesh_dim_names.pop(replicate_dim)
+    mesh = init_device_mesh(
+        device_type,
+        mesh_shape=tuple(_mesh_shape),
+        mesh_dim_names=tuple(_mesh_dim_names),
+    )
+
+    if device_type == "cpu":
+        pg = ProcessGroupGloo()
+    elif device_type == "cuda":
+        pg = ProcessGroupNCCL()
+    else:
+        raise ValueError(f"Unsupported device_type: {device_type}")
+
+    manager._pg = pg
+    replicate_pg = ManagedProcessGroup(manager)
+    # We have to use MultiProcessTestCase, otherwise c10d will complain
+    # that the same backend has been registered.
+    replicate_pg.register(mesh_dim_names[replicate_dim])
+
+    return ManagedDeviceMesh(
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+        replicate_pg=replicate_pg,
+        replicate_dim=replicate_dim,
+        parent=None,
+    )
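
To make the index arithmetic in ManagedDeviceMesh concrete: the inner DeviceMesh never sees the replicate dimension, so sizes and ranks are recombined from the two sources. A worked example with mesh_shape=(2, 2) and replicate_dim=0 (values illustrative):

# size() multiplies the inner mesh size by the replicate group size:
#   size() = mesh.size() * replicate_pg.size() = 2 * 2 = 4
#
# get_local_rank(None) composes the two ranks (replicate_dim must be 0 here):
#   other_dim_size    = mesh.size()            = 2
#   replicate_pg_rank = get_rank(replicate_pg) = 1   # second replica group
#   other_dim_rank    = mesh.get_local_rank()  = 0   # first shard
#   composite rank    = 2 * 1 + 0              = 2
#
# _real_mesh_dim() shifts dims past the replicate dim, since the inner mesh
# lacks that dimension:
#   _real_mesh_dim(1) = 0   # "dp_shard" is dim 1 outside, dim 0 inside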
