@@ -15,6 +15,7 @@
     Any,
     Callable,
     Literal,
+    TypedDict,
     Union,
     overload,
 )
@@ -87,6 +88,17 @@
 DUMMY_AXIS = -2


+class FactorizeKwargs(TypedDict, total=False):
+    """Used in _factorize_multiple"""
+
+    by: T_Bys
+    axes: T_Axes
+    fastpath: bool
+    expected_groups: T_ExpectIndexOptTuple | None
+    reindex: bool
+    sort: bool
+
+
 def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     """Account for numbagg not providing a fill_value kwarg."""
     from .aggregate_numbagg import DEFAULT_FILL_VALUE
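Aside: the FactorizeKwargs class above uses total=False, which makes every key optional, so a partial dict can be assembled incrementally and splatted into factorize_ while a type checker still validates each key. A minimal sketch of that pattern, with illustrative names (FKwargs and factorize here are not flox API):

from typing import TypedDict

class FKwargs(TypedDict, total=False):
    # total=False: all keys optional, so partially-filled dicts type-check
    fastpath: bool
    sort: bool

def factorize(*, fastpath: bool = False, sort: bool = True):
    return fastpath, sort

kwargs: FKwargs = {"fastpath": True}
kwargs["sort"] = False   # the checker verifies the key exists and the value is a bool
print(factorize(**kwargs))  # (True, False)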
@@ -1434,7 +1446,7 @@ def dask_groupby_agg(
     _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

     # tokenize here since by has already been hashed if its numpy
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis, method)

     # preprocess the array:
     #   - for argreductions, this zips the index together with the array block
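Context for the token change: dask.base.tokenize (public dask API) deterministically hashes its arguments, and the token seeds the names of graph layers. Including `method` ensures that graphs built with different methods cannot collide on identical keys. A small sketch:

import dask.base

t1 = dask.base.tokenize("array", "sum", "map-reduce")
t2 = dask.base.tokenize("array", "sum", "cohorts")
assert t1 != t2  # different inputs yield different tokens
assert t1 == dask.base.tokenize("array", "sum", "map-reduce")  # deterministic across calls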
@@ -1454,7 +1466,8 @@ def dask_groupby_agg(
     #    b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #       This allows us to discover groups at compute time, support argreductions, lower intermediate
     #       memory usage (but method="cohorts" would also work to reduce memory in some cases)
-    do_simple_combine = not _is_arg_reduction(agg)
+    labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
+    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown

     if method == "blockwise":
         #  use the "non dask" code path, but applied blockwise
@@ -1510,7 +1523,7 @@ def dask_groupby_agg(
     tree_reduce = partial(
         dask.array.reductions._tree_reduce,
-        name=f"{name}-reduce-{method}",
+        name=f"{name}-reduce",
         dtype=array.dtype,
         axis=axis,
         keepdims=True,
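_tree_reduce is dask-internal, but the public dask.array.reduction exposes the same chunk/combine/aggregate structure the partial above configures: apply `chunk` to every block, `combine` intermediates up a tree, then finish with `aggregate`. A runnable sketch with a plain sum:

import dask.array as da

x = da.ones((8,), chunks=2)
total = da.reduction(
    x,
    chunk=lambda b, axis, keepdims: b.sum(axis=axis, keepdims=keepdims),      # per block
    combine=lambda b, axis, keepdims: b.sum(axis=axis, keepdims=keepdims),    # tree levels
    aggregate=lambda b, axis, keepdims: b.sum(axis=axis, keepdims=keepdims),  # final step
    dtype=x.dtype,
)
print(total.compute())  # 8.0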
@@ -1529,7 +1542,7 @@ def dask_groupby_agg(
             combine=partial(combine, agg=agg),
             aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
         )
-        if is_duck_dask_array(by_input) and expected_groups is None:
+        if labels_are_unknown:
             groups = _extract_unknown_groups(reduced, dtype=by.dtype)
             group_chunks = ((np.nan,),)
         else:
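On group_chunks = ((np.nan,),): dask encodes chunk sizes that cannot be known until compute time as NaN, which is exactly the situation when groups are discovered at compute time. The same convention shows up with data-dependent indexing:

import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=5)
y = x[x > 3]     # output size depends on the data
print(y.chunks)  # ((nan, nan),)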
@@ -1747,17 +1760,26 @@ def _convert_expected_groups_to_index(


 def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
-    group_idx, *rest = factorize_(by, **kwargs)
+    group_idx, *_ = factorize_(by, **kwargs)
     return group_idx


 def _factorize_multiple(
     by: T_Bys,
     expected_groups: T_ExpectIndexOptTuple,
     any_by_dask: bool,
-    reindex: bool,
     sort: bool = True,
 ) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
+    kwargs: FactorizeKwargs = dict(
+        axes=(),  # always (), we offset later if necessary.
+        expected_groups=expected_groups,
+        fastpath=True,
+        # This is the only way it makes sense I think.
+        # reindex controls what's actually allocated in chunk_reduce
+        # At this point, we care about an accurate conversion to codes.
+        reindex=True,
+        sort=sort,
+    )
     if any_by_dask:
         import dask.array

@@ -1771,11 +1793,7 @@ def _factorize_multiple(
             *by_,
             chunks=tuple(chunks.values()),
             meta=np.array((), dtype=np.int64),
-            axes=(),  # always (), we offset later if necessary.
-            expected_groups=expected_groups,
-            fastpath=True,
-            reindex=reindex,
-            sort=sort,
+            **kwargs,
         )

         fg, gs = [], []
@@ -1796,14 +1814,8 @@ def _factorize_multiple(
         found_groups = tuple(fg)
         grp_shape = tuple(gs)
     else:
-        group_idx, found_groups, grp_shape, ngroups, size, props = factorize_(
-            by,
-            axes=(),  # always (), we offset later if necessary.
-            expected_groups=expected_groups,
-            fastpath=True,
-            reindex=reindex,
-            sort=sort,
-        )
+        kwargs["by"] = by
+        group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs)

     return (group_idx,), found_groups, grp_shape

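In rough outline, the numpy branch of _factorize_multiple turns each `by` into integer codes and fuses them into a single group_idx over the product group space of size prod(grp_shape). A simplified stand-in built from pandas and numpy (not the flox implementation):

import numpy as np
import pandas as pd

by1 = np.array(["a", "b", "a"])
by2 = np.array([10, 10, 20])
codes1, groups1 = pd.factorize(by1, sort=True)  # codes1 = [0, 1, 0]
codes2, groups2 = pd.factorize(by2, sort=True)  # codes2 = [0, 0, 1]
grp_shape = (groups1.size, groups2.size)        # (2, 2)
group_idx = np.ravel_multi_index((codes1, codes2), grp_shape)
print(group_idx)  # [0 2 1], one flat code per element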
@@ -2058,7 +2070,7 @@ def groupby_reduce(
     #    (pd.IntervalIndex or not)
     expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)

-    # Don't factorize " early only when
+    # Don't factorize early only when
     #    grouping by dask arrays, and not having expected_groups
     factorize_early = not (
         # can't do it if we are grouping by dask array but don't have expected_groups
@@ -2069,10 +2081,6 @@ def groupby_reduce(
             bys,
             expected_groups,
             any_by_dask=any_by_dask,
-            # This is the only way it makes sense I think.
-            # reindex controls what's actually allocated in chunk_reduce
-            # At this point, we care about an accurate conversion to codes.
-            reindex=True,
             sort=sort,
         )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
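Since group labels are integer codes after early factorization, the expected groups are simply 0 .. prod(grp_shape) - 1, which is what the pd.RangeIndex line constructs:

import math
import pandas as pd

grp_shape = (2, 3)  # e.g. two bys with 2 and 3 groups respectively
print(pd.RangeIndex(math.prod(grp_shape)))  # RangeIndex(start=0, stop=6, step=1)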
@@ -2103,21 +2111,17 @@ def groupby_reduce(
             "along a single axis or when reducing across all dimensions of `by`."
         )

-    # TODO: make sure expected_groups is unique
     if nax == 1 and by_.ndim > 1 and expected_groups is None:
-        if not any_by_dask:
-            expected_groups = _get_expected_groups(by_, sort)
-        else:
-            # When we reduce along all axes, we are guaranteed to see all
-            # groups in the final combine stage, so everything works.
-            # This is not necessarily true when reducing along a subset of axes
-            # (of by)
-            # TODO: Does this depend on chunking of by?
-            # For e.g., we could relax this if there is only one chunk along all
-            # by dim != axis?
-            raise NotImplementedError(
-                "Please provide ``expected_groups`` when not reducing along all axes."
-            )
+        # When we reduce along all axes, we are guaranteed to see all
+        # groups in the final combine stage, so everything works.
+        # This is not necessarily true when reducing along a subset of axes
+        # (of by)
+        # TODO: Does this depend on chunking of by?
+        # For e.g., we could relax this if there is only one chunk along all
+        # by dim != axis?
+        raise NotImplementedError(
+            "Please provide ``expected_groups`` when not reducing along all axes."
+        )

     assert nax <= by_.ndim
     if nax < by_.ndim: