Commit 71cc79c (1 parent: c3d5d54)

ProcessGroupBabyNCCL: support multiple streams and use event on start
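
In short: when a collective is issued with CUDA tensor arguments, the parent process now records a torch.cuda.Event(interprocess=True) on its current stream and sends it, along with the stream id, to the worker subprocess. The worker runs the op on a dedicated stream (one per parent stream id) after waiting on that event, and on wait() it hands back another interprocess event so only the caller's stream blocks, not the CPU thread. Below is a rough standalone sketch of the start-of-op handshake with illustrative helper names (submit_op, run_op, and worker_streams are not torchft APIs; the wait path is symmetric).

# Sketch of the record-event / wait-event handshake this commit adds.
# submit_op / run_op / worker_streams are illustrative names only.
from collections import defaultdict
from contextlib import nullcontext

import torch


def submit_op(tensors):
    # Parent side: capture the current stream id and record an interprocess
    # event so the worker can order the collective after kernels already
    # queued on this stream, without blocking the parent CPU thread.
    is_cuda = any(t.is_cuda for t in tensors)
    stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
    event = (
        torch.cuda.current_stream().record_event(torch.cuda.Event(interprocess=True))
        if is_cuda
        else None
    )
    return stream_id, event


# Worker side: one dedicated stream per parent stream id, so the parent's
# stream structure (and its synchronization behavior) is preserved.
worker_streams = defaultdict(torch.cuda.Stream)


def run_op(stream_id, event, launch_collective):
    stream = worker_streams[stream_id] if stream_id is not None else None
    with torch.cuda.stream(stream) if stream is not None else nullcontext():
        if event is not None:
            event.wait()  # the stream waits; the CPU keeps going
        launch_collective()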

File tree: 2 files changed, +183 -45 lines

torchft/process_group.py (+154 -44)

@@ -19,13 +19,16 @@
 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -586,29 +589,59 @@ def __init__(
         self._timeout = timeout

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.

-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False

-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     share able and don't need any changes.
-
     """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)

@@ -640,6 +670,10 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:

         self._timeout: float = timeout

+        self._cuda_device_id: Optional[int] = (
+            torch.cuda.current_device() if torch.cuda.is_available() else None
+        )
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         if self._p is not None:
             self._p.kill()
@@ -679,7 +713,15 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+                self._cuda_device_id,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -704,8 +746,12 @@ def _worker(
         rx: mp.Queue,
         tx: mp.Queue,
         future_queue: mp.Queue,
+        cuda_device_id: Optional[int],
     ) -> None:
         try:
+            if cuda_device_id is not None:
+                torch.cuda.set_device(cuda_device_id)
+
             store = create_store_client(store_addr)

             try:
@@ -716,23 +762,62 @@ def _worker(
                 return
             tx.put(None)

-            work = {}
+            streams = defaultdict(lambda: torch.cuda.Stream())
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0

             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    stream = streams[stream_id] if stream_id is not None else None
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
@@ -746,25 +831,12 @@ def callback(fut: Future[object]) -> None:
                        except Exception as e:
                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                    tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                elif cmd == "num_active_work":
                    tx.put(len(work))
+                elif cmd == "cuda_device_id":
+                    tx.put(torch.cuda.current_device())
                else:
                    raise ValueError(f"unknown cmd: {cmd}")

@@ -792,6 +864,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")

    def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
        with self._futures_lock:
            fut = Future()  # pyre-fixme[29]: is not a function
            self._futures[op_id] = fut
@@ -804,22 +878,50 @@ def _get_future(self, op_id: int) -> Future[object]:
        return fut

    def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
        rx = self._rx
        tx = self._tx
        assert rx is not None
        assert tx is not None

+        is_cuda = _is_any_cuda(args)
+
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
        tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_id,
+                event,
+            ),
            timeout=self._timeout,
        )

        op_id = _get(rx, self._timeout)
        assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

    def allreduce(
        self,
@@ -877,6 +979,13 @@ def num_active_work(self) -> int:
        assert self._rx is not None
        return cast(int, _get(self._rx, self._timeout))

+    def cuda_device_id(self) -> int:
+        assert self._tx is not None
+        self._tx.put(("cuda_device_id",), timeout=self._timeout)
+
+        assert self._rx is not None
+        return cast(int, _get(self._rx, self._timeout))
+

 @dataclass
 class _PickleSafeOptions:
@@ -950,9 +1059,10 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

    WARNING: If the child process is killed while an operation is running, CUDA
    tensors may leak in the current PyTorch implementation. TODO fix
-    """

-    WORK_CLASS = _BabyWorkNCCL
+    If CUDA tensors are being used on a non-default device you must call
+    ``torch.cuda.set_device()`` prior to instantiating this ProcessGroup.
+    """

    @classmethod
    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
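
For context, a minimal single-rank usage sketch that is consistent with the updated ProcessGroupBabyNCCL docstring and mirrors test_baby_nccl_apis added below; the TCPStore setup and configure() call are taken from the test, while device 0 and world size 1 are assumptions of this sketch, not requirements of the commit.

# Minimal sketch (assumes CUDA is available and a single rank; mirrors
# test_baby_nccl_apis below rather than documenting new API surface).
from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

# Per the updated docstring: select the device *before* constructing the
# process group so the child process inherits it via the cuda_device_id
# plumbing added in this commit.
torch.cuda.set_device(0)

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

t = torch.ones(3, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)
# wait() now blocks only the current CUDA stream (via the event returned by
# the worker), not the CPU thread.
work.wait()
torch.cuda.synchronize()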

torchft/process_group_test.py (+29 -1)

@@ -266,6 +266,32 @@ def test_baby_gloo_apis(self) -> None:

        self.assertEqual(a.num_active_work(), 0)

+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # set to 1 if more than >=2 gpus
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+        self.assertEqual(a.cuda_device_id(), device_id)
+
    def test_dummy(self) -> None:
        pg = ProcessGroupDummy(0, 1)
        m = nn.Linear(3, 4)
@@ -282,12 +308,14 @@ def test_baby_nccl_2gpu(self) -> None:
        store_addr: str = f"localhost:{store.port}/prefix"

        def run(rank: int) -> Tuple[torch.Tensor, Work]:
+            torch.cuda.set_device(rank)
+
            a = ProcessGroupBabyNCCL()
            a.configure(store_addr, rank, 2)

            self.assertEqual(a.size(), 2)

-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+            at = torch.tensor([rank + 1], device="cuda")

            a_work = a.allreduce([at], ReduceOp.SUM)
            return at, a_work
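
Since the headline feature is multiple streams, here is a sketch of issuing collectives from two parent streams; with this commit each parent stream id gets its own worker stream, so the two ops don't pick up accidental cross-stream synchronization points. This assumes `pg` is a ProcessGroupBabyNCCL configured as in the previous sketch (world size 1, CUDA available); it is illustrative, not part of the test suite.

# Sketch only: two collectives issued from two parent streams map onto two
# separate worker streams keyed by stream_id.
import torch
from torch.distributed import ReduceOp


def two_stream_allreduce(pg):
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()

    with torch.cuda.stream(s1):
        a = torch.ones(1024, device="cuda")
        w1 = pg.allreduce([a], ReduceOp.SUM)
    with torch.cuda.stream(s2):
        b = torch.ones(1024, device="cuda")
        w2 = pg.allreduce([b], ReduceOp.SUM)

    # Each wait() makes the *current* stream wait on the event the worker
    # recorded for that op; the CPU thread is never blocked.
    with torch.cuda.stream(s1):
        w1.wait()
    with torch.cuda.stream(s2):
        w2.wait()

    torch.cuda.synchronize()
    return a, b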
