@@ -1459,6 +1459,10 @@ def groupby_reduce(
1459
1459
by : tuple = tuple (np .asarray (b ) if not is_duck_array (b ) else b for b in by )
1460
1460
nby = len (by )
1461
1461
by_is_dask = any (is_duck_dask_array (b ) for b in by )
1462
+
1463
+ if method in ["split-reduce" , "cohorts" ] and by_is_dask :
1464
+ raise ValueError (f"method={ method !r} can only be used when grouping by numpy arrays." )
1465
+
1462
1466
if not is_duck_array (array ):
1463
1467
array = np .asarray (array )
1464
1468
if isinstance (isbin , bool ):
@@ -1477,9 +1481,11 @@ def groupby_reduce(
1477
1481
# (pd.IntervalIndex or not)
1478
1482
expected_groups = _convert_expected_groups_to_index (expected_groups , isbin , sort )
1479
1483
1480
- # when grouping by multiple variables, we factorize early.
1481
1484
# TODO: could restrict this to dask-only
1482
- if nby > 1 :
1485
+ factorize_early = (nby > 1 ) or (
1486
+ any (isbin ) and method in ["split-reduce" , "cohorts" ] and is_duck_dask_array (array )
1487
+ )
1488
+ if factorize_early :
1483
1489
by , final_groups , grp_shape = _factorize_multiple (
1484
1490
by , expected_groups , by_is_dask = by_is_dask
1485
1491
)
@@ -1497,6 +1503,7 @@ def groupby_reduce(
1497
1503
if method in ["blockwise" , "cohorts" , "split-reduce" ] and len (axis ) != by .ndim :
1498
1504
raise NotImplementedError (
1499
1505
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
1506
+ f"Received method={ method !r} "
1500
1507
)
1501
1508
1502
1509
# TODO: make sure expected_groups is unique
@@ -1617,10 +1624,12 @@ def groupby_reduce(
1617
1624
result = result [..., sorted_idx ]
1618
1625
groups = (groups [0 ][sorted_idx ],)
1619
1626
1620
- if nby > 1 :
1627
+ if factorize_early :
1621
1628
# nan group labels are factorized to -1, and preserved
1622
- # now we get rid of them
1623
- nanmask = groups [0 ] == - 1
1629
+ # now we get rid of them by reindexing
1630
+ # This also handles bins with no data
1631
+ result = reindex_ (
1632
+ result , from_ = groups [0 ], to = expected_groups , fill_value = fill_value
1633
+ ).reshape (result .shape [:- 1 ] + grp_shape )
1624
1634
groups = final_groups
1625
- result = result [..., ~ nanmask ].reshape (result .shape [:- 1 ] + grp_shape )
1626
1635
return (result , * groups )
0 commit comments