Skip to content

Commit 5190414

Browse files
committed
fixes test
1 parent a425493 commit 5190414

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

torchft/process_group.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,15 @@ def safe_args(cls, args: T) -> T:
10371037
return tuple(cls.safe_args(arg) for arg in args)
10381038
elif isinstance(args, list):
10391039
return [cls.safe_args(arg) for arg in args]
1040-
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
1040+
elif isinstance(
1041+
args,
1042+
(
1043+
AllreduceOptions,
1044+
AllgatherOptions,
1045+
BroadcastOptions,
1046+
ReduceScatterOptions,
1047+
),
1048+
):
10411049
return cls.from_torch(args)
10421050
else:
10431051
return args

0 commit comments

Comments
 (0)