Commit 0f8b44a

ProcessGroupBabyNCCL: support multiple streams and use event on start
1 parent 68e1d28 commit 0f8b44a

2 files changed: +179 −49 lines changed
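The commit drops the old "synchronize" round trip in favor of a CUDA-event handoff: the parent records an event on its current stream before sending an op to the NCCL subprocess, the subprocess mirrors that stream and waits on the event before launching the collective, and it hands an event back for wait(). Below is a minimal single-process sketch of that event-based stream ordering, assuming a CUDA device is available; in the commit itself the event is created with interprocess=True and shipped to the subprocess over a torch.multiprocessing queue, which this sketch does not attempt.

import torch

# Minimal single-process sketch (assumes a CUDA device). Work queued on one
# stream is made visible to another stream by recording an event and waiting
# on it; the wait is a stream-level dependency only, so the CPU thread is
# never blocked.
producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

with torch.cuda.stream(producer):
    t = torch.randn(1024, device="cuda")
    t += 1  # work enqueued on the producer stream
    # Record an event on the producer stream after the work above; the commit
    # uses torch.cuda.Event(interprocess=True) so the event can cross processes.
    ready = torch.cuda.current_stream().record_event(torch.cuda.Event())

with torch.cuda.stream(consumer):
    ready.wait()  # consumer stream waits for the event, CPU continues
    t *= 2  # ordered after the producer-side work on the GPU

torch.cuda.synchronize()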

torchft/process_group.py (+148 −46)

@@ -19,17 +19,18 @@
 import logging
 import queue
 import threading
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -58,7 +59,6 @@
     BroadcastOptions,
     ReduceOp,
     Work,
-    _world,
 )
 from torch.futures import Future
 
@@ -586,29 +586,59 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +647,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     share able and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -679,7 +706,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +750,76 @@ def _worker(
             return
         tx.put(None)
 
-        work = {}
+        streams: Dict[str, torch.cuda.Stream] = {}
+        work: Dict[int, _OpMetadata] = {}
         next_op_id: int = 0
 
         while True:
            op = rx.get()
            cmd = op[0]
            if cmd == "func":
-                func_name, args, kwargs = op[1:]
-                args = _PickleSafeOptions.unsafe_args(args)
-                fn = getattr(pg, func_name)
-                work[next_op_id] = fn(*args, **kwargs)
+                func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                print(f"func {func_name=}")
+
+                # To avoid potential deadlocks we need to preserve the
+                # stream/synchronization behavior of the parent process.
+                # We allocate one Stream per stream_id to make sure that we
+                # don't accidentally introduce cross stream synchronization
+                # points.
+                if stream_id is not None:
+                    stream_key = f"{stream_device}/{stream_id}"
+                    if stream_key not in streams:
+                        streams[stream_key] = torch.cuda.Stream(
+                            device=stream_device
+                        )
+                    stream = streams[stream_key]
+                else:
+                    stream = None
+
+                with (
+                    torch.cuda.stream(stream)
+                    if stream is not None
+                    else nullcontext()
+                ):
+                    print("stream created")
+
+                    # Make the stream wait on the cuda event to make sure we
+                    # don't start the operation until the tensor is ready.
+                    if event is not None:
+                        event.wait()
+
+                    print("waited")
+
+                    args = _PickleSafeOptions.unsafe_args(args)
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = _OpMetadata(
+                        work=fn(*args, **kwargs),
+                        stream=stream,
+                    )
                tx.put(next_op_id)
                next_op_id += 1
            elif cmd == "wait":
                op_id: int = op[1]
-                work[op_id].wait()
-                tx.put(op_id)
+
+                metadata = work[op_id]
+
+                with metadata.set_stream():
+                    # With WorkNCCL this makes the stream wait not the CPU when
+                    # no timeout is passed.
+                    metadata.work.wait()
+
+                    # Register event on the stream that we can pass to the main
+                    # process.
+                    event = (
+                        torch.cuda.current_stream().record_event(
+                            torch.cuda.Event(interprocess=True)
+                        )
+                        if metadata.stream is not None
+                        else None
+                    )
+
+                tx.put((op_id, event))
            elif cmd == "del":
                op_id: int = op[1]
                del work[op_id]
@@ -746,23 +833,8 @@ def callback(fut: Future[object]) -> None:
                    except Exception as e:
                        future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                work[op_id].get_future().add_done_callback(callback)
+                work[op_id].work.get_future().add_done_callback(callback)
                tx.put(op_id)
-            elif cmd == "synchronize":
-                # CUDA only, use events instead of waiting on CPU
-                op_id = op[1]
-
-                # With WorkNCCL this makes the stream wait not the CPU when
-                # no timeout is passed.
-                work[op_id].wait()
-
-                # Register event on the stream that we can pass to the main
-                # process.
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-
-                del work[op_id]
-                tx.put((op_id, event))
            elif cmd == "num_active_work":
                tx.put(len(work))
            else:
@@ -792,6 +864,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +878,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
@@ -952,8 +1056,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
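As a caller-facing illustration (not part of the commit), the hedged sketch below shows what the new behavior looks like from the parent process: collectives issued under a non-default stream keep that stream's semantics, and Work.wait() only makes the current stream wait on the event recorded by the subprocess instead of blocking the CPU thread. The TCPStore setup and allreduce call mirror the tests below; the single-rank world size and tensor shape are illustrative assumptions.

from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

# Same store setup as the tests in this commit.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(store_addr, 0, 1)  # rank 0, world size 1

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    t = torch.ones(8, device="cuda")
    work = pg.allreduce([t], ReduceOp.SUM)
    # wait() receives (op_id, event) from the subprocess and calls
    # event.wait(), so only this stream waits; the CPU is not blocked.
    work.wait()

torch.cuda.synchronize()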

torchft/process_group_test.py (+31 −3)

@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:
 
         self.assertEqual(a.num_active_work(), 0)
 
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # set to 1 if more than >=2 gpus
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
         store_addr: str = f"localhost:{store.port}/prefix"
 
         def run(rank: int) -> Tuple[torch.Tensor, Work]:
-            a = ProcessGroupBabyNCCL()
+            a = ProcessGroupBabyNCCL(
+                timeout=timedelta(seconds=10.0),
+            )
             a.configure(store_addr, rank, 2)
-
             self.assertEqual(a.size(), 2)
 
-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+            # We test using set_device to ensure stream device is correct.
+            torch.cuda.set_device(rank)
+            at = torch.tensor([rank + 1], device="cuda")
 
             a_work = a.allreduce([at], ReduceOp.SUM)
             return at, a_work
