|
7 | 7 | from __future__ import annotations
|
8 | 8 |
|
9 | 9 | import datetime
|
| 10 | +import itertools |
10 | 11 | from abc import ABC, abstractmethod
|
| 12 | +from collections.abc import Mapping, Sequence |
11 | 13 | from dataclasses import dataclass, field
|
12 | 14 | from typing import TYPE_CHECKING, Any, Literal, cast
|
13 | 15 |
|
|
16 | 18 |
|
17 | 19 | from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
|
18 | 20 | from xarray.core import duck_array_ops
|
| 21 | +from xarray.core.common import _contains_datetime_like_objects |
19 | 22 | from xarray.core.coordinates import Coordinates
|
20 | 23 | from xarray.core.dataarray import DataArray
|
21 | 24 | from xarray.core.groupby import T_Group, _DummyGroup
|
22 | 25 | from xarray.core.indexes import safe_cast_to_index
|
23 | 26 | from xarray.core.resample_cftime import CFTimeGrouper
|
| 27 | +from xarray.core.toolzcompat import sliding_window |
24 | 28 | from xarray.core.types import (
|
25 | 29 | Bins,
|
26 | 30 | DatetimeLike,
|
@@ -485,3 +489,217 @@ def unique_value_groups(
|
485 | 489 | if isinstance(values, pd.MultiIndex):
|
486 | 490 | values.names = ar.names
|
487 | 491 | return values, inverse
|
| 492 | + |
| 493 | + |
| 494 | +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: |
| 495 | + initials = "JFMAMJJASOND" |
| 496 | + starts = dict( |
| 497 | + ("".join(s), i + 1) |
| 498 | + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=False) |
| 499 | + ) |
| 500 | + result: list[tuple[int, ...]] = [] |
| 501 | + for i, season in enumerate(seasons): |
| 502 | + if len(season) == 1: |
| 503 | + if i < len(seasons) - 1: |
| 504 | + suffix = seasons[i + 1][0] |
| 505 | + else: |
| 506 | + suffix = seasons[0][0] |
| 507 | + else: |
| 508 | + suffix = season[1] |
| 509 | + |
| 510 | + start = starts[season[0] + suffix] |
| 511 | + |
| 512 | + month_append = [] |
| 513 | + for i in range(len(season[1:])): |
| 514 | + elem = start + i + 1 |
| 515 | + month_append.append(elem - 12 * (elem > 12)) |
| 516 | + result.append((start,) + tuple(month_append)) |
| 517 | + return tuple(result) |
| 518 | + |
| 519 | + |
@dataclass
class SeasonGrouper(Grouper):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.

    Examples
    --------
    >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
    SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'])

    >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"])
    SeasonGrouper(seasons=['DJFM', 'AM', 'JJA', 'SON'])
    """

    seasons: Sequence[str]
    # Month numbers for each season; derived from `seasons` in __post_init__.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    # drop_incomplete: bool = field(default=True)  # TODO

    def __post_init__(self) -> None:
        self.season_inds = season_to_month_tuple(self.seasons)

    def factorize(self, group: T_Group) -> EncodedGroups:
        """Assign each element of ``group`` an integer season code.

        Parameters
        ----------
        group : T_Group
            Datetime-like array to group by.

        Returns
        -------
        EncodedGroups

        Raises
        ------
        ValueError
            If ``group`` is not datetime-like, or if no element matched any
            season.
        """
        if TYPE_CHECKING:
            assert not isinstance(group, _DummyGroup)
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonGrouper can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds

        months = group.dt.month
        # -1 marks elements that belong to no defined season.
        codes_ = np.full(group.shape, -1)
        # Use a comprehension, NOT `[[]] * len(seasons)`: the latter creates
        # len(seasons) references to a *single* shared list, so seasons with
        # no matching months would all alias the same object.
        group_indices: list[list[int]] = [[] for _ in range(len(seasons))]

        index = np.arange(group.size)
        for idx, season_tuple in enumerate(season_inds):
            mask = months.isin(season_tuple)
            codes_[mask] = idx
            group_indices[idx] = index[mask]

        if np.all(codes_ == -1):
            raise ValueError(
                "Failed to group data. Are you grouping by a variable that is all NaN?"
            )
        codes = group.copy(data=codes_, deep=False).rename("season")
        unique_coord = Variable("season", seasons, attrs=group.attrs)
        full_index = pd.Index(seasons)
        return EncodedGroups(
            codes=codes,
            group_indices=tuple(group_indices),
            unique_coord=unique_coord,
            full_index=full_index,
        )
| 576 | + |
| 577 | + |
@dataclass
class SeasonResampler(Resampler):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: Sequence[str]
        An ordered list of seasons.
    drop_incomplete: bool
        Whether to drop seasons that are not completely included in the data.
        For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
        then observations from Jan-2001, and Feb-2001 are ignored in the grouping
        since Dec-2000 isn't present.

    Examples
    --------
    >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    drop_incomplete: bool = field(default=True)
    # Derived in __post_init__: tuple of month numbers per season.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    # Derived in __post_init__: maps season string -> its month numbers.
    season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self.season_inds = season_to_month_tuple(self.seasons)
        self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False))

    def factorize(self, group) -> EncodedGroups:
        """Compute per-season-per-year codes and contiguous bin slices.

        Parameters
        ----------
        group
            1-D datetime-like array being resampled.

        Returns
        -------
        EncodedGroups

        Raises
        ------
        ValueError
            If ``group`` is not 1-D, is not datetime-like, or if seasons are
            missing in the middle of the dataset.
        """
        if group.ndim != 1:
            raise ValueError(
                "SeasonResampler can only be used to resample by 1D arrays."
            )
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonResampler can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds
        season_tuples = self.season_tuples

        # Label each timestamp with its season string (fixed-width unicode,
        # wide enough for the longest season name).
        nstr = max(len(s) for s in seasons)
        year = group.dt.year.astype(int)
        month = group.dt.month.astype(int)
        season_label = np.full(group.shape, "", dtype=f"U{nstr}")

        # offset years for seasons with December and January
        for season_str, season_ind in zip(seasons, season_inds, strict=False):
            season_label[month.isin(season_ind)] = season_str
            if "DJ" in season_str:
                # Months after December belong to the *previous* year's season
                # (e.g. Jan-2001 in "DJF" counts with Dec-2000).
                after_dec = season_ind[season_str.index("D") + 1 :]
                year[month.isin(after_dec)] -= 1

        # Group array positions by (year, season).
        # NOTE(review): the slice-based group_indices below assume `group` is
        # in chronological order so each (year, season) bin is a contiguous
        # run — confirm callers guarantee sorted time coordinates.
        frame = pd.DataFrame(
            data={"index": np.arange(group.size), "month": month},
            index=pd.MultiIndex.from_arrays(
                [year.data, season_label], names=["year", "season"]
            ),
        )

        series = frame["index"]
        g = series.groupby(["year", "season"], sort=False)
        first_items = g.first()  # first array position of each bin
        counts = g.count()  # number of elements in each bin

        # these are the seasons that are present
        unique_coord = pd.DatetimeIndex(
            [
                pd.Timestamp(year=year, month=season_tuples[season][0], day=1)
                for year, season in first_items.index
            ]
        )

        # Turn consecutive first-positions into contiguous slices; the last
        # bin runs to the end of the array.
        sbins = first_items.values.astype(int)
        group_indices = [
            slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False)
        ]
        group_indices += [slice(sbins[-1], None)]

        # Make sure the first and last timestamps
        # are for the correct months; if not, we have incomplete seasons.
        unique_codes = np.arange(len(unique_coord))
        if self.drop_incomplete:
            # idx=0 checks the first timestamp against the first month of its
            # season; idx=-1 checks the last against the last month.  The
            # paired slicer trims the corresponding end when incomplete.
            for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False):
                stamp_year, stamp_season = frame.index[idx]
                code = seasons.index(stamp_season)
                stamp_month = season_inds[code][idx]
                if stamp_month != month[idx].item():
                    # we have an incomplete season!
                    group_indices = group_indices[slicer]
                    unique_coord = unique_coord[slicer]
                    if idx == 0:
                        # Shift all codes down; dropped elements become -1.
                        unique_codes -= 1
                    unique_codes[idx] = -1

        # all years and seasons
        complete_index = pd.DatetimeIndex(
            # This sorted call is a hack. It's hard to figure out how
            # to start the iteration
            sorted(
                [
                    pd.Timestamp(f"{y}-{m}-01")
                    for y, m in itertools.product(
                        range(year[0].item(), year[-1].item() + 1),
                        [s[0] for s in season_inds],
                    )
                ]
            )
        )
        # only keep that included in data
        range_ = complete_index.get_indexer(unique_coord[[0, -1]])
        full_index = complete_index[slice(range_[0], range_[-1] + 1)]
        # check that there are no "missing" seasons in the middle
        if not full_index.equals(unique_coord):
            raise ValueError("Are there seasons missing in the middle of the dataset?")

        # Expand bin codes to one code per element; bins flagged -1 above
        # (dropped incomplete seasons) propagate -1 to their elements.
        codes = group.copy(data=np.repeat(unique_codes, counts), deep=False)
        unique_coord_var = Variable(group.name, unique_coord, group.attrs)

        return EncodedGroups(
            codes=codes,
            group_indices=group_indices,
            unique_coord=unique_coord_var,
            full_index=full_index,
        )
0 commit comments