We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 94ca830 commit 39b6a2fCopy full SHA for 39b6a2f
torchft/process_group.py
@@ -1221,11 +1221,13 @@ def __init__(self, manager: "Manager") -> None:
1221
self._manager = manager
1222
1223
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1224
+ assert len(tensors) == 1
1225
+
1226
if isinstance(opts, ReduceOp):
- return self._manager.allreduce(tensors, reduce_op=opts)
1227
+ return self._manager.allreduce(tensors[0], reduce_op=opts)
1228
1229
if isinstance(opts, AllreduceOptions):
- return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)
1230
+ return self._manager.allreduce(tensors[0], reduce_op=opts.reduceOp)
1231
1232
assert False, "unreachable"
1233
0 commit comments