Skip to content

Commit aadf67c

Browse files
committed
[WIP] A test case to show how to use the DeviceMesh API to create a customized PG
This does not work yet because get_group will fail.
1 parent c7a7f1e commit aadf67c

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

torchft/process_group.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,7 @@ def size(self) -> int:
102102
def getBackendName(self) -> str:
103103
raise NotImplementedError("not implemented")
104104

105-
def register(self, name: str) -> BaseProcessGroup:
106-
"""
107-
Registers the process group with the global registry. This enables usage
108-
with things like functional_collectives which are compilable.
109-
110-
This should only be called once.
111-
112-
Args:
113-
name: name must be a unique name for this process group
114-
"""
115-
105+
def _register(self, name: str) -> str:
116106
group_name = f"{self.getBackendName()}:{name}"
117107

118108
# This is needed for DeviceMesh and functional collectives to work.
@@ -130,6 +120,20 @@ def create_pg(
130120
devices = ["cpu"]
131121
dist.Backend.register_backend(group_name, create_pg, devices=devices)
132122

123+
return group_name
124+
125+
def register(self, name: str) -> BaseProcessGroup:
126+
"""
127+
Registers the process group with the global registry. This enables usage
128+
with things like functional_collectives which are compilable.
129+
130+
This should only be called once.
131+
132+
Args:
133+
name: name must be a unique name for this process group
134+
"""
135+
group_name = self._register(name)
136+
133137
return dist.new_group(
134138
ranks=[dist.get_rank()],
135139
backend=group_name,

torchft/process_group_test.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from torch import nn
1414
from torch._C._distributed_c10d import _resolve_process_group
1515
from torch.distributed import _functional_collectives, ReduceOp, TCPStore
16-
from torch.distributed.device_mesh import init_device_mesh
16+
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
17+
18+
from torch.testing._internal.common_distributed import (
19+
MultiProcessTestCase,
20+
)
21+
from torch.testing._internal.common_utils import FILE_SCHEMA
1722

1823
from torchft.process_group import (
1924
extend_device_mesh,
@@ -194,3 +199,60 @@ def test_functional_collectives(self) -> None:
194199
_functional_collectives.all_reduce(t, "sum", pg).wait()
195200
finally:
196201
pg.unregister()
202+
203+
204+
class ProcessGroupMPTest(MultiProcessTestCase):
205+
@property
206+
def world_size(self):
207+
return 4
208+
209+
def setUp(self):
210+
super().setUp()
211+
# Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
212+
# which can cause unit test flakiness:
213+
# https://github.com/pytorch/pytorch/issues/90848
214+
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
215+
self._spawn_processes()
216+
217+
def test_init_device_mesh(self) -> None:
218+
os.environ["MASTER_ADDR"] = "localhost"
219+
os.environ["MASTER_PORT"] = str(12345)
220+
os.environ["RANK"] = str(self.rank)
221+
os.environ["WORLD_SIZE"] = str(4)
222+
223+
def ft_init_device_mesh(device, mesh_shape, mesh_dim_names, replicate_dim):
224+
if device == "cpu":
225+
pg = ProcessGroupGloo()
226+
elif device == "cuda":
227+
pg = ProcessGroupNCCL()
228+
else:
229+
raise ValueError()
230+
231+
# We have to use MultiProcessTestCase, otherwise c10d will complain
232+
# the same backend has been registered.
233+
backend_name = pg._register(mesh_dim_names[replicate_dim])
234+
# This currently doesn't work with NCCL as DeviceMesh will ignore
235+
# `_set_mesh_dim_group_options()` and just use `split_group()`.
236+
# We will need to change DeviceMesh to use `new_group()` instead of
237+
# `split_group()` when backend is not None.
238+
_mesh_resources._set_mesh_dim_group_options(
239+
replicate_dim, backend_name, None
240+
)
241+
device_mesh = init_device_mesh(
242+
device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
243+
)
244+
# We need an API to clear the mesh_dim_group_options because it will
245+
# affect the following flatten() API.
246+
return device_mesh
247+
248+
device_mesh = ft_init_device_mesh(
249+
"cpu", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"), replicate_dim=0
250+
)
251+
252+
store = TCPStore(
253+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
254+
)
255+
store_addr = f"localhost:{store.port}/prefix"
256+
pg = device_mesh.get_group("dp")
257+
pg.configure(store_addr, 0, 1)
258+
pg.unregister()

0 commit comments

Comments
 (0)