 import unittest
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import timedelta
-from typing import Any, Dict, List, Tuple, cast
+from typing import Any, Dict, List, Tuple, Union, cast
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock

 import torch
 import torch.distributed as dist
-from parameterized import parameterized
 from torch import nn
 from torch._C._distributed_c10d import (
     AllgatherOptions,
@@ -75,39 +74,45 @@ def run_collectives(
         [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
     ]
     tensor_list = [torch.empty_like(input_tensor)]
+
     works = []
+    input_tensors = []

     if "allreduce" in collectives:
         works += [
             pg.allreduce([input_tensor], AllreduceOptions()),
             pg.allreduce([input_tensor], ReduceOp.SUM),
         ]
-        input_tensors = input_tensor
+        input_tensors += [input_tensor, input_tensor]
     elif "allgather" in collectives:
         works += [pg.allgather(output_tensors, [input_tensor], AllgatherOptions())]
-        input_tensors = (output_tensors, input_tensor)
+        input_tensors += [(output_tensors, input_tensor)]
     elif "broadcast" in collectives:
         works += [pg.broadcast(tensor_list, BroadcastOptions())]
-        input_tensors = tensor_list
+        input_tensors += [tensor_list]
     elif "broadcast_one" in collectives:
         works += [pg.broadcast_one(input_tensor, 0)]
-        input_tensors = input_tensor
+        input_tensors += [input_tensor]

-    def check_tensors(input_tensors: Any) -> None:  # pyre-ignore[2]
+    def check_tensors(input_tensors: Union[torch.Tensor, List[torch.Tensor]]) -> None:
         """Recursively check tensors for input_tensors shape and dtype."""
         if isinstance(input_tensors, torch.Tensor):
-            assert input_tensors.dtype == dtype, f"Output dtype mismatch: {input_tensors.dtype} != {dtype}"
-            assert input_tensors.shape == shape, f"Output shape mismatch: {input_tensors.shape} != {shape}"
+            assert (
+                input_tensors.dtype == dtype
+            ), f"Output dtype mismatch: {input_tensors.dtype} != {dtype}"
+            assert (
+                input_tensors.shape == shape
+            ), f"Output shape mismatch: {input_tensors.shape} != {shape}"
         elif isinstance(input_tensors, (list, tuple)):
             for item in input_tensors:
                 check_tensors(item)

-    for work in works:
+    for work, input_tensor in zip(works, input_tensors):
         work.wait()
         fut = work.get_future()
         fut.wait()
         # Check that all tensor arguments have the input_tensors shapes and dtypes
-        check_tensors(input_tensors)
+        check_tensors(input_tensor)

     print(works)
     return works
@@ -128,8 +133,7 @@ def setUp(self) -> None:
         )
         self.store_addr = f"localhost:{self.store.port}/prefix"

-    @parameterized.expand(collectives)
-    def test_nccl(self, collective: str) -> None:
+    def test_nccl(self) -> None:
         device = "cuda"

         pg = ProcessGroupNCCL()
@@ -139,7 +143,7 @@ def test_nccl(self, collective: str) -> None:

         run_collectives(
             pg=pg,
-            collectives=[collective],
+            collectives=self.collectives,
             example_tensor=torch.tensor([2], device=device),
         )

@@ -153,7 +157,7 @@ def test_nccl(self, collective: str) -> None:

         run_collectives(
             pg=pg,
-            collectives=[collective],
+            collectives=self.collectives,
             example_tensor=torch.tensor([2], device=device),
         )

@@ -233,23 +237,21 @@ def setUp(self) -> None:
         )
         self.store_addr = f"localhost:{self.store.port}/prefix"

-    @parameterized.expand(collectives)
-    def test_gloo(self, collective: str) -> None:
+    def test_gloo(self) -> None:
         pg = ProcessGroupGloo()
         pg.configure(self.store_addr, 0, 1)

         self.assertEqual(pg.size(), 1)
-        run_collectives(pg=pg, collectives=[collective])
+        run_collectives(pg=pg, collectives=self.collectives)
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))

-    @parameterized.expand(collectives)
-    def test_baby_gloo_apis(self, collective: str) -> None:
+    def test_baby_gloo_apis(self) -> None:
         pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
         pg.configure(self.store_addr, 0, 1)

-        run_collectives(pg=pg, collectives=[collective])
+        run_collectives(pg=pg, collectives=self.collectives)

         # force collection to ensure no BabyWork objects remain
         gc.collect()
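The change above drops the parameterized per-collective test methods in favor of a single pass over self.collectives, and run_collectives now records one input_tensors entry per issued collective so the wait loop can pair each Work with the tensors that collective touched. A minimal, self-contained sketch of that pairing pattern follows; FakeWork and the plain tensors are hypothetical stand-ins, not the real ProcessGroup or Work APIs.

# Illustrative sketch only: pairs each "work" with the tensors it used,
# mirroring the zip(works, input_tensors) loop in the diff above.
from typing import Any, List, Tuple, Union

import torch


def check_tensors(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...]],
                  shape: torch.Size, dtype: torch.dtype) -> None:
    # Recursively assert that every tensor matches the expected shape/dtype.
    if isinstance(obj, torch.Tensor):
        assert obj.dtype == dtype, f"dtype mismatch: {obj.dtype} != {dtype}"
        assert obj.shape == shape, f"shape mismatch: {obj.shape} != {shape}"
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            check_tensors(item, shape, dtype)


class FakeWork:
    # Hypothetical stand-in for a torch.distributed Work handle.
    def wait(self) -> None:
        pass


example = torch.tensor([2])
works = [FakeWork(), FakeWork()]
# One entry per work: the tensor(s) that collective read or wrote.
input_tensors = [example, [torch.empty_like(example)]]

for work, tensors in zip(works, input_tensors):
    work.wait()
    check_tensors(tensors, example.shape, example.dtype)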