Skip to content

Commit 9180536

Browse files
committed
Support "subsampled" seasons
1 parent 82f3c21 commit 9180536

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

xarray/groupers.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from __future__ import annotations
88

99
import datetime
10+
import functools
1011
import itertools
12+
import operator
1113
from abc import ABC, abstractmethod
1214
from collections import defaultdict
1315
from collections.abc import Mapping, Sequence
@@ -670,7 +672,12 @@ class SeasonResampler(Resampler):
670672

671673
def __post_init__(self):
672674
self.season_inds = season_to_month_tuple(self.seasons)
673-
self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False))
675+
all_inds = functools.reduce(operator.add, self.season_inds)
676+
if len(all_inds) > len(set(all_inds)):
677+
raise ValueError(
678+
f"Overlapping seasons are not allowed. Received {self.seasons!r}"
679+
)
680+
self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True))
674681

675682
def factorize(self, group):
676683
if group.ndim != 1:
@@ -696,12 +703,22 @@ def factorize(self, group):
696703
season_label[month.isin(season_ind)] = season_str
697704
if "DJ" in season_str:
698705
after_dec = season_ind[season_str.index("D") + 1 :]
706+
# important this is assuming non-overlapping seasons
699707
year[month.isin(after_dec)] -= 1
700708

709+
# Allow users to skip one or more months?
710+
# present_seasons is a mask that is True for months that are requestsed in the output
711+
present_seasons = season_label != ""
712+
if present_seasons.all():
713+
present_seasons = slice(None)
701714
frame = pd.DataFrame(
702-
data={"index": np.arange(group.size), "month": month},
715+
data={
716+
"index": np.arange(group[present_seasons].size),
717+
"month": month[present_seasons],
718+
},
703719
index=pd.MultiIndex.from_arrays(
704-
[year.data, season_label], names=["year", "season"]
720+
[year.data[present_seasons], season_label[present_seasons]],
721+
names=["year", "season"],
705722
),
706723
)
707724

@@ -727,19 +744,19 @@ def factorize(self, group):
727744

728745
sbins = first_items.values.astype(int)
729746
group_indices = [
730-
slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False)
747+
slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)
731748
]
732749
group_indices += [slice(sbins[-1], None)]
733750

734751
# Make sure the first and last timestamps
735752
# are for the correct months,if not we have incomplete seasons
736753
unique_codes = np.arange(len(unique_coord))
737754
if self.drop_incomplete:
738-
for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False):
755+
for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True):
739756
stamp_year, stamp_season = frame.index[idx]
740757
code = seasons.index(stamp_season)
741758
stamp_month = season_inds[code][idx]
742-
if stamp_month != month[idx].item():
759+
if stamp_month != month[present_seasons][idx].item():
743760
# we have an incomplete season!
744761
group_indices = group_indices[slicer]
745762
unique_coord = unique_coord[slicer]
@@ -769,7 +786,9 @@ def factorize(self, group):
769786
if not full_index.equals(unique_coord):
770787
raise ValueError("Are there seasons missing in the middle of the dataset?")
771788

772-
codes = group.copy(data=np.repeat(unique_codes, counts), deep=False)
789+
final_codes = np.full(group.data.size, -1)
790+
final_codes[present_seasons] = np.repeat(unique_codes, counts)
791+
codes = group.copy(data=final_codes, deep=False)
773792
unique_coord_var = Variable(group.name, unique_coord, group.attrs)
774793

775794
return EncodedGroups(

xarray/tests/test_groupby.py

+3
Original file line numberDiff line numberDiff line change
@@ -2958,6 +2958,9 @@ def test_season_resampler():
29582958
# skip september
29592959
da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum()
29602960

2961+
# "subsampling"
2962+
da.groupby(time=SeasonResampler(["JJAS"])).sum()
2963+
29612964
# overlapping
29622965
with pytest.raises(ValueError):
29632966
da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()

0 commit comments

Comments
 (0)