Skip to content

Commit 97118ff

Browse files
authored
Add encode_cf, decode_cf (#69)
* Add encode_cf, decode_cf * cleanup * Update gitignore * Add test * Add cf_xarray as dependency * Update tests * Handle multiple CRS * Updates * Use crs_wkt directly in decode * fix tests * Check indexes for equality * Add comment * Don't set crs attribute * Revert "Don't set crs attribute" This reverts commit 2a7cf38. * fix * Add cf-xarray to conda env * Update docs * Add docstring * Typing fixes: Disallow dataarrays * Add to api.rst * Another fix.
1 parent 6167014 commit 97118ff

File tree

8 files changed

+281
-66
lines changed

8 files changed

+281
-66
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,11 @@ dmypy.json
139139

140140
# sphinx
141141
doc/source/generated
142+
doc/source/geo-encoded*
142143

143144
# ruff
144145
.ruff_cache
145146
doc/source/cube.joblib.compressed
146147
doc/source/cube.pickle
147148

148-
cache/
149+
cache/

doc/source/api.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ Methods
5656
Dataset.xvec.to_geopandas
5757
Dataset.xvec.extract_points
5858
Dataset.xvec.zonal_stats
59+
Dataset.xvec.encode_cf
60+
Dataset.xvec.decode_cf
5961

6062

6163
DataArray.xvec
@@ -91,4 +93,4 @@ Methods
9193
DataArray.xvec.to_geodataframe
9294
DataArray.xvec.to_geopandas
9395
DataArray.xvec.extract_points
94-
DataArray.xvec.zonal_stats
96+
DataArray.xvec.zonal_stats

doc/source/io.ipynb

Lines changed: 81 additions & 61 deletions
Large diffs are not rendered by default.

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ dependencies:
66
# required
77
- shapely=2
88
- xarray
9+
- cf_xarray
910
# testing
1011
- pytest
1112
- pytest-cov

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"xarray >= 2022.12.0",
3131
"pyproj >= 3.0.0",
3232
"shapely >= 2.0b1",
33+
"cf_xarray >= 0.9.2",
3334
]
3435

3536
[project.urls]

xvec/accessor.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def geom_coords(self) -> Mapping[Hashable, xr.DataArray]:
164164
).coords
165165

