diff --git a/torchft/process_group.py b/torchft/process_group.py index b38d291..c1b67f0 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -54,7 +54,10 @@ ) from torch.distributed.distributed_c10d import ( AllgatherOptions, + AllreduceCoalescedOptions, AllreduceOptions, + AllToAllOptions, + BarrierOptions, BroadcastOptions, ReduceOp, ReduceScatterOptions, @@ -107,40 +110,66 @@ def __init__(self, *args: object, **kwargs: object) -> None: self._group_name: Optional[str] = None - def configure(self, store_addr: str, rank: int, world_size: int) -> None: + # pyre-fixme[14]: inconsistent override + def allgather( + self, + output_tensors: List[List[torch.Tensor]], + input_tensor: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: """ - This reconfigures the ProcessGroup to use a new store, rank and world size. - - Every time this is called it must be provided with a unique prefixed - store address. I.e. localhost:1234/my/prefix/1 - - This function will block until the underlying ProcessGroup is created. - If an error occurs this will throw. + Gathers tensors from the whole group in a list. - Args: - store_addr: address of the store to use - rank: rank of this process - world_size: world size of this process group + See torch.distributed.all_gather for more details. """ raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override def allreduce( - self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + self, + tensors: List[torch.Tensor], + opts: Union[AllreduceOptions, ReduceOp], ) -> Work: + """ + Reduces the tensor data across all machines in such a way that all get the final result. + + See torch.distributed.all_reduce for more details. + """ + raise NotImplementedError("not implemented") + + def allreduce_coalesced( + self, + tensors: List[torch.Tensor], + opts: AllreduceCoalescedOptions, + ) -> Work: + """ + Performs an all_reduce operation in a coalesced manner. + + See torch.distributed.all_reduce_coalesced for more details. + """ raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override - def allgather( + def alltoall_base( self, - output_tensors: List[List[torch.Tensor]], - input_tensor: List[torch.Tensor], - opts: AllgatherOptions, + output_buffer: torch.Tensor, + input_buffer: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts: AllToAllOptions, ) -> Work: """ - Gathers tensors from the whole group in a list. + Performs an all_to_all operation. - See torch.distributed.all_gather for more details. + See torch.distributed.all_to_all_single for more details. + """ + raise NotImplementedError("not implemented") + + def barrier(self, opts: BarrierOptions) -> Work: + """ + Synchronizes all processes. + + See torch.distributed.barrier for more details. """ raise NotImplementedError("not implemented") @@ -160,6 +189,15 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work: opts.rootRank = root return self.broadcast([tensor], opts) + # pyre-fixme[14]: inconsistent override + def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work: + """ + Receives a list of tensors from the process with rank `rank`. + + See torch.distributed.recv for more details. 
+ """ + raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def reduce_scatter( self, @@ -174,6 +212,32 @@ def reduce_scatter( """ raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: + """ + Sends a list of tensors to the process with rank `dst_rank`. + + See torch.distributed.send for more details. + """ + raise NotImplementedError("not implemented") + + def configure(self, store_addr: str, rank: int, world_size: int) -> None: + """ + This reconfigures the ProcessGroup to use a new store, rank and world size. + + Every time this is called it must be provided with a unique prefixed + store address. I.e. localhost:1234/my/prefix/1 + + This function will block until the underlying ProcessGroup is created. + If an error occurs this will throw. + + Args: + store_addr: address of the store to use + rank: rank of this process + world_size: world size of this process group + """ + raise NotImplementedError("not implemented") + def size(self) -> int: raise NotImplementedError("not implemented") @@ -268,9 +332,6 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup: raise NotImplementedError("not implemented") - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: - return self.parent.allreduce(tensors, opts) - def allgather( self, output_tensors: List[List[torch.Tensor]], @@ -279,9 +340,35 @@ def allgather( ) -> Work: return self.parent.allgather(output_tensors, input_tensor, opts) + def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + return self.parent.allreduce(tensors, opts) + + def allreduce_coalesced( + self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + ) -> Work: + return self.parent.allreduce_coalesced(tensors, opts) + + def alltoall_base( + self, + output_buffer: torch.Tensor, + input_buffer: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts: AllToAllOptions, + ) -> Work: + return self.parent.alltoall_base( + output_buffer, input_buffer, output_split_sizes, input_split_sizes, opts + ) + + def barrier(self, opts: BarrierOptions) -> Work: + return self.parent.barrier(opts) + def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: return self.parent.broadcast(tensor_list, opts) + def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work: + return self.parent.recv(tensors, src_rank, tag) + def reduce_scatter( self, output_tensors: List[torch.Tensor], @@ -290,6 +377,9 @@ def reduce_scatter( ) -> Work: return self.parent.reduce_scatter(output_tensors, input_tensors, opts) + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: + return self.parent.send(tensors, dst_rank, tag) + def size(self) -> int: return self.parent.size() @@ -414,11 +504,37 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: self._work.append(res) return res + def allreduce_coalesced( + self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + ) -> Work: + res = _DummyWork(tensors) + self._work.append(res) + return res + + def alltoall_base( + self, + output_buffer: torch.Tensor, + input_buffer: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts: AllToAllOptions, + ) -> Work: + output_buffer.copy_(input_buffer) + res = 
_DummyWork([output_buffer]) + self._work.append(res) + return res + + def barrier(self, opts: BarrierOptions) -> Work: + return _DummyWork(None) + def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: res = _DummyWork(tensor_list) self._work.append(res) return res + def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work: + return _DummyWork(None) + def reduce_scatter( self, output_tensors: List[torch.Tensor], @@ -432,6 +548,9 @@ def reduce_scatter( self._work.append(res) return res + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: + return _DummyWork(None) + def size(self) -> int: return self._world @@ -653,6 +772,26 @@ def set_stream(self) -> Generator[None, None, None]: yield +def _maybe_share_tensors( + tensor: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor] +) -> None: + """Move a tensor / list of tensors to shared memory if not already in shared memory.""" + if isinstance(tensor, list): + for t in tensor: + _maybe_share_tensors(t) + elif isinstance(tensor, torch.Tensor): + if not tensor.is_shared(): + tensor.share_memory_() + else: + raise TypeError(f"expected tensor or list but got {type(tensor)}") + + +def _assert_list(tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]]) -> None: + """Assert that the input is a list of tensors or a nested list of tensors.""" + if not isinstance(tensors, list): + raise TypeError(f"expected list but got {type(tensors)}") + + class ProcessGroupBaby(ProcessGroup): """ This is a process group that runs the underlying process group in a @@ -979,71 +1118,89 @@ def _assert_alive(self) -> None: if not p.is_alive(): raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}") + def allgather( + self, + output_tensors: List[List[torch.Tensor]], + input_tensor: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + _assert_list(output_tensors) + _assert_list(input_tensor) + _maybe_share_tensors(output_tensors) + _maybe_share_tensors(input_tensor) + return self._run_func("allgather", output_tensors, input_tensor, opts) + def allreduce( self, tensors: List[torch.Tensor], opts: Union[dist.AllreduceOptions, dist.ReduceOp], ) -> Work: - assert isinstance(tensors, list), "input must be list" - - for tensor in tensors: - if not tensor.is_shared(): - tensor.share_memory_() - + _assert_list(tensors) + _maybe_share_tensors(tensors) return self._run_func("allreduce", tensors, opts) - def allgather( + def allreduce_coalesced( self, - output_tensors: List[List[torch.Tensor]], - input_tensor: List[torch.Tensor], - opts: AllgatherOptions, + tensors: List[torch.Tensor], + opts: Union[dist.AllreduceCoalescedOptions, dist.ReduceOp], ) -> Work: - assert isinstance(output_tensors, list), "input must be list" - assert isinstance(input_tensor, list), "input must be list" - - for tensor_list in output_tensors: - for tensor in tensor_list: - if not tensor.is_shared(): - tensor.share_memory_() + _assert_list(tensors) + _maybe_share_tensors(tensors) + return self._run_func("allreduce_coalesced", tensors, opts) - for tensor in input_tensor: - if not tensor.is_shared(): - tensor.share_memory_() + def alltoall_base( + self, + output_buffer: torch.Tensor, + input_buffer: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts: AllToAllOptions, + ) -> Work: + _maybe_share_tensors(output_buffer) + _maybe_share_tensors(input_buffer) + return self._run_func( + "alltoall_base", + output_buffer, + input_buffer, + output_split_sizes, + input_split_sizes, + 
opts, + ) - return self._run_func("allgather", output_tensors, input_tensor, opts) + def barrier(self, opts: BarrierOptions) -> Work: + return self._run_func("barrier", opts) def broadcast( self, tensor_list: List[torch.Tensor], opts: BroadcastOptions, ) -> Work: - assert isinstance(tensor_list, list), "input must be list" - - for tensor in tensor_list: - if not tensor.is_shared(): - tensor.share_memory_() - + _assert_list(tensor_list) + _maybe_share_tensors(tensor_list) return self._run_func("broadcast", tensor_list, opts) + def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work: + _assert_list(tensors) + _maybe_share_tensors(tensors) + return self._run_func("recv", tensors, src_rank, tag) + def reduce_scatter( self, output_tensors: List[torch.Tensor], input_tensors: List[List[torch.Tensor]], opts: ReduceScatterOptions, ) -> Work: - assert isinstance(output_tensors, list), "input must be list" - assert isinstance(input_tensors, list), "input must be list" - - for tensor in output_tensors: - if not tensor.is_shared(): - tensor.share_memory_() - - for tensor_list in input_tensors: - for tensor in tensor_list: - if not tensor.is_shared(): - tensor.share_memory_() + _assert_list(output_tensors) + _assert_list(input_tensors) + _maybe_share_tensors(output_tensors) + _maybe_share_tensors(input_tensors) return self._run_func("reduce_scatter", output_tensors, input_tensors, opts) + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: + _assert_list(tensors) + _maybe_share_tensors(tensors) + return self._run_func("send", tensors, dst_rank, tag) + def size(self) -> int: return self._world_size @@ -1069,8 +1226,11 @@ def safe_args(cls, args: T) -> T: elif isinstance( args, ( - AllreduceOptions, AllgatherOptions, + AllreduceOptions, + AllreduceCoalescedOptions, + AllToAllOptions, + BarrierOptions, BroadcastOptions, ReduceScatterOptions, ), diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index fb67457..6afc825 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -20,7 +20,10 @@ from torch import nn from torch._C._distributed_c10d import ( AllgatherOptions, + AllreduceCoalescedOptions, AllreduceOptions, + AllToAllOptions, + BarrierOptions, BroadcastOptions, ReduceOp, ReduceScatterOptions, @@ -88,11 +91,23 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] for item in arg: check_tensors(item) - # Test collectives + # Test collectives. send/recv require multiple processes to test, so we skip them here collectives = [ ("allreduce", ([input_tensor], AllreduceOptions())), ("allreduce", ([input_tensor], ReduceOp.SUM)), + ("allreduce_coalesced", ([input_tensor], AllreduceCoalescedOptions())), ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), + ( + "alltoall_base", + ( + output_tensors[0][0], + input_tensor, + [input_tensor.shape[0]], + [input_tensor.shape[0]], + AllToAllOptions(), + ), + ), + ("barrier", (BarrierOptions(),)), ("broadcast", (tensor_list, BroadcastOptions())), ("broadcast_one", (input_tensor, 0)), ( @@ -122,6 +137,120 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] return works +def _test_multi_pg(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None: + """ + Helper function to test a set of collective operations in settings with multiple + process groups. 
+ """ + # Test allgather + tensor_list = [ + torch.zeros(2, dtype=torch.int64, device=tensor.device) for _ in range(2) + ] + allgather_tensor = ( + torch.arange(2, dtype=torch.int64, device=tensor.device) + 1 + 2 * rank + ) + allgather_work = pg.allgather([tensor_list], [allgather_tensor], AllgatherOptions()) + allgather_work.wait() + torch.testing.assert_close( + tensor_list[0], torch.tensor([1, 2], device=tensor.device) + ) + torch.testing.assert_close( + tensor_list[1], torch.tensor([3, 4], device=tensor.device) + ) + + # Test allreduce + tc = tensor.clone() + allreduce_work = pg.allreduce([tc], ReduceOp.SUM) + allreduce_work.wait() + expected_tensor = torch.tensor([3], device=tc.device) + torch.testing.assert_close(tc, expected_tensor) + + # Test allreduce_coalesced + tensors = [tensor.clone(), tensor.clone() + 1] + allreduce_coalesced_work = pg.allreduce_coalesced( + tensors, AllreduceCoalescedOptions() + ) + allreduce_coalesced_work.wait() + torch.testing.assert_close(tensors[0], torch.tensor([3], device=tensor.device)) + torch.testing.assert_close(tensors[1], torch.tensor([5], device=tensor.device)) + + # Test all-to-all + input_tensor = torch.tensor([rank + 1, rank + 5], device=tensor.device) + output_tensor = torch.empty_like(input_tensor) + alltoall_work = pg.alltoall_base( + output_tensor, input_tensor, [1, 1], [1, 1], AllToAllOptions() + ) + alltoall_work.wait() + if rank == 0: + expected_alltoall = torch.tensor([1, 2], device=tensor.device) + else: + expected_alltoall = torch.tensor([5, 6], device=tensor.device) + torch.testing.assert_close(output_tensor, expected_alltoall) + + # Test broadcast + broadcast_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor) + broadcast_work = pg.broadcast([broadcast_tensor], BroadcastOptions()) + broadcast_work.wait() + expected_broadcast = torch.tensor([1], device=tensor.device) + torch.testing.assert_close(broadcast_tensor, expected_broadcast) + + # Test broadcast_one + broadcast_one_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor) + broadcast_one_work = pg.broadcast_one(broadcast_one_tensor, 0) + broadcast_one_work.wait() + torch.testing.assert_close( + broadcast_one_tensor, torch.tensor([1], device=tensor.device) + ) + + # Test barrier + barrier_work = pg.barrier(BarrierOptions()) + barrier_work.wait() + + # Test send/recv + if rank == 0: + send_tensor = tensor.clone() + send_work = pg.send([send_tensor], 1, 0) + send_work.wait() + else: + recv_tensor = torch.zeros_like(tensor) + recv_work = pg.recv([recv_tensor], 0, 0) + recv_work.wait() + expected = torch.tensor([1], device=tensor.device) + torch.testing.assert_close(recv_tensor, expected) + + # Test reduce_scatter + if tensor.device.type == "cuda": + # reduce scatter not supported on GLOO + input_tensors = [ + torch.tensor( + [rank + 1, rank + 3], device=tensor.device, dtype=torch.float32 + ), + torch.tensor( + [rank + 5, rank + 7], device=tensor.device, dtype=torch.float32 + ), + ] + output_tensor = torch.empty(2, device=tensor.device) + reduce_scatter_work = pg.reduce_scatter( + [output_tensor], [input_tensors], ReduceScatterOptions() + ) + reduce_scatter_work.wait() + # Input tensors become: + # rank 0: [[1, 3], [5, 7]] + # rank 1: [[2, 4], [6, 8]] + # Therefore expected outputs are: + # rank 0: [1 + 2 = 3, 3 + 4 = 7] + # rank 1: [5 + 6 = 11, 7 + 8 = 15] + if rank == 0: + expected_reduce_scatter = torch.tensor( + [3, 7], device=tensor.device, dtype=torch.float32 + ) + else: + expected_reduce_scatter = torch.tensor( + [11, 15], device=tensor.device, 
dtype=torch.float32 + ) + torch.testing.assert_close(output_tensor, expected_reduce_scatter) + + class ProcessGroupTest(TestCase): def test_gloo(self) -> None: store = TCPStore( @@ -187,31 +316,21 @@ def test_baby_gloo(self) -> None: store_addr: str = f"localhost:{store.port}/prefix" - def run(rank: int) -> Tuple[torch.Tensor, Work]: - a = ProcessGroupBabyGloo() - a.configure(store_addr, rank, 2) - - self.assertEqual(a.size(), 2) + def run(rank: int, store_addr: str = store_addr) -> None: + pg = ProcessGroupBabyGloo() + pg.configure(store_addr, rank, 2) - at = torch.tensor([rank + 1]) + self.assertEqual(pg.size(), 2) - a_work = a.allreduce([at], ReduceOp.SUM) - return at, a_work + tensor = torch.tensor([rank + 1]) + _test_multi_pg(pg, rank, tensor) with ThreadPoolExecutor(max_workers=2) as executor: a_fut = executor.submit(run, 0) b_fut = executor.submit(run, 1) - at, a_work = a_fut.result() - bt, b_work = b_fut.result() - - a_work.wait() - fut = b_work.get_future() - - fut.wait() - - torch.testing.assert_close(at, torch.tensor([3])) - torch.testing.assert_close(bt, torch.tensor([3])) + a_fut.result() + b_fut.result() def test_baby_gloo_timeout(self) -> None: store = TCPStore( @@ -291,16 +410,21 @@ def test_baby_nccl_apis(self) -> None: store_addr = f"localhost:{store.port}/prefix" a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10)) - a.configure(store_addr, 0, 1) + try: + a.configure(store_addr, 0, 1) - _test_pg(a, torch.randn((2, 3), device="cuda")) + _test_pg(a, torch.randn((2, 3), device="cuda")) - torch.cuda.synchronize() + torch.cuda.synchronize() - # force collection to ensure no BabyWork objects remain - gc.collect() + # force collection to ensure no BabyWork objects remain + gc.collect() - self.assertEqual(a.num_active_work(), 0) + self.assertEqual(a.num_active_work(), 0) + finally: + a.shutdown() + torch.cuda.synchronize() + torch.cuda.empty_cache() def test_dummy(self) -> None: pg = ProcessGroupDummy(0, 1) @@ -317,7 +441,7 @@ def test_baby_nccl_2gpu(self) -> None: store_addr: str = f"localhost:{store.port}/prefix" - def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]: + def run(rank: int) -> ProcessGroupBabyNCCL: a = ProcessGroupBabyNCCL( timeout=timedelta(seconds=10.0), ) @@ -327,31 +451,22 @@ def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]: # We test using set_device to ensure stream device is correct. torch.cuda.set_device(rank) at = torch.tensor([rank + 1], device="cuda") - - a_work = a.allreduce([at], ReduceOp.SUM) - return a, at, a_work + try: + _test_multi_pg(a, rank, at) + finally: + a.shutdown() + return a with ThreadPoolExecutor(max_workers=2) as executor: a_fut = executor.submit(run, 0) b_fut = executor.submit(run, 1) - a, at, a_work = a_fut.result() - b, bt, b_work = b_fut.result() + a = a_fut.result() + b = b_fut.result() - try: - a_work.wait() - b_work.get_future().wait() - torch.testing.assert_close(at.cpu(), bt.cpu()) - finally: - # cleanup - first ensure that babywork is deleted before shutting down PGs - # note futures must be deleted as they hold references to babywork - del a_fut - del b_fut - del a_work - del b_work - gc.collect() - b.shutdown() - a.shutdown() + # cleanup + torch.cuda.synchronize() + torch.cuda.empty_cache() def test_device_mesh(self) -> None: os.environ["MASTER_ADDR"] = "localhost"
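
Illustrative usage note (not part of the diff): a minimal sketch of how the collectives added above might be exercised through ProcessGroupBabyGloo from two threads, mirroring the pattern used in test_baby_gloo and _test_multi_pg. The `run` helper, the "/prefix" store path, port 0, and the world size of 2 are assumptions made for the example, not part of the change; only APIs introduced or already present in torchft.process_group are used.

# Sketch only, assuming torchft with the additions above is importable.
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import (
    AllreduceCoalescedOptions,
    BarrierOptions,
)

from torchft.process_group import ProcessGroupBabyGloo


def run(rank: int, store_addr: str) -> None:
    pg = ProcessGroupBabyGloo()
    pg.configure(store_addr, rank, 2)

    # Coalesced allreduce: each tensor ends up holding the sum across ranks.
    tensors = [torch.tensor([rank + 1.0]), torch.tensor([rank + 10.0])]
    pg.allreduce_coalesced(tensors, AllreduceCoalescedOptions()).wait()

    # Point-to-point exchange through the new send/recv wrappers (dst/src rank, tag).
    if rank == 0:
        pg.send([torch.tensor([42.0])], 1, 0).wait()
    else:
        out = torch.zeros(1)
        pg.recv([out], 0, 0).wait()

    # Synchronize before shutting the subprocess-backed group down.
    pg.barrier(BarrierOptions()).wait()
    pg.shutdown()


if __name__ == "__main__":
    store = TCPStore("localhost", port=0, is_master=True, wait_for_workers=False)
    store_addr = f"localhost:{store.port}/prefix"
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(run, rank, store_addr) for rank in range(2)]
        for fut in futures:
            fut.result()

Note that the tensors do not need to be moved to shared memory explicitly: the ProcessGroupBaby wrappers in this diff call _maybe_share_tensors on their arguments before forwarding the operation to the child process.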