Skip to content

Commit 594d4a7

Browse files
committed
Add SeasonGrouper, SeasonResampler
These two groupers allow defining custom seasons, and dropping incomplete seasons from the output. Both cases are treated by adjusting the factorization -- conversion from group labels to integer codes -- appropriately.
1 parent 3c74509 commit 594d4a7

File tree

5 files changed

+329
-6
lines changed

5 files changed

+329
-6
lines changed

doc/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,8 @@ Grouper Objects
11261126
groupers.BinGrouper
11271127
groupers.UniqueGrouper
11281128
groupers.TimeResampler
1129+
groupers.SeasonGrouper
1130+
groupers.SeasonResampler
11291131

11301132

11311133
Rolling objects

properties/test_properties.py

+31
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
pytest.importorskip("hypothesis")
44

5+
import hypothesis.strategies as st
56
from hypothesis import given
67

78
import xarray as xr
89
import xarray.testing.strategies as xrst
10+
from xarray.groupers import season_to_month_tuple
911

1012

1113
@given(attrs=xrst.simple_attrs)
@@ -15,3 +17,32 @@ def test_assert_identical(attrs):
1517

1618
ds = xr.Dataset(attrs=attrs)
1719
xr.testing.assert_identical(ds, ds.copy(deep=True))
20+
21+
22+
@given(
    roll=st.integers(min_value=0, max_value=12),
    breaks=st.lists(
        st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True
    ),
)
def test_property_season_month_tuple(roll, breaks):
    """Property test for ``season_to_month_tuple``.

    Rotate the calendar by ``roll`` months, cut the rotated year at arbitrary
    ``breaks`` to build contiguous season strings, and check that
    ``season_to_month_tuple`` maps those strings back to the corresponding
    month numbers.
    """
    initials = list("JFMAMJJASOND")
    month_numbers = tuple(range(1, 13))

    shifted_initials = initials[roll:] + initials[:roll]
    shifted_months = month_numbers[roll:] + month_numbers[:roll]

    # Normalize the cut points so the seasons always cover the whole
    # rotated year: 0 and 12 are forced in as the outer boundaries.
    cuts = sorted(set(breaks) | {0, 12})
    bounds = list(zip(cuts[:-1], cuts[1:], strict=False))

    seasons = tuple("".join(shifted_initials[lo:hi]) for lo, hi in bounds)
    expected = tuple(shifted_months[lo:hi] for lo, hi in bounds)

    assert expected == season_to_month_tuple(seasons)

xarray/core/toolzcompat.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# This file contains functions copied from the toolz library in accordance
2+
# with its license. The original copyright notice is duplicated below.
3+
4+
# Copyright (c) 2013 Matthew Rocklin
5+
6+
# All rights reserved.
7+
8+
# Redistribution and use in source and binary forms, with or without
9+
# modification, are permitted provided that the following conditions are met:
10+
11+
# a. Redistributions of source code must retain the above copyright notice,
12+
# this list of conditions and the following disclaimer.
13+
# b. Redistributions in binary form must reproduce the above copyright
14+
# notice, this list of conditions and the following disclaimer in the
15+
# documentation and/or other materials provided with the distribution.
16+
# c. Neither the name of toolz nor the names of its contributors
17+
# may be used to endorse or promote products derived from this software
18+
# without specific prior written permission.
19+
20+
21+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
25+
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29+
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30+
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
31+
# DAMAGE.
32+
33+
34+
def sliding_window(n, seq):
35+
"""A sequence of overlapping subsequences
36+
37+
>>> list(sliding_window(2, [1, 2, 3, 4]))
38+
[(1, 2), (2, 3), (3, 4)]
39+
40+
This function creates a sliding window suitable for transformations like
41+
sliding means / smoothing
42+
43+
>>> mean = lambda seq: float(sum(seq)) / len(seq)
44+
>>> list(map(mean, sliding_window(2, [1, 2, 3, 4])))
45+
[1.5, 2.5, 3.5]
46+
"""
47+
import collections
48+
import itertools
49+
50+
return zip(
51+
*(
52+
collections.deque(itertools.islice(it, i), 0) or it
53+
for i, it in enumerate(itertools.tee(seq, n))
54+
),
55+
strict=False,
56+
)

