Skip to content

Commit fecd9a6

Browse files
mathausedcherian
andauthored
use engine flox for ordered groups (#266)
* use engine flox for ordered groups * Add issorted helper func * Some fixes * In xarray too * formatting * simplify * retry * flox * minversion numabgg * cleanup * fix type * update gitignore * add types * Fix env? * fix * fix merge * cleanup * [skip-ci] bench * temporarily disable numbagg * don't cache env * Finally! * bugfix * Fix doctest * more fixes * Fix CI * readd numbagg * Fix. --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 789cf73 commit fecd9a6

File tree

9 files changed

+116
-20
lines changed

9 files changed

+116
-20
lines changed

.github/workflows/ci-additional.yaml

+4-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ jobs:
7272
conda list
7373
- name: Run doctests
7474
run: |
75-
python -m pytest --doctest-modules flox --ignore flox/tests --cov=./ --cov-report=xml
75+
python -m pytest --doctest-modules \
76+
flox/aggregations.py flox/core.py flox/xarray.py \
77+
--ignore flox/tests \
78+
--cov=./ --cov-report=xml
7679
- name: Upload code coverage to Codecov
7780
uses: codecov/[email protected]
7881
with:

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
asv_bench/pkgs/
12
docs/source/generated/
23
html/
34
.asv/

asv_bench/benchmarks/reduce.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
N = 3000
99
funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"]
10-
engines = ["flox", "numpy", "numbagg"]
10+
engines = [None, "flox", "numpy", "numbagg"]
1111
expected_groups = {
1212
"None": None,
1313
"bins": pd.IntervalIndex.from_breaks([1, 2, 4]),

ci/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ dependencies:
2323
- pooch
2424
- toolz
2525
- numba
26+
- numbagg>=0.3
2627
- scipy

flox/aggregations.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class AggDtypeInit(TypedDict):
2929

3030

3131
class AggDtype(TypedDict):
32+
user: DTypeLike | None
3233
final: np.dtype
3334
numpy: tuple[np.dtype | type[np.intp], ...]
3435
intermediate: tuple[np.dtype | type[np.intp], ...]
@@ -569,6 +570,7 @@ def _initialize_aggregation(
569570

570571
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
571572
agg.dtype = {
573+
"user": dtype, # Save to automatically choose an engine
572574
"final": final_dtype,
573575
"numpy": (final_dtype,),
574576
"intermediate": tuple(

flox/core.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
generic_aggregate,
3434
)
3535
from .cache import memoize
36-
from .xrutils import is_duck_array, is_duck_dask_array, isnull
36+
from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
37+
38+
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
3739

3840
if TYPE_CHECKING:
3941
try:
@@ -69,6 +71,7 @@
6971
T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
7072
T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
7173
T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
74+
T_EngineOpt = None | T_Engine
7275
T_Method = Literal["map-reduce", "blockwise", "cohorts"]
7376
T_IsBins = Union[bool | Sequence[bool]]
7477

@@ -83,6 +86,10 @@
8386
DUMMY_AXIS = -2
8487

8588

89+
def _issorted(arr: np.ndarray) -> bool:
90+
return bool((arr[:-1] <= arr[1:]).all())
91+
92+
8693
def _is_arg_reduction(func: T_Agg) -> bool:
8794
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
8895
return True
@@ -632,6 +639,7 @@ def chunk_argreduce(
632639
reindex: bool = False,
633640
engine: T_Engine = "numpy",
634641
sort: bool = True,
642+
user_dtype=None,
635643
) -> IntermediateDict:
636644
"""
637645
Per-chunk arg reduction.
@@ -652,6 +660,7 @@ def chunk_argreduce(
652660
dtype=dtype,
653661
engine=engine,
654662
sort=sort,
663+
user_dtype=user_dtype,
655664
)
656665
if not isnull(results["groups"]).all():
657666
idx = np.broadcast_to(idx, array.shape)
@@ -685,6 +694,7 @@ def chunk_reduce(
685694
engine: T_Engine = "numpy",
686695
kwargs: Sequence[dict] | None = None,
687696
sort: bool = True,
697+
user_dtype=None,
688698
) -> IntermediateDict:
689699
"""
690700
Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -785,6 +795,7 @@ def chunk_reduce(
785795
group_idx = group_idx.reshape(-1)
786796

787797
assert group_idx.ndim == 1
798+
788799
empty = np.all(props.nanmask)
789800

790801
results: IntermediateDict = {"groups": [], "intermediates": []}
@@ -1100,6 +1111,7 @@ def _grouped_combine(
11001111
dtype=(np.intp,),
11011112
engine=engine,
11021113
sort=sort,
1114+
user_dtype=agg.dtype["user"],
11031115
)["intermediates"][0]
11041116
)
11051117

@@ -1129,6 +1141,7 @@ def _grouped_combine(
11291141
dtype=(dtype,),
11301142
engine=engine,
11311143
sort=sort,
1144+
user_dtype=agg.dtype["user"],
11321145
)
11331146
results["intermediates"].append(*_results["intermediates"])
11341147
results["groups"] = _results["groups"]
@@ -1174,6 +1187,7 @@ def _reduce_blockwise(
11741187
engine=engine,
11751188
sort=sort,
11761189
reindex=reindex,
1190+
user_dtype=agg.dtype["user"],
11771191
)
11781192

11791193
if _is_arg_reduction(agg):
@@ -1366,6 +1380,7 @@ def dask_groupby_agg(
13661380
fill_value=agg.fill_value["intermediate"],
13671381
dtype=agg.dtype["intermediate"],
13681382
reindex=reindex,
1383+
user_dtype=agg.dtype["user"],
13691384
)
13701385
if do_simple_combine:
13711386
# Add a dummy dimension that then gets reduced over
@@ -1757,6 +1772,23 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
17571772
return expected_groups
17581773

17591774

1775+
def _choose_engine(by, agg: Aggregation):
1776+
dtype = agg.dtype["user"]
1777+
1778+
not_arg_reduce = not _is_arg_reduction(agg)
1779+
1780+
# numbagg only supports nan-skipping reductions
1781+
# without dtype specified
1782+
if HAS_NUMBAGG and "nan" in agg.name:
1783+
if not_arg_reduce and dtype is None:
1784+
return "numbagg"
1785+
1786+
if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
1787+
return "flox"
1788+
else:
1789+
return "numpy"
1790+
1791+
17601792
def groupby_reduce(
17611793
array: np.ndarray | DaskArray,
17621794
*by: T_By,
@@ -1769,7 +1801,7 @@ def groupby_reduce(
17691801
dtype: np.typing.DTypeLike = None,
17701802
min_count: int | None = None,
17711803
method: T_Method = "map-reduce",
1772-
engine: T_Engine = "numpy",
1804+
engine: T_EngineOpt = None,
17731805
reindex: bool | None = None,
17741806
finalize_kwargs: dict[Any, Any] | None = None,
17751807
) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]: # type: ignore[misc] # Unpack not in mypy yet
@@ -2027,9 +2059,14 @@ def groupby_reduce(
20272059
# overwrite than when min_count is set
20282060
fill_value = np.nan
20292061

2030-
kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
2062+
kwargs = dict(axis=axis_, fill_value=fill_value)
20312063
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)
20322064

2065+
# Need to set this early using `agg`
2066+
# It cannot be done in the core loop of chunk_reduce
2067+
# since we "prepare" the data for flox.
2068+
kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine
2069+
20332070
groups: tuple[np.ndarray | DaskArray, ...]
20342071
if not has_dask:
20352072
results = _reduce_blockwise(
@@ -2080,7 +2117,7 @@ def groupby_reduce(
20802117
assert len(groups) == 1
20812118
sorted_idx = np.argsort(groups[0])
20822119
# This optimization helps specifically with resampling
2083-
if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
2120+
if not _issorted(sorted_idx):
20842121
result = result[..., sorted_idx]
20852122
groups = (groups[0][sorted_idx],)
20862123

flox/xarray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def xarray_reduce(
7373
fill_value=None,
7474
dtype: np.typing.DTypeLike = None,
7575
method: str = "map-reduce",
76-
engine: str = "numpy",
76+
engine: str | None = None,
7777
keep_attrs: bool | None = True,
7878
skipna: bool | None = None,
7979
min_count: int | None = None,
@@ -369,7 +369,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
369369

370370
# Flox's count works with non-numeric and its faster than converting.
371371
requires_numeric = func not in ["count", "any", "all"] or (
372-
func == "count" and engine != "flox"
372+
func == "count" and kwargs["engine"] != "flox"
373373
)
374374
if requires_numeric:
375375
is_npdatetime = array.dtype.kind in "Mm"

flox/xrutils.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
# defined in xarray
33

44
import datetime
5+
import importlib
56
from collections.abc import Iterable
6-
from typing import Any
7+
from typing import Any, Optional
78

89
import numpy as np
910
import pandas as pd
1011
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
12+
from packaging.version import Version
1113

1214
try:
1315
import cftime
@@ -317,3 +319,26 @@ def nanlast(values, axis, keepdims=False):
317319
return np.expand_dims(result, axis=axis)
318320
else:
319321
return result
322+
323+
324+
def module_available(module: str, minversion: Optional[str] = None) -> bool:
325+
"""Checks whether a module is installed without importing it.
326+
327+
Use this for a lightweight check and lazy imports.
328+
329+
Parameters
330+
----------
331+
module : str
332+
Name of the module.
333+
334+
Returns
335+
-------
336+
available : bool
337+
Whether the module is installed.
338+
"""
339+
has = importlib.util.find_spec(module) is not None
340+
if has:
341+
mod = importlib.import_module(module)
342+
return Version(mod.__version__) < Version(minversion) if minversion is not None else True
343+
else:
344+
return False

tests/test_core.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from numpy_groupies.aggregate_numpy import aggregate
1212

1313
from flox import xrutils
14-
from flox.aggregations import Aggregation
14+
from flox.aggregations import Aggregation, _initialize_aggregation
1515
from flox.core import (
16+
HAS_NUMBAGG,
17+
_choose_engine,
1618
_convert_expected_groups_to_index,
1719
_get_optimal_chunks_for_groups,
1820
_normalize_indexes,
@@ -600,12 +602,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
600602
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
601603
rng = np.random.default_rng(12345)
602604
array = rng.random(by.shape)
603-
kwargs = dict(
604-
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
605-
)
606-
expected, _ = groupby_reduce(array, by, **kwargs)
605+
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
606+
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
607607
if engine == "flox":
608-
kwargs.pop("engine")
609608
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
610609
assert_equal(expected_npg, expected)
611610

@@ -622,12 +621,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
622621
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
623622
rng = np.random.default_rng(12345)
624623
array = rng.random(by.shape)
625-
kwargs = dict(
626-
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
627-
)
628-
expected, _ = groupby_reduce(array, by, **kwargs)
624+
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
625+
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
629626
if engine == "flox":
630-
kwargs.pop("engine")
631627
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
632628
assert_equal(expected_npg, expected)
633629

@@ -640,6 +636,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
640636
actual, _ = groupby_reduce(
641637
da.from_array(array, chunks=(-1, 2, 3)),
642638
da.from_array(by, chunks=(-1, 2, 2)),
639+
engine=engine,
643640
**kwargs,
644641
)
645642
assert_equal(actual, expected, tolerance)
@@ -1546,3 +1543,33 @@ def test_method_check_numpy():
15461543
]
15471544
)
15481545
assert_equal(actual, expected)
1546+
1547+
1548+
@pytest.mark.parametrize("dtype", [None, np.float64])
1549+
def test_choose_engine(dtype):
1550+
numbagg_possible = HAS_NUMBAGG and dtype is None
1551+
default = "numbagg" if numbagg_possible else "numpy"
1552+
mean = _initialize_aggregation(
1553+
"mean",
1554+
dtype=dtype,
1555+
array_dtype=np.dtype("int64"),
1556+
fill_value=0,
1557+
min_count=0,
1558+
finalize_kwargs=None,
1559+
)
1560+
argmax = _initialize_aggregation(
1561+
"argmax",
1562+
dtype=dtype,
1563+
array_dtype=np.dtype("int64"),
1564+
fill_value=0,
1565+
min_count=0,
1566+
finalize_kwargs=None,
1567+
)
1568+
1569+
# sorted by -> flox
1570+
sorted_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=mean)
1571+
assert sorted_engine == ("numbagg" if numbagg_possible else "flox")
1572+
# unsorted by -> numpy
1573+
assert _choose_engine(np.array([3, 1, 1]), agg=mean) == default
1574+
# argmax does not give engine="flox"
1575+
assert _choose_engine(np.array([1, 1, 2, 2]), agg=argmax) == "numpy"

0 commit comments

Comments
 (0)