 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from abc import ABC
 import logging
-from typing import Type, List, Optional, Callable, Tuple
-from datetime import timedelta
 import threading
+from abc import ABC
+from datetime import timedelta
+from typing import Callable, List, Optional, Tuple, Type

-from torch.futures import Future
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch._C._distributed_c10d import (
+    _register_process_group,
+    _unregister_process_group,
+)
 from torch.distributed import (
-    ProcessGroup as BaseProcessGroup,
-    Store,
-    TCPStore,
-    PrefixStore,
     BroadcastOptions,
+    DeviceMesh,
+    get_rank,
+    PrefixStore,
+    ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
+    Store,
+    TCPStore,
 )
-import torch.distributed as dist
-from torch.distributed.distributed_c10d import Work
-import torch
-import torch.multiprocessing as mp
+from torch.distributed.distributed_c10d import _world, Work
+
+from torch.futures import Future

 logger = logging.getLogger(__name__)

@@ -62,6 +69,11 @@ def create_store(store_addr: str) -> Store:


 class ProcessGroup(BaseProcessGroup):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self._group_name = None
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")

@@ -90,6 +102,44 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

+    def register(self, name: str) -> None:
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives, which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: a unique name for this process group
+        """
+
+        self._group_name = f"{self.getBackendName()}:{name}"
+        _register_process_group(self.group_name, self)
+
+        # This is needed for DeviceMesh to work.
+        # Resizable worlds don't fit well into DeviceMesh, so we register a
+        # world size 1 PG.
+        _world.pg_map[self] = (None, None)
+        _world.pg_names[self] = self._group_name
+        _world.pg_to_tag[self] = self._group_name
+        _world.tags_to_pg.setdefault(self._group_name, []).append(self)
+        # These PGs can be resized, so we lie about the rank mapping.
+        _world.pg_group_ranks[self] = {get_rank(): 0}
+
+    @property
+    def group_name(self) -> str:
+        if self._group_name is None:
+            raise ValueError("ProcessGroup name not set")
+        return self._group_name
+
+    def unregister(self) -> None:
+        """
+        Unregisters the process group from the global registry.
+
+        Must be registered first.
+        """
+        _unregister_process_group(self.group_name)
+

 class ProcessGroupWrapper(ProcessGroup):
     PG_CLASS: Type[BaseProcessGroup]
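For context, a minimal usage sketch of the register()/unregister() API added above. It assumes this module is importable as torchft.process_group and exposes a Gloo-backed wrapper named ProcessGroupGloo; the construction, store address, rank/world size, and group name are illustrative placeholders, not part of this diff.

```python
import torch
import torch.distributed._functional_collectives as ft_c

from torchft.process_group import ProcessGroupGloo  # assumed module path

# Placeholder construction/configuration; a real setup points the store
# address at a TCPStore shared by all participating ranks.
pg = ProcessGroupGloo()
pg.configure("localhost:29500/example", rank=0, world_size=1)

# Register once under a unique name so the group can be resolved by name,
# e.g. by functional collectives and compiled code.
pg.register("example")
try:
    t = torch.ones(4)
    # Functional collectives resolve the group via its registered name.
    t = ft_c.all_reduce(t, "sum", pg)
finally:
    # Tear down the registry entry when the group is no longer needed.
    pg.unregister()
```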
@@ -458,3 +508,32 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

     def getBackendName(self):
         return "torchft-baby-nccl"
+
+
+def extend_device_mesh(
+    mesh: DeviceMesh, pg: ProcessGroup, name: str = "dp", dim: int = 0
+) -> DeviceMesh:
+    """
+    Helper to extend a traditional DeviceMesh with a torchft ProcessGroup, for use with DeviceMesh-based APIs such as FSDPv2 with hybrid sharding.
+
+    Resizable PGs aren't natively supported by DeviceMesh, so we lie to
+    DeviceMesh and say the PG is world size 1. This is fine as long as any
+    numeric scaling is handled at the PG level.
+
+    Args:
+        mesh: The DeviceMesh to extend
+        pg: The ProcessGroup to add to the mesh
+        name: The name of the new dimension
+        dim: The dimension to add the ProcessGroup to
+    """
+    groups = mesh.get_all_groups()
+    groups.insert(dim, pg)
+    mesh_dim_names = list(mesh.mesh_dim_names)
+    mesh_dim_names.insert(dim, name)
+
+    return DeviceMesh.from_group(
+        group=groups,
+        device_type=mesh.device_type,
+        mesh=mesh.mesh.unsqueeze(dim),
+        mesh_dim_names=mesh_dim_names,
+    )
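Likewise, a sketch of how extend_device_mesh might splice a torchft ProcessGroup into an existing mesh for hybrid sharding. The device type, mesh shape, dimension names, and configure() arguments are placeholders; the default process group for the sharded dimension is assumed to be initialized already, and the module path is assumed as above.

```python
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import ProcessGroupGloo, extend_device_mesh  # assumed module path

# Conventional 1-D mesh for the intra-replica (sharded) dimension,
# e.g. 8 ranks per replica group; shape and name are placeholders.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("shard",))

# Resizable torchft PG for the cross-replica dimension; register() is what
# makes the group visible to the DeviceMesh machinery.
pg = ProcessGroupGloo()
pg.configure("localhost:29500/dp", rank=0, world_size=1)  # placeholder args
pg.register("dp")

# Prepend the torchft PG as a new "dp" dimension at dim 0; the mesh reports
# size 1 for that dimension, as described in the docstring above.
mesh = extend_device_mesh(mesh, pg, name="dp", dim=0)
```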