@@ -1195,7 +1195,6 @@ def dask_groupby_agg(
1195
1195
1196
1196
import dask .array
1197
1197
from dask .array .core import slices_from_chunks
1198
- from dask .highlevelgraph import HighLevelGraph
1199
1198
1200
1199
# I think _tree_reduce expects this
1201
1200
assert isinstance (axis , Sequence )
@@ -1268,6 +1267,9 @@ def dask_groupby_agg(
1268
1267
engine = engine ,
1269
1268
sort = sort ,
1270
1269
),
1270
+ # output indices are the same as input indices
1271
+ # Unlike xhistogram, we don't always know what the size of the group
1272
+ # dimension will be unless reindex=True
1271
1273
inds ,
1272
1274
array ,
1273
1275
inds ,
@@ -1277,7 +1279,7 @@ def dask_groupby_agg(
1277
1279
dtype = array .dtype , # this is purely for show
1278
1280
meta = array ._meta ,
1279
1281
align_arrays = False ,
1280
- token = f"{ name } -chunk-{ token } " ,
1282
+ name = f"{ name } -chunk-{ token } " ,
1281
1283
)
1282
1284
1283
1285
if expected_groups is None :
@@ -1364,35 +1366,63 @@ def dask_groupby_agg(
1364
1366
groups = (np .concatenate (groups_in_block ),)
1365
1367
ngroups_per_block = tuple (len (grp ) for grp in groups_in_block )
1366
1368
group_chunks = (ngroups_per_block ,)
1367
-
1368
1369
else :
1369
1370
raise ValueError (f"Unknown method={ method } ." )
1370
1371
1371
- # extract results from the dict
1372
+ out_inds = inds [: - len ( axis )] + ( inds [ - 1 ],)
1372
1373
output_chunks = reduced .chunks [: - len (axis )] + group_chunks
1374
+ if method == "blockwise" and len (axis ) > 1 :
1375
+ # The final results are available but the blocks along axes
1376
+ # need to be reshaped to axis=-1
1377
+ # I don't know that this is possible with blockwise
1378
+ # All other code paths benefit from an unmaterialized Blockwise layer
1379
+ reduced = _collapse_blocks_along_axes (reduced , axis , group_chunks )
1380
+
1381
+ # Can't use map_blocks because it forces concatenate=True along drop_axes.
1382
+ result = dask .array .blockwise (
1383
+ _extract_result ,
1384
+ out_inds ,
1385
+ reduced ,
1386
+ inds ,
1387
+ adjust_chunks = dict (zip (out_inds , output_chunks )),
1388
+ dtype = agg .dtype [agg .name ],
1389
+ key = agg .name ,
1390
+ name = f"{ name } -{ token } " ,
1391
+ concatenate = False ,
1392
+ )
1393
+
1394
+ return (result , groups )
1395
+
1396
+
1397
def _collapse_blocks_along_axes(reduced, axis, group_chunks):
    """Re-key the blocks of ``reduced`` so the block grid along ``axis`` is flattened.

    No data is computed or copied here: every key of the new layer is an
    alias for exactly one existing block key of ``reduced``.

    Parameters
    ----------
    reduced
        Dask array whose blocks are re-keyed. Only its ``numblocks``,
        ``chunks``, ``name`` and ``dtype`` attributes are read.
    axis
        Sequence of axis numbers of ``reduced`` whose block grid is flattened.
        The indexing below treats these as the trailing ``len(axis)`` block
        indices of ``reduced`` (presumably the trailing axes; confirm with
        the caller).
    group_chunks
        One-element tuple holding the chunk sizes of the final output axis.

    Returns
    -------
    dask.array.Array
        Array named ``reshape-<reduced.name>`` with chunks
        ``reduced.chunks[:-len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks``
        and the dtype of ``reduced``.
    """
    import dask.array
    from dask.highlevelgraph import HighLevelGraph

    blocks_per_axis = tuple(reduced.numblocks[ax] for ax in axis)
    naxes = len(axis)
    output_chunks = reduced.chunks[:-naxes] + ((1,) * (naxes - 1),) + group_chunks

    name = f"reshape-{reduced.name}"
    block_ranges = [range(len(dim_chunks)) for dim_chunks in output_chunks]

    # The last output block index is a flat (C-order) index into the block
    # grid along ``axis``; unravel it to recover the matching block key of
    # ``reduced``. Leading block indices pass through unchanged.
    layer: dict[tuple, tuple] = {
        (name, *out_block): (
            reduced.name,
            *out_block[:-naxes],
            *np.unravel_index(out_block[-1], blocks_per_axis),
        )
        for out_block in itertools.product(*block_ranges)
    }

    graph = HighLevelGraph.from_collections(name, layer, dependencies=[reduced])
    return dask.array.Array(graph, name, chunks=output_chunks, dtype=reduced.dtype)
1394
1419
1395
- return (result , groups )
1420
+
1421
def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
    """Return the entry stored under ``key`` in a block's results dict.

    ``result_dict`` is normally the dict itself, but it may arrive wrapped in
    one or more levels of list; the first innermost element is used.
    """
    from dask.array.core import deepfirst

    # deepfirst should not be needed here, but sometimes we receive a list of
    # dicts instead of a bare dict, so unwrap defensively.
    innermost = deepfirst(result_dict)
    return innermost[key]
1396
1426
1397
1427
1398
1428
def _validate_reindex (
0 commit comments