Commit c2c4e1d
Support reindexing in simple_combine (#177)
* Support reindexing in simple_combine

  For 1D combine, this is a great improvement for cohorts-type reductions; map-reduce
  takes similar time but uses more memory. Note that the map-reduce intermediates are a
  worst case in which there are no shared groups between the chunks being combined. That
  case is actually optimized in _grouped_combine, where reindexing is skipped when
  reducing along a single axis.

  [ 68.75%] ··· =========== ========= =========
                --                combine
                ----------- -------------------
                   kind      grouped   combine
                =========== ========= =========
                  cohorts     760M      631M
                  mapreduce   981M      1.81G
                =========== ========= =========

  [ 75.00%] ··· =========== ========== ===========
                --                   combine
                ----------- ----------------------
                   kind       grouped     combine
                =========== ========== ===========
                  cohorts     393±10ms    137±10ms
                  mapreduce   652±10ms   611±400ms
                =========== ========== ===========

* Fix bug in unique

* Fix bug with all-NaN blocks
1 parent 0db264a commit c2c4e1d
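
The core idea behind reindexed combines: pad each block's intermediate result out to a common set of group labels, after which blocks can be combined elementwise. A minimal standalone sketch of that idea (simplified; reindex_to_union and its arguments are illustrative, not flox's actual reindex_intermediates):

import numpy as np

def reindex_to_union(values, groups, union_groups, fill_value):
    # Scatter this block's per-group values into the shared "union" layout,
    # filling groups the block never saw with the reduction's identity value.
    out = np.full(len(union_groups), fill_value, dtype=values.dtype)
    out[np.searchsorted(union_groups, groups)] = values  # union_groups is sorted
    return out

# Two blocks that saw different groups, reindexed to the union {1, 2, 3}:
union = np.array([1, 2, 3])
a = reindex_to_union(np.array([1.0, 2.0]), np.array([1, 2]), union, fill_value=0.0)
b = reindex_to_union(np.array([5.0]), np.array([3]), union, fill_value=0.0)
print(a + b)  # [1. 2. 5.] -- combining is now a plain elementwise reduction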

3 files changed: +129 −90 lines changed

asv_bench/benchmarks/combine.py (+19 −10)
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 
 import flox
@@ -7,26 +9,31 @@
 N = 1000
 
 
+def _get_combine(combine):
+    if combine == "grouped":
+        return partial(flox.core._grouped_combine, engine="numpy")
+    else:
+        return partial(flox.core._simple_combine, reindex=False)
+
+
 class Combine:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def time_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def time_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def peakmem_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def peakmem_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )
 
 
@@ -47,7 +54,7 @@ def construct_member(groups):
         }
 
         # motivated by
-        self.x_chunk_mapreduce = [
+        self.x_chunk_not_reindexed = [
            construct_member(groups)
            for groups in [
                np.array((1, 2, 3, 4)),
@@ -57,5 +64,7 @@ def construct_member(groups):
             * 2
         ]
 
-        self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
+        self.x_chunk_reindexed = [
+            construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
+        ]
         self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
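
For context, the parameterized decorator used above is a small helper from the benchmark suite that attaches asv's per-benchmark parameter attributes, so asv runs the cartesian product of the parameter values. A sketch of what such a helper looks like (a reconstructed assumption; the actual helper lives elsewhere in asv_bench):

def parameterized(names, params):
    # asv reads `params` and `param_names` attributes off benchmark methods;
    # with several names, it benchmarks every combination of the value lists.
    def decorator(func):
        func.param_names = [names] if isinstance(names, str) else list(names)
        func.params = params
        return func
    return decorator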

flox/core.py (+60 −45)
@@ -136,7 +136,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)
 
 
-def _unique(a: np.ndarray):
+def _unique(a: np.ndarray) -> np.ndarray:
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
     return np.sort(pd.unique(a.reshape(-1)))
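
Per its docstring, _unique (now annotated with its return type) relies on pd.unique, which uniquifies via hashing in encounter order and leaves the single required sort to the end, whereas np.unique sorts the full input before uniquifying. A standalone illustration of the same trick (a sketch mirroring the helper above, not an import from flox):

import numpy as np
import pandas as pd

def unique_sorted(a: np.ndarray) -> np.ndarray:
    # hash-based uniquify first, then one sort of the (small) unique set
    return np.sort(pd.unique(a.reshape(-1)))

labels = np.array([[3, 1, 3], [1, 2, 2]])
print(unique_sorted(labels))  # [1 2 3]
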
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
+def _find_unique_groups(x_chunk) -> np.ndarray:
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = np.array([np.nan])
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@ def _simple_combine(
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step
+        # So now reindex before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]
 
-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
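
Per the docstring, once every block carries the same "groups" layout (whether reindexed blockwise, or reindexed here when reindex=False), combining amounts to stacking the intermediates along a throwaway axis and reducing over it. A toy illustration of the DUMMY_AXIS pattern (a simplified sketch, not flox's code):

import numpy as np

# Per-block "sum" intermediates, already reindexed to the same three groups:
blocks = [np.array([1.0, 2.0, 0.0]), np.array([0.0, 3.0, 5.0])]

stacked = np.stack(blocks, axis=0)             # stack along a dummy axis
combined = stacked.sum(axis=0, keepdims=True)  # reduce along DUMMY_AXIS
final = combined.squeeze(axis=0)               # squeezed out at the aggregate step
print(final)  # [1. 5. 5.]
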
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)
 
-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1268,31 +1292,32 @@ def dask_groupby_agg(
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"
 
-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1310,23 +1335,17 @@ def dask_groupby_agg(
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
+            index = pd.Index(cohort)
             subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            if do_simple_combine:
-                # reindex so that reindex can be set to True later
-                reindexed = dask.array.map_blocks(
-                    reindex_intermediates,
-                    subset,
-                    agg=agg,
-                    unique_groups=cohort,
-                    meta=subset._meta,
-                )
-            else:
-                reindexed = subset
-
+            reindexed = dask.array.map_blocks(
+                reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+            )
+            # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
-                    aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    combine=partial(combine, agg=agg, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
                 )
             )
             groups_.append(cohort)
@@ -1382,28 +1401,24 @@ def _validate_reindex(
     if reindex is True:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )
 
     if reindex is None:
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False
 
-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False
 
         elif method == "map-reduce":
             if expected_groups is None and by_is_dask:
                 reindex = False
             else:
                 reindex = True
 
-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
     assert isinstance(reindex, bool)
     return reindex
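
Taken together, the _validate_reindex changes amount to a simpler default policy: never reindex at the blockwise step for method="cohorts" (each cohort is reindexed explicitly before its tree reduction), and reject reindex=True outright for "blockwise" and "cohorts". A condensed restatement of the defaults (a sketch of the logic above; the is_arg_reduction flag stands in for flox's _is_arg_reduction check):

def default_reindex(method: str, is_arg_reduction: bool, expected_groups, by_is_dask: bool) -> bool:
    # Mirrors the `reindex is None` branch of _validate_reindex above.
    if method == "blockwise" or is_arg_reduction:
        return False
    if method == "cohorts":
        # blocks are reindexed to each cohort before the tree reduction instead
        return False
    # method == "map-reduce": reindex unless groups must be discovered at compute time
    return not (expected_groups is None and by_is_dask)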
