
Commit ec8ae32

linters

1 parent 435449c commit ec8ae32

1 file changed: +23 −21 lines changed

torchft/process_group_test.py (+23 −21)
@@ -11,13 +11,12 @@
 import unittest
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import timedelta
-from typing import Any, Dict, List, Tuple, cast
+from typing import Any, Dict, List, Tuple, Union, cast
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock

 import torch
 import torch.distributed as dist
-from parameterized import parameterized
 from torch import nn
 from torch._C._distributed_c10d import (
     AllgatherOptions,
@@ -75,39 +74,45 @@ def run_collectives(
         [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
     ]
     tensor_list = [torch.empty_like(input_tensor)]
+
     works = []
+    input_tensors = []

     if "allreduce" in collectives:
         works += [
             pg.allreduce([input_tensor], AllreduceOptions()),
             pg.allreduce([input_tensor], ReduceOp.SUM),
         ]
-        input_tensors = input_tensor
+        input_tensors += [input_tensor, input_tensor]
     elif "allgather" in collectives:
         works += [pg.allgather(output_tensors, [input_tensor], AllgatherOptions())]
-        input_tensors = (output_tensors, input_tensor)
+        input_tensors += [(output_tensors, input_tensor)]
     elif "broadcast" in collectives:
         works += [pg.broadcast(tensor_list, BroadcastOptions())]
-        input_tensors = tensor_list
+        input_tensors += [tensor_list]
     elif "broadcast_one" in collectives:
         works += [pg.broadcast_one(input_tensor, 0)]
-        input_tensors = input_tensor
+        input_tensors += [input_tensor]

-    def check_tensors(input_tensors: Any) -> None:  # pyre-ignore[2]
+    def check_tensors(input_tensors: Union[torch.Tensor, List[torch.Tensor]]) -> None:
         """Recursively check tensors for input_tensors shape and dtype."""
         if isinstance(input_tensors, torch.Tensor):
-            assert input_tensors.dtype == dtype, f"Output dtype mismatch: {input_tensors.dtype} != {dtype}"
-            assert input_tensors.shape == shape, f"Output shape mismatch: {input_tensors.shape} != {shape}"
+            assert (
+                input_tensors.dtype == dtype
+            ), f"Output dtype mismatch: {input_tensors.dtype} != {dtype}"
+            assert (
+                input_tensors.shape == shape
+            ), f"Output shape mismatch: {input_tensors.shape} != {shape}"
         elif isinstance(input_tensors, (list, tuple)):
             for item in input_tensors:
                 check_tensors(item)

-    for work in works:
+    for work, input_tensor in zip(works, input_tensors):
         work.wait()
         fut = work.get_future()
         fut.wait()
         # Check that all tensor arguments have the input_tensors shapes and dtypes
-        check_tensors(input_tensors)
+        check_tensors(input_tensor)

     print(works)
     return works
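The hunk above replaces the single `input_tensors` value with a per-collective list so that each `Work` handle can be paired with the tensors it touched via `zip(works, input_tensors)`. Below is a minimal standalone sketch of that pairing and of the recursive shape/dtype check; no process group is involved, and the `FakeWork` class plus the sample tensors are illustrative assumptions, not part of the commit.

from typing import List, Tuple, Union

import torch


def check_tensors(obj: Union[torch.Tensor, List, Tuple], shape, dtype) -> None:
    # Recursively walk nested lists/tuples and validate every tensor found.
    if isinstance(obj, torch.Tensor):
        assert obj.dtype == dtype, f"dtype mismatch: {obj.dtype} != {dtype}"
        assert obj.shape == shape, f"shape mismatch: {obj.shape} != {shape}"
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            check_tensors(item, shape, dtype)


class FakeWork:
    # Hypothetical stand-in for a completed collective's Work handle.
    def wait(self) -> None:
        pass


input_tensor = torch.tensor([2])
output_tensors = [[torch.empty_like(input_tensor)]]

# One entry in input_tensors per entry in works, so the two can be zipped.
works = [FakeWork(), FakeWork()]
input_tensors = [input_tensor, (output_tensors, input_tensor)]

for work, tensors in zip(works, input_tensors):
    work.wait()
    check_tensors(tensors, input_tensor.shape, input_tensor.dtype)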
@@ -128,8 +133,7 @@ def setUp(self) -> None:
         )
         self.store_addr = f"localhost:{self.store.port}/prefix"

-    @parameterized.expand(collectives)
-    def test_nccl(self, collective: str) -> None:
+    def test_nccl(self) -> None:
         device = "cuda"

         pg = ProcessGroupNCCL()
@@ -139,7 +143,7 @@ def test_nccl(self, collective: str) -> None:

         run_collectives(
             pg=pg,
-            collectives=[collective],
+            collectives=self.collectives,
             example_tensor=torch.tensor([2], device=device),
         )

@@ -153,7 +157,7 @@ def test_nccl(self, collective: str) -> None:

         run_collectives(
             pg=pg,
-            collectives=[collective],
+            collectives=self.collectives,
             example_tensor=torch.tensor([2], device=device),
         )

@@ -233,23 +237,21 @@ def setUp(self) -> None:
         )
         self.store_addr = f"localhost:{self.store.port}/prefix"

-    @parameterized.expand(collectives)
-    def test_gloo(self, collective: str) -> None:
+    def test_gloo(self) -> None:
         pg = ProcessGroupGloo()
         pg.configure(self.store_addr, 0, 1)

         self.assertEqual(pg.size(), 1)
-        run_collectives(pg=pg, collectives=[collective])
+        run_collectives(pg=pg, collectives=self.collectives)
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))

-    @parameterized.expand(collectives)
-    def test_baby_gloo_apis(self, collective: str) -> None:
+    def test_baby_gloo_apis(self) -> None:
         pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
         pg.configure(self.store_addr, 0, 1)

-        run_collectives(pg=pg, collectives=[collective])
+        run_collectives(pg=pg, collectives=self.collectives)

         # force collection to ensure no BabyWork objects remain
         gc.collect()
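With `parameterized.expand` removed, each test now passes the full collectives list to `run_collectives` through `self.collectives`. The attribute's definition is not part of this diff; the sketch below shows one plausible shape under stated assumptions: the class name `ProcessGroupExampleTest` and the literal list are hypothetical, with the names taken from the `if`/`elif` branches in `run_collectives` above.

from unittest import TestCase


class ProcessGroupExampleTest(TestCase):
    # Hypothetical class attribute; the real definition lives outside this diff.
    collectives = ["allreduce", "allgather", "broadcast", "broadcast_one"]

    def test_collectives_attribute(self) -> None:
        # Each test method can now call, e.g.,
        # run_collectives(pg=pg, collectives=self.collectives)
        # instead of being expanded once per collective by @parameterized.expand.
        self.assertIn("allreduce", self.collectives)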
