@@ -33,7 +33,9 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull
+from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
+
+HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
 
 if TYPE_CHECKING:
     try:
@@ -69,6 +71,7 @@
 T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
 T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
 T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
+T_EngineOpt = None | T_Engine
 T_Method = Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]
 
@@ -83,6 +86,10 @@
 DUMMY_AXIS = -2
 
 
+def _issorted(arr: np.ndarray) -> bool:
+    return bool((arr[:-1] <= arr[1:]).all())
+
+
 def _is_arg_reduction(func: T_Agg) -> bool:
     if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
         return True
@@ -632,6 +639,7 @@ def chunk_argreduce(
     reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
+    user_dtype=None,
 ) -> IntermediateDict:
     """
     Per-chunk arg reduction.
@@ -652,6 +660,7 @@ def chunk_argreduce(
         dtype=dtype,
         engine=engine,
         sort=sort,
+        user_dtype=user_dtype,
     )
     if not isnull(results["groups"]).all():
         idx = np.broadcast_to(idx, array.shape)
@@ -685,6 +694,7 @@ def chunk_reduce(
     engine: T_Engine = "numpy",
     kwargs: Sequence[dict] | None = None,
     sort: bool = True,
+    user_dtype=None,
 ) -> IntermediateDict:
     """
     Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -785,6 +795,7 @@ def chunk_reduce(
     group_idx = group_idx.reshape(-1)
 
     assert group_idx.ndim == 1
+
     empty = np.all(props.nanmask)
 
     results: IntermediateDict = {"groups": [], "intermediates": []}
@@ -1100,6 +1111,7 @@ def _grouped_combine(
                     dtype=(np.intp,),
                     engine=engine,
                     sort=sort,
+                    user_dtype=agg.dtype["user"],
                 )["intermediates"][0]
             )
 
@@ -1129,6 +1141,7 @@ def _grouped_combine(
                 dtype=(dtype,),
                 engine=engine,
                 sort=sort,
+                user_dtype=agg.dtype["user"],
             )
             results["intermediates"].append(*_results["intermediates"])
             results["groups"] = _results["groups"]
@@ -1174,6 +1187,7 @@ def _reduce_blockwise(
         engine=engine,
         sort=sort,
         reindex=reindex,
+        user_dtype=agg.dtype["user"],
     )
 
     if _is_arg_reduction(agg):
@@ -1366,6 +1380,7 @@ def dask_groupby_agg(
         fill_value=agg.fill_value["intermediate"],
         dtype=agg.dtype["intermediate"],
         reindex=reindex,
+        user_dtype=agg.dtype["user"],
     )
     if do_simple_combine:
         # Add a dummy dimension that then gets reduced over
@@ -1757,6 +1772,23 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
     return expected_groups
 
 
+def _choose_engine(by, agg: Aggregation):
+    dtype = agg.dtype["user"]
+
+    not_arg_reduce = not _is_arg_reduction(agg)
+
+    # numbagg only supports nan-skipping reductions
+    # without dtype specified
+    if HAS_NUMBAGG and "nan" in agg.name:
+        if not_arg_reduce and dtype is None:
+            return "numbagg"
+
+    if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
+        return "flox"
+    else:
+        return "numpy"
+
+
 def groupby_reduce(
     array: np.ndarray | DaskArray,
     *by: T_By,
@@ -1769,7 +1801,7 @@ def groupby_reduce(
     dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
     method: T_Method = "map-reduce",
-    engine: T_Engine = "numpy",
+    engine: T_EngineOpt = None,
     reindex: bool | None = None,
     finalize_kwargs: dict[Any, Any] | None = None,
 ) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]:  # type: ignore[misc]  # Unpack not in mypy yet
@@ -2027,9 +2059,14 @@ def groupby_reduce(
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
+    kwargs = dict(axis=axis_, fill_value=fill_value)
     agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)
 
+    # Need to set this early using `agg`
+    # It cannot be done in the core loop of chunk_reduce
+    # since we "prepare" the data for flox.
+    kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine
+
     groups: tuple[np.ndarray | DaskArray, ...]
     if not has_dask:
         results = _reduce_blockwise(
@@ -2080,7 +2117,7 @@ def groupby_reduce(
         assert len(groups) == 1
         sorted_idx = np.argsort(groups[0])
         # This optimization helps specifically with resampling
-        if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
+        if not _issorted(sorted_idx):
             result = result[..., sorted_idx]
             groups = (groups[0][sorted_idx],)
 
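With this commit, `engine` defaults to None and is resolved per call by the new
_choose_engine helper. The following is a minimal usage sketch of the resulting
behavior, not part of the commit itself: the sample data is illustrative, and
"numbagg" is only selected when numbagg >= 0.3.0 is importable, the reduction is
a nan-skipping one, no user dtype was requested, and it is not an arg reduction.

    import numpy as np
    from flox.core import groupby_reduce

    array = np.array([1.0, 2.0, np.nan, 4.0])
    by = np.array([0, 0, 1, 1])  # sorted, in-memory (non-dask) labels

    # engine=None (the new default): _choose_engine returns "numbagg" under the
    # conditions above, otherwise "flox" for sorted non-dask `by` (as here),
    # otherwise "numpy". The engine only affects performance; the result is
    # identical either way.
    result, groups = groupby_reduce(array, by, func="nanmean")
    print(groups, result)  # [0 1] [1.5 4. ]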