Commit aa66fc4

ProcessGroupBabyNCCL: support multiple streams and use event on start
1 parent c3d5d54 commit aa66fc4

File tree

2 files changed: +130 -42 lines changed
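
In short: the separate _BabyWorkNCCL / "synchronize" path is removed. The parent now sends the caller's CUDA stream_id and an interprocess CUDA event with every "func" request, the worker runs each op on a dedicated per-stream_id stream, and "wait" returns an event so the parent's stream can wait without blocking the CPU. The single-process sketch below illustrates the event handoff the diff relies on; the helper names are illustrative only, and the real code additionally ships the event across the process boundary.

    import torch

    def producer_side(t: torch.Tensor):
        # Record an interprocess-capable event on the stream that produced `t`.
        stream = torch.cuda.current_stream()
        event = stream.record_event(torch.cuda.Event(interprocess=True))
        return stream.stream_id, event

    def consumer_side(stream: torch.cuda.Stream, event: torch.cuda.Event, t: torch.Tensor) -> None:
        with torch.cuda.stream(stream):
            event.wait()  # stream-level wait: orders kernels, does not block the CPU
            t.add_(1)     # stand-in for the NCCL collective the worker launches

    if torch.cuda.is_available():
        t = torch.zeros(4, device="cuda")
        _, ev = producer_side(t)
        consumer_side(torch.cuda.Stream(), ev, t)
        torch.cuda.synchronize()
        print(t)  # tensor([1., 1., 1., 1.], device='cuda:0')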

2 files changed

+130
-42
lines changed

torchft/process_group.py (+109 -42)

@@ -19,13 +19,16 @@
 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -587,28 +590,56 @@ def __init__(

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.

-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False

-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


 class ProcessGroupBaby(ProcessGroup):
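
The new _BabyWork.wait() above hinges on CUDA event semantics: waiting on an event blocks the caller's current stream, not the Python thread. A minimal sketch of that behavior in plain PyTorch (nothing here is torchft-specific):

    import torch

    if torch.cuda.is_available():
        side = torch.cuda.Stream()

        with torch.cuda.stream(side):
            x = torch.randn(1 << 20, device="cuda")
            y = x * 2
            done = torch.cuda.current_stream().record_event(torch.cuda.Event())

        # Back on the default stream: wait() returns immediately on the CPU but
        # orders every kernel queued afterwards behind `done`.
        done.wait()
        z = y + 1  # safe: runs only after the side stream finished producing y
        torch.cuda.synchronize()
        print(z.mean().item())
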
@@ -620,8 +651,6 @@ class ProcessGroupBaby(ProcessGroup):

     """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)

@@ -716,23 +745,62 @@ def _worker(
                 return
             tx.put(None)

-            work = {}
+            streams = defaultdict(lambda: torch.cuda.Stream())
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0

             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    stream = streams[stream_id] if stream_id is not None else None
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
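
The defaultdict of streams in this hunk gives each distinct parent stream_id its own worker-side stream, so ops issued on different parent streams never serialize against each other in the worker. A small sketch of that bookkeeping (illustrative helper, not a torchft API):

    from collections import defaultdict
    from typing import Dict, Optional

    import torch

    streams: Dict[int, torch.cuda.Stream] = defaultdict(lambda: torch.cuda.Stream())

    def stream_for(stream_id: Optional[int]) -> Optional[torch.cuda.Stream]:
        # CPU-only ops carry no stream_id and run without entering a CUDA stream.
        return streams[stream_id] if stream_id is not None else None

    if torch.cuda.is_available():
        a = stream_for(101)
        b = stream_for(202)
        assert a is not b            # independent streams per parent stream_id
        assert stream_for(101) is a  # stable mapping across subsequent ops
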
@@ -746,23 +814,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -809,17 +862,33 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None

+        is_cuda = _is_any_cuda(args)
+
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )

         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)

     def allreduce(
         self,
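
On the parent side, _run_func now decides per call whether CUDA synchronization is needed at all: CPU-only ops send (None, None), while CUDA ops send the caller's stream_id plus an event recorded on that stream. A hedged sketch of that decision (hypothetical helper, simplified to a flat argument tuple rather than the recursive _is_any_cuda check):

    from typing import Optional, Tuple

    import torch

    def capture_sync_info(args: tuple) -> Tuple[Optional[int], Optional[torch.cuda.Event]]:
        is_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
        if not is_cuda:
            return None, None
        stream = torch.cuda.current_stream()
        event = stream.record_event(torch.cuda.Event(interprocess=True))
        return stream.stream_id, event

    if torch.cuda.is_available():
        print(capture_sync_info((torch.ones(2, device="cuda"),)))
    else:
        print(capture_sync_info((torch.ones(2),)))  # -> (None, None)
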
@@ -952,8 +1021,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL

torchft/process_group_test.py (+21)

@@ -266,6 +266,27 @@ def test_baby_gloo_apis(self) -> None:

         self.assertEqual(a.num_active_work(), 0)

+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
