@@ -19,39 +19,70 @@ class MisType(Warning):
19
19
pass
20
20
21
21
22
- _SUPPORTED_DTYPE = [int , float , list , bool , str ]
23
-
24
-
25
22
def _typer (raise_mistype = False , custom_types = {}):
23
+ """
24
+
25
+ Parameters
26
+ ----------
27
+ raise_mistype : TYPE, optional
28
+ DESCRIPTION. The default is False.
29
+ custom_types : TYPE, optional
30
+ DESCRIPTION. The default is {}.
31
+ Example : {
32
+ np.ndarray:{"func":np.asarray},
33
+ xr.Dataset:{"func":xr.DataArray.to_dataset,"kwargs":{"dim":"band"}}
34
+ }
35
+ Raises
36
+ ------
37
+ MisType
38
+ DESCRIPTION.
39
+
40
+ Returns
41
+ -------
42
+ TYPE
43
+ DESCRIPTION.
44
+
45
+ """
46
+
26
47
def decorator (func ):
27
48
def force (* args , ** kwargs ):
28
49
_args = list (args )
29
- func_arg = func .__code__ . co_varnames
30
- for key , val in func . __annotations__ . items ( ):
31
- if not isinstance ( val , ( list , tuple )):
32
- val = [ val ]
33
- idx = [ i for i in range ( len ( func_arg )) if func_arg [ i ] == key ][ 0 ]
50
+ for key , vals in func .__annotations__ . items ():
51
+ if not isinstance ( vals , ( list , tuple ) ):
52
+ vals = [ vals ]
53
+ val = vals [ 0 ]
54
+ idx = func . __code__ . co_varnames . index ( key )
34
55
is_kwargs = key in kwargs .keys ()
35
- if not is_kwargs and idx >= len (args ):
36
- continue
37
- input_value = kwargs .get (key , None ) if is_kwargs else args [idx ]
38
- if type (input_value ) in val :
56
+ if not is_kwargs and idx > len (_args ):
57
+ break
58
+ value = kwargs .get (key , None ) if is_kwargs else args [idx ]
59
+ if type (value ) in vals :
39
60
continue
40
61
if (
41
- type (kwargs .get (key )) not in val
62
+ type (kwargs .get (key )) not in vals
42
63
if is_kwargs
43
- else type (args [idx ]) not in val
64
+ else type (args [idx ]) not in vals
44
65
):
45
66
if raise_mistype :
46
67
if is_kwargs :
47
68
expected = f"{ type (kwargs [key ]).__name__ } ({ kwargs [key ]} )"
48
69
else :
49
70
expected = f"{ type (args [idx ]).__name__ } ({ args [idx ]} )"
50
71
raise MisType (f"{ key } expected { val .__name__ } , not { expected } ." )
51
- if is_kwargs :
52
- kwargs [key ] = val [0 ](kwargs [key ])
72
+ if any (val == k for k in custom_types .keys ()):
73
+ exp = custom_types [val ]
74
+ var = args [idx ]
75
+ res = exp ["func" ](var , ** exp .get ("kwargs" , {}))
76
+ if is_kwargs :
77
+ kwargs [key ] = res
78
+ else :
79
+ _args [idx ] = res
80
+ elif is_kwargs :
81
+ kwargs [key ] = (
82
+ var (kwargs [key ]) if var is not list else [kwargs [key ]]
83
+ )
53
84
else :
54
- _args [idx ] = val [ 0 ] (args [idx ])
85
+ _args [idx ] = var (args [idx ]) if var is not list else [ args [ idx ]]
55
86
args = tuple (_args )
56
87
return func (* args , ** kwargs )
57
88
@@ -258,7 +289,7 @@ def available_indices(self, details=False):
258
289
return available_indices
259
290
260
291
@_typer ()
261
- def add_indices (self , index : list , ** kwargs ):
292
+ def add_indices (self , indices : list , ** kwargs ):
262
293
"""
263
294
Uses spyndex to compute and add index.
264
295
@@ -267,7 +298,7 @@ def add_indices(self, index: list, **kwargs):
267
298
268
299
Parameters
269
300
----------
270
- index : list
301
+ indices : list
271
302
['NDVI'].
272
303
Returns
273
304
-------
@@ -279,11 +310,11 @@ def add_indices(self, index: list, **kwargs):
279
310
params = {}
280
311
params = self ._auto_mapper ()
281
312
params .update (** kwargs )
282
- idx = spyndex .computeIndex (index = index , params = params , ** kwargs )
313
+ idx = spyndex .computeIndex (index = indices , params = params , ** kwargs )
283
314
284
- if len (index ) == 1 :
285
- idx = idx .expand_dims (index = index )
286
- idx = idx .to_dataset (dim = "index " )
315
+ if len (indices ) == 1 :
316
+ idx = idx .expand_dims (indices = indices )
317
+ idx = idx .to_dataset (dim = "indices " )
287
318
288
319
return xr .merge ((self ._obj , idx ))
289
320
0 commit comments