Commit 8fb173c

Add _test_pg helper
1 parent 8a22dc8 commit 8fb173c

File tree

2 files changed, +83 -37 lines changed


CONTRIBUTING.md

+8
@@ -67,6 +67,14 @@ lintrunner -a
 
 ### Tests
 
+We use `pytest` as our testing framework. To execute a specific test, use the following command:
+
+```sh
+pytest torchft/process_group_test.py -k test_device_mesh
+```
+
+To run the entire suite of tests:
+
 ```sh
 $ scripts/test.sh
 ```
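Not part of the diff, but a brief aside on the `-k` flag used above: pytest evaluates it as an expression matched against test names, so it can also select a group of related tests rather than a single one. For example (test names taken from this file):

```sh
# run every test in the file whose name contains "gloo"
pytest torchft/process_group_test.py -k gloo

# run the NCCL tests, with verbose per-test output
pytest torchft/process_group_test.py -k nccl -v
```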

torchft/process_group_test.py

+75 -37
@@ -6,31 +6,44 @@
 
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple
-from unittest import TestCase, skipUnless
+from typing import Any, Tuple
+from unittest import skipUnless, TestCase
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import _resolve_process_group
-from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
+from torch._C._distributed_c10d import (
+    _resolve_process_group,
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+)
+from torch.distributed import (
+    _functional_collectives,
+    get_world_size,
+    ReduceOp,
+    TCPStore,
+    Work,
+)
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.manager import Manager
 from torchft.process_group import (
+    _DummyWork,
+    _ErrorSwallowingWork,
+    _ManagedWork,
     ErrorSwallowingProcessGroupWrapper,
+    extend_device_mesh,
     ManagedProcessGroup,
+    ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
     ProcessGroupGloo,
     ProcessGroupNCCL,
     ProcessGroupWrapper,
-    _DummyWork,
-    _ErrorSwallowingWork,
-    _ManagedWork,
-    extend_device_mesh,
 )
 
 
@@ -41,6 +54,54 @@ def dummy_init_pg() -> None:
     )
 
 
+def _test_pg(
+    pg: ProcessGroup,
+    example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+) -> None:
+    """
+    Helper function to test a set of collective operations on a given process group.
+    """
+
+    shape: torch.Size = example_tensor.shape
+    dtype: torch.dtype = example_tensor.dtype
+
+    # Create some dummy tensors for testing
+    input_tensor = example_tensor.clone()
+    output_tensors = [
+        [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
+    ]
+    tensor_list = [torch.empty_like(input_tensor)]
+
+    def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
+        """Recursively check tensors for expected shape and dtype."""
+        if isinstance(arg, torch.Tensor):
+            assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
+            assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                check_tensors(item)
+
+    # Test collectives
+    collectives = {
+        "allreduce": ([input_tensor], AllreduceOptions()),
+        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
+        "broadcast": (tensor_list, BroadcastOptions()),
+        "broadcast_one": (input_tensor, 0),
+    }
+    for coll_str, args in collectives.items():
+        coll = getattr(pg, coll_str)
+        work = coll(*args)
+        work.wait()
+
+        # Check that all tensor arguments have the expected shapes and dtypes
+        check_tensors(args)
+
+        # Check that get_future works
+        work = coll(*args)
+        fut = work.get_future()
+        fut.wait()
+
+
 class ProcessGroupTest(TestCase):
     def test_gloo(self) -> None:
         store = TCPStore(
@@ -53,11 +114,7 @@ def test_gloo(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2])
-
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg)
 
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +134,7 @@ def test_nccl(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +144,7 @@ def test_nccl(self) -> None:
         store_addr = f"localhost:{store.port}/prefix2"
         pg.configure(store_addr, 0, 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         torch.cuda.synchronize()
 
@@ -220,22 +272,13 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
         wrapper = ErrorSwallowingProcessGroupWrapper(pg)
         self.assertIs(wrapper.parent, pg)
 
-        t = torch.zeros(10)
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ErrorSwallowingWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(wrapper)
 
         err = RuntimeError("test")
         wrapper.report_error(err)
         self.assertEqual(wrapper.error(), err)
 
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _DummyWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(wrapper)
 
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
@@ -246,12 +289,7 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(pg.size(), 123)
 
-        t = torch.zeros(10)
-        work = pg.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ManagedWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(pg)
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
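With the helper in place, each call site above reduces to a one-line `_test_pg(...)` call that exercises allreduce, allgather, broadcast, and broadcast_one, and checks both `wait()` and `get_future()` on the returned work objects. Below is a minimal sketch of how a future single-rank test might reuse it; the `TCPStore` keyword arguments and the bare `ProcessGroupGloo()` construction simply mirror the pattern of `test_gloo`/`test_nccl` and are illustrative, not taken from this commit:

```python
import torch
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupGloo
from torchft.process_group_test import _test_pg

# Single-rank store and process group, following the shape of test_gloo above.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"
pg = ProcessGroupGloo()
pg.configure(store_addr, 0, 1)  # store address, rank, world size

# One call smoke-tests the collectives and their futures.
_test_pg(pg, torch.randn((2, 3), dtype=torch.float32))
```

Because `_test_pg` looks each collective up with `getattr(pg, coll_str)`, it stays agnostic to which wrapper produced the `Work` object, which is why the `ErrorSwallowingProcessGroupWrapper` and `ManagedProcessGroup` tests can share it with the plain Gloo/NCCL ones.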
