Skip to content

Commit 77dc5e0

Browse files
committed
small edits
1 parent 96ae241 commit 77dc5e0

File tree

3 files changed

+190
-88
lines changed

3 files changed

+190
-88
lines changed

properties/test_properties.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import pytest
24

35
pytest.importorskip("hypothesis")
@@ -37,12 +39,10 @@ def test_property_season_month_tuple(roll, breaks):
3739
if breaks[-1] != 12:
3840
breaks = breaks + [12]
3941
seasons = tuple(
40-
"".join(rolled_chars[start:stop])
41-
for start, stop in zip(breaks[:-1], breaks[1:], strict=False)
42+
"".join(rolled_chars[start:stop]) for start, stop in itertools.pairwise(breaks)
4243
)
4344
actual = season_to_month_tuple(seasons)
4445
expected = tuple(
45-
rolled_months[start:stop]
46-
for start, stop in zip(breaks[:-1], breaks[1:], strict=False)
46+
rolled_months[start:stop] for start, stop in itertools.pairwise(breaks)
4747
)
4848
assert expected == actual

xarray/groupers.py

+72-51
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from collections import defaultdict
1515
from collections.abc import Mapping, Sequence
1616
from dataclasses import dataclass, field
17-
from itertools import pairwise
18-
from itertools import chain
17+
from itertools import chain, pairwise
1918
from typing import TYPE_CHECKING, Any, Literal, cast
2019

2120
import numpy as np
@@ -25,16 +24,12 @@
2524
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
2625
from xarray.coding.cftimeindex import CFTimeIndex
2726
from xarray.core import duck_array_ops
28-
from xarray.core.computation import apply_ufunc
29-
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
30-
from xarray.core.coordinates import Coordinates
31-
from xarray.core.common import _contains_datetime_like_objects
32-
from xarray.core.common import _contains_datetime_like_objects
3327
from xarray.core.common import (
3428
_contains_cftime_datetimes,
3529
_contains_datetime_like_objects,
3630
)
37-
from xarray.core.coordinates import Coordinates
31+
from xarray.core.computation import apply_ufunc
32+
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
3833
from xarray.core.dataarray import DataArray
3934
from xarray.core.duck_array_ops import isnull
4035
from xarray.core.formatting import first_n_items
@@ -751,14 +746,16 @@ def __post_init__(self):
751746
)
752747
self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True))
753748

754-
def factorize(self, group):
749+
def factorize(self, group: T_Group) -> EncodedGroups:
755750
if group.ndim != 1:
756751
raise ValueError(
757752
"SeasonResampler can only be used to resample by 1D arrays."
758753
)
759-
if not _contains_datetime_like_objects(group.variable):
754+
if not isinstance(group, DataArray) or not _contains_datetime_like_objects(
755+
group.variable
756+
):
760757
raise ValueError(
761-
"SeasonResampler can only be used to group by datetime-like arrays."
758+
"SeasonResampler can only be used to group by datetime-like DataArrays."
762759
)
763760

764761
seasons = self.seasons
@@ -775,13 +772,14 @@ def factorize(self, group):
775772
season_label[month.isin(season_ind)] = season_str
776773
if "DJ" in season_str:
777774
after_dec = season_ind[season_str.index("D") + 1 :]
778-
# important this is assuming non-overlapping seasons
775+
# important: this is assuming non-overlapping seasons
779776
year[month.isin(after_dec)] -= 1
780777

781778
# Allow users to skip one or more months?
782-
# present_seasons is a mask that is True for months that are requestsed in the output
779+
# present_seasons is a mask that is True for months that are requested in the output
783780
present_seasons = season_label != ""
784781
if present_seasons.all():
782+
# avoid copies if we can.
785783
present_seasons = slice(None)
786784
frame = pd.DataFrame(
787785
data={
@@ -794,10 +792,13 @@ def factorize(self, group):
794792
),
795793
)
796794

