From c782f4e1d089382c4444b3a29bf2fab72d077e43 Mon Sep 17 00:00:00 2001
From: Allen Wang <9057208+allenwang28@users.noreply.github.com>
Date: Fri, 21 Feb 2025 10:54:29 -0600
Subject: [PATCH] Enhances `process_group_test` (#113)

* initial commit exploring multi process group tests

* modify all collective tests to be robust to world size

* add resiliency tests

* play w/ timeouts to speed things up

* finalize

* increase baby gloo timeout

* COLLECTIVES => SKIP

---------

Co-authored-by: Allen Chenjim Wang
---
 torchft/process_group_test.py | 512 +++++++++++++++++++++++++---------
 1 file changed, 377 insertions(+), 135 deletions(-)

diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 6afc825..b3ca93e 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -11,12 +11,13 @@
 import unittest
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import timedelta
-from typing import Any, Dict, Tuple, cast
+from typing import Any, Callable, Dict, List, cast
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
+from parameterized import parameterized
 from torch import nn
 from torch._C._distributed_c10d import (
     AllgatherOptions,
@@ -137,64 +138,132 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
     return works
 
 
-def _test_multi_pg(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
-    """
-    Helper function to test a set of collective operations in settings with multiple
-    process groups.
+def run_allgather_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test allgather collective operation.
+
+    Suppose each rank's local tensor = [rank+1, rank+2];
+    we allgather => gather onto a list of length world_sz.
     """
-    # Test allgather
-    tensor_list = [
-        torch.zeros(2, dtype=torch.int64, device=tensor.device) for _ in range(2)
+    world_sz = pg.size()
+    to_gather = torch.stack([tensor, tensor + 1], dim=0)
+    # shape: (2,)
+    to_gather = to_gather.reshape(-1)
+
+    # Gathers as follows: [ [ recv0 ], [ recv1 ], ... [ recv_{sz-1} ] ]
+    # Each recv is shape (2,)
+    output_list = [
+        torch.zeros(2, device=tensor.device, dtype=tensor.dtype)
+        for _ in range(world_sz)
     ]
-    allgather_tensor = (
-        torch.arange(2, dtype=torch.int64, device=tensor.device) + 1 + 2 * rank
-    )
-    allgather_work = pg.allgather([tensor_list], [allgather_tensor], AllgatherOptions())
-    allgather_work.wait()
-    torch.testing.assert_close(
-        tensor_list[0], torch.tensor([1, 2], device=tensor.device)
-    )
-    torch.testing.assert_close(
-        tensor_list[1], torch.tensor([3, 4], device=tensor.device)
-    )
 
-    # Test allreduce
+    work = pg.allgather([output_list], [to_gather], AllgatherOptions())
+    work.wait()
+
+    for r in range(world_sz):
+        expected = torch.tensor(
+            [r + 1, r + 2], device=tensor.device, dtype=tensor.dtype
+        )
+        torch.testing.assert_close(output_list[r], expected)
+
+
+def run_allreduce_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test allreduce collective operation.
+
+    Assume each rank's tensor has value = rank + 1.
+    The final result after allreduce(SUM) should be 1 + 2 + ... + world_sz.
+    """
     tc = tensor.clone()
-    allreduce_work = pg.allreduce([tc], ReduceOp.SUM)
-    allreduce_work.wait()
-    expected_tensor = torch.tensor([3], device=tc.device)
-    torch.testing.assert_close(tc, expected_tensor)
-
-    # Test allreduce_coalesced
-    tensors = [tensor.clone(), tensor.clone() + 1]
-    allreduce_coalesced_work = pg.allreduce_coalesced(
-        tensors, AllreduceCoalescedOptions()
+    world_sz = pg.size()
+    work = pg.allreduce([tc], ReduceOp.SUM)
+    work.wait()
+    expected_val = sum(r + 1 for r in range(world_sz))
+    torch.testing.assert_close(tc, torch.tensor([expected_val], device=tensor.device))
+
+
+def run_allreduce_coalesced_test(
+    pg: ProcessGroup, rank: int, tensor: torch.Tensor
+) -> None:
+    """Test allreduce_coalesced collective operation.
+
+    Assume each rank's tensor has value = rank + 1.
+    We coalesce 2 tensors:
+    - t0 = [rank + 1]
+    - t1 = [rank + 2]
+
+    Our final sums should be sum(r + 1 for r in range(world_sz)) for t0
+    and sum(r + 2 for r in range(world_sz)) for t1.
+    """
+    world_sz = pg.size()
+    t0 = tensor.clone()
+    t1 = tensor.clone() + 1
+    work = pg.allreduce_coalesced([t0, t1], AllreduceCoalescedOptions())
+    work.wait()
+    sum_t0 = sum(r + 1 for r in range(world_sz))
+    sum_t1 = sum(r + 2 for r in range(world_sz))
+    torch.testing.assert_close(t0, torch.tensor([sum_t0], device=t0.device))
+    torch.testing.assert_close(t1, torch.tensor([sum_t1], device=t1.device))
+
+
+def run_alltoall_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test all-to-all collective operation.
+
+    Suppose each rank's local tensor = [rank*ws+1, rank*ws+2, ..., rank*ws+ws],
+    where ws = world_sz.
+
+    e.g.:
+    rank=0 => [1,2]
+    rank=1 => [3,4]
+
+    After all-to-all, rank r's output[k] = the element from rank k that is destined for rank r,
+    i.e. (k*ws) + (r+1):
+
+    rank=0 => [1,3]
+    rank=1 => [2,4]
+
+    """
+    world_sz = pg.size()
+    if world_sz < 2:
+        return
+
+    input_tensor = torch.arange(
+        start=rank * world_sz + 1,
+        end=rank * world_sz + 1 + world_sz,
+        device=tensor.device,
+        dtype=tensor.dtype,
     )
-    allreduce_coalesced_work.wait()
-    torch.testing.assert_close(tensors[0], torch.tensor([3], device=tensor.device))
-    torch.testing.assert_close(tensors[1], torch.tensor([5], device=tensor.device))
+    output_tensor = torch.empty(world_sz, device=tensor.device, dtype=tensor.dtype)
+
+    send_sz = [1] * world_sz
+    recv_sz = [1] * world_sz
 
-    # Test all-to-all
-    input_tensor = torch.tensor([rank + 1, rank + 5], device=tensor.device)
-    output_tensor = torch.empty_like(input_tensor)
     alltoall_work = pg.alltoall_base(
-        output_tensor, input_tensor, [1, 1], [1, 1], AllToAllOptions()
+        output_tensor, input_tensor, send_sz, recv_sz, AllToAllOptions()
     )
     alltoall_work.wait()
-    if rank == 0:
-        expected_alltoall = torch.tensor([1, 2], device=tensor.device)
-    else:
-        expected_alltoall = torch.tensor([5, 6], device=tensor.device)
-    torch.testing.assert_close(output_tensor, expected_alltoall)
 
-    # Test broadcast
+    expected = torch.empty(world_sz, device=tensor.device, dtype=tensor.dtype)
+    for k in range(world_sz):
+        val = k * world_sz + (rank + 1)
+        expected[k] = val
+
+    torch.testing.assert_close(output_tensor, expected)
+
+
+def run_broadcast_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test broadcast collective operation.
+
+    rank0 will broadcast a known value and all other ranks should get it.
+    """
     broadcast_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
     broadcast_work = pg.broadcast([broadcast_tensor], BroadcastOptions())
     broadcast_work.wait()
     expected_broadcast = torch.tensor([1], device=tensor.device)
     torch.testing.assert_close(broadcast_tensor, expected_broadcast)
 
-    # Test broadcast_one
+
+def run_broadcast_one_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test broadcast_one collective operation.
+
+    rank0 will broadcast a known value and all other ranks should get it.
+    """
     broadcast_one_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
     broadcast_one_work = pg.broadcast_one(broadcast_one_tensor, 0)
     broadcast_one_work.wait()
@@ -202,53 +271,98 @@ def _test_multi_pg(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
         broadcast_one_tensor, torch.tensor([1], device=tensor.device)
     )
 
-    # Test barrier
+
+# pyre-fixme[2]: Parameter must be annotated.
+def run_barrier_test(pg: ProcessGroup, *args) -> None:
+    """Test barrier collective operation."""
     barrier_work = pg.barrier(BarrierOptions())
     barrier_work.wait()
 
-    # Test send/recv
+
+def run_send_recv_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test send/recv point-to-point operations.
+
+    Simple point-to-point between ranks 0 and 1, ignored for other ranks.
+    """
+    if pg.size() < 2:
+        return
     if rank == 0:
         send_tensor = tensor.clone()
         send_work = pg.send([send_tensor], 1, 0)
         send_work.wait()
-    else:
+    elif rank == 1:
         recv_tensor = torch.zeros_like(tensor)
         recv_work = pg.recv([recv_tensor], 0, 0)
         recv_work.wait()
         expected = torch.tensor([1], device=tensor.device)
         torch.testing.assert_close(recv_tensor, expected)
 
-    # Test reduce_scatter
-    if tensor.device.type == "cuda":
-        # reduce scatter not supported on GLOO
-        input_tensors = [
-            torch.tensor(
-                [rank + 1, rank + 3], device=tensor.device, dtype=torch.float32
-            ),
-            torch.tensor(
-                [rank + 5, rank + 7], device=tensor.device, dtype=torch.float32
-            ),
-        ]
-        output_tensor = torch.empty(2, device=tensor.device)
-        reduce_scatter_work = pg.reduce_scatter(
-            [output_tensor], [input_tensors], ReduceScatterOptions()
+
+def run_reduce_scatter_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
+    """Test reduce_scatter collective operation.
+
+    Assume each rank creates a matrix where each row r contains values:
+    [r * world_sz + 1, ..., r * world_sz + world_sz]
+
+    For example, with world_size=2:
+    [[1, 2],
+     [3, 4]]
+
+    The reduce_scatter operation then:
+    - Reduces (sums) corresponding rows across all ranks
+    - Scatters the results so each rank gets one row of the final sum
+    - Since all ranks had the same initial data, the expected result for each rank r is:
+      rank r receives: [r*world_sz + 1, ..., r*world_sz + world_sz] * world_sz
+
+    For example, with 2 ranks:
+      rank 0 gets: [1, 2] * 2 = [2, 4] (first row)
+      rank 1 gets: [3, 4] * 2 = [6, 8] (second row)
+    """
+    if tensor.device.type == "cpu":
+        # reduce scatter not supported on GLOO
+        return
+    world_sz = pg.size()
+    if world_sz < 2:
+        return
+
+    local_data = []
+    for r in range(world_sz):
+        row_vals = torch.arange(
+            start=r * world_sz + 1,
+            end=r * world_sz + world_sz + 1,
+            device=tensor.device,
+            dtype=torch.float32,
        )
-        reduce_scatter_work.wait()
-        # Input tensors become:
-        # rank 0: [[1, 3], [5, 7]]
-        # rank 1: [[2, 4], [6, 8]]
-        # Therefore expected outputs are:
-        # rank 0: [1 + 2 = 3, 3 + 4 = 7]
-        # rank 1: [5 + 6 = 11, 7 + 8 = 15]
-        if rank == 0:
-            expected_reduce_scatter = torch.tensor(
-                [3, 7], device=tensor.device, dtype=torch.float32
-            )
-        else:
-            expected_reduce_scatter = torch.tensor(
-                [11, 15], device=tensor.device, dtype=torch.float32
-            )
-        torch.testing.assert_close(output_tensor, expected_reduce_scatter)
+        local_data.append(row_vals)
+
+    out = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
+    opts = ReduceScatterOptions()
+    opts.reduceOp = ReduceOp.SUM
+    work = pg.reduce_scatter([out], [local_data], opts)
+    work.wait()
+
+    expected_row = torch.arange(
+        start=rank * world_sz + 1,
+        end=rank * world_sz + world_sz + 1,
+        device=tensor.device,
+        dtype=torch.float32,
+    )
+    expected_sum = expected_row * world_sz
+    torch.testing.assert_close(out, expected_sum)
+
+
+_COLLECTIVE_TO_FUNC: Dict[str, Callable[[ProcessGroup, int, torch.Tensor], None]] = {
+    "allgather": run_allgather_test,
+    "allreduce": run_allreduce_test,
+    "allreduce_coalesced": run_allreduce_coalesced_test,
+    "alltoall_base": run_alltoall_test,
+    "barrier": run_barrier_test,
+    "broadcast": run_broadcast_test,
+    "broadcast_one": run_broadcast_one_test,
+    "reduce_scatter": run_reduce_scatter_test,
+    "send/recv": run_send_recv_test,
+}
+_ALL_COLLECTIVES: List[str] = list(_COLLECTIVE_TO_FUNC.keys())
 
 
 class ProcessGroupTest(TestCase):
@@ -309,29 +423,6 @@ def test_nccl(self) -> None:
 
         torch.cuda.synchronize()
 
-    def test_baby_gloo(self) -> None:
-        store = TCPStore(
-            host_name="localhost", port=0, is_master=True, wait_for_workers=False
-        )
-
-        store_addr: str = f"localhost:{store.port}/prefix"
-
-        def run(rank: int, store_addr: str = store_addr) -> None:
-            pg = ProcessGroupBabyGloo()
-            pg.configure(store_addr, rank, 2)
-
-            self.assertEqual(pg.size(), 2)
-
-            tensor = torch.tensor([rank + 1])
-            _test_multi_pg(pg, rank, tensor)
-
-        with ThreadPoolExecutor(max_workers=2) as executor:
-            a_fut = executor.submit(run, 0)
-            b_fut = executor.submit(run, 1)
-
-            a_fut.result()
-            b_fut.result()
-
     def test_baby_gloo_timeout(self) -> None:
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
@@ -432,42 +523,6 @@ def test_dummy(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    @skipUnless(torch.cuda.device_count() >= 2, "need two CUDA devices")
-    def test_baby_nccl_2gpu(self) -> None:
-        store = TCPStore(
-            host_name="localhost", port=0, is_master=True, wait_for_workers=False
-        )
-
-        store_addr: str = f"localhost:{store.port}/prefix"
-
-        def run(rank: int) -> ProcessGroupBabyNCCL:
-            a = ProcessGroupBabyNCCL(
-                timeout=timedelta(seconds=10.0),
-            )
-            a.configure(store_addr, rank, 2)
-            self.assertEqual(a.size(), 2)
-
-            # We test using set_device to ensure stream device is correct.
-            torch.cuda.set_device(rank)
-            at = torch.tensor([rank + 1], device="cuda")
-            try:
-                _test_multi_pg(a, rank, at)
-            finally:
-                a.shutdown()
-            return a
-
-        with ThreadPoolExecutor(max_workers=2) as executor:
-            a_fut = executor.submit(run, 0)
-            b_fut = executor.submit(run, 1)
-
-            a = a_fut.result()
-            b = b_fut.result()
-
-            # cleanup
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-
     def test_device_mesh(self) -> None:
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = str(0)
@@ -600,3 +655,190 @@ def test_init_device_mesh(self) -> None:
             for i in range(4):
                 future = executor.submit(self._test_init_device_mesh, 4, i)
                 futures.append(future)
+
+
+class MultiPgBaseTest(TestCase):
+    """
+    A base test that creates N processes (via ThreadPoolExecutor) sharing
+    a single ProcessGroup. Each test_* method will reuse the same PG.
+
+    Subclasses can specify:
+    - BACKEND: the backend to use for the ProcessGroup ("gloo" or "nccl")
+    - WORLD_SIZE: how many ranks to simulate
+    - Additional config for the PG, e.g. timeouts.
+    """
+
+    BACKEND = "gloo"
+    WORLD_SIZE = 2
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+
+        cls.store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        cls.store_addr = f"localhost:{cls.store.port}/prefix"
+
+        cls.pg_pool: List[ProcessGroup] = []
+
+        cls.executor = ThreadPoolExecutor(max_workers=cls.WORLD_SIZE)
+
+        def init_pg(rank: int) -> ProcessGroup:
+            pg = cls._create_pg(cls.BACKEND)
+            pg.configure(cls.store_addr, rank, cls.WORLD_SIZE)
+            return pg
+
+        futures = [cls.executor.submit(init_pg, rank) for rank in range(cls.WORLD_SIZE)]
+        cls.pg_pool = [future.result() for future in futures]
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # Cleanup
+        for pg in cls.pg_pool:
+            shutdown = getattr(pg, "shutdown", None)
+            if shutdown is not None:
+                shutdown()
+        cls.executor.shutdown(wait=True)
+        super().tearDownClass()
+
+    @classmethod
+    def _create_pg(cls, backend: str) -> ProcessGroup:
+        """
+        Helper that creates a new ProcessGroup of the specified type.
+
+        NCCL groups aren't currently supported - we prefer to test
+        BabyNCCLGroups as they spin up their own subprocesses.
+        """
+        if backend == "gloo":
+            return ProcessGroupGloo(timeout=timedelta(seconds=1))
+        elif backend == "baby_gloo":
+            return ProcessGroupBabyGloo(timeout=timedelta(seconds=5))
+        elif backend == "baby_nccl":
+            return ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        else:
+            # fallback / dummy
+            return ProcessGroupDummy(0, 1)
+
+    # pyre-fixme[3]: Return type must be annotated.
+    def _run_parallel(self, collective: str, device: str = "cpu") -> List[Any]:
+        """
+        Helper to run on all ranks in parallel, returning a list
+        of results or raising an exception if any fail.
+        """
+        func = _COLLECTIVE_TO_FUNC[collective]
+        futures = []
+        for rank in range(self.WORLD_SIZE):
+            pg = self.pg_pool[rank]
+            # Each worker calls `func(pg=pg, rank=rank, tensor=tensor, *args, **kwargs)`
+            if "cuda" in device:
+                device = f"cuda:{rank}"
+            tensor = torch.tensor([rank + 1], device=device)
+            fut = self.executor.submit(func, pg, rank, tensor)
+            futures.append(fut)
+        return [f.result() for f in futures]
+
+    def _run_with_resiliency(self, collective: str, device: str = "cpu") -> List[str]:
+        """
+        Run a collective with resiliency:
+        - fault_rank (last rank) simulates a crash.
+        - surviving ranks detect the error, then reconfigure PG to exclude fault_rank.
+        - surviving ranks run the same collective again successfully.
+        """
+
+        def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
+            fault_rank = self.WORLD_SIZE - 1
+            test = _COLLECTIVE_TO_FUNC[collective]
+
+            t1 = torch.tensor([rank + 1], device=dev, dtype=torch.float32)
+            # Simulate failure on the fault rank, but other ranks should still succeed.
+            if rank == fault_rank:
+                return f"Rank{rank} crashed"
+
+            try:
+                test(pg, rank, t1.clone())
+            except RuntimeError as e:
+                assert f"Simulated rank{rank} failure" in str(e)
+
+            # Re-configure the PG to exclude the fault rank
+            new_world_size = self.WORLD_SIZE - 1
+            new_store_addr = f"localhost:{self.store.port}/reconfig_{collective}"
+            pg.configure(new_store_addr, rank, new_world_size)
+
+            # run the same collective again successfully
+            t2 = torch.tensor([rank + 1], device=dev, dtype=torch.float32)
+            test(pg, rank, t2)
+            return f"Rank{rank} final success."
+
+        # run in parallel
+        futs = [
+            self.executor.submit(worker, self.pg_pool[r], r, device)
+            for r in range(self.WORLD_SIZE)
+        ]
+        results = []
+        for f in futs:
+            try:
+                results.append(f.result(timeout=20))
+            except Exception as e:
+                results.append(e)
+        return results
+
+
+class GlooMultiPgTest(MultiPgBaseTest):
+    BACKEND = "gloo"
+    WORLD_SIZE = 3
+    SKIP = [
+        "alltoall_base",
+        "reduce_scatter",
+    ]
+    COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP))
+
+    @parameterized.expand(COLLECTIVES)
+    def test_collective(self, collective: str) -> None:
+        self._run_parallel(collective, device="cpu")
+
+    @parameterized.expand(COLLECTIVES)
+    def test_collective_with_resiliency(self, collective: str) -> None:
+        self._run_with_resiliency(collective, device="cpu")
+
+
+class BabyGlooMultiPgTest(MultiPgBaseTest):
+    BACKEND = "baby_gloo"
+    WORLD_SIZE = 3
+    SKIP = [
+        "alltoall_base",
+        "reduce_scatter",
+    ]
+    COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP))
+
+    @parameterized.expand(COLLECTIVES)
+    def test_collective(self, collective: str) -> None:
+        self._run_parallel(collective, device="cpu")
+
+    @parameterized.expand(COLLECTIVES)
+    def test_collective_with_resiliency(self, collective: str) -> None:
+        self._run_with_resiliency(collective, device="cpu")
+
+
+@skipUnless(
+    torch.cuda.is_available() and torch.cuda.device_count() >= 2, "needs 2 CUDA devices"
+)
+class BabyNcclMultiPgTest(MultiPgBaseTest):
+    BACKEND = "baby_nccl"
+    WORLD_SIZE = 2
+
+    @parameterized.expand(_ALL_COLLECTIVES)
+    def test_collective(self, collective: str) -> None:
+        self._run_parallel(collective, device="cuda")
+
+
+@skipUnless(
+    torch.cuda.is_available() and torch.cuda.device_count() >= 3, "needs 3 CUDA devices"
+)
+class BabyNcclResiliencyTest(MultiPgBaseTest):
+    BACKEND = "baby_nccl"
+    WORLD_SIZE = 3
+
+    @parameterized.expand(_ALL_COLLECTIVES)
+    def test_collective_with_resiliency(self, collective: str) -> None:
+        self._run_with_resiliency(collective, device="cuda")
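
Illustrative sketch (not part of the patch): the pattern that the new MultiPgBaseTest
relies on (one TCPStore, one thread per rank, and a shared store prefix passed to
ProcessGroup.configure), reduced to a single allreduce. It assumes torchft is installed
and exposes ProcessGroupGloo with the configure()/allreduce() API exercised in the diff
above; names such as worker, store_addr, and WORLD_SIZE exist only in this sketch.

    from concurrent.futures import ThreadPoolExecutor
    from datetime import timedelta

    import torch
    from torch.distributed import ReduceOp, TCPStore

    from torchft.process_group import ProcessGroupGloo

    WORLD_SIZE = 2

    # One store shared by all ranks; all ranks rendezvous on the same prefix.
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    store_addr = f"localhost:{store.port}/prefix"

    def worker(rank: int) -> torch.Tensor:
        pg = ProcessGroupGloo(timeout=timedelta(seconds=1))
        pg.configure(store_addr, rank, WORLD_SIZE)
        t = torch.tensor([rank + 1])
        work = pg.allreduce([t], ReduceOp.SUM)
        work.wait()
        return t  # every rank should now hold 1 + 2 = 3

    with ThreadPoolExecutor(max_workers=WORLD_SIZE) as executor:
        futures = [executor.submit(worker, rank) for rank in range(WORLD_SIZE)]
        assert all(int(fut.result().item()) == 3 for fut in futures)

Threads are sufficient here because each ProcessGroupGloo owns its own communication;
the Baby* variants additionally run collectives in a subprocess, which is why the CUDA
tests above use ProcessGroupBabyNCCL rather than a raw NCCL group.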