Skip to content

Commit 7a81465

Browse files
authored
Merge pull request #46 from earthdaily/dev
v0.0.7
2 parents c5c32f0 + afcbcba commit 7a81465

12 files changed

+335
-45
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [0.0.7] - 2024-02-28
8+
9+
#### Added
10+
11+
- `ed` accessor.
12+
713
## [0.0.6] - 2024-02-23
814

915
### Fixed

earthdaily/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import earthdatastore, datasets
2+
from .accessor import EarthDailyAccessorDataArray, EarthDailyAccessorDataset
23

3-
__version__ = "0.0.6"
4+
__version__ = "0.0.7"

earthdaily/accessor/__init__.py

+284
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
import warnings
2+
import xarray as xr
3+
import rioxarray as rxr
4+
import numpy as np
5+
import pandas as pd
6+
import geopandas as gpd
7+
from shapely.geometry import Point
8+
from dask import array as da
9+
import spyndex
10+
from dask_image import ndfilters as ndimage
11+
12+
from xarray.core.extensions import AccessorRegistrationWarning
13+
14+
warnings.filterwarnings("ignore", category=AccessorRegistrationWarning)
15+
16+
17+
class MisType(Warning):
18+
pass
19+
20+
21+
_SUPPORTED_DTYPE = [int, float, list, bool, str]
22+
23+
24+
def _typer(raise_mistype=False):
25+
def decorator(func):
26+
def force(*args, **kwargs):
27+
for key, val in func.__annotations__.items():
28+
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None:
29+
continue
30+
if raise_mistype and val != type(kwargs.get(key)):
31+
raise MisType(
32+
f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})"
33+
)
34+
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
35+
return func(*args, **kwargs)
36+
37+
return force
38+
39+
return decorator
40+
41+
42+
@_typer()
43+
def xr_loop_func(
44+
dataset: xr.Dataset,
45+
func,
46+
to_numpy: bool = False,
47+
loop_dimension: str = "time",
48+
**kwargs,
49+
):
50+
def _xr_loop_func(dataset, metafunc, loop_dimension, **kwargs):
51+
if to_numpy is True:
52+
dataset_func = dataset.copy()
53+
looped = [
54+
metafunc(dataset.isel({loop_dimension: i}).load().data, **kwargs)
55+
for i in range(dataset[loop_dimension].size)
56+
]
57+
dataset_func.data = np.asarray(looped)
58+
return dataset_func
59+
else:
60+
return xr.concat(
61+
[
62+
metafunc(dataset.isel({loop_dimension: i}), **kwargs)
63+
for i in range(dataset[loop_dimension].size)
64+
],
65+
dim=loop_dimension,
66+
)
67+
68+
return dataset.map(
69+
func=_xr_loop_func, metafunc=func, loop_dimension=loop_dimension, **kwargs
70+
)
71+
72+
73+
@_typer()
74+
def _lee_filter(img, window_size: int):
75+
try:
76+
from dask_image import ndfilters
77+
except ImportError:
78+
raise ImportError("Please install dask-image to run lee_filter")
79+
80+
img_ = img.copy()
81+
ndimage_type = ndfilters
82+
if hasattr(img, "data"):
83+
if isinstance(img.data, (memoryview, np.ndarray)):
84+
ndimage_type = ndimage
85+
img = img.data
86+
# print(ndimage_type)
87+
binary_nan = ndimage_type.minimum_filter(
88+
xr.where(np.isnan(img), 0, 1), size=window_size
89+
)
90+
binary_nan = np.where(binary_nan == 0, np.nan, 1)
91+
img = xr.where(np.isnan(img), 0, img)
92+
window_size = da.from_array([window_size, window_size, 1])
93+
94+
img_mean = ndimage_type.uniform_filter(img, window_size)
95+
img_sqr_mean = ndimage_type.uniform_filter(img**2, window_size)
96+
img_variance = img_sqr_mean - img_mean**2
97+
98+
overall_variance = np.var(img, axis=(0, 1))
99+
100+
img_weights = img_variance / (np.add(img_variance, overall_variance))
101+
102+
img_output = img_mean + img_weights * (np.subtract(img, img_mean))
103+
img_output = xr.where(np.isnan(binary_nan), img_, img_output)
104+
return img_output
105+
106+
107+
@xr.register_dataarray_accessor("ed")
108+
class EarthDailyAccessorDataArray:
109+
def __init__(self, xarray_obj):
110+
self._obj = xarray_obj
111+
112+
@_typer()
113+
def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
114+
return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs)
115+
116+
@_typer()
117+
def plot_index(
118+
self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
119+
):
120+
return self._obj.plot.imshow(
121+
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
122+
)
123+
124+
125+
@xr.register_dataset_accessor("ed")
126+
class EarthDailyAccessorDataset:
127+
def __init__(self, xarray_obj):
128+
self._obj = xarray_obj
129+
130+
@_typer()
131+
def plot_rgb(
132+
self,
133+
red: str = "red",
134+
green: str = "green",
135+
blue: str = "blue",
136+
col="time",
137+
col_wrap=5,
138+
**kwargs,
139+
):
140+
return (
141+
self._obj[[red, green, blue]]
142+
.to_array(dim="bands")
143+
.plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
144+
)
145+
146+
@_typer()
147+
def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
148+
return self._obj[band].plot.imshow(
149+
cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
150+
)
151+
152+
@_typer()
153+
def plot_index(
154+
self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
155+
):
156+
return self._obj[index].plot.imshow(
157+
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
158+
)
159+
160+
@_typer()
161+
def lee_filter(self, window_size: int = 7):
162+
return xr.apply_ufunc(
163+
_lee_filter,
164+
self._obj,
165+
input_core_dims=[["time"]],
166+
dask="allowed",
167+
output_core_dims=[["time"]],
168+
kwargs=dict(window_size=window_size),
169+
)
170+
171+
@_typer()
172+
def centroid(self, to_wkt: str = False, to_4326: bool = True):
173+
"""Return the geographic center point in 4326/WKT of this dataset."""
174+
# we can use a cache on our accessor objects, because accessors
175+
# themselves are cached on instances that access them.
176+
lon = float(self._obj.x[int(self._obj.x.size / 2)])
177+
lat = float(self._obj.y[int(self._obj.y.size / 2)])
178+
point = gpd.GeoSeries([Point(lon, lat)], crs=self._obj.rio.crs)
179+
if to_4326:
180+
point = point.to_crs(epsg="4326")
181+
if to_wkt:
182+
point = point.map(lambda x: x.wkt).iloc[0]
183+
return point
184+
185+
def _auto_mapper(self):
186+
_BAND_MAPPING = {
187+
"coastal": "A",
188+
"blue": "B",
189+
"green": "G",
190+
"yellow": "Y",
191+
"red": "R",
192+
"rededge1": "RE1",
193+
"rededge2": "RE2",
194+
"rededge3": "RE3",
195+
"nir": "N",
196+
"nir08": "N2",
197+
"watervapor": "WV",
198+
"swir16": "S1",
199+
"swir22": "S2",
200+
"lwir": "T1",
201+
"lwir11": "T2",
202+
"vv": "VV",
203+
"vh": "VH",
204+
"hh": "HH",
205+
"hv": "HV",
206+
}
207+
208+
params = {}
209+
data_vars = list(
210+
self._obj.rename(
211+
{var: var.lower() for var in self._obj.data_vars}
212+
).data_vars
213+
)
214+
for v in data_vars:
215+
if v in _BAND_MAPPING.keys():
216+
params[_BAND_MAPPING[v]] = self._obj[v]
217+
return params
218+
219+
def list_available_index(self, details=False):
220+
mapper = list(self._auto_mapper().keys())
221+
indices = spyndex.indices
222+
available_indices = []
223+
for k, v in indices.items():
224+
needed_bands = v.bands
225+
for needed_band in needed_bands:
226+
if needed_band not in mapper:
227+
break
228+
available_indices.append(spyndex.indices[k] if details else k)
229+
return available_indices
230+
231+
@_typer()
232+
def add_index(self, index: list, **kwargs):
233+
"""
234+
Uses spyndex to compute and add index.
235+
236+
For list of indices, see https://github.com/awesome-spectral-indices/awesome-spectral-indices.
237+
238+
239+
Parameters
240+
----------
241+
index : list
242+
['NDVI'].
243+
Returns
244+
-------
245+
xr.Dataset
246+
The input xr.Dataset with new data_vars of indices.
247+
248+
"""
249+
250+
params = {}
251+
bands_mapping = self._auto_mapper()
252+
for k, v in bands_mapping.items():
253+
params[k] = self._obj[v]
254+
params.update(**kwargs)
255+
idx = spyndex.computeIndex(index=index, params=params, **kwargs)
256+
257+
if len(index) == 1:
258+
idx = idx.expand_dims(index=index)
259+
idx = idx.to_dataset(dim="index")
260+
261+
return xr.merge((self._obj, idx))
262+
263+
@_typer()
264+
def sel_nearest_dates(
265+
self,
266+
target,
267+
max_delta: int = 0,
268+
method: str = "nearest",
269+
return_target: bool = False,
270+
):
271+
src_time = self._obj.sel(time=target.time.dt.date, method=method).time.dt.date
272+
target_time = target.time.dt.date
273+
pos = np.abs(src_time.data - target_time.data)
274+
pos = [
275+
src_time.isel(time=i).time.values
276+
for i, j in enumerate(pos)
277+
if j.days <= max_delta
278+
]
279+
if return_target:
280+
method_convert = {"bfill": "ffill", "ffill": "bfill", "nearest": "nearest"}
281+
return self._obj.sel(time=pos), target.sel(
282+
time=pos, method=method_convert[method]
283+
)
284+
return self._obj.sel(time=pos)

