
Commit 07a15c4

Faster subsetting for cohorts (#397)
* Faster subsetting for cohorts. Closes #396
* typing
1 parent: 8556811

1 file changed: flox/core.py (+8, -3 lines)
@@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
 def subset_to_blocks(
     array: DaskArray,
     flatblocks: Sequence[int],
-    blkshape: tuple[int] | None = None,
+    blkshape: tuple[int, ...] | None = None,
     reindexer=identity,
+    chunks_as_array: tuple[np.ndarray, ...] | None = None,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
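
This hunk makes two changes. The blkshape annotation is corrected from tuple[int], which means a 1-tuple, to tuple[int, ...], which means a tuple of any length (the "typing" part of the commit message), and a new optional chunks_as_array argument lets callers pass in precomputed per-axis chunk-size arrays. A minimal sketch of the annotation fix; the names here are illustrative, not from flox:

    # requires Python >= 3.10 for the "X | None" annotation syntax
    def f(blkshape: tuple[int] | None = None) -> None:
        """tuple[int] is a 1-tuple: mypy rejects f((2, 3))."""

    def g(blkshape: tuple[int, ...] | None = None) -> None:
        """tuple[int, ...] is any length: g((2, 3)) type-checks."""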
@@ -1518,6 +1519,9 @@ def subset_to_blocks(
     if blkshape is None:
         blkshape = array.blocks.shape
 
+    if chunks_as_array is None:
+        chunks_as_array = tuple(np.array(c) for c in array.chunks)
+
     index = _normalize_indexes(array, flatblocks, blkshape)
 
     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
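
When no precomputed value is passed, subset_to_blocks builds the arrays itself from array.chunks, so existing callers keep working unchanged. A rough sketch of what that conversion produces, using a small illustrative dask array that is not part of the commit:

    import numpy as np
    import dask.array as da

    arr = da.ones((10, 12), chunks=(5, 4))
    print(arr.chunks)  # ((5, 5), (4, 4, 4)): a tuple of tuples of ints

    # the same conversion as the fallback above: one numpy array per axis
    chunks_as_array = tuple(np.array(c) for c in arr.chunks)
    print(chunks_as_array[1])  # [4 4 4], ready for fancy indexing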
@@ -1531,7 +1535,7 @@ def subset_to_blocks(
     new_keys = array._key_array[index]
 
     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
-    chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
+    chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed))
 
     keys = itertools.product(*(range(len(c)) for c in chunks))
     layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
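
This hunk is the payoff: the old line rebuilt np.array(c) for every axis on every call, once per cohort, while the new line fancy-indexes arrays that were built once. A self-contained sketch of the indexing step, with made-up chunk sizes and block indices:

    import numpy as np

    # per-axis chunk sizes, already converted to numpy arrays
    chunks_as_array = (np.array([5, 5, 5]), np.array([4, 4]))
    # per-axis block selection for one cohort: blocks 0 and 2 on axis 0
    squeezed = (np.array([0, 2]), slice(None))

    chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed))
    print(chunks)  # ((5, 5), (4, 4)): chunk sizes of the subset array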
@@ -1726,14 +1730,15 @@ def dask_groupby_agg(
 
     reduced_ = []
     groups_ = []
+    chunks_as_array = tuple(np.array(c) for c in array.chunks)
     for blks, cohort in chunks_cohorts.items():
         cohort_index = pd.Index(cohort)
         reindexer = (
             partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
             if do_simple_combine
             else identity
         )
-        reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
+        reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
         # now that we have reindexed, we can set reindex=True explicitlly
         reduced_.append(
             tree_reduce(
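
In dask_groupby_agg the conversion is hoisted out of the cohort loop: chunks_as_array is computed once and passed to every subset_to_blocks call instead of being rebuilt from array.chunks per cohort, which appears to be the speedup the commit title advertises (closes #396). A sketch of the hoisting pattern; subset_chunks is a hypothetical stand-in for the chunk indexing inside subset_to_blocks:

    import numpy as np

    def subset_chunks(chunks_as_array, squeezed):
        # hypothetical stand-in for the indexing done in subset_to_blocks
        return tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed))

    chunks = ((5, 5, 5), (4, 4))
    cohorts = [(np.array([0, 1]), slice(None)), (np.array([2]), slice(None))]

    # hoisted: convert once, then reuse for every cohort
    chunks_as_array = tuple(np.array(c) for c in chunks)
    for squeezed in cohorts:
        print(subset_chunks(chunks_as_array, squeezed))
    # prints ((5, 5), (4, 4)) then ((5,), (4, 4))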
