7
7
from __future__ import annotations
8
8
9
9
import datetime
10
+ import functools
10
11
import itertools
12
+ import operator
11
13
from abc import ABC , abstractmethod
12
14
from collections import defaultdict
13
15
from collections .abc import Mapping , Sequence
@@ -670,7 +672,12 @@ class SeasonResampler(Resampler):
670
672
671
673
def __post_init__ (self ):
672
674
self .season_inds = season_to_month_tuple (self .seasons )
673
- self .season_tuples = dict (zip (self .seasons , self .season_inds , strict = False ))
675
+ all_inds = functools .reduce (operator .add , self .season_inds )
676
+ if len (all_inds ) > len (set (all_inds )):
677
+ raise ValueError (
678
+ f"Overlapping seasons are not allowed. Received { self .seasons !r} "
679
+ )
680
+ self .season_tuples = dict (zip (self .seasons , self .season_inds , strict = True ))
674
681
675
682
def factorize (self , group ):
676
683
if group .ndim != 1 :
@@ -696,12 +703,22 @@ def factorize(self, group):
696
703
season_label [month .isin (season_ind )] = season_str
697
704
if "DJ" in season_str :
698
705
after_dec = season_ind [season_str .index ("D" ) + 1 :]
706
+ # important this is assuming non-overlapping seasons
699
707
year [month .isin (after_dec )] -= 1
700
708
709
+ # Allow users to skip one or more months?
710
+ # present_seasons is a mask that is True for months that are requestsed in the output
711
+ present_seasons = season_label != ""
712
+ if present_seasons .all ():
713
+ present_seasons = slice (None )
701
714
frame = pd .DataFrame (
702
- data = {"index" : np .arange (group .size ), "month" : month },
715
+ data = {
716
+ "index" : np .arange (group [present_seasons ].size ),
717
+ "month" : month [present_seasons ],
718
+ },
703
719
index = pd .MultiIndex .from_arrays (
704
- [year .data , season_label ], names = ["year" , "season" ]
720
+ [year .data [present_seasons ], season_label [present_seasons ]],
721
+ names = ["year" , "season" ],
705
722
),
706
723
)
707
724
@@ -727,19 +744,19 @@ def factorize(self, group):
727
744
728
745
sbins = first_items .values .astype (int )
729
746
group_indices = [
730
- slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :], strict = False )
747
+ slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :], strict = True )
731
748
]
732
749
group_indices += [slice (sbins [- 1 ], None )]
733
750
734
751
# Make sure the first and last timestamps
735
752
# are for the correct months,if not we have incomplete seasons
736
753
unique_codes = np .arange (len (unique_coord ))
737
754
if self .drop_incomplete :
738
- for idx , slicer in zip ([0 , - 1 ], (slice (1 , None ), slice (- 1 )), strict = False ):
755
+ for idx , slicer in zip ([0 , - 1 ], (slice (1 , None ), slice (- 1 )), strict = True ):
739
756
stamp_year , stamp_season = frame .index [idx ]
740
757
code = seasons .index (stamp_season )
741
758
stamp_month = season_inds [code ][idx ]
742
- if stamp_month != month [idx ].item ():
759
+ if stamp_month != month [present_seasons ][ idx ].item ():
743
760
# we have an incomplete season!
744
761
group_indices = group_indices [slicer ]
745
762
unique_coord = unique_coord [slicer ]
@@ -769,7 +786,9 @@ def factorize(self, group):
769
786
if not full_index .equals (unique_coord ):
770
787
raise ValueError ("Are there seasons missing in the middle of the dataset?" )
771
788
772
- codes = group .copy (data = np .repeat (unique_codes , counts ), deep = False )
789
+ final_codes = np .full (group .data .size , - 1 )
790
+ final_codes [present_seasons ] = np .repeat (unique_codes , counts )
791
+ codes = group .copy (data = final_codes , deep = False )
773
792
unique_coord_var = Variable (group .name , unique_coord , group .attrs )
774
793
775
794
return EncodedGroups (
0 commit comments