Commit e98c67c

Update

[ghstack-poisoned]

1 parent 884c4bd commit e98c67c

3 files changed: +109 -23

torchft/fsdp_test.py

+70
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Tuple
+from unittest import skipUnless, TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch._C._distributed_c10d import (
+    _resolve_process_group,
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+)
+from torch.distributed import (
+    _functional_collectives,
+    get_world_size,
+    ReduceOp,
+    TCPStore,
+    Work,
+)
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+
+from torchft.manager import Manager
+from torchft.process_group import ft_init_device_mesh, ManagedProcessGroup
+
+
+class FSDPTest(MultiProcessTestCase):
+    @property
+    def world_size(self):
+        return 4
+
+    def setUp(self):
+        super().setUp()
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    def test_fsdp(self) -> None:
+        group_size = self.world_size // 2
+        group = self.rank // group_size
+        group_rank = self.rank % group_size
+
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346 + group)
+        os.environ["RANK"] = str(group_rank)
+        os.environ["WORLD_SIZE"] = str(group_size)
+
+        manager = Mock(spec=Manager)
+        device_mesh = ft_init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(2, 2),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        manager.num_participants.return_value = 1
+        model = nn.Linear(128, 128).cuda()
+        batch = torch.randn(4, 128).cuda()
+        shard_model = fully_shard(model, mesh=device_mesh)
+        shard_model(batch).mean().backward()

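Note: the rank bookkeeping at the top of test_fsdp splits the 4-rank world into 2 replica groups of 2 ranks each, with a separate MASTER_PORT per group. A minimal sketch of that arithmetic (plain Python, variable names hypothetical):

# Sketch: how test_fsdp maps global ranks onto replica groups.
world_size = 4
group_size = world_size // 2  # 2 ranks per replica group

for rank in range(world_size):
    group = rank // group_size      # which replica group this rank joins
    group_rank = rank % group_size  # rank within that group
    print(f"rank {rank} -> group {group}, group_rank {group_rank}")

# rank 0 -> group 0, group_rank 0
# rank 1 -> group 0, group_rank 1
# rank 2 -> group 1, group_rank 0
# rank 3 -> group 1, group_rank 1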
torchft/process_group.py

+39-22
@@ -250,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
     This is a reconfigurable version of ProcessGroupGloo.
     """
 
-    PG_CLASS: Type[
-        BaseProcessGroup
-    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
+    PG_CLASS: Type[BaseProcessGroup] = (
+        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
+    )
 
     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -269,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     abort when reconfiguring, we need to ensure this is safe.
     """
 
-    PG_CLASS: Type[
-        BaseProcessGroup
-    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
+    PG_CLASS: Type[BaseProcessGroup] = (
+        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
+    )
 
     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -745,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
     ProcessGroupBabyNCCL.
     """
 
-    PG_CLASS: Type[
-        BaseProcessGroup
-    ] = BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
+    PG_CLASS: Type[BaseProcessGroup] = (
+        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
+    )
 
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -769,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    PG_CLASS: Type[
-        BaseProcessGroup
-    ] = BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
+    PG_CLASS: Type[BaseProcessGroup] = (
+        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
+    )
     WORK_CLASS = _BabyWorkNCCL
 
     def getBackendName(self) -> str:
@@ -807,27 +807,34 @@ def extend_device_mesh(
     )
 
 
-class ManagedDeviceMesh(DeviceMesh):
+class _ManagedDeviceMesh(DeviceMesh):
     def __init__(
         self,
         mesh: Optional[DeviceMesh],
         mesh_dim_names: Tuple[str],
         replicate_pg: ManagedProcessGroup,
         replicate_dim: int,
-        parent: Optional["ManagedDeviceMesh"],
+        parent: Optional["_ManagedDeviceMesh"],
     ):
+        if mesh is None and parent is None:
+            raise ValueError(
+                "_ManagedDeviceMesh doesn't support mesh and parent both being None."
+            )
         self.mesh = mesh
         self.mesh_dim_names = mesh_dim_names
         self.replicate_pg = replicate_pg
         self.replicate_dim = replicate_dim
         self.replicate_dim_name = mesh_dim_names[replicate_dim]
         self.parent = parent
         self.flatten_meshes = {}
+        self.device_type = mesh.device_type if mesh is not None else parent.device_type
+        self._flatten_mesh_list = tuple()
+        self._thread_id = None
 
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
         if isinstance(mesh_dim_names, str):
             if mesh_dim_names == self.replicate_dim_name:
-                return ManagedDeviceMesh(
+                return _ManagedDeviceMesh(
                     mesh=None,
                     mesh_dim_names=(mesh_dim_names,),
                     replicate_pg=self.replicate_pg,
@@ -843,22 +850,25 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
             if self.replicate_dim_name in mesh_dim_names:
                 return self.mesh[mesh_dim_names]
             else:
-                return ManagedDeviceMesh(
+                return _ManagedDeviceMesh(
                     self.mesh[mesh_dim_names],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_name.index(self.replicate_dim_name),
                     parent=self,
                 )
 
+    def _real_mesh_dim(self, mesh_dim: int) -> int:
+        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
+
     def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
         if mesh_dim is None:
             assert self.mesh is None
             return self.replicate_pg
         elif mesh_dim == self.replicate_dim_name:
             return self.replicate_pg
         else:
-            return self.mesh.get_group(mesh_dim)
+            return self.mesh.get_group(self._real_mesh_dim(mesh_dim))
 
     def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
         flatten_mesh = _FlattenDeviceMesh(self)
@@ -877,7 +887,7 @@ def size(self, mesh_dim: Optional[int] = None) -> int:
         elif mesh_dim == self.replicate_dim:
             return self.replicate_pg.size()
         else:
-            return self.mesh.size(mesh_dim)
+            return self.mesh.size(self._real_mesh_dim(mesh_dim))
 
     @property
     def ndim(self) -> int:
@@ -904,14 +914,21 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
         elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
             return get_rank(self.replicate_pg)
         else:
-            return self.mesh.get_local_rank(mesh_dim)
+            return self.mesh.get_local_rank(self._real_mesh_dim(mesh_dim))
+
+    def get_coordinate(self) -> Optional[List[int]]:
+        """
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
+        """
+        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
 
     def get_all_groups(self) -> List[ProcessGroup]:
         raise NotImplementedError
 
 
 class _FlattenDeviceMesh(DeviceMesh):
-    def __init__(self, managed_mesh: ManagedDeviceMesh):
+    def __init__(self, managed_mesh: _ManagedDeviceMesh):
         self.managed_mesh = managed_mesh
 
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
@@ -954,7 +971,7 @@ def ft_init_device_mesh(
     replicate_dim: int,
     manager: "Manager",
 ):
-    # We have to lie DeviceMesh that the replicate_dim has only
+    # We need to mislead DeviceMesh into thinking that replicate_dim has only
     # 1 rank.
     _mesh_shape = list(mesh_shape)
     _mesh_shape.pop(replicate_dim)
@@ -979,7 +996,7 @@
     # the same backend has been registered.
     replicate_pg.register(mesh_dim_names[replicate_dim])
 
-    return ManagedDeviceMesh(
+    return _ManagedDeviceMesh(
         mesh=mesh,
         mesh_dim_names=mesh_dim_names,
         replicate_pg=replicate_pg,

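Note on _real_mesh_dim: the managed mesh presents the replicate dimension to callers, but the wrapped DeviceMesh is constructed without it (ft_init_device_mesh pops replicate_dim from the mesh shape), so integer dims that come after the replicate dim shift down by one in the wrapped mesh. A minimal sketch of the mapping, written as a hypothetical standalone function mirroring the method above (the first example assumes the ("dp_replicate", "dp_shard") mesh from fsdp_test.py):

# Sketch: translate a managed-mesh dim index to the wrapped mesh's index.
def real_mesh_dim(mesh_dim: int, replicate_dim: int) -> int:
    # Dims past the (virtual) replicate dim shift down by one.
    return mesh_dim - 1 if mesh_dim > replicate_dim else mesh_dim

assert real_mesh_dim(1, replicate_dim=0) == 0  # "dp_shard" is dim 0 of the wrapped mesh
assert real_mesh_dim(0, replicate_dim=1) == 0  # dims before the replicate dim are unchanged
assert real_mesh_dim(2, replicate_dim=1) == 1  # dims after it shift down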
torchft/process_group_test.py

-1
@@ -33,7 +33,6 @@
 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
-    ManagedDeviceMesh,
     ManagedProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
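Taken together: on the managed mesh, operations on the replicate dimension route to the fault-tolerant ManagedProcessGroup, while everything else is forwarded to the wrapped DeviceMesh with the dim index translated. A hedged usage sketch, assuming the device_mesh built in fsdp_test.py above (shown as comments, not run here):

# Hypothetical calls against the mesh from test_fsdp:
# device_mesh.get_group("dp_replicate")  # -> the ManagedProcessGroup
# device_mesh.get_group("dp_shard")      # -> forwarded to the wrapped DeviceMesh
# device_mesh.size(0)                    # -> replicate_pg.size()
# device_mesh.size(1)                    # -> mesh.size(_real_mesh_dim(1)) == mesh.size(0)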