797-
series = frame["index"]
798-
g = series.groupby(["year", "season"], sort=False)
799-
first_items = g.first()
800-
counts = g.count()
795+
agged = (
796+
frame["index"]
797+
.groupby(["year", "season"], sort=False)
798+
.agg(["first", "count"])
799+
)
800+
first_items = agged["first"]
801+
counts = agged["count"]
801802

802803
if _contains_cftime_datetimes(group.data):
803804
index_class = CFTimeIndex
@@ -814,32 +815,18 @@ def factorize(self, group):
814815
]
815816
)
816817

817-
sbins = first_items.values.astype(int)
818-
group_indices = [
819-
slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)
820-
]
821-
group_indices += [slice(sbins[-1], None)]
822-
823-
# Make sure the first and last timestamps
824-
# are for the correct months,if not we have incomplete seasons
825-
unique_codes = np.arange(len(unique_coord))
826-
if self.drop_incomplete:
827-
for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True):
828-
stamp_year, stamp_season = frame.index[idx]
829-
code = seasons.index(stamp_season)
830-
stamp_month = season_inds[code][idx]
831-
if stamp_month != month[present_seasons][idx].item():
832-
# we have an incomplete season!
833-
group_indices = group_indices[slicer]
834-
unique_coord = unique_coord[slicer]
835-
if idx == 0:
836-
unique_codes -= 1
837-
unique_codes[idx] = -1
838-
839-
# all years and seasons
818+
# sbins = first_items.values.astype(int)
819+
# group_indices = [
820+
# slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)
821+
# ]
822+
# group_indices += [slice(sbins[-1], None)]
823+
824+
# This sorted call is a hack. It's hard to figure out how
825+
# to start the iteration for arbitrary season ordering
826+
# for example "DJF" as first entry or last entry
827+
# So we construct the largest possible index and slice it to the
828+
# range present in the data.
840829
complete_index = index_class(
841-
# This sorted call is a hack. It's hard to figure out how
842-
# to start the iteration
843830
sorted(
844831
[
845832
datetime_class(year=y, month=m, day=1)
@@ -850,22 +837,56 @@ def factorize(self, group):
850837
]
851838
)
852839
)
853-
# only keep that included in data
854-
range_ = complete_index.get_indexer(unique_coord[[0, -1]])
855-
full_index = complete_index[slice(range_[0], range_[-1] + 1)]
840+
841+
# all years and seasons
842+
def get_label(year, season):
843+
month = season_tuples[season][0]
844+
return f"{year}-{month}-01"
845+
846+
unique_codes = np.arange(len(unique_coord))
847+
first_valid_season = season_label[0]
848+
last_valid_season = season_label[-1]
849+
first_year, last_year = year.data[[0, -1]]
850+
if self.drop_incomplete:
851+
if month.data[0] != season_tuples[first_valid_season][0]:
852+
if "DJ" in first_valid_season:
853+
first_year += 1
854+
first_valid_season = seasons[
855+
(seasons.index(first_valid_season) + 1) % len(seasons)
856+
]
857+
# group_indices = group_indices[slice(1, None)]
858+
unique_codes -= 1
859+
860+
if month.data[-1] != season_tuples[last_valid_season][-1]:
861+
last_valid_season = seasons[seasons.index(last_valid_season) - 1]
862+
if "DJ" in last_valid_season:
863+
last_year -= 1
864+
# group_indices = group_indices[slice(-1)]
865+
unique_codes[-1] = -1
866+
867+
first_label = get_label(first_year, first_valid_season)
868+
last_label = get_label(last_year, last_valid_season)
869+
870+
slicer = complete_index.slice_indexer(first_label, last_label)
871+
full_index = complete_index[slicer]
872+
# TODO: group must be sorted
873+
# codes = np.searchsorted(edges, group.data, side="left")
874+
# codes -= 1
875+
# codes[~present_seasons | group.data >= edges[-1]] = -1
876+
# codes[isnull(group.data)] = -1
877+
# import ipdb; ipdb.set_trace()
856878
# check that there are no "missing" seasons in the middle
857-
# print(full_index, unique_coord)
858-
if not full_index.equals(unique_coord):
859-
raise ValueError("Are there seasons missing in the middle of the dataset?")
879+
# if not full_index.equals(unique_coord):
880+
# raise ValueError("Are there seasons missing in the middle of the dataset?")
860881

