Skip to content

Commit 07411f6

Browse files
committed
cleanup
1 parent 4da732f commit 07411f6

File tree

3 files changed

+67
-35
lines changed

3 files changed

+67
-35
lines changed

properties/test_properties.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
pytest.importorskip("hypothesis")
66

77
import hypothesis.strategies as st
8-
from hypothesis import given
8+
from hypothesis import given, note
99

1010
import xarray as xr
1111
import xarray.testing.strategies as xrst
12-
from xarray.groupers import season_to_month_tuple
12+
from xarray.groupers import find_independent_seasons, season_to_month_tuple
1313

1414

1515
@given(attrs=xrst.simple_attrs)
@@ -46,3 +46,18 @@ def test_property_season_month_tuple(roll, breaks):
4646
rolled_months[start:stop] for start, stop in itertools.pairwise(breaks)
4747
)
4848
assert expected == actual
49+
50+
51+
@given(data=st.data(), nmonths=st.integers(min_value=1, max_value=11))
52+
def test_property_find_independent_seasons(data, nmonths):
53+
chars = "JFMAMJJASOND"
54+
# if stride > nmonths, then we can't infer season order
55+
stride = data.draw(st.integers(min_value=1, max_value=nmonths))
56+
chars = chars + chars[:nmonths]
57+
seasons = [list(chars[i : i + nmonths]) for i in range(0, 12, stride)]
58+
note(seasons)
59+
groups = find_independent_seasons(seasons)
60+
for group in groups:
61+
inds = tuple(itertools.chain(*group.inds))
62+
assert len(inds) == len(set(inds))
63+
assert len(group.codes) == len(set(group.codes))

xarray/groupers.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,14 @@ def unique_value_groups(
604604

605605

606606
def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]:
607+
"""
608+
>>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"])
609+
((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11))
610+
>>> season_to_month_tuple(["DJFM", "MAMJ", "JJAS", "SOND"])
611+
((12, 1, 2, 3), (3, 4, 5, 6), (6, 7, 8, 9), (9, 10, 11, 12))
612+
>>> season_to_month_tuple(["DJFM", "SOND"])
613+
((12, 1, 2, 3), (9, 10, 11, 12))
614+
"""
607615
initials = "JFMAMJJASOND"
608616
starts = dict(
609617
("".join(s), i + 1)
@@ -629,7 +637,7 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]
629637
return tuple(result)
630638

631639

632-
def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]:
640+
def inds_to_season_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]:
633641
inits = "JFMAMJJASOND"
634642
return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints)
635643

@@ -660,24 +668,36 @@ def is_sorted_periodic(lst):
660668
return lst[-1] <= lst[0]
661669

662670

663-
@dataclass
671+
@dataclass(kw_only=True, frozen=True)
664672
class SeasonsGroup:
665673
seasons: tuple[str, ...]
674+
# tuple[integer months] corresponding to each season
666675
inds: tuple[tuple[int, ...], ...]
676+
# integer code for each season, this is not simply range(len(seasons))
677+
# when the seasons have overlaps
667678
codes: Sequence[int]
668679

669680

670681
def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]:
671682
"""
672683
Iterates though a list of seasons e.g. ["DJF", "FMA", ...],
673684
and splits that into multiple sequences of non-overlapping seasons.
685+
686+
>>> find_independent_seasons(
687+
... ["DJF", "FMA", "AMJ", "JJA", "ASO", "OND"]
688+
... ) # doctest: +NORMALIZE_WHITESPACE
689+
[SeasonsGroup(seasons=('DJF', 'AMJ', 'ASO'), inds=((12, 1, 2), (4, 5, 6), (8, 9, 10)), codes=[0, 2, 4]),
690+
SeasonsGroup(seasons=('FMA', 'JJA', 'OND'), inds=((2, 3, 4), (6, 7, 8), (10, 11, 12)), codes=[1, 3, 5])]
691+
692+
>>> find_independent_seasons(["DJF", "MAM", "JJA", "SON"])
693+
[SeasonsGroup(seasons=('DJF', 'MAM', 'JJA', 'SON'), inds=((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)), codes=[0, 1, 2, 3])]
674694
"""
675695
season_inds = season_to_month_tuple(seasons)
676696
grouped = defaultdict(list)
677697
codes = defaultdict(list)
678698
seen: set[tuple[int, ...]] = set()
679699
idx = 0
680-
# This is quadratic, but the length of seasons is at most 12
700+
# This is quadratic, but the number of seasons is at most 12
681701
for i, current in enumerate(season_inds):
682702
# Start with a group
683703
if current not in seen:
@@ -699,7 +719,7 @@ def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]:
699719

