14
14
from collections import defaultdict
15
15
from collections .abc import Mapping , Sequence
16
16
from dataclasses import dataclass , field
17
- from itertools import pairwise
18
- from itertools import chain
17
+ from itertools import chain , pairwise
19
18
from typing import TYPE_CHECKING , Any , Literal , cast
20
19
21
20
import numpy as np
25
24
from xarray .coding .cftime_offsets import BaseCFTimeOffset , _new_to_legacy_freq
26
25
from xarray .coding .cftimeindex import CFTimeIndex
27
26
from xarray .core import duck_array_ops
28
- from xarray .core .computation import apply_ufunc
29
- from xarray .core .coordinates import Coordinates , _coordinates_from_variable
30
- from xarray .core .coordinates import Coordinates
31
- from xarray .core .common import _contains_datetime_like_objects
32
- from xarray .core .common import _contains_datetime_like_objects
33
27
from xarray .core .common import (
34
28
_contains_cftime_datetimes ,
35
29
_contains_datetime_like_objects ,
36
30
)
37
- from xarray .core .coordinates import Coordinates
31
+ from xarray .core .computation import apply_ufunc
32
+ from xarray .core .coordinates import Coordinates , _coordinates_from_variable
38
33
from xarray .core .dataarray import DataArray
39
34
from xarray .core .duck_array_ops import isnull
40
35
from xarray .core .formatting import first_n_items
@@ -751,14 +746,16 @@ def __post_init__(self):
751
746
)
752
747
self .season_tuples = dict (zip (self .seasons , self .season_inds , strict = True ))
753
748
754
- def factorize (self , group ) :
749
+ def factorize (self , group : T_Group ) -> EncodedGroups :
755
750
if group .ndim != 1 :
756
751
raise ValueError (
757
752
"SeasonResampler can only be used to resample by 1D arrays."
758
753
)
759
- if not _contains_datetime_like_objects (group .variable ):
754
+ if not isinstance (group , DataArray ) or not _contains_datetime_like_objects (
755
+ group .variable
756
+ ):
760
757
raise ValueError (
761
- "SeasonResampler can only be used to group by datetime-like arrays ."
758
+ "SeasonResampler can only be used to group by datetime-like DataArrays ."
762
759
)
763
760
764
761
seasons = self .seasons
@@ -775,13 +772,14 @@ def factorize(self, group):
775
772
season_label [month .isin (season_ind )] = season_str
776
773
if "DJ" in season_str :
777
774
after_dec = season_ind [season_str .index ("D" ) + 1 :]
778
- # important this is assuming non-overlapping seasons
775
+ # important: this is assuming non-overlapping seasons
779
776
year [month .isin (after_dec )] -= 1
780
777
781
778
# Allow users to skip one or more months?
782
- # present_seasons is a mask that is True for months that are requestsed in the output
779
+ # present_seasons is a mask that is True for months that are requested in the output
783
780
present_seasons = season_label != ""
784
781
if present_seasons .all ():
782
+ # avoid copies if we can.
785
783
present_seasons = slice (None )
786
784
frame = pd .DataFrame (
787
785
data = {
@@ -794,10 +792,13 @@ def factorize(self, group):
794
792
),
795
793
)
796
794
797
- series = frame ["index" ]
798
- g = series .groupby (["year" , "season" ], sort = False )
799
- first_items = g .first ()
800
- counts = g .count ()
795
+ agged = (
796
+ frame ["index" ]
797
+ .groupby (["year" , "season" ], sort = False )
798
+ .agg (["first" , "count" ])
799
+ )
800
+ first_items = agged ["first" ]
801
+ counts = agged ["count" ]
801
802
802
803
if _contains_cftime_datetimes (group .data ):
803
804
index_class = CFTimeIndex
@@ -814,32 +815,18 @@ def factorize(self, group):
814
815
]
815
816
)
816
817
817
- sbins = first_items .values .astype (int )
818
- group_indices = [
819
- slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :], strict = True )
820
- ]
821
- group_indices += [slice (sbins [- 1 ], None )]
822
-
823
- # Make sure the first and last timestamps
824
- # are for the correct months,if not we have incomplete seasons
825
- unique_codes = np .arange (len (unique_coord ))
826
- if self .drop_incomplete :
827
- for idx , slicer in zip ([0 , - 1 ], (slice (1 , None ), slice (- 1 )), strict = True ):
828
- stamp_year , stamp_season = frame .index [idx ]
829
- code = seasons .index (stamp_season )
830
- stamp_month = season_inds [code ][idx ]
831
- if stamp_month != month [present_seasons ][idx ].item ():
832
- # we have an incomplete season!
833
- group_indices = group_indices [slicer ]
834
- unique_coord = unique_coord [slicer ]
835
- if idx == 0 :
836
- unique_codes -= 1
837
- unique_codes [idx ] = - 1
838
-
839
- # all years and seasons
818
+ # sbins = first_items.values.astype(int)
819
+ # group_indices = [
820
+ # slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)
821
+ # ]
822
+ # group_indices += [slice(sbins[-1], None)]
823
+
824
+ # This sorted call is a hack. It's hard to figure out how
825
+ # to start the iteration for arbitrary season ordering
826
+ # for example "DJF" as first entry or last entry
827
+ # So we construct the largest possible index and slice it to the
828
+ # range present in the data.
840
829
complete_index = index_class (
841
- # This sorted call is a hack. It's hard to figure out how
842
- # to start the iteration
843
830
sorted (
844
831
[
845
832
datetime_class (year = y , month = m , day = 1 )
@@ -850,22 +837,56 @@ def factorize(self, group):
850
837
]
851
838
)
852
839
)
853
- # only keep that included in data
854
- range_ = complete_index .get_indexer (unique_coord [[0 , - 1 ]])
855
- full_index = complete_index [slice (range_ [0 ], range_ [- 1 ] + 1 )]
840
+
841
+ # all years and seasons
842
+ def get_label (year , season ):
843
+ month = season_tuples [season ][0 ]
844
+ return f"{ year } -{ month } -01"
845
+
846
+ unique_codes = np .arange (len (unique_coord ))
847
+ first_valid_season = season_label [0 ]
848
+ last_valid_season = season_label [- 1 ]
849
+ first_year , last_year = year .data [[0 , - 1 ]]
850
+ if self .drop_incomplete :
851
+ if month .data [0 ] != season_tuples [first_valid_season ][0 ]:
852
+ if "DJ" in first_valid_season :
853
+ first_year += 1
854
+ first_valid_season = seasons [
855
+ (seasons .index (first_valid_season ) + 1 ) % len (seasons )
856
+ ]
857
+ # group_indices = group_indices[slice(1, None)]
858
+ unique_codes -= 1
859
+
860
+ if month .data [- 1 ] != season_tuples [last_valid_season ][- 1 ]:
861
+ last_valid_season = seasons [seasons .index (last_valid_season ) - 1 ]
862
+ if "DJ" in last_valid_season :
863
+ last_year -= 1
864
+ # group_indices = group_indices[slice(-1)]
865
+ unique_codes [- 1 ] = - 1
866
+
867
+ first_label = get_label (first_year , first_valid_season )
868
+ last_label = get_label (last_year , last_valid_season )
869
+
870
+ slicer = complete_index .slice_indexer (first_label , last_label )
871
+ full_index = complete_index [slicer ]
872
+ # TODO: group must be sorted
873
+ # codes = np.searchsorted(edges, group.data, side="left")
874
+ # codes -= 1
875
+ # codes[~present_seasons | group.data >= edges[-1]] = -1
876
+ # codes[isnull(group.data)] = -1
877
+ # import ipdb; ipdb.set_trace()
856
878
# check that there are no "missing" seasons in the middle
857
- # print(full_index, unique_coord)
858
- if not full_index .equals (unique_coord ):
859
- raise ValueError ("Are there seasons missing in the middle of the dataset?" )
879
+ # if not full_index.equals(unique_coord):
880
+ # raise ValueError("Are there seasons missing in the middle of the dataset?")
860
881
861
882
final_codes = np .full (group .data .size , - 1 )
862
883
final_codes [present_seasons ] = np .repeat (unique_codes , counts )
863
884
codes = group .copy (data = final_codes , deep = False )
864
- unique_coord_var = Variable (group .name , unique_coord , group .attrs )
885
+ # unique_coord_var = Variable(group.name, unique_coord, group.attrs)
865
886
866
887
return EncodedGroups (
867
888
codes = codes ,
868
- group_indices = group_indices ,
869
- unique_coord = unique_coord_var ,
889
+ # group_indices=group_indices,
890
+ # unique_coord=unique_coord_var,
870
891
full_index = full_index ,
871
892
)
0 commit comments