@@ -853,17 +853,23 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
853853 )
854854
855855
856- def inner_concat_aligned_mapping (mappings , * , reindexers = None , index = None , axis = 0 ):
856+ def inner_concat_aligned_mapping (
857+ mappings , * , reindexers = None , index = None , axis = 0 , concat_axis = None
858+ ):
859+ if concat_axis is None :
860+ concat_axis = axis
857861 result = {}
858862
859863 for k in intersect_keys (mappings ):
860864 els = [m [k ] for m in mappings ]
861865 if reindexers is None :
862- cur_reindexers = gen_inner_reindexers (els , new_index = index , axis = axis )
866+ cur_reindexers = gen_inner_reindexers (
867+ els , new_index = index , axis = concat_axis
868+ )
863869 else :
864870 cur_reindexers = reindexers
865871
866- result [k ] = concat_arrays (els , cur_reindexers , index = index , axis = axis )
872+ result [k ] = concat_arrays (els , cur_reindexers , index = index , axis = concat_axis )
867873 return result
868874
869875
@@ -959,15 +965,19 @@ def missing_element(
959965
960966
961967def outer_concat_aligned_mapping (
962- mappings , * , reindexers = None , index = None , axis = 0 , fill_value = None
968+ mappings , * , reindexers = None , index = None , axis = 0 , concat_axis = None , fill_value = None
963969):
970+ if concat_axis is None :
971+ concat_axis = axis
964972 result = {}
965973 ns = [m .parent .shape [axis ] for m in mappings ]
966974
967975 for k in union_keys (mappings ):
968976 els = [m .get (k , MissingVal ) for m in mappings ]
969977 if reindexers is None :
970- cur_reindexers = gen_outer_reindexers (els , ns , new_index = index , axis = axis )
978+ cur_reindexers = gen_outer_reindexers (
979+ els , ns , new_index = index , axis = concat_axis
980+ )
971981 else :
972982 cur_reindexers = reindexers
973983
@@ -986,15 +996,15 @@ def outer_concat_aligned_mapping(
986996 if not_missing (el )
987997 else missing_element (
988998 n ,
989- axis = axis ,
999+ axis = concat_axis ,
9901000 els = els ,
9911001 fill_value = fill_value ,
9921002 off_axis_size = off_axis_size ,
9931003 )
9941004 for el , n in zip (els , ns )
9951005 ],
9961006 cur_reindexers ,
997- axis = axis ,
1007+ axis = concat_axis ,
9981008 index = index ,
9991009 fill_value = fill_value ,
10001010 )
@@ -1368,7 +1378,10 @@ def concat(
13681378 [a .layers for a in adatas ], axis = axis , reindexers = reindexers
13691379 )
13701380 concat_mapping = concat_aligned_mapping (
1371- [getattr (a , f"{ axis_name } m" ) for a in adatas ], index = concat_indices
1381+ [getattr (a , f"{ axis_name } m" ) for a in adatas ],
1382+ axis = axis ,
1383+ concat_axis = 0 ,
1384+ index = concat_indices ,
13721385 )
13731386 if pairwise :
13741387 concat_pairwise = concat_pairwise_mapping (
0 commit comments