File tree 1 file changed +12
-14
lines changed
1 file changed +12
-14
lines changed Original file line number Diff line number Diff line change @@ -100,27 +100,25 @@ def run_collective(
100
100
coll = getattr (pg , collective )
101
101
args_list = _build_args (pg = pg , collective = collective , example_tensor = example_tensor )
102
102
works : Dict [str , dist ._Work ] = {}
103
+ def check_tensors (arg : Any ) -> None : # pyre-ignore[2]
104
+ """Recursively check tensors for expected shape and dtype."""
105
+ if isinstance (arg , torch .Tensor ):
106
+ assert (
107
+ arg .dtype == dtype
108
+ ), f"Output dtype mismatch: { arg .dtype } != { dtype } "
109
+ assert (
110
+ arg .shape == shape
111
+ ), f"Output shape mismatch: { arg .shape } != { shape } "
112
+ elif isinstance (arg , (list , tuple )):
113
+ for item in arg :
114
+ check_tensors (item )
103
115
104
116
for args in args_list :
105
117
work = coll (* args )
106
118
works [collective ] = work
107
119
work .wait ()
108
120
fut = work .get_future ()
109
121
fut .wait ()
110
-
111
- def check_tensors (arg : Any ) -> None : # pyre-ignore[2]
112
- """Recursively check tensors for expected shape and dtype."""
113
- if isinstance (arg , torch .Tensor ):
114
- assert (
115
- arg .dtype == dtype
116
- ), f"Output dtype mismatch: { arg .dtype } != { dtype } "
117
- assert (
118
- arg .shape == shape
119
- ), f"Output shape mismatch: { arg .shape } != { shape } "
120
- elif isinstance (arg , (list , tuple )):
121
- for item in arg :
122
- check_tensors (item )
123
-
124
122
check_tensors (args )
125
123
print (works )
126
124
return works
You can’t perform that action at this time.
0 commit comments