@@ -971,14 +971,16 @@ def reduce(self, data, op="sum", root=0):
971971 )
972972 for deeper_key in first_value .keys ()
973973 }
974+ else :
975+ # Otherwise, we assume our values can be concatenated
976+ reduced_values = reduction_op ([reduction_op (d ) for d in data .values ()])
974977
975- # Otherwise, we assume our values can be concatenated
976- reduced_values = reduction_op ([reduction_op (d ) for d in data .values ()])
977-
978- if self ._distributed :
979- reduced_values = self .comm .allreduce (reduced_values , op = mpi_reduction_op )
978+ if self ._distributed :
979+ reduced_values = self .comm .allreduce (
980+ reduced_values , op = mpi_reduction_op
981+ )
980982
981- return reduced_values
983+ return reduced_values
982984
983985 def all_concat (self , data : dict [Any , np .ndarray | dict [str , np .ndarray ]]):
984986 """
@@ -999,14 +1001,16 @@ def all_concat(self, data: dict[Any, np.ndarray | dict[str, np.ndarray]]):
9991001 )
10001002 for deeper_key in first_value .keys ()
10011003 }
1004+ else :
1005+ # Otherwise, we assume our values can be concatenated
1006+ all_values = np .concatenate (
1007+ [np .atleast_1d (value ) for value in data .values ()]
1008+ )
10021009
1003- # Otherwise, we assume our values can be concatenated
1004- all_values = np .concatenate ([np .atleast_1d (value ) for value in data .values ()])
1005-
1006- if self ._distributed :
1007- all_values = np .concatenate (self .comm .allgather (all_values ))
1010+ if self ._distributed :
1011+ all_values = np .concatenate (self .comm .allgather (all_values ))
10081012
1009- return all_values
1013+ return all_values
10101014
10111015 def concat (
10121016 self , data : dict [Any , np .ndarray | dict [str , np .ndarray ]], root : int = 0
@@ -1030,16 +1034,18 @@ def concat(
10301034 )
10311035 for deeper_key in first_value .keys ()
10321036 }
1037+ else :
1038+ # Otherwise, we assume our values can be concatenated
1039+ all_values = np .concatenate (
1040+ [np .atleast_1d (value ) for value in data .values ()]
1041+ )
10331042
1034- # Otherwise, we assume our values can be concatenated
1035- all_values = np .concatenate ([np .atleast_1d (value ) for value in data .values ()])
1036-
1037- if self ._distributed :
1038- tmp = self .comm .gather (all_values , root = root )
1039- if self .comm .rank == root :
1040- return np .concatenate (tmp )
1041- else :
1042- return None
1043+ if self ._distributed :
1044+ tmp = self .comm .gather (all_values , root = root )
1045+ if self .comm .rank == root :
1046+ return np .concatenate (tmp )
1047+ else :
1048+ return None
10431049
10441050 ###
10451051 # Non-blocking stuff.
0 commit comments