861882
final_codes = np.full(group.data.size, -1)
862883
final_codes[present_seasons] = np.repeat(unique_codes, counts)
863884
codes = group.copy(data=final_codes, deep=False)
864-
unique_coord_var = Variable(group.name, unique_coord, group.attrs)
885+
# unique_coord_var = Variable(group.name, unique_coord, group.attrs)
865886

866887
return EncodedGroups(
867888
codes=codes,
868-
group_indices=group_indices,
869-
unique_coord=unique_coord_var,
889+
# group_indices=group_indices,
890+
# unique_coord=unique_coord_var,
870891
full_index=full_index,
871892
)

xarray/tests/test_groupby.py

+114-33
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
from packaging.version import Version
1313

1414
import xarray as xr
15-
from xarray import DataArray, Dataset, Variable, cftime_range
15+
from xarray import DataArray, Dataset, Variable, cftime_range, date_range
1616
from xarray.core.alignment import broadcast
1717
from xarray.core.groupby import _consolidate_slices
1818
from xarray.core.types import InterpOptions, ResampleCompatible
1919
from xarray.groupers import (
2020
BinGrouper,
2121
EncodedGroups,
2222
Grouper,
23+
SeasonGrouper,
2324
SeasonResampler,
2425
TimeResampler,
2526
UniqueGrouper,
@@ -44,6 +45,7 @@
4445
requires_pandas_ge_2_2,
4546
requires_scipy,
4647
)
48+
from xarray.tests.test_coding_times import _ALL_CALENDARS
4749

4850

4951
@pytest.fixture
@@ -3144,48 +3146,127 @@ def test_groupby_dask_eager_load_warnings():
31443146
ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False)
31453147

31463148

3147-
# TODO: Possible property tests to add to this module
3148-
# 1. lambda x: x
3149-
# 2. grouped-reduce on unique coords is identical to array
3150-
# 3. group_over == groupby-reduce along other dimensions
3151-
# 4. result is equivalent for transposed input
3152-
def test_season_to_month_tuple():
3153-
assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == (
3154-
(1, 2),
3155-
(3, 4, 5),
3156-
(6, 7, 8, 9),
3157-
(10, 11, 12),
3158-
)
3159-
assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == (
3160-
(12, 1, 2, 3),
3161-
(4, 5),
3162-
(6, 7, 8, 9),
3163-
(10, 11),
3149+
class TestSeasonGrouperAndResampler:
3150+
def test_season_to_month_tuple(self):
3151+
assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == (
3152+
(1, 2),
3153+
(3, 4, 5),
3154+
(6, 7, 8, 9),
3155+
(10, 11, 12),
3156+
)
3157+
assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == (
3158+
(12, 1, 2, 3),
3159+
(4, 5),
3160+
(6, 7, 8, 9),
3161+
(10, 11),
3162+
)
3163+
3164+
@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
3165+
def test_season_grouper_simple(self, calendar) -> None:
3166+
time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar)
3167+
da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
3168+
expected = da.groupby("time.season").mean()
3169+
# note season order matches expected
3170+
actual = da.groupby(
3171+
time=SeasonGrouper(
3172+
["DJF", "JJA", "MAM", "SON"], # drop_incomplete=False
3173+
)
3174+
).mean()
3175+
assert_identical(expected, actual)
3176+
3177+
# TODO: drop_incomplete
3178+
@requires_cftime
3179+
@pytest.mark.parametrize("drop_incomplete", [True, False])
3180+
@pytest.mark.parametrize(
3181+
"seasons",
3182+
[
3183+
pytest.param(["DJF", "MAM", "JJA", "SON"], id="standard"),
3184+
pytest.param(["MAM", "JJA", "SON", "DJF"], id="standard-diff-order"),
3185+
pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"),
3186+
pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"),
3187+
pytest.param(["JJAS"], id="jjas-only"),
3188+
pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"),
3189+
pytest.param(["JJA", "MAM", "SON", "DJF"], id="out-of-order"),
3190+
],
31643191
)
3192+
def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> None:
3193+
calendar = "standard"
3194+
time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar)
3195+
da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
3196+
counts = da.resample(time="ME").count()
3197+
3198+
seasons_as_ints = season_to_month_tuple(seasons)
3199+
month = counts.time.dt.month.data
3200+
year = counts.time.dt.year.data
3201+
for season, as_ints in zip(seasons, seasons_as_ints, strict=True):
3202+
if "DJ" in season:
3203+
for imonth in as_ints[season.index("D") + 1 :]:
3204+
year[month == imonth] -= 1
3205+
counts["time"] = (
3206+
"time",
3207+
[pd.Timestamp(f"{y}-{m}-01") for y, m in zip(year, month, strict=True)],
3208+
)
3209+
counts = counts.convert_calendar(calendar, "time", align_on="date")
3210+
3211+
expected_vals = []
3212+
expected_time = []
3213+
for year in [2001, 2002]:
3214+
for season, as_ints in zip(seasons, seasons_as_ints, strict=True):
3215+
out_year = year
3216+
if "DJ" in season:
3217+
out_year = year - 1
3218+
available = [
3219+
counts.sel(time=f"{out_year}-{month:02d}").data for month in as_ints
3220+
]
3221+
if any(len(a) == 0 for a in available) and drop_incomplete:
3222+
continue
3223+
output_label = pd.Timestamp(f"{out_year}-{as_ints[0]:02d}-01")
3224+
expected_time.append(output_label)
3225+
# use concatenate to handle empty array when dec value does not exist
3226+
expected_vals.append(np.concatenate(available).sum())
31653227

