Skip to content

Commit ee864cf

Browse files
authored
process_group: register via public API (#14)
1 parent 4631d6f commit ee864cf

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

torchft/process_group.py

+21 -11
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def size(self) -> int:
102102
def getBackendName(self) -> str:
103103
raise NotImplementedError("not implemented")
104104

105-
def register(self, name: str) -> None:
105+
def register(self, name: str) -> BaseProcessGroup:
106106
"""
107107
Registers the process group with the global registry. This enables usage
108108
with things like functional_collectives which are compilable.
@@ -113,32 +113,42 @@ def register(self, name: str) -> None:
113113
name: name must be a unique name for this process group
114114
"""
115115

116-
self._group_name = f"{self.getBackendName()}:{name}"
117-
_register_process_group(self.group_name, self)
116+
group_name = f"{self.getBackendName()}:{name}"
118117

119-
# This is needed for DeviceMesh to work
118+
# This is needed for DeviceMesh and functional collectives to work.
120119
# Resizable worlds don't fit well into DeviceMesh so we register a world
121120
# size 1 PG.
122-
_world.pg_map[self] = (None, None)
123-
_world.pg_names[self] = self._group_name
124-
_world.pg_to_tag[self] = self._group_name
125-
_world.tags_to_pg.setdefault(self._group_name, []).append(self)
126-
# these PGs can be resized so we lie about the rank mapping
127-
_world.pg_group_ranks[self] = {get_rank(): 0}
121+
122+
def create_pg(
123+
prefix_store: PrefixStore, rank: int, world_size: int, timeout: float
124+
) -> ProcessGroup:
125+
return self
126+
127+
dist.Backend.register_backend(group_name, create_pg)
128+
129+
return dist.new_group(
130+
ranks=[dist.get_rank()],
131+
backend=group_name,
132+
group_desc=group_name,
133+
timeout=timedelta(seconds=60.0), # this timeout isn't used
134+
)
128135

129136
@property
130137
def group_name(self) -> str:
131138
if self._group_name is None:
132139
raise ValueError("ProcessGroup name not set")
133140
return self._group_name
134141

142+
def _set_group_name(self, name: str) -> None:
143+
self._group_name = name
144+
135145
def unregister(self) -> None:
136146
"""
137147
Unregisters the process group with the global registry.
138148
139149
Must be registered first.
140150
"""
141-
_unregister_process_group(self.group_name)
151+
dist.destroy_process_group(self)
142152

143153

144154
class ProcessGroupWrapper(ProcessGroup):

torchft/process_group_test.py

+20 -14
Original file line number | Diff line number | Diff line change
@@ -4,31 +4,35 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from unittest import TestCase, skipUnless
8-
from concurrent.futures import ThreadPoolExecutor
97
import os
8+
from concurrent.futures import ThreadPoolExecutor
9+
from unittest import skipUnless, TestCase
1010

1111
import torch
12-
from torch.distributed import TCPStore, ReduceOp
1312
import torch.distributed as dist
1413
from torch import nn
15-
from torch._C._distributed_c10d import (
16-
_resolve_process_group,
17-
)
18-
from torch.distributed import _functional_collectives
14+
from torch._C._distributed_c10d import _resolve_process_group
15+
from torch.distributed import _functional_collectives, ReduceOp, TCPStore
1916
from torch.distributed.device_mesh import init_device_mesh
2017

2118
from torchft.process_group import (
19+
extend_device_mesh,
20+
ProcessGroup,
2221
ProcessGroupBabyGloo,
2322
ProcessGroupBabyNCCL,
23+
ProcessGroupDummy,
2424
ProcessGroupGloo,
2525
ProcessGroupNCCL,
26-
ProcessGroupDummy,
27-
ProcessGroup,
28-
extend_device_mesh,
2926
)
3027

3128

29+
def dummy_init_pg() -> None:
30+
if not dist.is_initialized():
31+
dist.init_process_group(
32+
backend="gloo", rank=0, world_size=1, store=dist.HashStore()
33+
)
34+
35+
3236
class ProcessGroupTest(TestCase):
3337
def test_gloo(self) -> None:
3438
store = TCPStore(
@@ -168,18 +172,20 @@ def test_device_mesh(self) -> None:
168172
mesh_2d = extend_device_mesh(mesh_1d, pg)
169173
assert mesh_2d.ndim == 2
170174

175+
pg.unregister()
176+
171177
def test_functional_collectives(self) -> None:
178+
dummy_init_pg()
179+
172180
store = TCPStore(
173181
host_name="localhost", port=0, is_master=True, wait_for_workers=False
174182
)
175183
store_addr = f"localhost:{store.port}/prefix"
176184

177-
pg = ProcessGroupGloo()
185+
pg = ProcessGroupGloo().register("test_func_col")
178186
pg.configure(store_addr, 0, 1)
179187

180-
pg.register("test_func_col")
181-
182-
self.assertEqual(pg.group_name, "torchft-gloo:test_func_col")
188+
self.assertEqual(pg.group_name, str(dist.get_pg_count() - 1))
183189

184190
self.assertIs(_resolve_process_group(pg.group_name), pg)
185191

0 commit comments

Comments
 (0)