Skip to content

Commit 0c7ac68

Browse files
committed
linter
1 parent 5f73001 commit 0c7ac68

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

torchft/process_group_test.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,12 @@ def run_collective(
100100
coll = getattr(pg, collective)
101101
args_list = _build_args(pg=pg, collective=collective, example_tensor=example_tensor)
102102
works: Dict[str, dist._Work] = {}
103+
103104
def check_tensors(arg: Any) -> None: # pyre-ignore[2]
104105
"""Recursively check tensors for expected shape and dtype."""
105106
if isinstance(arg, torch.Tensor):
106-
assert (
107-
arg.dtype == dtype
108-
), f"Output dtype mismatch: {arg.dtype} != {dtype}"
109-
assert (
110-
arg.shape == shape
111-
), f"Output shape mismatch: {arg.shape} != {shape}"
107+
assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
108+
assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
112109
elif isinstance(arg, (list, tuple)):
113110
for item in arg:
114111
check_tensors(item)

0 commit comments

Comments
 (0)