|
13 | 13 | from torch import nn
|
14 | 14 | from torch._C._distributed_c10d import _resolve_process_group
|
15 | 15 | from torch.distributed import _functional_collectives, ReduceOp, TCPStore
|
16 |
| -from torch.distributed.device_mesh import init_device_mesh |
| 16 | +from torch.distributed.device_mesh import _mesh_resources, init_device_mesh |
| 17 | + |
| 18 | +from torch.testing._internal.common_distributed import ( |
| 19 | + MultiProcessTestCase, |
| 20 | +) |
| 21 | +from torch.testing._internal.common_utils import FILE_SCHEMA |
17 | 22 |
|
18 | 23 | from torchft.process_group import (
|
19 | 24 | extend_device_mesh,
|
@@ -194,3 +199,60 @@ def test_functional_collectives(self) -> None:
|
194 | 199 | _functional_collectives.all_reduce(t, "sum", pg).wait()
|
195 | 200 | finally:
|
196 | 201 | pg.unregister()
|
| 202 | + |
| 203 | + |
| 204 | +class ProcessGroupMPTest(MultiProcessTestCase): |
| 205 | + @property |
| 206 | + def world_size(self): |
| 207 | + return 4 |
| 208 | + |
| 209 | + def setUp(self): |
| 210 | + super().setUp() |
| 211 | + # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`, |
| 212 | + # which can cause unit test flakiness: |
| 213 | + # https://github.com/pytorch/pytorch/issues/90848 |
| 214 | + os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0" |
| 215 | + self._spawn_processes() |
| 216 | + |
| 217 | + def test_init_device_mesh(self) -> None: |
| 218 | + os.environ["MASTER_ADDR"] = "localhost" |
| 219 | + os.environ["MASTER_PORT"] = str(12345) |
| 220 | + os.environ["RANK"] = str(self.rank) |
| 221 | + os.environ["WORLD_SIZE"] = str(4) |
| 222 | + |
| 223 | + def ft_init_device_mesh(device, mesh_shape, mesh_dim_names, replicate_dim): |
| 224 | + if device == "cpu": |
| 225 | + pg = ProcessGroupGloo() |
| 226 | + elif device == "cuda": |
| 227 | + pg = ProcessGroupNCCL() |
| 228 | + else: |
| 229 | + raise ValueError() |
| 230 | + |
| 231 | + # We have to use MultiProcessTestCase, otherwise c10d will complain |
| 232 | + # the same backend has been registered. |
| 233 | + backend_name = pg._register(mesh_dim_names[replicate_dim]) |
| 234 | + # This currently doesn't work with NCCL as DeviceMesh will ignore |
| 235 | + # `_set_mesh_dim_group_options()` and just use `split_group()`. |
| 236 | + # We will need to change DeviceMesh to use `new_group()` instead of |
| 237 | + # `split_group()` when backend is not None. |
| 238 | + _mesh_resources._set_mesh_dim_group_options( |
| 239 | + replicate_dim, backend_name, None |
| 240 | + ) |
| 241 | + device_mesh = init_device_mesh( |
| 242 | + device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names |
| 243 | + ) |
| 244 | + # We need an API to clear the mesh_dim_group_options because it will |
| 245 | + # affect the following flatten() API. |
| 246 | + return device_mesh |
| 247 | + |
| 248 | + device_mesh = ft_init_device_mesh( |
| 249 | + "cpu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"), replicate_dim=0 |
| 250 | + ) |
| 251 | + |
| 252 | + store = TCPStore( |
| 253 | + host_name="localhost", port=0, is_master=True, wait_for_workers=False |
| 254 | + ) |
| 255 | + store_addr = f"localhost:{store.port}/prefix" |
| 256 | + pg = device_mesh.get_group("dp") |
| 257 | + pg.configure(store_addr, 0, 1) |
| 258 | + pg.unregister() |
0 commit comments