
Commit 7d48c52

ProcessGroupBabyNCCL: support multiple streams and use event on start
1 parent 68e1d28 commit 7d48c52

File tree

2 files changed: +180 -47 lines

torchft/process_group.py

+149 -44
@@ -19,13 +19,16 @@
 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -586,29 +589,59 @@ def __init__(
         self._timeout = timeout

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.

-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False

-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     share able and don't need any changes.
-
     """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
@@ -679,7 +709,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +753,76 @@ def _worker(
                 return
             tx.put(None)

-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0

             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    print(f"func {func_name=}")
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        print("stream created")
+
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        print("waited")
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
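The comment in the hunk above carries the key idea: the worker allocates one CUDA stream per parent (device, stream_id) pair, so ops issued from different parent streams never pick up accidental cross-stream synchronization points in the child. A minimal sketch of that dispatch, pulled out of the worker loop purely for illustration (the helper name get_worker_stream is made up and not part of the commit; assumes a CUDA device is available):

from typing import Dict

import torch

# Illustrative sketch only: one worker-side stream per parent stream key,
# mirroring the `streams` dict in the worker loop above.
_streams: Dict[str, torch.cuda.Stream] = {}


def get_worker_stream(stream_device: torch.device, stream_id: int) -> torch.cuda.Stream:
    # Key format matches the hunk above: "<device>/<stream_id>".
    key = f"{stream_device}/{stream_id}"
    if key not in _streams:
        _streams[key] = torch.cuda.Stream(device=stream_device)
    return _streams[key]


# Two distinct parent streams map to two distinct worker streams, so their
# collectives are not serialized against each other in the child process.
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
assert get_worker_stream(s1.device, s1.stream_id) is not get_worker_stream(s2.device, s2.stream_id)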
@@ -746,23 +836,8 @@ def callback(fut: Future[object]) -> None:
                        except Exception as e:
                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -792,6 +867,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")

     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +881,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut

     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None

+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )

         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

     def allreduce(
         self,
@@ -952,8 +1059,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
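Taken together, the parent-side _run_func records an interprocess CUDA event on the caller's current stream and the worker makes its own stream wait on that event before launching the collective; on wait() the flow reverses, with the worker handing back an event for the caller's stream to wait on. A single-process sketch of that record/wait ordering (illustrative only; the stream and tensor names below are made up, and a CUDA device is assumed):

import torch

producer_stream = torch.cuda.Stream()  # stands in for the parent's stream
worker_stream = torch.cuda.Stream()    # stands in for the worker's per-key stream

with torch.cuda.stream(producer_stream):
    t = torch.randn(1024, device="cuda")
    t.add_(1)  # pending GPU work on the producer stream
    # Record an event after the writes; with interprocess=True the handle
    # could also be shipped to another process, which is what the commit does.
    ready = producer_stream.record_event(torch.cuda.Event(interprocess=True))

with torch.cuda.stream(worker_stream):
    # Block the worker stream (not the CPU) until the producer's writes land.
    ready.wait()
    out = t * 2  # safe: ordered after the producer's writes

torch.cuda.synchronize()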

torchft/process_group_test.py

+31 -3
@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:

         self.assertEqual(a.num_active_work(), 0)

+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # set to 1 if more than >=2 gpus
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
         store_addr: str = f"localhost:{store.port}/prefix"

         def run(rank: int) -> Tuple[torch.Tensor, Work]:
-            a = ProcessGroupBabyNCCL()
+            a = ProcessGroupBabyNCCL(
+                timeout=timedelta(seconds=10.0),
+            )
             a.configure(store_addr, rank, 2)
-
             self.assertEqual(a.size(), 2)

-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+            # We test using set_device to ensure stream device is correct.
+            torch.cuda.set_device(rank)
+            at = torch.tensor([rank + 1], device="cuda")

             a_work = a.allreduce([at], ReduceOp.SUM)
             return at, a_work
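For reference, a hedged usage sketch built on the new single-rank test above, showing the behavior the commit enables: wait() on a CUDA collective now only makes the caller's stream wait on the worker's event, so the CPU stays unblocked until an explicit synchronize. Assumes one CUDA device and a free TCP port; this is not part of the commit itself.

from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

# Single-rank setup mirroring test_baby_nccl_apis above.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)

pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

t = torch.ones(3, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)

# wait() makes the current CUDA stream wait on the worker's event;
# synchronize before reading the result on the CPU.
work.wait()
torch.cuda.synchronize()
print(t)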
