Skip to content

Commit edc1344

Browse files
authored
Expand groupby_reduce property tests (#385)
* Expand groupby_reduce property tests * Add back var, std * cast quantile result * Revert "Add back var, std" This reverts commit 805b8d3. * pin numpy in benchmark env * Add benchmarks as test
1 parent a9f51a1 commit edc1344

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

ci/benchmark.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ dependencies:
66
- build
77
- cachey
88
- dask-core
9-
- numpy>=1.22
9+
- numpy<2
1010
- mamba
1111
- pip
1212
- python=3.10

flox/aggregate_flox.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def _np_grouped_op(
160160

161161
def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
162162
if fillna in [dtypes.INF, dtypes.NINF]:
163-
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
163+
fillna = dtypes._get_fill_value(kwargs.get("dtype", None) or array.dtype, fillna)
164164
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
165165
# np.nanmax([np.nan, np.nan]) = np.nan
166166
# To recover this behaviour, we need to search for the fillna value

tests/test_asv.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Run asv benchmarks as tests
2+
3+
import pytest
4+
5+
pytest.importorskip("dask")
6+
7+
from asv_bench.benchmarks import reduce
8+
9+
10+
@pytest.mark.parametrize(
11+
"problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes]
12+
)
13+
def test_reduce(problem) -> None:
14+
testcase = problem()
15+
testcase.setup()
16+
for args in zip(*testcase.time_reduce.params):
17+
testcase.time_reduce(*args)
18+
19+
20+
def test_reduce_bare() -> None:
21+
testcase = reduce.ChunkReduce1D()
22+
testcase.setup()
23+
for args in zip(*testcase.time_reduce_bare.params):
24+
testcase.time_reduce_bare(*args)

tests/test_properties.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ def not_overflowing_array(array: np.ndarray[Any, Any]) -> bool:
6161
return result
6262

6363

64-
@given(data=st.data(), array=numeric_arrays, func=func_st)
64+
@given(
65+
data=st.data(),
66+
array=st.one_of(numeric_arrays, chunked_arrays(arrays=numeric_arrays)),
67+
func=func_st,
68+
)
6569
def test_groupby_reduce(data, array, func: str) -> None:
6670
# overflow behaviour differs between bincount and sum (for example)
6771
assume(not_overflowing_array(array))
@@ -93,7 +97,13 @@ def test_groupby_reduce(data, array, func: str) -> None:
9397

9498
# numpy-groupies always does the calculation in float64
9599
if (
96-
("var" in func or "std" in func or "sum" in func or "mean" in func)
100+
(
101+
"var" in func
102+
or "std" in func
103+
or "sum" in func
104+
or "mean" in func
105+
or "quantile" in func
106+
)
97107
and array.dtype.kind == "f"
98108
and array.dtype.itemsize != 8
99109
):

0 commit comments

Comments
 (0)