3228+
expected = xr.DataArray(
3229+
expected_vals, dims="time", coords={"time": expected_time}
3230+
).convert_calendar(calendar, align_on="date")
3231+
rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete)
3232+
# through resample
3233+
actual = da.resample(time=rs).sum()
3234+
assert_identical(actual, expected)
31663235

3167-
def test_season_resampler():
3168-
time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day")
3169-
da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
3236+
def test_season_resampler_errors(self):
3237+
time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day")
3238+
da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
31703239

3171-
# through resample
3172-
da.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum()
3240+
# non-datetime array
3241+
with pytest.raises(ValueError):
3242+
DataArray(np.ones(5), dims="time").groupby(time=SeasonResampler(["DJF"]))
31733243

3174-
# through groupby
3175-
da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum()
3244+
# ndim > 1 array
3245+
with pytest.raises(ValueError):
3246+
DataArray(
3247+
np.ones((5, 5)), dims=("t", "x"), coords={"x": np.arange(5)}
3248+
).groupby(x=SeasonResampler(["DJF"]))
31763249

3177-
# skip september
3178-
da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum()
3250+
# overlapping seasons
3251+
with pytest.raises(ValueError):
3252+
da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()
31793253

3180-
# "subsampling"
3181-
da.groupby(time=SeasonResampler(["JJAS"])).sum()
3254+
@requires_cftime
3255+
def test_season_resampler_groupby_identical(self):
3256+
time = date_range("2001-01-01", "2002-12-30", freq="D")
3257+
da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
31823258

3183-
# overlapping
3184-
with pytest.raises(ValueError):
3185-
da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()
3259+
# through resample
3260+
resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"])
3261+
rs = da.resample(time=resampler).sum()
31863262

3263+
# through groupby
3264+
gb = da.groupby(time=resampler).sum()
3265+
assert_identical(rs, gb)
31873266

3188-
# Possible property tests
3267+
3268+
# TODO: Possible property tests to add to this module
31893269
# 1. lambda x: x
31903270
# 2. grouped-reduce on unique coords is identical to array
31913271
# 3. group_over == groupby-reduce along other dimensions
3272+
# 4. result is equivalent for transposed input

0 commit comments

Comments
 (0)