Skip to content

Commit 5f73001

Browse files
committed
minor adjustment, move check_tensors out of the loop
1 parent 4c3b78e commit 5f73001

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

torchft/process_group_test.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -100,27 +100,25 @@ def run_collective(
100100
coll = getattr(pg, collective)
101101
args_list = _build_args(pg=pg, collective=collective, example_tensor=example_tensor)
102102
works: Dict[str, dist._Work] = {}
103+
def check_tensors(arg: Any) -> None: # pyre-ignore[2]
104+
"""Recursively check tensors for expected shape and dtype."""
105+
if isinstance(arg, torch.Tensor):
106+
assert (
107+
arg.dtype == dtype
108+
), f"Output dtype mismatch: {arg.dtype} != {dtype}"
109+
assert (
110+
arg.shape == shape
111+
), f"Output shape mismatch: {arg.shape} != {shape}"
112+
elif isinstance(arg, (list, tuple)):
113+
for item in arg:
114+
check_tensors(item)
103115

104116
for args in args_list:
105117
work = coll(*args)
106118
works[collective] = work
107119
work.wait()
108120
fut = work.get_future()
109121
fut.wait()
110-
111-
def check_tensors(arg: Any) -> None: # pyre-ignore[2]
112-
"""Recursively check tensors for expected shape and dtype."""
113-
if isinstance(arg, torch.Tensor):
114-
assert (
115-
arg.dtype == dtype
116-
), f"Output dtype mismatch: {arg.dtype} != {dtype}"
117-
assert (
118-
arg.shape == shape
119-
), f"Output shape mismatch: {arg.shape} != {shape}"
120-
elif isinstance(arg, (list, tuple)):
121-
for item in arg:
122-
check_tensors(item)
123-
124122
check_tensors(args)
125123
print(works)
126124
return works

0 commit comments

Comments
 (0)