Skip to content

Commit 471aa94

Browse files
authored
Use blockwise to extract final result (#182)
* Use blockwise to extract final result for method="blockwise" * For all methods * bugfix * Try return_array from _finalize_results * Revert "Try return_array from _finalize_results" This reverts commit cb25e38. * Fixes.
1 parent df0da40 commit 471aa94

File tree

1 file changed

+50
-20
lines changed

1 file changed

+50
-20
lines changed

flox/core.py

+50-20
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,6 @@ def dask_groupby_agg(
11951195

11961196
import dask.array
11971197
from dask.array.core import slices_from_chunks
1198-
from dask.highlevelgraph import HighLevelGraph
11991198

12001199
# I think _tree_reduce expects this
12011200
assert isinstance(axis, Sequence)
@@ -1268,6 +1267,9 @@ def dask_groupby_agg(
12681267
engine=engine,
12691268
sort=sort,
12701269
),
1270+
# output indices are the same as input indices
1271+
# Unlike xhistogram, we don't always know what the size of the group
1272+
# dimension will be unless reindex=True
12711273
inds,
12721274
array,
12731275
inds,
@@ -1277,7 +1279,7 @@ def dask_groupby_agg(
12771279
dtype=array.dtype, # this is purely for show
12781280
meta=array._meta,
12791281
align_arrays=False,
1280-
token=f"{name}-chunk-{token}",
1282+
name=f"{name}-chunk-{token}",
12811283
)
12821284

12831285
if expected_groups is None:
@@ -1364,35 +1366,63 @@ def dask_groupby_agg(
13641366
groups = (np.concatenate(groups_in_block),)
13651367
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
13661368
group_chunks = (ngroups_per_block,)
1367-
13681369
else:
13691370
raise ValueError(f"Unknown method={method}.")
13701371

1371-
# extract results from the dict
1372+
out_inds = inds[: -len(axis)] + (inds[-1],)
13721373
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
1374+
if method == "blockwise" and len(axis) > 1:
1375+
# The final results are available but the blocks along axes
1376+
# need to be reshaped to axis=-1
1377+
# I don't know that this is possible with blockwise
1378+
# All other code paths benefit from an unmaterialized Blockwise layer
1379+
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)
1380+
1381+
# Can't use map_blocks because it forces concatenate=True along drop_axes,
1382+
result = dask.array.blockwise(
1383+
_extract_result,
1384+
out_inds,
1385+
reduced,
1386+
inds,
1387+
adjust_chunks=dict(zip(out_inds, output_chunks)),
1388+
dtype=agg.dtype[agg.name],
1389+
key=agg.name,
1390+
name=f"{name}-{token}",
1391+
concatenate=False,
1392+
)
1393+
1394+
return (result, groups)
1395+
1396+
1397+
def _collapse_blocks_along_axes(reduced, axis, group_chunks):
1398+
import dask.array
1399+
from dask.highlevelgraph import HighLevelGraph
1400+
1401+
nblocks = tuple(reduced.numblocks[ax] for ax in axis)
1402+
output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks
1403+
1404+
# extract results from the dict
13731405
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
13741406
layer2: dict[tuple, tuple] = {}
1375-
agg_name = f"{name}-{token}"
1376-
for ochunk in itertools.product(*ochunks):
1377-
if method == "blockwise":
1378-
if len(axis) == 1:
1379-
inchunk = ochunk
1380-
else:
1381-
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
1382-
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
1383-
else:
1384-
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)
1407+
name = f"reshape-{reduced.name}"
13851408

1386-
layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
1409+
for ochunk in itertools.product(*ochunks):
1410+
inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks)
1411+
layer2[(name, *ochunk)] = (reduced.name, *inchunk)
13871412

1388-
result = dask.array.Array(
1389-
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
1390-
agg_name,
1413+
return dask.array.Array(
1414+
HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]),
1415+
name,
13911416
chunks=output_chunks,
1392-
dtype=agg.dtype[agg.name],
1417+
dtype=reduced.dtype,
13931418
)
13941419

1395-
return (result, groups)
1420+
1421+
def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
1422+
from dask.array.core import deepfirst
1423+
1424+
# deepfirst should not be needed here, but sometimes we receive a list of dicts?
1425+
return deepfirst(result_dict)[key]
13961426

13971427

13981428
def _validate_reindex(

0 commit comments

Comments
 (0)