Skip to content

Commit

Permalink
Update on "Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT"
Browse files Browse the repository at this point in the history



ManagedDeviceMesh allows users to manipulate a DeviceMesh with a TorchFT ManagedProcessGroup. This currently works with a simple HSDP case, but the actual integration and e2e tests are likely to expose more issues, e.g., checkpointing.

[ghstack-poisoned]
  • Loading branch information
fegin committed Jan 7, 2025
2 parents 6fb19cd + 50b7520 commit dd0bdb1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
6 changes: 4 additions & 2 deletions torchft/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@

class FSDPTest(MultiProcessTestCase):
@property
def world_size(self):
def world_size(self) -> int:
return 4

def setUp(self):
def setUp(self) -> None:
super().setUp()
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
self._spawn_processes()

def test_fsdp(self) -> None:
group_size = self.world_size // 2
# pyre-ignore[16]
group = self.rank // group_size
# pyre-ignore[16]
group_rank = self.rank % group_size

os.environ["MASTER_ADDR"] = "127.0.0.1"
Expand Down
35 changes: 21 additions & 14 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,18 +807,18 @@ def extend_device_mesh(
)


class _ManagedDeviceMesh(DeviceMesh):
class ManagedDeviceMesh(DeviceMesh):
def __init__(
self,
mesh: Optional[DeviceMesh],
mesh_dim_names: Tuple[str],
mesh_dim_names: Tuple[str, ...],
replicate_pg: ManagedProcessGroup,
replicate_dim: int,
parent: Optional["_ManagedDeviceMesh"],
):
parent: Optional["ManagedDeviceMesh"],
) -> None:
if mesh is None and parent is not None:
raise ValueError(
"_ManagedDeviceMesh doesn't support both mesh and parent are None."
"ManagedDeviceMesh doesn't support both mesh and parent are None."
)
self.mesh = mesh
self.mesh_dim_names = mesh_dim_names
Expand All @@ -834,7 +834,7 @@ def __init__(
def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
return _ManagedDeviceMesh(
return ManagedDeviceMesh(
mesh=None,
mesh_dim_names=(mesh_dim_names,),
replicate_pg=self.replicate_pg,
Expand All @@ -850,11 +850,11 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
if self.replicate_dim_name in mesh_dim_names:
return self.mesh[mesh_dim_names]
else:
return _ManagedDeviceMesh(
return ManagedDeviceMesh(
self.mesh[mesh_dim_names],
mesh_dim_names,
self.replicate_pg,
mesh_dim_name.index(self.replicate_dim_name),
mesh_dim_names.index(self.replicate_dim_name),
parent=self,
)

Expand All @@ -868,7 +868,8 @@ def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
elif mesh_dim == self.replicate_dim_name:
return self.replicate_pg
else:
return self.mesh.get_group(self._real_mesh_dim(mesh_dim))
dim = self.mesh_dim_names.index(mesh_dim)
return self.mesh.get_group(self._real_mesh_dim(dim))

def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
flatten_mesh = _FlattenDeviceMesh(self)
Expand Down Expand Up @@ -897,11 +898,17 @@ def ndim(self) -> int:
def shape(self) -> Tuple[int, ...]:
ret = list(self.mesh.shape)
ret.insert(self.replicate_dim, self.replicate_pg.size())
return ret

def get_rank(self) -> int:
return self.mesh.get_rank()

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
if isinstance(mesh_dim, str):
dim = self.mesh_dim_names.index(mesh_dim)
else:
dim = 0 if mesh_dim is None else int(mesh_dim)

if mesh_dim is None:
if self.mesh is None:
return get_rank(self.replicate_pg)
Expand All @@ -911,10 +918,10 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
other_dim_rank = self.mesh.get_local_rank()
replicate_pg_rank = get_rank(self.replicate_pg)
return other_dim_size * replicate_pg_rank + other_dim_rank
elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
elif dim == self.replicate_dim:
return get_rank(self.replicate_pg)
else:
return self.mesh.get_local_rank(self._real_mesh_dim(mesh_dim))
return self.mesh.get_local_rank(self._real_mesh_dim(dim))

def get_coordinate(self) -> Optional[List[int]]:
"""
Expand All @@ -928,7 +935,7 @@ def get_all_groups(self) -> List[ProcessGroup]:


class _FlattenDeviceMesh(DeviceMesh):
def __init__(self, managed_mesh: _ManagedDeviceMesh):
def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
self.managed_mesh = managed_mesh

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
Expand Down Expand Up @@ -970,7 +977,7 @@ def ft_init_device_mesh(
mesh_dim_names: Tuple[str, ...],
replicate_dim: int,
manager: "Manager",
):
) -> "ManagedDeviceMesh":
# We need to mislead DeviceMesh into thinking that replicate_dim has only
# 1 rank.
_mesh_shape = list(mesh_shape)
Expand All @@ -996,7 +1003,7 @@ def ft_init_device_mesh(
# the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

return _ManagedDeviceMesh(
return ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
replicate_pg=replicate_pg,
Expand Down
9 changes: 6 additions & 3 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,19 @@ def test_managed_process_group(self) -> None:
self.assertEqual(manager.wrap_future.call_count, 1)


class DevideMeshTest(MultiProcessTestCase):
class DeviceMeshTest(MultiProcessTestCase):
@property
def world_size(self):
def world_size(self) -> int:
return 4

def setUp(self):
def setUp(self) -> None:
super().setUp()
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
self._spawn_processes()

def test_init_device_mesh(self) -> None:
os.environ["MASTER_PORT"] = str(12346)
# pyre-ignore[16]
os.environ["RANK"] = str(self.rank)
os.environ["WORLD_SIZE"] = str(4)

Expand All @@ -331,13 +332,15 @@ def test_init_device_mesh(self) -> None:
manager=manager,
)

# pyre-ignore[16]
self.assertTrue(
isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
)
self.assertTrue(
not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
)
replicate_group = device_mesh.get_group("dp_replicate")
# pyre-ignore[16]
self.assertEqual(replicate_group._manager, manager)
replicate_mesh = device_mesh["dp_replicate"]
self.assertEqual(replicate_mesh.get_group(), replicate_group)
Expand Down

0 comments on commit dd0bdb1

Please sign in to comment.