|
20 | 20 | from flox.xrutils import notnull
|
21 | 21 |
|
22 | 22 | from . import assert_equal
|
23 |
| -from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays |
| 23 | +from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays |
24 | 24 | from .strategies import chunks as chunks_strategy
|
25 | 25 |
|
26 | 26 | dask.config.set(scheduler="sync")
|
@@ -244,3 +244,25 @@ def test_first_last_useless(data, func):
|
244 | 244 | actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
|
245 | 245 | expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
|
246 | 246 | assert_equal(actual, expected)
|
| 247 | + |
| 248 | + |
| 249 | +@given( |
| 250 | + func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), |
| 251 | + engine=st.sampled_from(["numpy", "flox"]), |
| 252 | + array_dtype=st.none() | array_dtypes, |
| 253 | + dtype=st.none() | array_dtypes, |
| 254 | +) |
| 255 | +def test_agg_dtype_specified(func, array_dtype, dtype, engine): |
| 256 | + # regression test for GH388 |
| 257 | + counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype) |
| 258 | + group = np.array([1, 1, 1, 2, 2]) |
| 259 | + actual, _ = groupby_reduce( |
| 260 | + counts, |
| 261 | + group, |
| 262 | + expected_groups=(np.array([1, 2]),), |
| 263 | + func=func, |
| 264 | + dtype=dtype, |
| 265 | + engine=engine, |
| 266 | + ) |
| 267 | + expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) |
| 268 | + assert actual.dtype == expected.dtype |
0 commit comments