earthdaily/earthdatastore/cube_utils/_zonal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
"""
77

88
from rasterio import features
9-
from scipy.sparse import csr_matrix
109
import numpy as np
1110
import xarray as xr
1211
import tqdm
1312
from . import custom_operations
1413
from .preprocessing import rasterize
14+
from scipy.sparse import csr_matrix
1515

1616

1717
def _compute_M(data):

examples/compare_scale_s2.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,26 @@ def get_cube(rescale=True):
4646

4747
pivot_cube = get_cube(rescale=False) * 0.0001
4848

49-
#####################################################################da#########
49+
##############################################################################
5050
# Plots cube with SCL with at least 50% of clear data
5151
# ----------------------------------------------------
5252

53-
pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3)
53+
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
5454
plt.show()
5555

56-
#####################################################################da#########
56+
##############################################################################
5757
# Get cube with automatic rescale (default option)
5858
# ----------------------------------------------------
5959

6060
pivot_cube = get_cube()
6161
pivot_cube.clear_percent.plot.scatter(x="time")
6262
plt.show()
6363

64-
#####################################################################da#########
64+
##############################################################################
6565
# Plots cube with SCL with at least 50% of clear data
6666
# ----------------------------------------------------
6767

6868

69-
pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3)
69+
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
7070

7171
plt.show()

examples/earthdaily_simulated_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
# Plot RGB image time series
6161
# -------------------------------------------
6262

63-
datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
63+
datacube[["red", "green", "blue"]].ed.plot_rgb(
6464
col="time", col_wrap=4, vmax=0.2
6565
)
6666

examples/field_evolution.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
# ----------------------------------------------------
5959

6060
zonal_stats = earthdatastore.cube_utils.zonal_stats(
61-
pivot_cube, pivot, operations=["mean", "max", "min"]
61+
pivot_cube, pivot, operations=["mean", "max", "min"],
62+
method="standard"
6263
)
6364
zonal_stats = zonal_stats.load()
6465

examples/first_steps_create_datacube.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@
4444
plt.title("Percentage of clear pixels on the study site")
4545
plt.show()
4646

47-
s2_datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
48-
vmin=0, vmax=0.2, col="time", col_wrap=4
49-
)
47+
s2_datacube.ed.plot_rgb(vmin=0, vmax=0.2, col="time", col_wrap=4)
5048

5149
###########################################################
5250
# Create datacube in three steps
@@ -82,6 +80,6 @@
8280
s2_datacube, 50
8381
) # at least 50% of clear pixels
8482
#
85-
s2_datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
83+
s2_datacube.ed.plot_rgb(
8684
vmin=0, vmax=0.2, col="time", col_wrap=4
8785
)

examples/venus_cube_mask.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,4 @@
7272
)
7373
print(venus_datacube)
7474

75-
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500))[
76-
["red", "green", "blue"]
77-
].to_array(dim="band").plot.imshow(col="time", vmin=0, vmax=0.30)
75+
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).plot_rgb()

0 commit comments

Comments
 (0)