Add _test_pg helper #45

Merged
merged 2 commits on Dec 18, 2024
Changes from 1 commit
8 changes: 8 additions & 0 deletions CONTRIBUTING.md
@@ -67,6 +67,14 @@ lintrunner -a

### Tests

We use `pytest` as our testing framework. To execute a specific test, use the following command:

```sh
pytest torchft/process_group_test.py -k test_device_mesh
```

To run the entire suite of tests:

```sh
$ scripts/test.sh
```
108 changes: 76 additions & 32 deletions torchft/process_group_test.py
@@ -6,21 +6,34 @@

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock

import torch
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
from torch._C._distributed_c10d import (
AllgatherOptions,
AllreduceOptions,
BroadcastOptions,
ReduceOp,
_resolve_process_group,
)
from torch.distributed import (
ReduceOp,
TCPStore,
Work,
_functional_collectives,
get_world_size,
)
from torch.distributed.device_mesh import init_device_mesh

from torchft.manager import Manager
from torchft.process_group import (
ErrorSwallowingProcessGroupWrapper,
ManagedProcessGroup,
ProcessGroup,
ProcessGroupBabyGloo,
ProcessGroupBabyNCCL,
ProcessGroupDummy,
@@ -41,6 +54,56 @@ def dummy_init_pg() -> None:
)


def _test_pg(
pg: ProcessGroup,
example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
) -> Dict[str, dist._Work]:
"""
Helper function to test a set of collective operations on a given process group.
"""

shape: torch.Size = example_tensor.shape
dtype: torch.dtype = example_tensor.dtype

# Create some dummy tensors for testing
input_tensor = example_tensor.clone()
output_tensors = [
[torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
]
tensor_list = [torch.empty_like(input_tensor)]

def check_tensors(arg: Any) -> None: # pyre-ignore[2]
"""Recursively check tensors for expected shape and dtype."""
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
elif isinstance(arg, (list, tuple)):
for item in arg:
check_tensors(item)

# Test collectives
collectives = {
"allreduce": ([input_tensor], AllreduceOptions()),
"allgather": (output_tensors, [input_tensor], AllgatherOptions()),
"broadcast": (tensor_list, BroadcastOptions()),
"broadcast_one": (input_tensor, 0),
}
works: Dict[str, dist._Work] = {}
for coll_str, args in collectives.items():
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work
work.wait()
fut = work.get_future()
fut.wait()

# Check that all tensor arguments have the expected shapes and dtypes
check_tensors(args)

print(works)
return works
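# Illustrative usage (a minimal sketch, assuming `pg` is a ProcessGroup that has
# already been configured, as in the tests below): `_test_pg` runs allreduce,
# allgather, broadcast, and broadcast_one, waits on each Work and its future,
# and returns the Work handles keyed by collective name.
#
#   works = _test_pg(pg, torch.randn((2, 3), dtype=torch.float32))
#   assert set(works) == {"allreduce", "allgather", "broadcast", "broadcast_one"}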


class ProcessGroupTest(TestCase):
def test_gloo(self) -> None:
store = TCPStore(
@@ -53,11 +116,7 @@ def test_gloo(self) -> None:

self.assertEqual(pg.size(), 1)

at = torch.tensor([2])

a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()
_test_pg(pg)

m = nn.Linear(3, 4)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +136,7 @@ def test_nccl(self) -> None:

self.assertEqual(pg.size(), 1)

at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()
_test_pg(pg, torch.tensor([2], device=device))

m = nn.Linear(3, 4).to(device)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +146,7 @@ def test_nccl(self) -> None:
store_addr = f"localhost:{store.port}/prefix2"
pg.configure(store_addr, 0, 1)

at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
_test_pg(pg, torch.tensor([2], device=device))

torch.cuda.synchronize()

@@ -220,22 +274,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
wrapper = ErrorSwallowingProcessGroupWrapper(pg)
self.assertIs(wrapper.parent, pg)

t = torch.zeros(10)
work = wrapper.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _ErrorSwallowingWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(wrapper)
self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork)

err = RuntimeError("test")
wrapper.report_error(err)
self.assertEqual(wrapper.error(), err)

work = wrapper.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _DummyWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(wrapper)
for work in works.values():
self.assertIsInstance(work, _DummyWork)

def test_managed_process_group(self) -> None:
manager = Mock(spec=Manager)
@@ -246,12 +294,8 @@ def test_managed_process_group(self) -> None:

self.assertEqual(pg.size(), 123)

t = torch.zeros(10)
work = pg.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _ManagedWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(pg)
self.assertIsInstance(list(works.values())[0], _ManagedWork)

self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)