 
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple
-from unittest import TestCase, skipUnless
+from typing import Any, Tuple
+from unittest import skipUnless, TestCase
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import _resolve_process_group
-from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
+from torch._C._distributed_c10d import (
+    _resolve_process_group,
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+)
+from torch.distributed import (
+    _functional_collectives,
+    get_world_size,
+    ReduceOp,
+    TCPStore,
+    Work,
+)
 
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.manager import Manager
 from torchft.process_group import (
+    _DummyWork,
+    _ErrorSwallowingWork,
+    _ManagedWork,
     ErrorSwallowingProcessGroupWrapper,
+    extend_device_mesh,
     ManagedProcessGroup,
+    ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
     ProcessGroupGloo,
     ProcessGroupNCCL,
     ProcessGroupWrapper,
-    _DummyWork,
-    _ErrorSwallowingWork,
-    _ManagedWork,
-    extend_device_mesh,
 )
 
 
@@ -41,6 +54,54 @@ def dummy_init_pg() -> None:
         )
 
 
+def _test_pg(
+    pg: ProcessGroup,
+    example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+) -> None:
+    """
+    Helper function to test a set of collective operations on a given process group.
+    """
+
+    shape: torch.Size = example_tensor.shape
+    dtype: torch.dtype = example_tensor.dtype
+
+    # Create some dummy tensors for testing
+    input_tensor = example_tensor.clone()
+    output_tensors = [
+        [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
+    ]
+    tensor_list = [torch.empty_like(input_tensor)]
+
+    def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
+        """Recursively check tensors for expected shape and dtype."""
+        if isinstance(arg, torch.Tensor):
+            assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
+            assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                check_tensors(item)
+
+    # Test collectives
+    collectives = {
+        "allreduce": ([input_tensor], AllreduceOptions()),
+        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
+        "broadcast": (tensor_list, BroadcastOptions()),
+        "broadcast_one": (input_tensor, 0),
+    }
+    for coll_str, args in collectives.items():
+        coll = getattr(pg, coll_str)
+        work = coll(*args)
+        work.wait()
+
+        # Check that all tensor arguments have the expected shapes and dtypes
+        check_tensors(args)
+
+        # Check that get_future works
+        work = coll(*args)
+        fut = work.get_future()
+        fut.wait()
+
+
 class ProcessGroupTest(TestCase):
     def test_gloo(self) -> None:
         store = TCPStore(
@@ -53,11 +114,7 @@ def test_gloo(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2])
-
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg)
 
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +134,7 @@ def test_nccl(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +144,7 @@ def test_nccl(self) -> None:
         store_addr = f"localhost:{store.port}/prefix2"
         pg.configure(store_addr, 0, 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         torch.cuda.synchronize()
 
@@ -220,22 +272,13 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
         wrapper = ErrorSwallowingProcessGroupWrapper(pg)
         self.assertIs(wrapper.parent, pg)
 
-        t = torch.zeros(10)
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ErrorSwallowingWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(wrapper)
 
         err = RuntimeError("test")
         wrapper.report_error(err)
         self.assertEqual(wrapper.error(), err)
 
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _DummyWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(wrapper)
 
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
@@ -246,12 +289,7 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(pg.size(), 123)
 
-        t = torch.zeros(10)
-        work = pg.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ManagedWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        _test_pg(pg)
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
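For reference, the sketch below shows one way to drive the new `_test_pg` helper by hand against a single-rank Gloo group, mirroring the setup used in `test_gloo` above. It is illustrative only: it assumes `_test_pg` can be imported from this test module (the `torchft.process_group_test` path is an assumption, not confirmed by the diff) and that `ProcessGroupGloo()` and `configure(store_addr, rank, world_size)` behave as they do in the tests shown here.

```python
# Minimal sketch (not part of the PR): exercise _test_pg on a one-rank Gloo group.
import torch
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupGloo
from torchft.process_group_test import _test_pg  # assumed import path for the helper

# Single-rank rendezvous store; port=0 binds an ephemeral port, as in test_gloo.
store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)

pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)  # rank 0, world size 1

# Runs allreduce/allgather/broadcast/broadcast_one, checks output shapes and
# dtypes via check_tensors, and exercises the get_future() path for each collective.
_test_pg(pg, torch.randn((2, 3), dtype=torch.float32))
```

Because `_test_pg` takes any `ProcessGroup`, the same call pattern covers the wrapper classes (`ErrorSwallowingProcessGroupWrapper`, `ManagedProcessGroup`) that the tests above now route through it.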