Skip to content

Commit aa6648d

Browse files
authored
v0.0.12
v0.0.12
2 parents 2178a63 + 963fc9e commit aa6648d

File tree

3 files changed

+90
-26
lines changed

3 files changed

+90
-26
lines changed

earthdaily/accessor/__init__.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,36 @@ class MisType(Warning):
2222
_SUPPORTED_DTYPE = [int, float, list, bool, str]
2323

2424

25-
def _typer(raise_mistype=False):
25+
def _typer(raise_mistype=False, custom_types={}):
2626
def decorator(func):
2727
def force(*args, **kwargs):
2828
_args = list(args)
29-
idx = 1
29+
func_arg = func.__code__.co_varnames
3030
for key, val in func.__annotations__.items():
31+
if not isinstance(val, (list, tuple)):
32+
val = [val]
33+
idx = [i for i in range(len(func_arg)) if func_arg[i] == key][0]
3134
is_kwargs = key in kwargs.keys()
3235
if not is_kwargs and idx >= len(args):
3336
continue
3437
input_value = kwargs.get(key, None) if is_kwargs else args[idx]
35-
if type(input_value) == val:
38+
if type(input_value) in val:
3639
continue
37-
if raise_mistype and (
38-
val != type(kwargs.get(key))
40+
if (
41+
type(kwargs.get(key)) not in val
3942
if is_kwargs
40-
else val != type(args[idx])
43+
else type(args[idx]) not in val
4144
):
45+
if raise_mistype:
46+
if is_kwargs:
47+
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
48+
else:
49+
expected = f"{type(args[idx]).__name__} ({args[idx]})"
50+
raise MisType(f"{key} expected {val.__name__}, not {expected}.")
4251
if is_kwargs:
43-
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
52+
kwargs[key] = val[0](kwargs[key])
4453
else:
45-
expected = f"{type(args[idx]).__name__} ({args[idx]})"
46-
47-
raise MisType(
48-
f"{key} expected a {val.__name__}, not a {expected}."
49-
)
50-
if is_kwargs:
51-
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
52-
elif len(args) >= idx:
53-
if isinstance(val, (list, tuple)) and len(val) > 1:
54-
val = val[0]
55-
_args[idx] = val(args[idx]) if val != list else [args[idx]]
56-
idx += 1
54+
_args[idx] = val[0](args[idx])
5755
args = tuple(_args)
5856
return func(*args, **kwargs)
5957

@@ -335,8 +333,18 @@ def whittaker(
335333
max_iter=max_iter,
336334
)
337335

338-
def zonal_stats(self, geometry, operations: list = ["mean"]):
336+
def zonal_stats(
337+
self,
338+
geometry,
339+
operations: list = ["mean"],
340+
raise_missing_geometry: bool = False,
341+
):
339342
from ..earthdatastore.cube_utils import zonal_stats, GeometryManager
340343

341344
geometry = GeometryManager(geometry).to_geopandas()
342-
return zonal_stats(self._obj, geometry, operations=operations)
345+
return zonal_stats(
346+
self._obj,
347+
geometry,
348+
operations=operations,
349+
raise_missing_geometry=raise_missing_geometry,
350+
)

earthdaily/earthdatastore/cube_utils/_zonal.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -105,27 +105,77 @@ def zonal_stats_numpy(
105105
def zonal_stats(
106106
dataset,
107107
gdf,
108-
operations=["mean"],
108+
operations: list = ["mean"],
109109
all_touched=False,
110110
method="geocube",
111111
verbose=False,
112+
raise_missing_geometry=False,
112113
):
114+
"""
115+
116+
117+
Parameters
118+
----------
119+
dataset : xr.Dataset
120+
DESCRIPTION.
121+
gdf : gpd.GeoDataFrame
122+
DESCRIPTION.
123+
operations : TYPE, list.
124+
DESCRIPTION. The default is ["mean"].
125+
all_touched : TYPE, optional
126+
DESCRIPTION. The default is False.
127+
method : TYPE, optional
128+
DESCRIPTION. The default is "geocube".
129+
verbose : TYPE, optional
130+
DESCRIPTION. The default is False.
131+
raise_missing_geometry : TYPE, optional
132+
DESCRIPTION. The default is False.
133+
134+
Raises
135+
------
136+
ValueError
137+
DESCRIPTION.
138+
NotImplementedError
139+
DESCRIPTION.
140+
141+
Returns
142+
-------
143+
TYPE
144+
DESCRIPTION.
145+
146+
"""
113147
if method == "geocube":
114148
from geocube.api.core import make_geocube
149+
from geocube.rasterize import rasterize_image
150+
151+
def custom_rasterize_image(all_touched=all_touched, **kwargs):
152+
return rasterize_image(all_touched=all_touched, **kwargs)
115153

116154
gdf["tmp_index"] = np.arange(gdf.shape[0])
117155
out_grid = make_geocube(
118156
gdf,
119157
measurements=["tmp_index"],
120158
like=dataset, # ensure the data are on the same grid
159+
rasterize_function=custom_rasterize_image,
121160
)
122161
cube = dataset.groupby(out_grid.tmp_index)
123162
zonal_stats = xr.concat(
124163
[getattr(cube, operation)() for operation in operations], dim="stats"
125164
)
126165
zonal_stats["stats"] = operations
127-
zonal_stats["tmp_index"] = list(gdf.index)
128166

167+
if zonal_stats["tmp_index"].size != gdf.shape[0]:
168+
index_list = [
169+
gdf.index[i] for i in zonal_stats["tmp_index"].values.astype(np.int16)
170+
]
171+
if raise_missing_geometry:
172+
diff = gdf.shape[0] - len(index_list)
173+
raise ValueError(
174+
f'{diff} geometr{"y is" if diff==1 else "ies are"} missing in the zonal stats. This can be due to too small geometries, duplicated...'
175+
)
176+
else:
177+
index_list = list(gdf.index)
178+
zonal_stats["tmp_index"] = index_list
129179
return zonal_stats.rename(dict(tmp_index="feature"))
130180

131181
tqdm_bar = tqdm.tqdm(total=gdf.shape[0])

tests/test_zonalstats.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,23 @@ def setUp(self, constant=np.random.randint(1, 12)):
2727
"time": times,
2828
},
2929
).rio.write_crs("EPSG:4326")
30-
ds = ds.transpose('time','x','y')
3130
# first pixel
31+
3232
geometry = [
33+
Polygon([(0, 0), (0, 0.5), (0.5, 0.5), (0.5, 0)]),
3334
Polygon([(0, 0), (0, 1.2), (1.2, 1.2), (1.2, 0)]),
3435
Polygon([(1, 1), (9, 1), (9, 2.1), (1, 1)])
3536
]
3637
# out of bound geom # Polygon([(10,10), (10,11), (11,11), (11,10)])]
3738
gdf = gpd.GeoDataFrame({"geometry": geometry}, crs="EPSG:4326")
39+
gdf.index = ['tosmall','ok','ok']
3840
self.gdf = gdf
3941
self.datacube = ds
4042

4143

4244
def test_basic(self):
4345
zonalstats = earthdaily.earthdatastore.cube_utils.zonal_stats(
44-
self.datacube, self.gdf, all_touched=True, operations=["min", "max"]
46+
self.datacube, self.gdf, operations=["min", "max"], raise_missing_geometry=False
4547
)
4648
for operation in ["min", "max"]:
4749
self._check_results(
@@ -55,6 +57,10 @@ def _check_results(self, stats_values, operation="min"):
5557
}
5658
self.assertTrue(np.all(stats_values == results[operation]))
5759

58-
60+
def test_error(self):
61+
with self.assertRaises(ValueError):
62+
earthdaily.earthdatastore.cube_utils.zonal_stats(
63+
self.datacube, self.gdf, operations=["min", "max"], raise_missing_geometry=True)
64+
5965
if __name__ == "__main__":
6066
unittest.main()

0 commit comments

Comments
 (0)