@@ -136,7 +136,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)


-def _unique(a: np.ndarray):
+def _unique(a: np.ndarray) -> np.ndarray:
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
     return np.sort(pd.unique(a.reshape(-1)))
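The docstring's claim is easy to demonstrate: `pd.unique` deduplicates in order of appearance without sorting the full array, so sorting only the deduplicated values is cheaper than `np.unique`, which sorts everything up front. A minimal sketch using plain NumPy/pandas (no flox internals):

```python
import numpy as np
import pandas as pd

a = np.array([[3, 1, 3], [1, 2, 2]])

# pd.unique requires 1D input, hence the reshape(-1);
# it preserves first-appearance order instead of sorting
pd.unique(a.reshape(-1))           # array([3, 1, 2])

# sorting the already-deduplicated values matches np.unique
np.sort(pd.unique(a.reshape(-1)))  # array([1, 2, 3])
np.unique(a)                       # array([1, 2, 3])
```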
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results


+def _find_unique_groups(x_chunk) -> np.ndarray:
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = np.array([np.nan])
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@ def _simple_combine(
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step,
+        # so reindex now before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]

-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
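The new `reindex` branch is the crux of `_simple_combine`: intermediates from different blocks may carry different group labels, and aligning them all to one shared, sorted set of groups is what lets the combine step reduce along `DUMMY_AXIS` with plain array operations. A deliberately simplified sketch of that alignment, using toy dicts and a hypothetical `toy_reindex` helper rather than flox's actual `reindex_intermediates`:

```python
import numpy as np

# Two toy blockwise intermediates whose blocks saw different groups;
# values are per-group partial sums.
a = {"groups": np.array([1, 2]), "sum": np.array([10.0, 20.0])}
b = {"groups": np.array([2, 3]), "sum": np.array([5.0, 7.0])}

def toy_reindex(chunk, unique_groups):
    # Align a chunk's values to the shared group index; groups the
    # chunk never saw get the reduction's identity (0 for sum).
    out = np.zeros(len(unique_groups))
    out[np.searchsorted(unique_groups, chunk["groups"])] = chunk["sum"]
    return out

unique_groups = np.unique(np.concatenate([a["groups"], b["groups"]]))
combined = toy_reindex(a, unique_groups) + toy_reindex(b, unique_groups)
# unique_groups -> [1 2 3]; combined -> [10. 25. 7.]
```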
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap

     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
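`_find_unique_groups` leans on the two dask helpers imported above: `deepmap`, which applies a function to every leaf of an arbitrarily nested list, and `flatten`, which iterates over all leaves. A small self-contained sketch of how they cooperate to collect group labels from nested intermediates (the nested structure below is made up for illustration):

```python
from dask.base import flatten
from dask.utils import deepmap

# Intermediates arrive as arbitrarily nested lists of dicts.
nested = [[{"groups": [1, 2]}, {"groups": [2, 3]}], [{"groups": [3]}]]

# deepmap applies the function at every leaf, preserving nesting...
listified = deepmap(lambda d: list(d["groups"]), nested)
# [[[1, 2], [2, 3]], [[3]]]

# ...and flatten walks all leaves, ready for deduplication
print(list(flatten(listified)))  # [1, 2, 2, 3, 3]
```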
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)

-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1268,31 +1292,32 @@ def dask_groupby_agg(
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"

-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
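Note the pattern in this hunk: `tree_reduce` and `aggregate` are built with `functools.partial`, and each call site then layers on the arguments that differ (`combine`, `expected_groups`, `reindex`). A minimal illustration of that layered binding, with placeholder names rather than flox's real signatures:

```python
from functools import partial

def aggregate(x, *, combine, fill_value, expected_groups, reindex):
    """Placeholder standing in for an aggregation step."""
    ...

# bind what is common to all call sites once...
agg_base = partial(aggregate, combine=sum, fill_value=0)

# ...then let each code path supply what differs
map_reduce_agg = partial(agg_base, expected_groups=None, reindex=True)
cohorts_agg = partial(agg_base, expected_groups=[1, 2], reindex=True)
```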
@@ -1310,23 +1335,17 @@ def dask_groupby_agg(
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
+            index = pd.Index(cohort)
             subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis):])
-            if do_simple_combine:
-                # reindex so that reindex can be set to True later
-                reindexed = dask.array.map_blocks(
-                    reindex_intermediates,
-                    subset,
-                    agg=agg,
-                    unique_groups=cohort,
-                    meta=subset._meta,
-                )
-            else:
-                reindexed = subset
-
+            reindexed = dask.array.map_blocks(
+                reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+            )
+            # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
-                    aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    combine=partial(combine, agg=agg, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
                 )
             )
             groups_.append(cohort)
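Wrapping each cohort in a `pd.Index` makes the per-block alignment cheap: `Index.get_indexer` returns vectorized positions of a block's labels within the cohort. A small sketch of the alignment being set up here, with made-up data and an inline stand-in for `reindex_intermediates`:

```python
import numpy as np
import pandas as pd

# Hypothetical cohort: labels that co-occur in one set of blocks.
index = pd.Index([10, 20, 30])

# One block's intermediate only saw a subset of the cohort's labels.
block_groups = np.array([20, 10])
block_sums = np.array([4.0, 1.0])

# Vectorized lookup of each label's position in the cohort index;
# labels the block never saw keep the identity value (0 for sum).
aligned = np.zeros(len(index))
aligned[index.get_indexer(block_groups)] = block_sums
# aligned -> [1., 4., 0.], one slot per cohort member
```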
@@ -1382,28 +1401,24 @@ def _validate_reindex(
     if reindex is True:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )

     if reindex is None:
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False

-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False

         elif method == "map-reduce":
             if expected_groups is None and by_is_dask:
                 reindex = False
             else:
                 reindex = True

-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
     assert isinstance(reindex, bool)
     return reindex
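For readers tracing the new defaults, here is a condensed, illustrative restatement of the `reindex is None` logic after this change; it mirrors the hunk above rather than quoting flox verbatim:

```python
def default_reindex(method: str, is_arg_reduction: bool,
                    expected_groups, by_is_dask: bool) -> bool:
    if method == "blockwise" or is_arg_reduction:
        return False
    if method == "cohorts":
        # cohorts now reindexes to each cohort internally,
        # so the user-facing default is False
        return False
    if method == "map-reduce":
        # cannot reindex up front when groups are unknown at graph-build time
        return not (expected_groups is None and by_is_dask)
    raise ValueError(f"unknown method: {method!r}")
```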