166166
@property
167-
def geom_coords_indexed(self) -> Mapping[Hashable, xr.DataArray]:
167+
def geom_coords_indexed(self) -> xr.Coordinates:
168168
"""Returns a dictionary of xarray.DataArray objects corresponding to
169169
coordinate variables using :class:`~xvec.GeometryIndex`.
170170
@@ -1290,6 +1290,126 @@ def extract_points(
12901290
)
12911291
return result
12921292

1293+
def encode_cf(self) -> xr.Dataset:
1294+
"""
1295+
Encode all geometry variables and associated CRS with CF conventions.
1296+
1297+
Use this method prior to writing an Xarray dataset to any array format
1298+
(e.g. netCDF or Zarr).
1299+
1300+
The following invariant is satisfied:
1301+
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``
1302+
1303+
CRS information on the ``GeometryIndex`` is encoded using CF's ``grid_mapping`` convention.
1304+
1305+
This function uses ``cf_xarray.geometry.encode_geometries`` under the hood and will only
1306+
work on Datasets.
1307+
1308+
Returns
1309+
-------
1310+
Dataset
1311+
"""
1312+
import cf_xarray as cfxr
1313+
1314+
if not isinstance(self._obj, xr.Dataset):
1315+
raise ValueError(
1316+
"CF encoding is only valid on Datasets. Convert to a dataset using `.to_dataset()` first."
1317+
)
1318+
1319+
ds = self._obj.copy()
1320+
coords = self.geom_coords_indexed
1321+
1322+
# TODO: this could use geoxarray, but is quite simple in any case
1323+
# Adapted from rioxarray
1324+
# 1. First find all unique CRS objects
1325+
# preserve ordering for roundtripping
1326+
unique_crs = []
1327+
for _, xi in sorted(coords.xindexes.items()):
1328+
if xi.crs not in unique_crs:
1329+
unique_crs.append(xi.crs)
1330+
if len(unique_crs) == 1:
1331+
grid_mappings = {unique_crs.pop(): "spatial_ref"}
1332+
else:
1333+
grid_mappings = {
1334+
crs_: f"spatial_ref_{i}" for i, crs_ in enumerate(unique_crs)
1335+
}
1336+
1337+
# 2. Convert CRS to grid_mapping variables and assign them
1338+
for crs, grid_mapping in grid_mappings.items():
1339+
grid_mapping_attrs = crs.to_cf()
1340+
# TODO: not all CRS can be represented by CF grid_mappings
1341+
# For now, we allow this.
1342+
# if "grid_mapping_name" not in grid_mapping_attrs:
1343+
# raise ValueError
1344+
wkt_str = crs.to_wkt()
1345+
grid_mapping_attrs["spatial_ref"] = wkt_str
1346+
grid_mapping_attrs["crs_wkt"] = wkt_str
1347+
ds.coords[grid_mapping] = xr.Variable(
1348+
dims=(), data=0, attrs=grid_mapping_attrs
1349+
)
1350+
1351+
# 3. Associate other variables with appropriate grid_mapping variable
1352+
# We assume that this relation follows from dimension names being shared between
1353+
# the GeometryIndex and the variable being checked.
1354+
for name, coord in coords.items():
1355+
dims = set(coord.dims)
1356+
index = coords.xindexes[name]
1357+
varnames = (k for k, v in ds._variables.items() if dims & set(v.dims))
1358+
for name in varnames:
1359+
if TYPE_CHECKING:
1360+
assert isinstance(index, GeometryIndex)
1361+
ds._variables[name].attrs["grid_mapping"] = grid_mappings[index.crs]
1362+
1363+
encoded = cfxr.geometry.encode_geometries(ds)
1364+
return encoded
1365+
1366+
def decode_cf(self) -> xr.Dataset:
1367+
"""
1368+
Decode geometries stored as CF-compliant arrays to shapely geometries.
1369+
1370+
The following invariant is satisfied:
1371+
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``
1372+
1373+
1374+
A ``GeometryIndex`` is created automatically and CRS information, if available
1375+
following CF's ``grid_mapping`` convention, will be associated with the ``GeometryIndex``.
1376+
1377+
This function uses ``cf_xarray.geometry.decode_geometries`` under the hood, and will only
1378+
work on Datasets.
1379+
1380+
Returns
1381+
-------
1382+
Dataset
1383+
"""
1384+
import cf_xarray as cfxr
1385+
1386+
if not isinstance(self._obj, xr.Dataset):
1387+
raise ValueError(
1388+
"CF decoding is only supported on Datasets. Convert to a Dataset using `.to_dataset()` first."
1389+
)
1390+
1391+
decoded = cfxr.geometry.decode_geometries(self._obj.copy())
1392+
crs = {
1393+
name: CRS.from_user_input(var.attrs["crs_wkt"])
1394+
for name, var in decoded._variables.items()
1395+
if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs
1396+
}
1397+
dims = decoded.xvec.geom_coords.dims
1398+
for dim in dims:
1399+
decoded = (
1400+
decoded.set_xindex(dim) if dim not in decoded._indexes else decoded
1401+
)
1402+
decoded = decoded.xvec.set_geom_indexes(
1403+
dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None))
1404+
)
1405+
for name in crs:
1406+
# remove spatial_ref so the coordinate system is only stored on the index
1407+
del decoded[name]
1408+
for var in decoded._variables.values():
1409+
if set(dims) & set(var.dims):
1410+
var.attrs.pop("grid_mapping", None)
1411+
return decoded
1412+
12931413

12941414
def _resolve_input(
12951415
positional: Mapping[Any, Any] | None,

xvec/tests/conftest.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def multi_dataset(geom_array, geom_array_z):
6969

7070
@pytest.fixture(scope="session")
7171
def multi_geom_dataset(geom_array, geom_array_z):
72-
return (
72+
ds = (
7373
xr.Dataset(
7474
coords={
7575
"geom": geom_array,
@@ -80,11 +80,32 @@ def multi_geom_dataset(geom_array, geom_array_z):
8080
.set_xindex("geom", GeometryIndex, crs=26915)
8181
.set_xindex("geom_z", GeometryIndex, crs=26915)
8282
)
83+
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
84+
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
85+
return ds
86+
87+
88+
@pytest.fixture(scope="session")
89+
def multi_geom_multi_crs_dataset(geom_array, geom_array_z):
90+
ds = (
91+
xr.Dataset(
92+
coords={
93+
"geom": geom_array,
94+
"geom_z": geom_array_z,
95+
}
96+
)
97+
.drop_indexes(["geom", "geom_z"])
98+
.set_xindex("geom", GeometryIndex, crs=26915)
99+
.set_xindex("geom_z", GeometryIndex, crs="EPSG:4362")
100+
)
101+
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
102+
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
103+
return ds
83104

84105

85106
@pytest.fixture(scope="session")
86107
def multi_geom_no_index_dataset(geom_array, geom_array_z):
87-
return (
108+
ds = (
88109
xr.Dataset(
89110
coords={
90111
"geom": geom_array,
@@ -96,6 +117,9 @@ def multi_geom_no_index_dataset(geom_array, geom_array_z):
96117
.set_xindex("geom", GeometryIndex, crs=26915)
97118
.set_xindex("geom_z", GeometryIndex, crs=26915)
98119
)
120+
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
121+
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
122+
return ds
99123

100124

101125
@pytest.fixture(scope="session")
@@ -157,3 +181,18 @@ def traffic_dataset(geom_array):
157181
"day": pd.date_range("2023-01-01", periods=10),
158182
},
159183
).xvec.set_geom_indexes(["origin", "destination"], crs=26915)
184+
185+
186+
@pytest.fixture(
187+
params=[
188+
"first_geom_dataset",
189+
"multi_dataset",
190+
"multi_geom_dataset",
191+
"multi_geom_no_index_dataset",
192+
"multi_geom_multi_crs_dataset",
193+
"traffic_dataset",
194+
],
195+
scope="session",
196+
)
197+
def all_datasets(request):
198+
return request.getfixturevalue(request.param)

xvec/tests/test_accessor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,34 @@ def test_extract_points_array():
674674
geometry=4326
675675
),
676676
)
677+
678+
679+
def test_cf_roundtrip(all_datasets):
680+
ds = all_datasets
681+
copy = ds.copy(deep=True)
682+
encoded = ds.xvec.encode_cf()
683+
684+
if unique_crs := {
685+
idx.crs for idx in ds.xvec.geom_coords_indexed.xindexes.values() if idx.crs
686+
}:
687+
nwkts = sum(1 for var in encoded._variables.values() if "crs_wkt" in var.attrs)
688+
assert len(unique_crs) == nwkts
689+
roundtripped = encoded.xvec.decode_cf()
690+
691+
xr.testing.assert_identical(ds, roundtripped)
692+
assert_indexes_equals(ds, roundtripped)
693+
# make sure we didn't modify the original dataset.
694+
xr.testing.assert_identical(ds, copy)
695+
696+
697+
def assert_indexes_equals(left, right):
698+
# Till https://github.com/pydata/xarray/issues/5812 is resolved
699+
# Also, we don't record whether an unindexed coordinate was serialized
700+
# So just assert that the left ("expected") dataset has fewer indexes
701+
# than the right.
702+
# This isn't great...
703+
assert sorted(left.xindexes.keys()) <= sorted(right.xindexes.keys())
704+
for k in left.xindexes:
705+
if not isinstance(left.xindexes[k], GeometryIndex):
706+
continue
707+
assert left.xindexes[k].equals(right.xindexes[k])

0 commit comments

Comments
 (0)