700720
grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx)
701721
return [
702-
SeasonsGroup(seasons=inds_to_string(inds), inds=inds, codes=codes)
722+
SeasonsGroup(seasons=inds_to_season_string(inds), inds=inds, codes=codes)
703723
for inds, codes in zip(grouped_ints, codes.values(), strict=False)
704724
]
705725

xarray/tests/test_groupby.py

+26-29
Original file line numberDiff line numberDiff line change
@@ -3378,7 +3378,7 @@ def test_season_grouper_with_partial_years(self, calendar):
33783378

33793379
assert_allclose(expected, actual)
33803380

3381-
@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
3381+
@pytest.mark.parametrize("calendar", ["standard"])
33823382
def test_season_grouper_with_single_month_seasons(self, calendar):
33833383
time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar)
33843384
# fmt: off
@@ -3392,29 +3392,32 @@ def test_season_grouper_with_single_month_seasons(self, calendar):
33923392
da = DataArray(data, dims="time", coords={"time": time})
33933393
da["year"] = da.time.dt.year
33943394

3395-
actual = da.groupby(
3396-
year=UniqueGrouper(),
3397-
time=SeasonGrouper(
3398-
["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"]
3399-
),
3400-
).mean()
3395+
# TODO: Consider supporting this if needed
3396+
# It does not work without flox, because the group labels are not unique,
3397+
# and so the stack/unstack approach does not work.
3398+
with pytest.raises(ValueError):
3399+
da.groupby(
3400+
year=UniqueGrouper(),
3401+
time=SeasonGrouper(
3402+
["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"]
3403+
),
3404+
).mean()
34013405

34023406
# Expected if single month seasons are handled correctly
3403-
expected = xr.DataArray(
3404-
data=np.array(
3405-
[
3406-
[1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7],
3407-
[1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75],
3408-
]
3409-
),
3410-
dims=["year", "season"],
3411-
coords={
3412-
"year": [2001, 2002],
3413-
"season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"],
3414-
},
3415-
)
3416-
3417-
assert_allclose(expected, actual)
3407+
# expected = xr.DataArray(
3408+
# data=np.array(
3409+
# [
3410+
# [1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7],
3411+
# [1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75],
3412+
# ]
3413+
# ),
3414+
# dims=["year", "season"],
3415+
# coords={
3416+
# "year": [2001, 2002],
3417+
# "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"],
3418+
# },
3419+
# )
3420+
# assert_allclose(expected, actual)
34183421

34193422
@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
34203423
def test_season_grouper_with_months_spanning_calendar_year_using_previous_year(
@@ -3481,13 +3484,7 @@ def test_season_resampling_raises_unsorted_seasons(self, seasons):
34813484
da.resample(time=SeasonResampler(seasons))
34823485

34833486
@pytest.mark.parametrize(
3484-
"use_cftime",
3485-
[
3486-
pytest.param(
3487-
True, marks=pytest.mark.skipif(not has_cftime, reason="no cftime")
3488-
),
3489-
False,
3490-
],
3487+
"use_cftime", [pytest.param(True, marks=requires_cftime), False]
34913488
)
34923489
@pytest.mark.parametrize("drop_incomplete", [True, False])
34933490
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)