Skip to content

Commit 3853101

Browse files
authored
Revert "Support first, last with datetime, timedelta (#402)" (#404)
This reverts commit 4f6164f.
1 parent 4f6164f commit 3853101

File tree

5 files changed

+27
-73
lines changed

5 files changed

+27
-73
lines changed

flox/aggregate_numbagg.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
"nanmean": {np.int_: np.float64},
3131
"nanvar": {np.int_: np.float64},
3232
"nanstd": {np.int_: np.float64},
33-
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
34-
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
3533
}
3634

3735

@@ -53,7 +51,7 @@ def _numbagg_wrapper(
5351
if cast_to:
5452
for from_, to_ in cast_to.items():
5553
if np.issubdtype(array.dtype, from_):
56-
array = array.astype(to_, copy=False)
54+
array = array.astype(to_)
5755

5856
func_ = getattr(numbagg.grouped, f"group_{func}")
5957

flox/core.py

+1-32
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@
4545
)
4646
from .cache import memoize
4747
from .xrutils import (
48-
_contains_cftime_datetimes,
49-
_datetime_nanmin,
50-
_to_pytimedelta,
51-
datetime_to_numeric,
5248
is_chunked_array,
5349
is_duck_array,
5450
is_duck_cubed_array,
@@ -2477,8 +2473,7 @@ def groupby_reduce(
24772473
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
24782474
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
24792475

2480-
is_first_last = _is_first_last_reduction(func)
2481-
if is_first_last:
2476+
if _is_first_last_reduction(func):
24822477
if has_dask and nax != 1:
24832478
raise ValueError(
24842479
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2491,24 +2486,6 @@ def groupby_reduce(
24912486
"along a single axis or when reducing across all dimensions of `by`."
24922487
)
24932488

2494-
# Flox's count works with non-numeric data, and it's faster than converting.
2495-
is_npdatetime = array.dtype.kind in "Mm"
2496-
is_cftime = _contains_cftime_datetimes(array)
2497-
requires_numeric = (
2498-
(func not in ["count", "any", "all"] and not is_first_last)
2499-
or (func == "count" and engine != "flox")
2500-
or (is_first_last and is_cftime)
2501-
)
2502-
if requires_numeric:
2503-
if is_npdatetime:
2504-
offset = _datetime_nanmin(array)
2505-
# xarray always uses np.datetime64[ns] for np.datetime64 data
2506-
dtype = "timedelta64[ns]"
2507-
array = datetime_to_numeric(array, offset)
2508-
elif is_cftime:
2509-
offset = array.min()
2510-
array = datetime_to_numeric(array, offset, datetime_unit="us")
2511-
25122489
if nax == 1 and by_.ndim > 1 and expected_ is None:
25132490
# When we reduce along all axes, we are guaranteed to see all
25142491
# groups in the final combine stage, so everything works.
@@ -2694,14 +2671,6 @@ def groupby_reduce(
26942671

26952672
if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
26962673
result = result.astype(bool)
2697-
2698-
# Output of count has an int dtype.
2699-
if requires_numeric and func != "count":
2700-
if is_npdatetime:
2701-
return result.astype(dtype) + offset
2702-
elif is_cftime:
2703-
return _to_pytimedelta(result, unit="us") + offset
2704-
27052674
return (result, *groups)
27062675

27072676

flox/xarray.py

+25
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88
import xarray as xr
99
from packaging.version import Version
10+
from xarray.core.duck_array_ops import _datetime_nanmin
1011

1112
from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
1213
from .core import (
@@ -17,6 +18,7 @@
1718
)
1819
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
1920
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
21+
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
2022

2123
if TYPE_CHECKING:
2224
from xarray.core.types import T_DataArray, T_Dataset
@@ -364,6 +366,22 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
364366
if "nan" not in func and func not in ["all", "any", "count"]:
365367
func = f"nan{func}"
366368

369+
# Flox's count works with non-numeric data, and it's faster than converting.
370+
requires_numeric = func not in ["count", "any", "all"] or (
371+
func == "count" and kwargs["engine"] != "flox"
372+
)
373+
if requires_numeric:
374+
is_npdatetime = array.dtype.kind in "Mm"
375+
is_cftime = _contains_cftime_datetimes(array)
376+
if is_npdatetime:
377+
offset = _datetime_nanmin(array)
378+
# xarray always uses np.datetime64[ns] for np.datetime64 data
379+
dtype = "timedelta64[ns]"
380+
array = datetime_to_numeric(array, offset)
381+
elif is_cftime:
382+
offset = array.min()
383+
array = datetime_to_numeric(array, offset, datetime_unit="us")
384+
367385
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
368386

369387
# Transpose the new quantile dimension to the end. This is ugly.
@@ -377,6 +395,13 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
377395
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
378396
result = np.moveaxis(result, 0, -1)
379397

398+
# Output of count has an int dtype.
399+
if requires_numeric and func != "count":
400+
if is_npdatetime:
401+
return result.astype(dtype) + offset
402+
elif is_cftime:
403+
return _to_pytimedelta(result, unit="us") + offset
404+
380405
return result
381406

382407
# These data variables do not have any of the core dimension,

flox/xrutils.py

-22
Original file line numberDiff line numberDiff line change
@@ -345,28 +345,6 @@ def _contains_cftime_datetimes(array) -> bool:
345345
return False
346346

347347

348-
def _datetime_nanmin(array):
349-
"""nanmin() function for datetime64.
350-
351-
Caveats that this function deals with:
352-
353-
- In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
354-
- numpy nanmin() doesn't work on datetime64 (all versions at the moment of writing)
355-
- dask min() does not work on datetime64 (all versions at the moment of writing)
356-
"""
357-
from .xrdtypes import is_datetime_like
358-
359-
dtype = array.dtype
360-
assert is_datetime_like(dtype)
361-
# (NaT).astype(float) does not produce NaN...
362-
array = np.where(pd.isnull(array), np.nan, array.astype(float))
363-
array = min(array, skipna=True)
364-
if isinstance(array, float):
365-
array = np.array(array)
366-
# ...but (NaN).astype("M8") does produce NaT
367-
return array.astype(dtype)
368-
369-
370348
def _select_along_axis(values, idx, axis):
371349
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
372350
sl = other_ind[:axis] + (idx,) + other_ind[axis:]

tests/test_core.py

-16
Original file line numberDiff line numberDiff line change
@@ -2006,19 +2006,3 @@ def test_blockwise_avoid_rechunk():
20062006
actual, groups = groupby_reduce(array, by, func="first")
20072007
assert_equal(groups, ["", "0", "1"])
20082008
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))
2009-
2010-
2011-
@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
2012-
def test_datetime_timedelta_first_last(engine, func):
2013-
import flox
2014-
2015-
idx = 0 if "first" in func else -1
2016-
2017-
dt = pd.date_range("2001-01-01", freq="d", periods=5).values
2018-
by = np.ones(dt.shape, dtype=int)
2019-
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2020-
assert_equal(actual, dt[[idx]])
2021-
2022-
dt = dt - dt[0]
2023-
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2024-
assert_equal(actual, dt[[idx]])

0 commit comments

Comments
 (0)