Skip to content

Commit 731fa05

Browse files
committed
Fix simple_combine
1 parent ef21d21 commit 731fa05

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

flox/core.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -735,14 +735,24 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
735735
def _simple_combine(
736736
x_chunk, agg: Aggregation, axis: Sequence, keepdims: bool, is_aggregate: bool = False
737737
) -> IntermediateDict:
738+
"""
739+
'Simple' combination of blockwise results.
740+
741+
1. After the blockwise groupby-reduce, all blocks contain a value for all possible groups,
742+
and are of the same shape; i.e. reindex must have been True
743+
2. _expand_dims was used to insert an extra axis DUMMY_AXIS
744+
3. Here we concatenate along DUMMY_AXIS, and then call the combine function along
745+
DUMMY_AXIS
746+
4. At the final aggregate step, we squeeze out DUMMY_AXIS
747+
"""
738748
from dask.array.core import deepfirst
739749

740750
results = {"groups": deepfirst(x_chunk)["groups"]}
741751
results["intermediates"] = []
742752
for idx, combine in enumerate(agg.combine):
743-
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
753+
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis[:-1] + (DUMMY_AXIS,))
744754
assert array.ndim >= 2
745-
result = getattr(np, combine)(array, axis=axis, keepdims=True)
755+
result = getattr(np, combine)(array, axis=axis[:-1] + (DUMMY_AXIS,), keepdims=True)
746756
if is_aggregate:
747757
# squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
748758
result = result.squeeze(axis=DUMMY_AXIS)

0 commit comments

Comments
 (0)