xarray/groupers.py

+218
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from __future__ import annotations
88

99
import datetime
10+
import itertools
1011
from abc import ABC, abstractmethod
12+
from collections.abc import Mapping, Sequence
1113
from dataclasses import dataclass, field
1214
from typing import TYPE_CHECKING, Any, Literal, cast
1315

@@ -16,11 +18,13 @@
1618

1719
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
1820
from xarray.core import duck_array_ops
21+
from xarray.core.common import _contains_datetime_like_objects
1922
from xarray.core.coordinates import Coordinates
2023
from xarray.core.dataarray import DataArray
2124
from xarray.core.groupby import T_Group, _DummyGroup
2225
from xarray.core.indexes import safe_cast_to_index
2326
from xarray.core.resample_cftime import CFTimeGrouper
27+
from xarray.core.toolzcompat import sliding_window
2428
from xarray.core.types import (
2529
Bins,
2630
DatetimeLike,
@@ -485,3 +489,217 @@ def unique_value_groups(
485489
if isinstance(values, pd.MultiIndex):
486490
values.names = ar.names
487491
return values, inverse
492+
493+
494+
def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]:
    """Convert season strings of month initials into tuples of month numbers.

    Each season is a string of consecutive month initials, e.g. ``"DJF"`` ->
    ``(12, 1, 2)``.  Repeated initials (J, M, A) are disambiguated by the
    *pair* of the first initial and the following month's initial: ``"JF"``
    starts in January, ``"JJ..."`` in June, ``"JA..."`` in July.  For a
    single-letter season the disambiguating suffix is taken from the first
    initial of the *next* season (wrapping around to the first season).

    Parameters
    ----------
    seasons : sequence of str
        Ordered season strings, e.g. ``["DJF", "MAM", "JJA", "SON"]``.

    Returns
    -------
    tuple of tuple of int
        One tuple of month numbers (1-12) per season, wrapping December
        into January where needed.

    Raises
    ------
    KeyError
        If a season's first two initials do not match any pair of
        consecutive months.
    """
    initials = "JFMAMJJASOND"
    # Map each consecutive two-initial pair ("JF", "FM", ..., "DJ") to the
    # month number of the first initial; the pair is unambiguous even though
    # single initials are not.
    wrapped = initials + initials[0]
    starts = {wrapped[m : m + 2]: m + 1 for m in range(12)}

    result: list[tuple[int, ...]] = []
    for pos, season in enumerate(seasons):
        if len(season) == 1:
            # Single-month season: borrow the next season's first initial
            # (wrapping to the first season after the last one).
            suffix = seasons[(pos + 1) % len(seasons)][0]
        else:
            suffix = season[1]

        start = starts[season[0] + suffix]
        # The remaining months follow consecutively, wrapping Dec -> Jan
        # via modular arithmetic (months are 1-based, hence the -1/+1).
        result.append(
            tuple((start - 1 + offset) % 12 + 1 for offset in range(len(season)))
        )
    return tuple(result)
518+
519+
520+
@dataclass
class SeasonGrouper(Grouper):
    """Allows grouping using a custom definition of seasons.

    Each season is a string of consecutive month initials, e.g. ``"DJF"``
    for December-January-February.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.

    Examples
    --------
    >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    # Month numbers per season (e.g. ((12, 1, 2), ...)); derived from
    # ``seasons`` in __post_init__, never user-supplied.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    # drop_incomplete: bool = field(default=True)  # TODO

    def __post_init__(self) -> None:
        # Translate e.g. ["JF", "MAM"] into ((1, 2), (3, 4, 5)).
        self.season_inds = season_to_month_tuple(self.seasons)

    def factorize(self, group: T_Group) -> EncodedGroups:
        """Assign an integer season code to every element of ``group``.

        ``group`` must be a datetime-like array; elements are binned by
        calendar month into the configured seasons.
        """
        if TYPE_CHECKING:
            assert not isinstance(group, _DummyGroup)
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonGrouper can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds

        months = group.dt.month
        # -1 marks elements whose month belongs to none of the requested
        # seasons; presumably treated as "not grouped" downstream -- confirm
        # against the GroupBy machinery.
        codes_ = np.full(group.shape, -1)
        # NOTE(review): [[]] * n aliases one list n times; harmless here only
        # because every slot is reassigned (not mutated) in the loop below.
        group_indices: list[list[int]] = [[]] * len(seasons)

        index = np.arange(group.size)
        for idx, season_tuple in enumerate(season_inds):
            # All positions whose month falls in this season get code ``idx``.
            mask = months.isin(season_tuple)
            codes_[mask] = idx
            group_indices[idx] = index[mask]

        if np.all(codes_ == -1):
            raise ValueError(
                "Failed to group data. Are you grouping by a variable that is all NaN?"
            )
        codes = group.copy(data=codes_, deep=False).rename("season")
        # The output coordinate is the season labels themselves.
        unique_coord = Variable("season", seasons, attrs=group.attrs)
        full_index = pd.Index(seasons)
        return EncodedGroups(
            codes=codes,
            group_indices=tuple(group_indices),
            unique_coord=unique_coord,
            full_index=full_index,
        )
576+
577+
578+
@dataclass
class SeasonResampler(Resampler):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: Sequence[str]
        An ordered list of seasons.
    drop_incomplete: bool
        Whether to drop seasons that are not completely included in the data.
        For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
        then observations from Jan-2001, and Feb-2001 are ignored in the grouping
        since Dec-2000 isn't present.

    Examples
    --------
    >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    drop_incomplete: bool = field(default=True)
    # Derived in __post_init__: month numbers per season, and a mapping
    # from season string to its month numbers.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self.season_inds = season_to_month_tuple(self.seasons)
        self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False))

    def factorize(self, group):
        """Assign one integer code per (year, season) bin along ``group``.

        NOTE(review): the contiguous-slice construction below assumes
        ``group`` is ordered in time so each (year, season) combination is
        one contiguous run -- confirm against callers.
        """
        if group.ndim != 1:
            raise ValueError(
                "SeasonResampler can only be used to resample by 1D arrays."
            )
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonResampler can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds
        season_tuples = self.season_tuples

        # Label every element with its season string, in a fixed-width
        # unicode array wide enough for the longest season name.
        nstr = max(len(s) for s in seasons)
        year = group.dt.year.astype(int)
        month = group.dt.month.astype(int)
        season_label = np.full(group.shape, "", dtype=f"U{nstr}")

        # offset years for seasons with December and January
        for season_str, season_ind in zip(seasons, season_inds, strict=False):
            season_label[month.isin(season_ind)] = season_str
            if "DJ" in season_str:
                # Months that follow December within this season (e.g. J, F,
                # M for "DJFM") belong to the bin labelled by the year of the
                # December that starts it, hence year - 1.
                after_dec = season_ind[season_str.index("D") + 1 :]
                year[month.isin(after_dec)] -= 1

        frame = pd.DataFrame(
            data={"index": np.arange(group.size), "month": month},
            index=pd.MultiIndex.from_arrays(
                [year.data, season_label], names=["year", "season"]
            ),
        )

        # For each (year, season) bin: the position of its first element and
        # its size, in order of first appearance (sort=False).
        series = frame["index"]
        g = series.groupby(["year", "season"], sort=False)
        first_items = g.first()
        counts = g.count()

        # these are the seasons that are present
        unique_coord = pd.DatetimeIndex(
            [
                pd.Timestamp(year=year, month=season_tuples[season][0], day=1)
                for year, season in first_items.index
            ]
        )

        # Contiguous slices covering each bin; the last bin runs to the end.
        sbins = first_items.values.astype(int)
        group_indices = [
            slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False)
        ]
        group_indices += [slice(sbins[-1], None)]

        # Make sure the first and last timestamps
        # are for the correct months; if not we have incomplete seasons.
        unique_codes = np.arange(len(unique_coord))
        if self.drop_incomplete:
            for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False):
                # idx=0: first element must be the first month of its season;
                # idx=-1: last element must be the last month of its season
                # (season_inds[code][idx] picks first/last month respectively).
                stamp_year, stamp_season = frame.index[idx]
                code = seasons.index(stamp_season)
                stamp_month = season_inds[code][idx]
                if stamp_month != month[idx].item():
                    # we have an incomplete season!
                    group_indices = group_indices[slicer]
                    unique_coord = unique_coord[slicer]
                    if idx == 0:
                        # Dropping the leading bin shifts every remaining
                        # code down by one; -1 marks the dropped bin.
                        unique_codes -= 1
                    unique_codes[idx] = -1

        # all years and seasons
        complete_index = pd.DatetimeIndex(
            # This sorted call is a hack. It's hard to figure out how
            # to start the iteration
            sorted(
                [
                    pd.Timestamp(f"{y}-{m}-01")
                    for y, m in itertools.product(
                        range(year[0].item(), year[-1].item() + 1),
                        [s[0] for s in season_inds],
                    )
                ]
            )
        )
        # only keep that included in data
        range_ = complete_index.get_indexer(unique_coord[[0, -1]])
        full_index = complete_index[slice(range_[0], range_[-1] + 1)]
        # check that there are no "missing" seasons in the middle
        if not full_index.equals(unique_coord):
            raise ValueError("Are there seasons missing in the middle of the dataset?")

        # Expand per-bin codes back to one code per element; relies on each
        # bin being one contiguous run in the original order.
        codes = group.copy(data=np.repeat(unique_codes, counts), deep=False)
        unique_coord_var = Variable(group.name, unique_coord, group.attrs)

        return EncodedGroups(
            codes=codes,
            group_indices=group_indices,
            unique_coord=unique_coord_var,
            full_index=full_index,
        )

xarray/tests/test_groupby.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Grouper,
2222
TimeResampler,
2323
UniqueGrouper,
24+
season_to_month_tuple,
2425
)
2526
from xarray.tests import (
2627
InaccessibleArray,
@@ -2915,12 +2916,6 @@ def test_gappy_resample_reductions(reduction):
29152916
assert_identical(expected, actual)
29162917

29172918

2918-
# Possible property tests
2919-
# 1. lambda x: x
2920-
# 2. grouped-reduce on unique coords is identical to array
2921-
# 3. group_over == groupby-reduce along other dimensions
2922-
2923-
29242919
def test_groupby_transpose():
29252920
# GH5361
29262921
data = xr.DataArray(
@@ -2932,3 +2927,24 @@ def test_groupby_transpose():
29322927
second = data.groupby("x").sum()
29332928

29342929
assert_identical(first, second.transpose(*first.dims))
2930+
2931+
2932+
def test_season_to_month_tuple():
    # Spot-check two season layouts, including one wrapping Dec -> Jan.
    cases = {
        ("JF", "MAM", "JJAS", "OND"): (
            (1, 2),
            (3, 4, 5),
            (6, 7, 8, 9),
            (10, 11, 12),
        ),
        ("DJFM", "AM", "JJAS", "ON"): (
            (12, 1, 2, 3),
            (4, 5),
            (6, 7, 8, 9),
            (10, 11),
        ),
    }
    for seasons, expected in cases.items():
        assert season_to_month_tuple(list(seasons)) == expected
2945+
2946+
2947+
# Possible property tests
2948+
# 1. lambda x: x
2949+
# 2. grouped-reduce on unique coords is identical to array
2950+
# 3. group_over == groupby-reduce along other dimensions

0 commit comments

Comments
 (0)