
Commit 4631d6f

process_group: added registration to support DeviceMesh and functional_collectives (#13)
* process_group: added registration support
* process_group: added initial DeviceMesh support
1 parent 936016c commit 4631d6f

2 files changed: +139 -12 lines changed


torchft/process_group.py

+91 -12
@@ -4,26 +4,33 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from abc import ABC
 import logging
-from typing import Type, List, Optional, Callable, Tuple
-from datetime import timedelta
 import threading
+from abc import ABC
+from datetime import timedelta
+from typing import Callable, List, Optional, Tuple, Type
 
-from torch.futures import Future
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch._C._distributed_c10d import (
+    _register_process_group,
+    _unregister_process_group,
+)
 from torch.distributed import (
-    ProcessGroup as BaseProcessGroup,
-    Store,
-    TCPStore,
-    PrefixStore,
     BroadcastOptions,
+    DeviceMesh,
+    get_rank,
+    PrefixStore,
+    ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
+    Store,
+    TCPStore,
 )
-import torch.distributed as dist
-from torch.distributed.distributed_c10d import Work
-import torch
-import torch.multiprocessing as mp
+from torch.distributed.distributed_c10d import _world, Work
+
+from torch.futures import Future
 
 logger = logging.getLogger(__name__)
 
@@ -62,6 +69,11 @@ def create_store(store_addr: str) -> Store:
 
 
 class ProcessGroup(BaseProcessGroup):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self._group_name = None
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
@@ -90,6 +102,44 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
+    def register(self, name: str) -> None:
+        """
+        Registers the process group with the global registry. This enables use
+        with APIs such as functional_collectives, which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: a unique name for this process group
+        """
+
+        self._group_name = f"{self.getBackendName()}:{name}"
+        _register_process_group(self.group_name, self)
+
+        # This is needed for DeviceMesh to work.
+        # Resizable worlds don't fit well into DeviceMesh, so we register a
+        # world-size-1 PG.
+        _world.pg_map[self] = (None, None)
+        _world.pg_names[self] = self._group_name
+        _world.pg_to_tag[self] = self._group_name
+        _world.tags_to_pg.setdefault(self._group_name, []).append(self)
+        # These PGs can be resized, so we lie about the rank mapping.
+        _world.pg_group_ranks[self] = {get_rank(): 0}
+
+    @property
+    def group_name(self) -> str:
+        if self._group_name is None:
+            raise ValueError("ProcessGroup name not set")
+        return self._group_name
+
+    def unregister(self) -> None:
+        """
+        Unregisters the process group from the global registry.
+
+        Must be registered first.
+        """
+        _unregister_process_group(self.group_name)
+
 
 class ProcessGroupWrapper(ProcessGroup):
     PG_CLASS: Type[BaseProcessGroup]
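
For illustration, here is a minimal sketch of how the new register()/unregister() API can be exercised with functional collectives on a single rank; the store address and group name are illustrative, and the flow mirrors the test_functional_collectives test added in torchft/process_group_test.py below.

import torch
from torch.distributed import TCPStore, _functional_collectives

from torchft.process_group import ProcessGroupGloo

# Illustrative single-rank setup: a local TCPStore backs the torchft PG.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

# Register once; the global name becomes "<backend>:<name>", e.g. "torchft-gloo:example".
pg.register("example")

try:
    t = torch.zeros(10)
    # Functional collectives resolve the PG via the registered group name.
    _functional_collectives.all_reduce(t, "sum", pg).wait()
finally:
    # Remove the PG from the global registry when done.
    pg.unregister()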
@@ -458,3 +508,32 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
 
     def getBackendName(self):
         return "torchft-baby-nccl"
+
+
+def extend_device_mesh(
+    mesh: DeviceMesh, pg: ProcessGroup, name: str = "dp", dim: int = 0
+) -> DeviceMesh:
+    """
+    This is a helper method to extend a traditional DeviceMesh with a torchft ProcessGroup, for use with DeviceMesh-based APIs such as FSDPv2 with hybrid sharding.
+
+    Resizable PGs aren't natively supported by DeviceMesh, so we lie to
+    DeviceMesh and say the PG is world size 1. This is fine as long as any
+    numeric scaling is handled at the PG level.
+
+    Args:
+        mesh: the DeviceMesh to extend
+        pg: the ProcessGroup to add to the mesh
+        name: the name of the new dimension
+        dim: the dimension to add the ProcessGroup to
+    """
+    groups = mesh.get_all_groups()
+    groups.insert(dim, pg)
+    mesh_dim_names = list(mesh.mesh_dim_names)
+    mesh_dim_names.insert(dim, name)
+
+    return DeviceMesh.from_group(
+        group=groups,
+        device_type=mesh.device_type,
+        mesh=mesh.mesh.unsqueeze(dim),
+        mesh_dim_names=mesh_dim_names,
+    )
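
Similarly, a minimal sketch of extending an existing 1-D DeviceMesh with a torchft ProcessGroup; the environment bootstrap, names, and store address are illustrative and mirror the test_device_mesh test added below.

import os

from torch.distributed import TCPStore
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import ProcessGroupGloo, extend_device_mesh

# Single-process bootstrap so init_device_mesh can create the default PG.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(0)
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

# A conventional 1-D mesh, e.g. a tensor-parallel dimension.
mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))

# A torchft PG registered as a world-size-1 group.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo()
pg.register("dp_example")
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

# Prepend the resizable PG as a new "dp" dimension in front of "tp".
mesh_2d = extend_device_mesh(mesh_1d, pg, name="dp", dim=0)
assert mesh_2d.ndim == 2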

torchft/process_group_test.py

+48
@@ -6,11 +6,17 @@
 
 from unittest import TestCase, skipUnless
 from concurrent.futures import ThreadPoolExecutor
+import os
 
 import torch
 from torch.distributed import TCPStore, ReduceOp
 import torch.distributed as dist
 from torch import nn
+from torch._C._distributed_c10d import (
+    _resolve_process_group,
+)
+from torch.distributed import _functional_collectives
+from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.process_group import (
     ProcessGroupBabyGloo,
@@ -19,6 +25,7 @@
     ProcessGroupNCCL,
     ProcessGroupDummy,
     ProcessGroup,
+    extend_device_mesh,
 )
 
 
@@ -140,3 +147,44 @@ def run(rank: int) -> None:
         b_work.get_future().wait()
 
         torch.testing.assert_close(at.cpu(), bt.cpu())
+
+    def test_device_mesh(self) -> None:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(0)
+        os.environ["RANK"] = str(0)
+        os.environ["WORLD_SIZE"] = str(1)
+
+        mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        store_addr = f"localhost:{store.port}/prefix"
+
+        pg = ProcessGroupGloo()
+        pg.register("test_device_mesh")
+        pg.configure(store_addr, 0, 1)
+
+        mesh_2d = extend_device_mesh(mesh_1d, pg)
+        assert mesh_2d.ndim == 2
+
+    def test_functional_collectives(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        store_addr = f"localhost:{store.port}/prefix"
+
+        pg = ProcessGroupGloo()
+        pg.configure(store_addr, 0, 1)
+
+        pg.register("test_func_col")
+
+        self.assertEqual(pg.group_name, "torchft-gloo:test_func_col")
+
+        self.assertIs(_resolve_process_group(pg.group_name), pg)
+
+        try:
+            t = torch.zeros(10)
+            _functional_collectives.all_reduce(t, "sum", pg).wait()
+        finally:
+            pg.unregister()
