@@ -735,14 +735,24 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
735
735
def _simple_combine (
736
736
x_chunk , agg : Aggregation , axis : Sequence , keepdims : bool , is_aggregate : bool = False
737
737
) -> IntermediateDict :
738
+ """
739
+ 'Simple' combination of blockwise results.
740
+
741
+ 1. After the blockwise groupby-reduce, all blocks contain a value for all possible groups,
742
+ and are of the same shape; i.e. reindex must have been True
743
+ 2. _expand_dims was used to insert an extra axis DUMMY_AXIS
744
+ 3. Here we concatenate along DUMMY_AXIS, and then call the combine function along
745
+ DUMMY_AXIS
746
+ 4. At the final agggregate step, we squeeze out DUMMY_AXIS
747
+ """
738
748
from dask .array .core import deepfirst
739
749
740
750
results = {"groups" : deepfirst (x_chunk )["groups" ]}
741
751
results ["intermediates" ] = []
742
752
for idx , combine in enumerate (agg .combine ):
743
- array = _conc2 (x_chunk , key1 = "intermediates" , key2 = idx , axis = axis )
753
+ array = _conc2 (x_chunk , key1 = "intermediates" , key2 = idx , axis = axis [: - 1 ] + ( DUMMY_AXIS ,) )
744
754
assert array .ndim >= 2
745
- result = getattr (np , combine )(array , axis = axis , keepdims = True )
755
+ result = getattr (np , combine )(array , axis = axis [: - 1 ] + ( DUMMY_AXIS ,) , keepdims = True )
746
756
if is_aggregate :
747
757
# squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
748
758
result = result .squeeze (axis = DUMMY_AXIS )
0 commit comments