[Core] Add register_collective_backend API for customized collective backends #60701
base: master
Changes from all commits
629452e
1761527
4c6c39b
5b5f82d
2e368ee
New file (imported as `ray.util.collective.backend_registry`), +48 lines:

```python
from typing import Dict, Type

from .collective_group.base_collective_group import BaseGroup


class BackendRegistry:
    _instance = None
    _map: Dict[str, Type[BaseGroup]]

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(BackendRegistry, cls).__new__(cls)
            cls._instance._map = {}
        return cls._instance

    def put(self, name: str, group_cls: Type[BaseGroup]) -> None:
        if not issubclass(group_cls, BaseGroup):
            raise TypeError(f"{group_cls} is not a subclass of BaseGroup")
        if name.upper() in self._map:
            raise ValueError(f"Backend {name} already registered")
        self._map[name.upper()] = group_cls

    def get(self, name: str) -> Type[BaseGroup]:
        name = name.upper()
        if name not in self._map:
            raise ValueError(f"Backend {name} not registered")
        return self._map[name]

    def check(self, name: str) -> bool:
        try:
            cls = self.get(name)
            return cls.check_backend_availability()
        except (ValueError, AttributeError):
            return False

    def list_backends(self) -> list:
        return list(self._map.keys())


_global_registry = BackendRegistry()


def register_collective_backend(name: str, group_cls: Type[BaseGroup]) -> None:
    _global_registry.put(name, group_cls)


def get_backend_registry() -> BackendRegistry:
    return _global_registry
```
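For context, a hypothetical third-party backend would plug into this registry as sketched below; `MyRingGroup` and its internals are illustrative, not part of the PR:

```python
# Hypothetical sketch -- not part of this diff. Shows how a custom
# backend would be registered through the new API.
from ray.util.collective.backend_registry import register_collective_backend
from ray.util.collective.collective_group.base_collective_group import BaseGroup


class MyRingGroup(BaseGroup):
    """Stand-in custom group; a real backend must implement the full
    BaseGroup interface (allreduce, broadcast, destroy_group, ...)."""

    @classmethod
    def backend(cls):
        return "MYRING"

    @classmethod
    def check_backend_availability(cls) -> bool:
        # Probe whatever the backend needs (drivers, wheels, env vars, ...).
        return True


# Names are upper-cased by BackendRegistry.put(), so "myring" and "MYRING"
# refer to the same entry; registering a taken name raises ValueError.
register_collective_backend("myring", MyRingGroup)
```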
Changes to the collective API module (the file defining `GroupManager`, `init_collective_group`, and `create_collective_group`):

```diff
@@ -13,6 +13,10 @@
 import ray.experimental.internal_kv as _internal_kv
 from . import types
 from ray._common.network_utils import find_free_port, is_ipv6
+from ray.util.collective.backend_registry import (
+    get_backend_registry,
+    register_collective_backend,
+)
 from ray.util.collective.collective_group.torch_gloo_collective_group import (
     get_master_address_metadata_key as _get_master_addr_key,
 )
```
```diff
@@ -38,6 +42,11 @@
 except ImportError:
     _TORCH_DISTRIBUTED_AVAILABLE = False

+if _NCCL_AVAILABLE:
+    register_collective_backend("NCCL", NCCLGroup)
+if _TORCH_DISTRIBUTED_AVAILABLE:
+    register_collective_backend("GLOO", TorchGLOOGroup)
+

 def nccl_available():
     global _LOG_NCCL_WARNING
```
```diff
@@ -57,10 +66,6 @@ def gloo_available():
     return _TORCH_DISTRIBUTED_AVAILABLE


-def torch_distributed_available():
-    return _TORCH_DISTRIBUTED_AVAILABLE
-
-
 def get_address_and_port() -> Tuple[str, int]:
     """Returns the IP address and a free port on this node."""
     addr = ray.util.get_node_ip_address()
```
```diff
@@ -78,18 +83,25 @@ class GroupManager(object):

     def __init__(self):
         self._name_group_map = {}
+        self._registry = get_backend_registry()

     def create_collective_group(
-        self, backend, world_size, rank, group_name, gloo_timeout
+        self, backend, world_size, rank, group_name, gloo_timeout=None
     ):
         """The entry to create new collective groups in the manager.

         Put the registration and the group information into the manager
         metadata as well.
         """
-        backend = types.Backend(backend)
-        if backend == types.Backend.GLOO:
+        backend = backend.upper()
+        backend_cls = self._registry.get(backend)
+        if not backend_cls.check_backend_availability():
+            raise RuntimeError(
+                f"Backend {backend} is not available. Please check the installation."
+            )
+
+        if backend == "GLOO":
             # Rendezvous: ensure a MASTER_ADDR:MASTER_PORT is published in internal_kv.
             metadata_key = _get_master_addr_key(group_name)
             if rank == 0:
                 addr, port = get_address_and_port()
```
```diff
@@ -112,13 +124,9 @@ def create_collective_group(
             logger.debug(
                 "Creating torch.distributed GLOO group: '{}'...".format(group_name)
             )
-            g = TorchGLOOGroup(world_size, rank, group_name, gloo_timeout)
-        elif backend == types.Backend.NCCL:
-            _check_backend_availability(backend)
-            logger.debug("Creating NCCL group: '{}'...".format(group_name))
-            g = NCCLGroup(world_size, rank, group_name)
+            g = backend_cls(world_size, rank, group_name, gloo_timeout)
         else:
-            raise RuntimeError(f"Unexpected backend: {backend}")
+            g = backend_cls(world_size, rank, group_name)

         self._name_group_map[group_name] = g
         return self._name_group_map[group_name]
```

> **Review comment — Hardcoded GLOO check causes crashes with custom backends (Medium Severity):** The hardcoded string check `backend == "GLOO"` means only the built-in GLOO backend receives the rendezvous setup and the `gloo_timeout` constructor argument; every other registered backend is constructed as `backend_cls(world_size, rank, group_name)`, so a custom backend that needs either will crash.
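One way to remove the special case flagged above (a sketch, not code from this PR) is to forward `gloo_timeout` only to group classes whose constructor declares it:

```python
import inspect

# Sketch: construct any registered backend uniformly; the timeout is
# forwarded only when backend_cls.__init__ actually accepts it.
if "gloo_timeout" in inspect.signature(backend_cls.__init__).parameters:
    g = backend_cls(world_size, rank, group_name, gloo_timeout)
else:
    g = backend_cls(world_size, rank, group_name)
```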
```diff
@@ -171,7 +179,7 @@ def is_group_initialized(group_name):
 def init_collective_group(
     world_size: int,
     rank: int,
-    backend=types.Backend.NCCL,
+    backend: str = "NCCL",
     group_name: str = "default",
     gloo_timeout: int = 30000,
 ):
```
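With `backend` now a plain string, a registered custom backend is selected by its registry name; continuing the hypothetical `MYRING` example from earlier (illustrative only):

```python
import ray.util.collective as col

# Must run inside a Ray actor/task; unknown backend names raise
# ValueError from the registry before any group is created.
col.init_collective_group(world_size=2, rank=0, backend="MYRING")
```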
```diff
@@ -187,11 +195,13 @@ def init_collective_group(
         None
     """
     _check_inside_actor()
-    backend = types.Backend(backend)
-    _check_backend_availability(backend)

     global _group_mgr
     global _group_mgr_lock

+    backend_cls = _group_mgr._registry.get(backend)
+    if not backend_cls.check_backend_availability():
+        raise RuntimeError("Backend '{}' is not available.".format(backend))
+
     # TODO(Hao): implement a group auto-counter.
     if not group_name:
```

cursor[bot] marked this conversation as resolved.

> **Review comment — Redundant backend availability checks (Low Severity):** Backend availability is checked three times for the same operation: both `init_collective_group` and `create_collective_group` call `backend_cls.check_backend_availability()` before `GroupManager.create_collective_group` repeats the identical check. (Flagged in two additional locations.)
```diff
@@ -212,7 +222,7 @@ def create_collective_group(
     actors,
     world_size: int,
     ranks: List[int],
-    backend=types.Backend.NCCL,
+    backend: str = "NCCL",
     group_name: str = "default",
     gloo_timeout: int = 30000,
 ):
```
```diff
@@ -230,8 +240,9 @@ def create_collective_group(
     Returns:
         None
     """
-    backend = types.Backend(backend)
-    _check_backend_availability(backend)
+    backend_cls = _group_mgr._registry.get(backend)
+    if not backend_cls.check_backend_availability():
+        raise RuntimeError("Backend '{}' is not available.".format(backend))

     name = "info_" + group_name
     try:
```
```diff
@@ -805,17 +816,6 @@ def _check_single_tensor_input(tensor):
     )


-def _check_backend_availability(backend: types.Backend):
-    """Check whether the backend is available."""
-    if backend == types.Backend.GLOO:
-        # Now we have deprecated pygloo, and use torch_gloo in all cases.
-        if not torch_distributed_available():
-            raise RuntimeError("torch.distributed is not available.")
-    elif backend == types.Backend.NCCL:
-        if not nccl_available():
-            raise RuntimeError("NCCL is not available.")
-
-
 def _check_inside_actor():
     """Check if currently it is inside a Ray actor/task."""
     worker = ray._private.worker.global_worker
```

> **Review comment — Type mismatch in `gloo_timeout` environment variable (High Severity).**
Changes to the NCCL collective group module:

```diff
@@ -2,9 +2,6 @@
 import logging
 import time

-import cupy
-import torch
-
 import ray
 from ray.util.collective.collective_group import nccl_util
 from ray.util.collective.collective_group.base_collective_group import BaseGroup
```
```diff
@@ -13,7 +10,6 @@
 from ray.util.collective.types import (
     AllGatherOptions,
     AllReduceOptions,
-    Backend,
     BarrierOptions,
     BroadcastOptions,
     RecvOptions,
```
```diff
@@ -25,6 +21,18 @@
 logger = logging.getLogger(__name__)

+global _LOG_NCCL_WARNING, _NCCL_AVAILABLE
+
+try:
+    import cupy
+    import torch
+
+    _NCCL_AVAILABLE = True
+    _LOG_NCCL_WARNING = False
+except ImportError:
+    _NCCL_AVAILABLE = False
+    _LOG_NCCL_WARNING = True
+

 class Rendezvous:
     """A rendezvous class for different actor/task processes to meet.
```

> **Review comment — Unnecessary `global` statement at module level (Low Severity):** The `global _LOG_NCCL_WARNING, _NCCL_AVAILABLE` statement is a no-op here: `global` only affects name binding inside a function body, and these names are already module-level globals.

cursor[bot] marked this conversation as resolved.
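For contrast with the flagged statement, `global` is only required inside a function that rebinds a module-level name, as `check_backend_availability` does below with `_LOG_NCCL_WARNING`; a minimal generic illustration:

```python
_FLAG = True  # module scope: the name is already global, no keyword needed


def consume_flag() -> bool:
    global _FLAG  # required here because this function rebinds _FLAG
    was_set, _FLAG = _FLAG, False
    return was_set
```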
```diff
@@ -163,7 +171,19 @@ def destroy_group(self):
     @classmethod
     def backend(cls):
-        return Backend.NCCL
+        return "NCCL"
+
+    @classmethod
+    def check_backend_availability(cls) -> bool:
+        global _LOG_NCCL_WARNING, _NCCL_AVAILABLE
+        if ray.get_gpu_ids() and _LOG_NCCL_WARNING:
+            logger.warning(
+                "NCCL seems unavailable. Please install Cupy "
+                "following the guide at: "
+                "https://docs.cupy.dev/en/stable/install.html."
+            )
+            _LOG_NCCL_WARNING = False
+        return _NCCL_AVAILABLE

     def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
         """AllReduce tensors across the collective group following options.
```
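This classmethod is what `BackendRegistry.check()` delegates to, so callers can probe availability without constructing a group; a small usage sketch under that assumption:

```python
from ray.util.collective.backend_registry import get_backend_registry

registry = get_backend_registry()
# True only if the guarded `import cupy` / `import torch` above succeeded
# (and the NCCL backend was registered).
if registry.check("NCCL"):
    print("NCCL groups can be created:", registry.get("NCCL").backend())
```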
New test file (+55 lines):

```python
import torch

import ray
from ray.util.collective import (
    allreduce,
    create_collective_group,
    init_collective_group,
)
from ray.util.collective.backend_registry import get_backend_registry
from ray.util.collective.types import ReduceOp


def test_gloo_via_registry():
    ray.init()

    registry = get_backend_registry()
    assert "GLOO" in registry.list_backends()
    assert registry.check("GLOO")

    @ray.remote
    class Worker:
        def __init__(self, rank):
            self.rank = rank
            self.tensor = None

        def setup(self, world_size):
            init_collective_group(
                world_size=world_size,
                rank=self.rank,
                backend="GLOO",
                group_name="default",
                gloo_timeout=30000,
            )

        def compute(self):
            self.tensor = torch.tensor([self.rank + 1], dtype=torch.float32)
            allreduce(self.tensor, op=ReduceOp.SUM)
            return self.tensor.item()

    actors = [Worker.remote(rank=i) for i in range(2)]
    create_collective_group(
        actors=actors,
        world_size=2,
        ranks=[0, 1],
        backend="GLOO",
        group_name="default",
        gloo_timeout=30000,
    )

    ray.get([a.setup.remote(2) for a in actors])
    results = ray.get([a.compute.remote() for a in actors])

    assert results == [3.0, 3.0], f"Expected [3.0, 3.0], got {results}"

    ray.shutdown()
```


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
> **Review comment:** The conditional registration of backends is based on the `_NCCL_AVAILABLE` and `_TORCH_DISTRIBUTED_AVAILABLE` flags. Due to the changes in this PR (guarded imports within the group classes), these flags are no longer reliable here and will likely always be `True`. The registration should be unconditional, as the actual availability of the backend is checked dynamically when a collective group is created.
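A minimal sketch of the suggested fix, assuming the registry semantics above (not code from the PR):

```python
# Register unconditionally; availability is verified lazily via
# check_backend_availability() when a collective group is created.
register_collective_backend("NCCL", NCCLGroup)
register_collective_backend("GLOO", TorchGLOOGroup)
```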