File tree 1 file changed +3
-6
lines changed
1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -100,15 +100,12 @@ def run_collective(
100
100
coll = getattr (pg , collective )
101
101
args_list = _build_args (pg = pg , collective = collective , example_tensor = example_tensor )
102
102
works : Dict [str , dist ._Work ] = {}
103
+
103
104
def check_tensors (arg : Any ) -> None : # pyre-ignore[2]
104
105
"""Recursively check tensors for expected shape and dtype."""
105
106
if isinstance (arg , torch .Tensor ):
106
- assert (
107
- arg .dtype == dtype
108
- ), f"Output dtype mismatch: { arg .dtype } != { dtype } "
109
- assert (
110
- arg .shape == shape
111
- ), f"Output shape mismatch: { arg .shape } != { shape } "
107
+ assert arg .dtype == dtype , f"Output dtype mismatch: { arg .dtype } != { dtype } "
108
+ assert arg .shape == shape , f"Output shape mismatch: { arg .shape } != { shape } "
112
109
elif isinstance (arg , (list , tuple )):
113
110
for item in arg :
114
111
check_tensors (item )
You can’t perform that action at this time.
0 commit comments