 
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import _resolve_process_group
-from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    _resolve_process_group,
+)
+from torch.distributed import (
+    ReduceOp,
+    TCPStore,
+    Work,
+    _functional_collectives,
+    get_world_size,
+)
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
@@ -41,6 +54,56 @@ def dummy_init_pg() -> None:
 )
 
 
+def _test_pg(
+    pg: ProcessGroup,
+    example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+) -> Dict[str, dist._Work]:
+    """
+    Helper function to test a set of collective operations on a given process group.
+    """
+
+    shape: torch.Size = example_tensor.shape
+    dtype: torch.dtype = example_tensor.dtype
+
+    # Create some dummy tensors for testing
+    input_tensor = example_tensor.clone()
+    output_tensors = [
+        [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
+    ]
+    tensor_list = [torch.empty_like(input_tensor)]
+
+    def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
+        """Recursively check tensors for expected shape and dtype."""
+        if isinstance(arg, torch.Tensor):
+            assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
+            assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                check_tensors(item)
+
+    # Test collectives
+    collectives = {
+        "allreduce": ([input_tensor], AllreduceOptions()),
+        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
+        "broadcast": (tensor_list, BroadcastOptions()),
+        "broadcast_one": (input_tensor, 0),
+    }
+    works: Dict[str, dist._Work] = {}
+    for coll_str, args in collectives.items():
+        coll = getattr(pg, coll_str)
+        work = coll(*args)
+        works[coll_str] = work
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
+
+        # Check that all tensor arguments have the expected shapes and dtypes
+        check_tensors(args)
+
+    print(works)
+    return works
+
+
 class ProcessGroupTest(TestCase):
     def test_gloo(self) -> None:
         store = TCPStore(
@@ -53,11 +116,7 @@ def test_gloo(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2])
-
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg)
 
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +136,7 @@ def test_nccl(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
-        a_work.get_future().wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +146,7 @@ def test_nccl(self) -> None:
         store_addr = f"localhost:{store.port}/prefix2"
         pg.configure(store_addr, 0, 1)
 
-        at = torch.tensor([2], device=device)
-        a_work = pg.allreduce([at], ReduceOp.SUM)
-        a_work.wait()
+        _test_pg(pg, torch.tensor([2], device=device))
 
         torch.cuda.synchronize()
 
@@ -220,22 +274,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
         wrapper = ErrorSwallowingProcessGroupWrapper(pg)
         self.assertIs(wrapper.parent, pg)
 
-        t = torch.zeros(10)
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ErrorSwallowingWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(wrapper)
+        self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork)
 
         err = RuntimeError("test")
         wrapper.report_error(err)
         self.assertEqual(wrapper.error(), err)
 
-        work = wrapper.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _DummyWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(wrapper)
+        for work in works.values():
+            self.assertIsInstance(work, _DummyWork)
 
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
@@ -246,12 +294,8 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(pg.size(), 123)
 
-        t = torch.zeros(10)
-        work = pg.allreduce([t], ReduceOp.SUM)
-        self.assertIsInstance(work, _ManagedWork)
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
+        works = _test_pg(pg)
+        self.assertIsInstance(list(works.values())[0], _ManagedWork)
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
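A minimal sketch of driving the new _test_pg helper standalone, outside the TestCase methods above. It assumes torchft exports ProcessGroupGloo from torchft.process_group (the full import list is folded out of the visible hunks) and that the TCPStore arguments match the usual single-rank setup in these tests; the configure(store_addr, rank, world_size) call mirrors the one shown in test_nccl.

import torch
from torch.distributed import TCPStore
from torchft.process_group import ProcessGroupGloo  # assumed export, hidden in the folded import hunk

# Single-rank store and process group (TCPStore arguments are an assumption,
# mirroring the truncated TCPStore(...) call in test_gloo above).
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)  # store_addr, rank, world_size

# Runs allreduce, allgather, broadcast and broadcast_one once each and returns
# the resulting Work objects keyed by collective name.
works = _test_pg(pg)
assert set(works) == {"allreduce", "allgather", "broadcast", "broadcast_one"}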