@@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
1494
1494
def subset_to_blocks (
1495
1495
array : DaskArray ,
1496
1496
flatblocks : Sequence [int ],
1497
- blkshape : tuple [int ] | None = None ,
1497
+ blkshape : tuple [int , ... ] | None = None ,
1498
1498
reindexer = identity ,
1499
+ chunks_as_array : tuple [np .ndarray , ...] | None = None ,
1499
1500
) -> DaskArray :
1500
1501
"""
1501
1502
Advanced indexing of .blocks such that we always get a regular array back.
@@ -1518,6 +1519,9 @@ def subset_to_blocks(
1518
1519
if blkshape is None :
1519
1520
blkshape = array .blocks .shape
1520
1521
1522
+ if chunks_as_array is None :
1523
+ chunks_as_array = tuple (np .array (c ) for c in array .chunks )
1524
+
1521
1525
index = _normalize_indexes (array , flatblocks , blkshape )
1522
1526
1523
1527
if all (not isinstance (i , np .ndarray ) and i == slice (None ) for i in index ):
@@ -1531,7 +1535,7 @@ def subset_to_blocks(
1531
1535
new_keys = array ._key_array [index ]
1532
1536
1533
1537
squeezed = tuple (np .squeeze (i ) if isinstance (i , np .ndarray ) else i for i in index )
1534
- chunks = tuple (tuple (np . array ( c ) [i ].tolist ()) for c , i in zip (array . chunks , squeezed ))
1538
+ chunks = tuple (tuple (c [i ].tolist ()) for c , i in zip (chunks_as_array , squeezed ))
1535
1539
1536
1540
keys = itertools .product (* (range (len (c )) for c in chunks ))
1537
1541
layer : Graph = {(name ,) + key : (reindexer , tuple (new_keys [key ].tolist ())) for key in keys }
@@ -1726,14 +1730,15 @@ def dask_groupby_agg(
1726
1730
1727
1731
reduced_ = []
1728
1732
groups_ = []
1733
+ chunks_as_array = tuple (np .array (c ) for c in array .chunks )
1729
1734
for blks , cohort in chunks_cohorts .items ():
1730
1735
cohort_index = pd .Index (cohort )
1731
1736
reindexer = (
1732
1737
partial (reindex_intermediates , agg = agg , unique_groups = cohort_index )
1733
1738
if do_simple_combine
1734
1739
else identity
1735
1740
)
1736
- reindexed = subset_to_blocks (intermediate , blks , block_shape , reindexer )
1741
+ reindexed = subset_to_blocks (intermediate , blks , block_shape , reindexer , chunks_as_array )
1737
1742
# now that we have reindexed, we can set reindex=True explicitly
1738
1743
reduced_ .append (
1739
1744
tree_reduce (
0 commit comments