Skip to content

Commit 7421cb1

Browse files
authored
Preserve dtype better when specified. (#389)
* Preserve dtype better when specified.
* Add one more test
* tweak test
* more test
* [revert] test with Xarray PR branch
* tweak
* show versions
* Drop python 3.9, use ruff
* switch to Ruff
* fix mypy
* remove toctrees
* fix
* one more
1 parent 0438a7e commit 7421cb1

File tree

9 files changed

+85
-15
lines changed

9 files changed

+85
-15
lines changed

.github/workflows/ci.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ jobs:
7070
- name: Run Tests
7171
id: status
7272
run: |
73+
python -c "import xarray; xarray.show_versions()"
7374
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
7475
- name: Upload code coverage to Codecov
7576
uses: codecov/codecov-action@v4.5.0
@@ -98,7 +99,7 @@ jobs:
9899
steps:
99100
- uses: actions/checkout@v4
100101
with:
101-
repository: "pydata/xarray"
102+
repository: "dcherian/xarray"
102103
fetch-depth: 0 # Fetch all history for all branches and tags.
103104
- name: Set up conda environment
104105
uses: mamba-org/setup-micromamba@v1
@@ -112,6 +113,7 @@ jobs:
112113
pint>=0.22
113114
- name: Install xarray
114115
run: |
116+
git checkout flox-preserve-dtype
115117
python -m pip install --no-deps .
116118
- name: Install upstream flox
117119
run: |

ci/environment.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ dependencies:
1919
- pytest-pretty
2020
- pytest-xdist
2121
- syrupy
22-
- xarray
2322
- pre-commit
2423
- numpy_groupies>=0.9.19
2524
- pooch
2625
- toolz
2726
- numba
2827
- numbagg>=0.3
2928
- hypothesis
29+
- pip:
30+
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype

ci/no-dask.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ dependencies:
1414
- pytest-pretty
1515
- pytest-xdist
1616
- syrupy
17-
- xarray
1817
- numpydoc
1918
- pre-commit
2019
- numpy_groupies>=0.9.19
2120
- pooch
2221
- toolz
2322
- numba
2423
- numbagg>=0.3
24+
- pip:
25+
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype

flox/aggregations.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -549,20 +549,23 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
549549
return (Dim(name="quantile", values=q),)
550550

551551

552+
# if the input contains integers or floats smaller than float64,
553+
# the output data-type is float64. Otherwise, the output data-type is the same as that
554+
# of the input.
552555
quantile = Aggregation(
553556
name="quantile",
554557
fill_value=dtypes.NA,
555558
chunk=None,
556559
combine=None,
557-
final_dtype=np.floating,
560+
final_dtype=np.float64,
558561
new_dims_func=quantile_new_dims_func,
559562
)
560563
nanquantile = Aggregation(
561564
name="nanquantile",
562565
fill_value=dtypes.NA,
563566
chunk=None,
564567
combine=None,
565-
final_dtype=np.floating,
568+
final_dtype=np.float64,
566569
new_dims_func=quantile_new_dims_func,
567570
)
568571
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
@@ -801,9 +804,9 @@ def _initialize_aggregation(
801804
dtype_: np.dtype | None = (
802805
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
803806
)
804-
final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
805-
if not agg.preserves_dtype:
806-
final_dtype = dtypes._maybe_promote_int(final_dtype)
807+
final_dtype = dtypes._normalize_dtype(
808+
dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
809+
)
807810
agg.dtype = {
808811
"user": dtype, # Save to automatically choose an engine
809812
"final": final_dtype,

flox/xrdtypes.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,14 @@ def is_datetime_like(dtype):
150150
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
151151

152152

153-
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
153+
def _normalize_dtype(
154+
dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
155+
) -> np.dtype:
154156
if dtype is None:
155-
dtype = array_dtype
157+
if not preserves_dtype:
158+
dtype = _maybe_promote_int(array_dtype)
159+
else:
160+
dtype = array_dtype
156161
if dtype is np.floating:
157162
# mean, std, var always result in floating
158163
# but we preserve the array's dtype if it is floating

tests/strategies.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
2727

2828

2929
# TODO: stop excluding everything but U
30-
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
30+
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
3131
by_dtype_st = supported_dtypes()
3232

3333
NON_NUMPY_FUNCS = [
@@ -43,7 +43,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
4343

4444
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
4545
numeric_arrays = npst.arrays(
46-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
46+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
4747
)
4848
all_arrays = npst.arrays(
4949
elements={"allow_subnormal": False},

tests/test_core.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _get_array_func(func: str) -> Callable:
8181

8282
def npfunc(x, **kwargs):
8383
x = np.asarray(x)
84-
return (~np.isnan(x)).sum()
84+
return (~xrutils.isnull(x)).sum(**kwargs)
8585

8686
elif func in ["nanfirst", "nanlast"]:
8787
npfunc = getattr(xrutils, func)
@@ -1984,3 +1984,16 @@ def test_blockwise_nans():
19841984
)
19851985
assert_equal(expected_groups, actual_groups)
19861986
assert_equal(expected, actual)
1987+
1988+
1989+
@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
1990+
@pytest.mark.parametrize("engine", ["flox", "numpy"])
1991+
def test_agg_dtypes(func, engine):
1992+
# regression test for GH388
1993+
counts = np.array([0, 2, 1, 0, 1])
1994+
group = np.array([1, 1, 1, 2, 2])
1995+
actual, _ = groupby_reduce(
1996+
counts, group, expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine
1997+
)
1998+
expected = _get_array_func(func)(counts, dtype="uint8")
1999+
assert actual.dtype == np.uint8 == expected.dtype

tests/test_properties.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from flox.xrutils import notnull
2121

2222
from . import assert_equal
23-
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
23+
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
2424
from .strategies import chunks as chunks_strategy
2525

2626
dask.config.set(scheduler="sync")
@@ -244,3 +244,25 @@ def test_first_last_useless(data, func):
244244
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
245245
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
246246
assert_equal(actual, expected)
247+
248+
249+
@given(
250+
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
251+
engine=st.sampled_from(["numpy", "flox"]),
252+
array_dtype=st.none() | array_dtypes,
253+
dtype=st.none() | array_dtypes,
254+
)
255+
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
256+
# regression test for GH388
257+
counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
258+
group = np.array([1, 1, 1, 2, 2])
259+
actual, _ = groupby_reduce(
260+
counts,
261+
group,
262+
expected_groups=(np.array([1, 2]),),
263+
func=func,
264+
dtype=dtype,
265+
engine=engine,
266+
)
267+
expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
268+
assert actual.dtype == expected.dtype

tests/test_xarray.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# test against legacy xarray implementation
2626
# avoid some compilation overhead
27-
xr.set_options(use_flox=False, use_numbagg=False)
27+
xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False)
2828
tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
2929
np.random.seed(123)
3030

@@ -760,3 +760,26 @@ def test_direct_reduction(func):
760760
with xr.set_options(use_flox=False):
761761
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
762762
xr.testing.assert_identical(expected, actual)
763+
764+
765+
@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"])
766+
def test_groupby_preserve_dtype(reduction):
767+
# all groups are present, we should follow numpy exactly
768+
ds = xr.Dataset(
769+
{
770+
"test": (
771+
["x", "y"],
772+
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
773+
)
774+
},
775+
coords={"idx": ("x", [1, 2, 1])},
776+
)
777+
778+
kwargs = {"engine": "numpy"}
779+
if "nan" in reduction:
780+
kwargs["skipna"] = True
781+
with xr.set_options(use_flox=True):
782+
actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype
783+
expected = getattr(np, reduction)(ds.test.data, axis=0).dtype
784+
785+
assert actual == expected

0 commit comments

Comments
 (0)