Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT #56

Merged
merged 27 commits on Jan 10, 2025
Changes from 3 commits
239 changes: 214 additions & 25 deletions torchft/process_group.py
@@ -20,7 +20,7 @@
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
@@ -38,6 +38,7 @@
Store,
TCPStore,
get_rank,
init_device_mesh,
)
from torch.distributed.distributed_c10d import Work, _world
from torch.futures import Future
@@ -130,17 +131,7 @@ def size(self) -> int:
def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def register(self, name: str) -> "ProcessGroup":
"""
Registers the process group with the global registry. This enables usage
with things like functional_collectives which are compilable.

This should only be called once.

Args:
name: name must be a unique name for this process group
"""

def _register(self, name: str) -> str:
group_name = f"{self.getBackendName()}:{name}"

# This is needed for DeviceMesh and functional collectives to work.
@@ -158,6 +149,21 @@ def create_pg(
devices = ["cpu"]
dist.Backend.register_backend(group_name, create_pg, devices=devices)

return group_name

def register(self, name: str) -> "ProcessGroup":
"""
Registers the process group with the global registry. This enables usage
with things like functional_collectives which are compilable.

This should only be called once.

Args:
name: name must be a unique name for this process group
"""

group_name = self._register(name)

return dist.new_group(
ranks=[dist.get_rank()],
backend=group_name,
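
For context, a minimal sketch (not part of this diff) of how a registered process group is meant to be used: register() exposes the torchft backend under a unique c10d group name so regular and functional collectives can dispatch to it. The store address and rank/world-size values below are placeholders.

import torch
import torch.distributed as dist
from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()
pg.configure("localhost:29500/prefix", 0, 1)  # placeholder (store_addr, rank, world_size)

group = pg.register("example")   # dist.ProcessGroup backed by the torchft-gloo backend
t = torch.ones(4)
dist.all_reduce(t, group=group)  # dispatches through the registered backend
pg.unregister()                  # clean up the global registration when done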
@@ -244,9 +250,9 @@ class ProcessGroupGloo(ProcessGroupWrapper):
This is a reconfigurable version of ProcessGroupGloo.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
PG_CLASS: Type[
BaseProcessGroup
] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo

def getBackendName(self) -> str:
return "torchft-gloo"
@@ -263,9 +269,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
abort when reconfiguring, we need to ensure this is safe.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
PG_CLASS: Type[
BaseProcessGroup
] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL

def getBackendName(self) -> str:
return "torchft-nccl"
@@ -496,6 +502,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
def size(self) -> int:
return self._manager.num_participants()

def getBackendName(self) -> str:
return self._manager._pg.getBackendName()


class _BabyWork(Work):
def __init__(
@@ -689,7 +698,6 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future[object]:

with self._futures_lock:
fut = Future() # pyre-fixme[29]: is not a function
self._futures[op_id] = fut
@@ -737,9 +745,9 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
ProcessGroupBabyNCCL.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
PG_CLASS: Type[
BaseProcessGroup
] = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo

def getBackendName(self) -> str:
return "torchft-baby-gloo"
@@ -761,9 +769,9 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
PG_CLASS: Type[
BaseProcessGroup
] = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
WORK_CLASS = _BabyWorkNCCL

def getBackendName(self) -> str:
@@ -797,3 +805,184 @@ def extend_device_mesh(
mesh=mesh.mesh.unsqueeze(dim),
mesh_dim_names=tuple(mesh_dim_names),
)


class ManagedDeviceMesh(DeviceMesh):
def __init__(
self,
mesh: Optional[DeviceMesh],
        mesh_dim_names: Tuple[str, ...],
replicate_pg: ManagedProcessGroup,
replicate_dim: int,
parent: Optional["ManagedDeviceMesh"],
):
self.mesh = mesh
self.mesh_dim_names = mesh_dim_names
self.replicate_pg = replicate_pg
self.replicate_dim = replicate_dim
self.replicate_dim_name = mesh_dim_names[replicate_dim]
self.parent = parent
self.flatten_meshes = {}

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
return ManagedDeviceMesh(
mesh=None,
mesh_dim_names=(mesh_dim_names,),
replicate_pg=self.replicate_pg,
replicate_dim=0,
parent=self,
)
elif mesh_dim_names in self.flatten_meshes:
return self.flatten_meshes[mesh_dim_names]
else:
return self.mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
            if self.replicate_dim_name not in mesh_dim_names:
                return self.mesh[mesh_dim_names]
            else:
                return ManagedDeviceMesh(
                    self.mesh[
                        tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
                    ],
                    mesh_dim_names,
                    self.replicate_pg,
                    mesh_dim_names.index(self.replicate_dim_name),
                    parent=self,
                )

def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
if mesh_dim is None:
assert self.mesh is None
return self.replicate_pg
elif mesh_dim == self.replicate_dim_name:
return self.replicate_pg
else:
return self.mesh.get_group(mesh_dim)

def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
flatten_mesh = _FlattenDeviceMesh(self)
if self.parent is None:
self.flatten_meshes[mesh_dim_name] = flatten_mesh
else:
self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
return flatten_mesh

def size(self, mesh_dim: Optional[int] = None) -> int:
if mesh_dim is None:
if self.mesh is None:
return self.replicate_pg.size()
else:
return self.mesh.size() * self.replicate_pg.size()
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
else:
return self.mesh.size(mesh_dim)

@property
def ndim(self) -> int:
return self.mesh.ndim + 1

@property
def shape(self) -> Tuple[int, ...]:
ret = list(self.mesh.shape)
        ret.insert(self.replicate_dim, self.replicate_pg.size())
        return tuple(ret)

def get_rank(self) -> int:
return self.mesh.get_rank()

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
if mesh_dim is None:
if self.mesh is None:
return get_rank(self.replicate_pg)

assert self.replicate_dim == 0, "replicate_dim must be the first one"
other_dim_size = self.mesh.size()
other_dim_rank = self.mesh.get_local_rank()
replicate_pg_rank = get_rank(self.replicate_pg)
return other_dim_size * replicate_pg_rank + other_dim_rank
elif mesh_dim in (self.replicate_dim, self.replicate_dim_name):
return get_rank(self.replicate_pg)
else:
return self.mesh.get_local_rank(mesh_dim)

def get_all_groups(self) -> List[ProcessGroup]:
raise NotImplementedError


class _FlattenDeviceMesh(DeviceMesh):
Member:
Oh nice! Does this solve the issue with flattening in FSDP, or does it just throw an error for now?

Contributor Author:
It should work for the case where we flatten the device mesh to compute the global loss average, but it won't work for the data loader. I think we need to customize the data loader anyway.
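
A hedged sketch of that loss-averaging use case, using only APIs shown in this diff (get_group, _flatten, size); the dimension names and the global_avg_loss helper are illustrative and not part of the PR:

import torch
import torch.distributed as dist

def global_avg_loss(loss: torch.Tensor, device_mesh) -> torch.Tensor:
    # Sum the loss over the shard group, then over the managed replicate group,
    # and normalize by the flattened participant count.
    dist.all_reduce(loss, group=device_mesh.get_group("dp_shard"))
    dist.all_reduce(loss, group=device_mesh.get_group("dp_replicate"))
    return loss / device_mesh._flatten("dp").size()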

def __init__(self, managed_mesh: ManagedDeviceMesh):
self.managed_mesh = managed_mesh

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
raise NotImplementedError

def get_group(self, mesh_dim: Optional[str] = None) -> BaseProcessGroup:
raise NotImplementedError

def _flatten(self, mesh_dim_name: str) -> "DeviceMesh":
raise NotImplementedError

def size(self, mesh_dim: Optional[int] = None) -> int:
assert mesh_dim is None
return self.managed_mesh.size()

@property
def ndim(self) -> int:
raise NotImplementedError

@property
def shape(self) -> Tuple[int, ...]:
raise NotImplementedError

def get_rank(self) -> int:
raise NotImplementedError

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
assert mesh_dim is None
return self.managed_mesh.get_local_rank()

def get_all_groups(self) -> List[ProcessGroup]:
raise NotImplementedError


def ft_init_device_mesh(
*,
device_type: str,
mesh_shape: Tuple[int, ...],
mesh_dim_names: Tuple[str, ...],
replicate_dim: int,
manager: "Manager",
) -> "ManagedDeviceMesh":
    # We have to lie to DeviceMesh that the replicate_dim has only
    # 1 rank.
_mesh_shape = list(mesh_shape)
_mesh_shape.pop(replicate_dim)
_mesh_dim_names = list(mesh_dim_names)
_mesh_dim_names.pop(replicate_dim)
mesh = init_device_mesh(
device_type,
mesh_shape=tuple(_mesh_shape),
mesh_dim_names=tuple(_mesh_dim_names),
)

if device_type == "cpu":
pg = ProcessGroupGloo()
elif device_type == "cuda":
pg = ProcessGroupNCCL()
else:
        raise ValueError(f"unsupported device_type: {device_type}")

manager._pg = pg
replicate_pg = ManagedProcessGroup(manager)
    # We have to use MultiProcessTestCase; otherwise c10d will complain that
    # the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

return ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
replicate_pg=replicate_pg,
replicate_dim=replicate_dim,
parent=None,
)
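
To make the "lie" in ft_init_device_mesh concrete, here is a small sketch (not code from this PR) of the shape bookkeeping: the replicate dimension is dropped before the real DeviceMesh is created, and ManagedDeviceMesh re-inserts it virtually, so size() reports the product of the real mesh size and the replicate group size.

# Sketch using the same (2, 4) shape as the test below.
mesh_shape = (2, 4)            # (dp_replicate, dp_shard)
replicate_dim = 0
real_shape = tuple(s for i, s in enumerate(mesh_shape) if i != replicate_dim)
assert real_shape == (4,)      # only dp_shard is physically created via init_device_mesh
# ManagedDeviceMesh.size() == replicate_pg.size() * mesh.size(),
# e.g. 2 managed replicas * 4 shards == 8 total participants.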
47 changes: 47 additions & 0 deletions torchft/process_group_test.py
@@ -28,10 +28,12 @@
get_world_size,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import MultiProcessTestCase

from torchft.manager import Manager
from torchft.process_group import (
ErrorSwallowingProcessGroupWrapper,
ManagedDeviceMesh,
ManagedProcessGroup,
ProcessGroup,
ProcessGroupBabyGloo,
@@ -44,6 +46,7 @@
_ErrorSwallowingWork,
_ManagedWork,
extend_device_mesh,
ft_init_device_mesh,
)


@@ -234,6 +237,7 @@ def test_device_mesh(self) -> None:
pg.configure(store_addr, 0, 1)

mesh_2d = extend_device_mesh(mesh_1d, pg)
mesh_2d.get_group("dp")
assert mesh_2d.ndim == 2

pg.unregister()
@@ -299,3 +303,46 @@ def test_managed_process_group(self) -> None:

self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)


class DeviceMeshTest(MultiProcessTestCase):
@property
def world_size(self):
return 4

def setUp(self):
super().setUp()
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
self._spawn_processes()

def test_init_device_mesh(self) -> None:
os.environ["MASTER_PORT"] = str(12346)
os.environ["RANK"] = str(self.rank)
os.environ["WORLD_SIZE"] = str(4)

manager = Mock(spec=Manager)
        # Even though we only have 4 workers, we can still initialize a (2, 4) mesh.
        # That's because the replicate group is NOT physically created in the
        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
device_mesh = ft_init_device_mesh(
device_type="cpu",
mesh_shape=(2, 4),
Member:
Is this value used at all? I assume it doesn't really matter what it's set to?

Contributor Author:
The replicate part is not going to be valid, but the other parts are valid and will be used.

mesh_dim_names=("dp_replicate", "dp_shard"),
replicate_dim=0,
manager=manager,
)

self.assertTrue(
isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
)
self.assertTrue(
not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
)
replicate_group = device_mesh.get_group("dp_replicate")
self.assertEqual(replicate_group._manager, manager)
replicate_mesh = device_mesh["dp_replicate"]
self.assertEqual(replicate_mesh.get_group(), replicate_group)
flatten_mesh = device_mesh._flatten("dp")
manager.num_participants.return_value = 1
self.assertEqual(flatten_mesh.size(), 4)
Contributor:
Should this be equal to world_size?

self.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())