12
12
from xarray .core .extensions import AccessorRegistrationWarning
13
13
from ..earthdatastore .cube_utils import GeometryManager
14
14
15
- warnings .filterwarnings ("ignore" , category = AccessorRegistrationWarning )
16
-
17
-
18
- class MisType (Warning ):
19
- pass
20
-
21
-
22
- def _typer (raise_mistype = False , custom_types = {}):
23
- """
24
-
25
- Parameters
26
- ----------
27
- raise_mistype : TYPE, optional
28
- DESCRIPTION. The default is False.
29
- custom_types : TYPE, optional
30
- DESCRIPTION. The default is {}.
31
- Example : {
32
- np.ndarray:{"func":np.asarray},
33
- xr.Dataset:{"func":xr.DataArray.to_dataset,"kwargs":{"dim":"band"}}
34
- }
35
- Raises
36
- ------
37
- MisType
38
- DESCRIPTION.
39
-
40
- Returns
41
- -------
42
- TYPE
43
- DESCRIPTION.
44
-
45
- """
46
-
47
- def decorator (func ):
48
- def force (* args , ** kwargs ):
49
- _args = list (args )
50
- for key , vals in func .__annotations__ .items ():
51
- if not isinstance (vals , (list , tuple )):
52
- vals = [vals ]
53
- val = vals [0 ]
54
- idx = func .__code__ .co_varnames .index (key )
55
- is_kwargs = key in kwargs .keys ()
56
- if not is_kwargs and idx >= len (_args ):
57
- break
58
- value = kwargs .get (key , None ) if is_kwargs else args [idx ]
59
- if type (value ) in vals :
60
- continue
61
- if (
62
- type (kwargs .get (key )) not in vals
63
- if is_kwargs
64
- else type (args [idx ]) not in vals
65
- ):
66
- if raise_mistype :
67
- if is_kwargs :
68
- expected = f"{ type (kwargs [key ]).__name__ } ({ kwargs [key ]} )"
69
- else :
70
- expected = f"{ type (args [idx ]).__name__ } ({ args [idx ]} )"
71
- raise MisType (f"{ key } expected { val .__name__ } , not { expected } ." )
72
- if any (val == k for k in custom_types .keys ()):
73
- exp = custom_types [val ]
74
- var = args [idx ]
75
- res = exp ["func" ](var , ** exp .get ("kwargs" , {}))
76
- if is_kwargs :
77
- kwargs [key ] = res
78
- else :
79
- _args [idx ] = res
80
- elif is_kwargs :
81
- kwargs [key ] = (
82
- val (kwargs [key ]) if val is not list else [kwargs [key ]]
83
- )
84
- else :
85
- _args [idx ] = val (args [idx ]) if val is not list else [args [idx ]]
86
- args = tuple (_args )
87
- return func (* args , ** kwargs )
88
-
89
- return force
90
-
91
- return decorator
92
-
93
-
94
- @_typer ()
15
+
95
16
def xr_loop_func (
96
17
dataset : xr .Dataset ,
97
18
func ,
@@ -122,7 +43,6 @@ def _xr_loop_func(dataset, metafunc, loop_dimension, **kwargs):
122
43
)
123
44
124
45
125
- @_typer ()
126
46
def _lee_filter (img , window_size : int ):
127
47
img_ = img .copy ()
128
48
if isinstance (img , np .ndarray ):
@@ -166,7 +86,6 @@ def clip(self, geom):
166
86
def _max_time_wrap (self , wish = 5 , col = "time" ):
167
87
return np .min ((wish , self ._obj [col ].size ))
168
88
169
- @_typer ()
170
89
def plot_band (self , cmap = "Greys" , col = "time" , col_wrap = 5 , ** kwargs ):
171
90
return self ._obj .plot .imshow (
172
91
cmap = cmap ,
@@ -175,6 +94,29 @@ def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
175
94
** kwargs ,
176
95
)
177
96
97
+ def whittaker (
98
+ self ,
99
+ lmbd : float ,
100
+ weights : np .ndarray = None ,
101
+ a : float = 0.5 ,
102
+ min_value : float = - np .inf ,
103
+ max_value : float = np .inf ,
104
+ max_iter : int = 10 ,
105
+ time = "time" ,
106
+ ):
107
+ from . import whittaker
108
+
109
+ return whittaker .xr_wt (
110
+ self ._obj .to_dataset (name = "index" ),
111
+ lmbd ,
112
+ time = time ,
113
+ weights = weights ,
114
+ a = a ,
115
+ min_value = min_value ,
116
+ max_value = max_value ,
117
+ max_iter = max_iter ,
118
+ )["index" ]
119
+
178
120
179
121
@xr .register_dataset_accessor ("ed" )
180
122
class EarthDailyAccessorDataset :
@@ -187,7 +129,6 @@ def clip(self, geom):
187
129
def _max_time_wrap (self , wish = 5 , col = "time" ):
188
130
return np .min ((wish , self ._obj [col ].size ))
189
131
190
- @_typer ()
191
132
def plot_rgb (
192
133
self ,
193
134
red : str = "red" ,
@@ -205,7 +146,6 @@ def plot_rgb(
205
146
)
206
147
)
207
148
208
- @_typer ()
209
149
def plot_band (self , band , cmap = "Greys" , col = "time" , col_wrap = 5 , ** kwargs ):
210
150
return self ._obj [band ].plot .imshow (
211
151
cmap = cmap ,
@@ -214,7 +154,6 @@ def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
214
154
** kwargs ,
215
155
)
216
156
217
- @_typer ()
218
157
def lee_filter (self , window_size : int ):
219
158
return xr .apply_ufunc (
220
159
_lee_filter ,
@@ -225,7 +164,6 @@ def lee_filter(self, window_size: int):
225
164
kwargs = dict (window_size = window_size ),
226
165
)
227
166
228
- @_typer ()
229
167
def centroid (self , to_wkt : str = False , to_4326 : bool = True ):
230
168
"""Return the geographic center point in 4326/WKT of this dataset."""
231
169
# we can use a cache on our accessor objects, because accessors
@@ -288,7 +226,6 @@ def available_indices(self, details=False):
288
226
available_indices .append (spyndex .indices [k ] if details else k )
289
227
return available_indices
290
228
291
- @_typer ()
292
229
def add_indices (self , indices : list , ** kwargs ):
293
230
"""
294
231
Uses spyndex to compute and add index.
@@ -313,12 +250,11 @@ def add_indices(self, indices: list, **kwargs):
313
250
idx = spyndex .computeIndex (index = indices , params = params , ** kwargs )
314
251
315
252
if len (indices ) == 1 :
316
- idx = idx .expand_dims (indices = indices )
317
- idx = idx .to_dataset (dim = "indices " )
253
+ idx = idx .expand_dims (index = indices )
254
+ idx = idx .to_dataset (dim = "index " )
318
255
319
256
return xr .merge ((self ._obj , idx ))
320
257
321
- @_typer ()
322
258
def sel_nearest_dates (
323
259
self ,
324
260
target : (xr .Dataset , xr .DataArray ),
@@ -334,14 +270,14 @@ def sel_nearest_dates(
334
270
for i , j in enumerate (pos )
335
271
if j .days <= max_delta
336
272
]
273
+ pos = np .unique (pos )
337
274
if return_target :
338
275
method_convert = {"bfill" : "ffill" , "ffill" : "bfill" , "nearest" : "nearest" }
339
276
return self ._obj .sel (time = pos ), target .sel (
340
277
time = pos , method = method_convert [method ]
341
278
)
342
279
return self ._obj .sel (time = pos )
343
280
344
- @_typer ()
345
281
def whittaker (
346
282
self ,
347
283
lmbd : float ,
@@ -350,15 +286,16 @@ def whittaker(
350
286
min_value : float = - np .inf ,
351
287
max_value : float = np .inf ,
352
288
max_iter : int = 10 ,
289
+ time = "time" ,
353
290
):
354
291
from . import whittaker
355
292
356
293
return whittaker .xr_wt (
357
294
self ._obj ,
358
295
lmbd ,
359
- time = " time" ,
360
- weights = None ,
361
- a = 0.5 ,
296
+ time = time ,
297
+ weights = weights ,
298
+ a = a ,
362
299
min_value = min_value ,
363
300
max_value = max_value ,
364
301
max_iter = max_iter ,
0 commit comments