Skip to content

Commit 39b6a2f

Browse files
tushar00jainfacebook-github-bot
authored andcommitted
fix call to manager allreduce (#273)
Summary: Pull Request resolved: #273 manager allreduce expects a tensor but we're passing a list Reviewed By: amirafzali, d4l3k Differential Revision: D83353505 fbshipit-source-id: ea5bdee40f9de22720a3135198df478237f1405d
1 parent 94ca830 commit 39b6a2f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchft/process_group.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,11 +1221,13 @@ def __init__(self, manager: "Manager") -> None:
12211221
self._manager = manager
12221222

12231223
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1224+
assert len(tensors) == 1
1225+
12241226
if isinstance(opts, ReduceOp):
1225-
return self._manager.allreduce(tensors, reduce_op=opts)
1227+
return self._manager.allreduce(tensors[0], reduce_op=opts)
12261228

12271229
if isinstance(opts, AllreduceOptions):
1228-
return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)
1230+
return self._manager.allreduce(tensors[0], reduce_op=opts.reduceOp)
12291231

12301232
assert False, "unreachable"
12311233

0 commit comments

Comments
 (0)