diff --git a/torchft/process_group.py b/torchft/process_group.py
index 0110c6e..4790352 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -19,6 +19,7 @@
 import logging
 import queue
 import threading
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
@@ -26,10 +27,10 @@
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -58,9 +59,9 @@
     BroadcastOptions,
     ReduceOp,
     Work,
-    _world,
 )
 from torch.futures import Future
+from torch.utils._pytree import tree_any
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -586,10 +587,24 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
@@ -597,18 +612,27 @@ def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are
     implicitly share able and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0
 
             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
         except Exception as e:
             logger.exception("worker errored")
             tx.put(e)
+            raise
 
     def _future_handler(self, future_queue: mp.Queue) -> None:
         try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 6abeb39..f765625 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:
 
         self.assertEqual(a.num_active_work(), 0)
 
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # set to 1 if more than >=2 gpus
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
         store_addr: str = f"localhost:{store.port}/prefix"
 
         def run(rank: int) -> Tuple[torch.Tensor, Work]:
-            a = ProcessGroupBabyNCCL()
+            a = ProcessGroupBabyNCCL(
+                timeout=timedelta(seconds=10.0),
+            )
            a.configure(store_addr, rank, 2)
-
             self.assertEqual(a.size(), 2)
-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+
+            # We test using set_device to ensure stream device is correct.
+            torch.cuda.set_device(rank)
+            at = torch.tensor([rank + 1], device="cuda")
 
             a_work = a.allreduce([at], ReduceOp.SUM)
             return at, a_work
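
Usage sketch (not part of the patch): the snippet below illustrates how the stream-aware ProcessGroupBabyNCCL above is expected to be driven from a single rank, mirroring test_baby_nccl_apis. The store address, timeout, and tensor values are illustrative assumptions rather than values required by the change.

# Minimal single-rank sketch of the new CUDA-aware flow. Assumes one visible
# GPU and a locally reachable TCPStore; names and values are illustrative.
from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(store_addr, rank=0, world_size=1)

# The tensor is produced on the caller's current CUDA stream; _run_func records
# an interprocess CUDA event on that stream so the worker subprocess starts the
# collective only once the tensor is ready.
t = torch.ones(2, 3, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)

# wait() receives (op_id, event) back from the worker and calls event.wait(),
# which blocks the caller's CUDA stream rather than the CPU thread.
work.wait()

# Synchronize only if host code needs the reduced values immediately.
torch.cuda.synchronize()
print(t)

The 2-GPU path in test_baby_nccl_2gpu follows the same pattern, with torch.cuda.set_device(rank) called per rank so the worker picks up the correct stream device.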