
Commit 49d2aec

Add _test_pg helper (#45)
* Add _test_pg helper
* update CONTRIBUTING.md
1 parent 6d6e9a4 commit 49d2aec

2 files changed: +90 -32 lines changed


CONTRIBUTING.md (+14)
@@ -67,6 +67,20 @@ lintrunner -a
 
 ### Tests
 
+We use `pytest` as our testing framework. To execute a specific test, use the following command:
+
+```sh
+pytest torchft/process_group_test.py -k test_device_mesh
+```
+
+To run the Rust tests run:
+
+```sh
+cargo test
+```
+
+To run the entire suite of tests:
+
 ```sh
 $ scripts/test.sh
 ```

torchft/process_group_test.py (+76 -32)
@@ -6,21 +6,34 @@
 
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import _resolve_process_group
-from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    _resolve_process_group,
+)
+from torch.distributed import (
+    ReduceOp,
+    TCPStore,
+    Work,
+    _functional_collectives,
+    get_world_size,
+)
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
@@ -41,6 +54,56 @@ def dummy_init_pg() -> None:
         )
 
 
+def _test_pg(
+    pg: ProcessGroup,
+    example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+) -> Dict[str, dist._Work]:
+    """
+    Helper function to test a set of collective operations on a given process group.
+    """
+
+    shape: torch.Size = example_tensor.shape
+    dtype: torch.dtype = example_tensor.dtype
+
+    # Create some dummy tensors for testing
+    input_tensor = example_tensor.clone()
+    output_tensors = [
+        [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
+    ]
+    tensor_list = [torch.empty_like(input_tensor)]
+
+    def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
+        """Recursively check tensors for expected shape and dtype."""
+        if isinstance(arg, torch.Tensor):
+            assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
+            assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                check_tensors(item)
+
+    # Test collectives
+    collectives = {
+        "allreduce": ([input_tensor], AllreduceOptions()),
+        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
+        "broadcast": (tensor_list, BroadcastOptions()),
+        "broadcast_one": (input_tensor, 0),
+    }
+    works: Dict[str, dist._Work] = {}
+    for coll_str, args in collectives.items():
+        coll = getattr(pg, coll_str)
+        work = coll(*args)
+        works[coll_str] = work
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
+
+        # Check that all tensor arguments have the expected shapes and dtypes
+        check_tensors(args)
+
+    print(works)
+    return works
+
+
 class ProcessGroupTest(TestCase):
     def test_gloo(self) -> None:
         store = TCPStore(
@@ -53,11 +116,7 @@ def test_gloo(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2])
-
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg)
 
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +136,7 @@ def test_nccl(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +146,7 @@ def test_nccl(self) -> None:
         store_addr = f"localhost:{store.port}/prefix2"
         pg.configure(store_addr, 0, 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         torch.cuda.synchronize()
 
@@ -220,22 +274,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
         wrapper = ErrorSwallowingProcessGroupWrapper(pg)
         self.assertIs(wrapper.parent, pg)
 
-        t = torch.zeros(10)
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ErrorSwallowingWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(wrapper)
+        self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork)
 
         err = RuntimeError("test")
         wrapper.report_error(err)
         self.assertEqual(wrapper.error(), err)
 
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _DummyWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(wrapper)
+        for work in works.values():
+            self.assertIsInstance(work, _DummyWork)
 
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
@@ -246,12 +294,8 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(pg.size(), 123)
 
-        t = torch.zeros(10)
-        work = pg.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ManagedWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(pg)
+        self.assertIsInstance(list(works.values())[0], _ManagedWork)
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
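For reference, here is a minimal sketch (not part of this commit) of how a future test in this file could reuse the new `_test_pg` helper. The `ProcessGroupGloo` wrapper and the `TCPStore`/`configure` pattern are assumptions borrowed from the existing `test_gloo` test; only `_test_pg` itself is defined by this diff.

```python
# Hypothetical sketch only -- not part of commit 49d2aec.
# Assumes torchft's ProcessGroupGloo wrapper and the TCPStore/configure
# pattern already used by test_gloo in this file; _test_pg is the helper
# added by this commit.
import torch
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupGloo


def check_pg_with_helper() -> None:
    # Single-rank TCPStore on an OS-assigned port, matching the existing tests.
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    store_addr = f"localhost:{store.port}/prefix"

    pg = ProcessGroupGloo()
    pg.configure(store_addr, 0, 1)  # store address, rank, world size

    # _test_pg runs allreduce/allgather/broadcast/broadcast_one against the
    # group and returns the resulting Work objects keyed by collective name.
    works = _test_pg(pg, torch.randn((2, 3), dtype=torch.float32))
    assert set(works) == {"allreduce", "allgather", "broadcast", "broadcast_one"}
```

Because `_test_pg` returns the per-collective `Work` objects, callers can also assert on the concrete `Work` subclass, as the wrapper and managed-process-group tests above do.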
