Skip to content

Commit 9fdf797

Browse files
author
nicolasK
committed
fix(_typer)
1 parent c6470c3 commit 9fdf797

File tree

4 files changed

+40
-24
lines changed

4 files changed

+40
-24
lines changed

earthdaily/accessor/__init__.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,27 @@ class MisType(Warning):
2424
def _typer(raise_mistype=False):
2525
def decorator(func):
2626
def force(*args, **kwargs):
27+
_args = list(args)
28+
idx = 1
2729
for key, val in func.__annotations__.items():
28-
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None:
30+
is_kwargs = key in kwargs.keys()
31+
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None and is_kwargs or len(args)==1:
2932
continue
30-
if raise_mistype and val != type(kwargs.get(key)):
31-
raise MisType(
32-
f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})"
33+
if raise_mistype and (val != type(kwargs.get(key)) if is_kwargs else val != type(args[idx])):
34+
if is_kwargs:
35+
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
36+
else:
37+
expected = f"{type(args[idx]).__name__} ({args[idx]})"
38+
39+
raise MisType(
40+
f"{key} expected a {val.__name__}, not a {expected}."
3341
)
34-
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
42+
if is_kwargs:
43+
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
44+
else:
45+
_args[idx] = val(args[idx]) if val != list else [args[idx]]
46+
idx+=1
47+
args = tuple(_args)
3548
return func(*args, **kwargs)
3649

3750
return force
@@ -103,22 +116,24 @@ def _lee_filter(img, window_size: int):
103116
img_output = xr.where(np.isnan(binary_nan), img_, img_output)
104117
return img_output
105118

106-
107119
@xr.register_dataarray_accessor("ed")
108120
class EarthDailyAccessorDataArray:
109121
def __init__(self, xarray_obj):
110122
self._obj = xarray_obj
123+
124+
def _max_time_wrap(self, wish=5):
125+
return np.min((wish,self._obj['time'].size))
111126

112127
@_typer()
113128
def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
114-
return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs)
129+
return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs)
115130

116131
@_typer()
117132
def plot_index(
118133
self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
119134
):
120135
return self._obj.plot.imshow(
121-
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
136+
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
122137
)
123138

124139

@@ -127,6 +142,10 @@ class EarthDailyAccessorDataset:
127142
def __init__(self, xarray_obj):
128143
self._obj = xarray_obj
129144

145+
def _max_time_wrap(self, wish=5):
146+
return np.min((wish,self._obj['time'].size))
147+
148+
130149
@_typer()
131150
def plot_rgb(
132151
self,
@@ -140,21 +159,21 @@ def plot_rgb(
140159
return (
141160
self._obj[[red, green, blue]]
142161
.to_array(dim="bands")
143-
.plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
162+
.plot.imshow(col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs)
144163
)
145164

146165
@_typer()
147166
def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
148167
return self._obj[band].plot.imshow(
149-
cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
168+
cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
150169
)
151170

152171
@_typer()
153172
def plot_index(
154173
self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
155174
):
156175
return self._obj[index].plot.imshow(
157-
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
176+
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
158177
)
159178

160179
@_typer()
@@ -216,7 +235,7 @@ def _auto_mapper(self):
216235
params[_BAND_MAPPING[v]] = self._obj[v]
217236
return params
218237

219-
def list_available_index(self, details=False):
238+
def available_index(self, details=False):
220239
mapper = list(self._auto_mapper().keys())
221240
indices = spyndex.indices
222241
available_indices = []
@@ -248,9 +267,7 @@ def add_index(self, index: list, **kwargs):
248267
"""
249268

250269
params = {}
251-
bands_mapping = self._auto_mapper()
252-
for k, v in bands_mapping.items():
253-
params[k] = self._obj[v]
270+
params = self._auto_mapper()
254271
params.update(**kwargs)
255272
idx = spyndex.computeIndex(index=index, params=params, **kwargs)
256273

examples/compare_scale_s2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_cube(rescale=True):
5050
# Plots cube with SCL with at least 50% of clear data
5151
# ----------------------------------------------------
5252

53-
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
53+
pivot_cube.ed.plot_rgb(col_wrap=3)
5454
plt.show()
5555

5656
##############################################################################
@@ -66,6 +66,6 @@ def get_cube(rescale=True):
6666
# ----------------------------------------------------
6767

6868

69-
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
69+
pivot_cube.ed.plot_rgb(col_wrap=3)
7070

7171
plt.show()

examples/multisensors_cube.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@
4343
)
4444

4545
# Add the NDVI
46-
datacube["ndvi"] = (datacube["nir"] - datacube["red"]) / (
47-
datacube["nir"] + datacube["red"]
48-
)
46+
datacube = datacube.ed.add_index('NDVI')
4947

5048
# Load in memory
5149
datacube = datacube.load()
@@ -63,7 +61,7 @@
6361
# See the NDVI evolution
6462
# -------------------------------------------
6563

66-
datacube["ndvi"].plot.imshow(
64+
datacube["NDVI"].plot.imshow(
6765
col="time", col_wrap=3, vmin=0, vmax=0.8, cmap="RdYlGn"
6866
)
6967
plt.show()
@@ -72,6 +70,6 @@
7270
# See the NDVI mean evolution
7371
# -------------------------------------------
7472

75-
datacube["ndvi"].groupby("time").mean(...).plot.line(x="time")
73+
datacube["NDVI"].groupby("time").mean(...).plot.line(x="time")
7674
plt.title("NDVI evolution")
7775
plt.show()

examples/venus_cube_mask.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
# Search for items
3838
# -------------------------------------------
3939

40-
items = eds.search(collection, query=query, prefer_alternate="download")
40+
items = eds.search(collection, query=query, prefer_alternate="download", limit=5)
4141

4242
##############################################################################
4343
# .. note::
@@ -72,4 +72,5 @@
7272
)
7373
print(venus_datacube)
7474

75-
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb()
75+
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb(vmax=0.2)
76+

0 commit comments

Comments
 (0)