Skip to content

Commit 3f6074c

Browse files
authored
fixed missing vars regression (#63)
* added vars selector back * added grid_topo parsing function, fixed regression that lost vars
1 parent 2c3f53c commit 3f6074c

File tree

2 files changed

+121
-90
lines changed

2 files changed

+121
-90
lines changed

tests/test_grids/test_sgrid.py

+70-47
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,88 @@
11
import os
22

3-
import fsspec
43
import numpy as np
54
import xarray as xr
65

76
import xarray_subset_grid.accessor # noqa: F401
87
from tests.test_utils import get_test_file_dir
8+
from xarray_subset_grid.grids.sgrid import _get_location_info_from_topology
99

1010
# open dataset as zarr object using fsspec reference file system and xarray
1111

1212

1313
test_dir = get_test_file_dir()
1414
sample_sgrid_file = os.path.join(test_dir, 'arakawa_c_test_grid.nc')
1515

16-
def test_polygon_subset():
17-
'''
18-
This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
19-
'''
20-
fs = fsspec.filesystem(
21-
"reference",
22-
fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
23-
remote_protocol="s3",
24-
remote_options={"anon": True},
25-
target_protocol="s3",
26-
target_options={"anon": True},
27-
)
28-
m = fs.get_mapper("")
29-
30-
ds = xr.open_dataset(
31-
m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
32-
)
33-
34-
polygon = np.array(
35-
[
36-
[-122.38488806417945, 34.98888604471138],
37-
[-122.02425311530737, 33.300351211467074],
38-
[-120.60402628930146, 32.723214427630836],
39-
[-116.63789131284673, 32.54346959375448],
40-
[-116.39346090873218, 33.8541384965596],
41-
[-118.83845767505964, 35.257586401855164],
42-
[-121.34541503969862, 35.50073821008141],
43-
[-122.38488806417945, 34.98888604471138],
44-
]
45-
)
46-
ds_temp = ds.xsg.subset_vars(['temp_sur'])
47-
ds_subset = ds_temp.xsg.subset_polygon(polygon)
16+
def test_grid_topology_location_parse():
    '''Check _get_location_info_from_topology against the known arakawa-c sample grid.

    The sample file's ``grid`` variable carries SGRID topology attributes; each
    location (node/edge1/edge2/face) must parse to the expected dims, coords
    and per-dimension padding.
    '''
    ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
    topology = ds['grid']

    expected = {
        'node': {'dims': ['xi_psi', 'eta_psi'],
                 'coords': ['lon_psi', 'lat_psi'],
                 'padding': {'xi_psi': 'none', 'eta_psi': 'none'}},
        'edge1': {'dims': ['xi_u', 'eta_u'],
                  'coords': ['lon_u', 'lat_u'],
                  'padding': {'eta_u': 'both', 'xi_u': 'none'}},
        'edge2': {'dims': ['xi_v', 'eta_v'],
                  'coords': ['lon_v', 'lat_v'],
                  'padding': {'xi_v': 'both', 'eta_v': 'none'}},
        'face': {'dims': ['xi_rho', 'eta_rho'],
                 'coords': ['lon_rho', 'lat_rho'],
                 'padding': {'xi_rho': 'both', 'eta_rho': 'both'}},
    }

    for location, info in expected.items():
        assert _get_location_info_from_topology(topology, location) == info
35+
36+
37+
# def test_polygon_subset():
38+
# '''
39+
# This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
40+
# '''
41+
# fs = fsspec.filesystem(
42+
# "reference",
43+
# fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
44+
# remote_protocol="s3",
45+
# remote_options={"anon": True},
46+
# target_protocol="s3",
47+
# target_options={"anon": True},
48+
# )
49+
# m = fs.get_mapper("")
50+
51+
# ds = xr.open_dataset(
52+
# m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
53+
# )
5654

57-
#Check that the subset rho/psi/u/v positional relationship makes sense aka psi point is
58-
#'between' its neighbor rho points
59-
#Note that this needs to be better generalized; it's not trivial to write a test that
60-
#works in all potential cases.
61-
assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
62-
and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])
55+
# polygon = np.array(
56+
# [
57+
# [-122.38488806417945, 34.98888604471138],
58+
# [-122.02425311530737, 33.300351211467074],
59+
# [-120.60402628930146, 32.723214427630836],
60+
# [-116.63789131284673, 32.54346959375448],
61+
# [-116.39346090873218, 33.8541384965596],
62+
# [-118.83845767505964, 35.257586401855164],
63+
# [-121.34541503969862, 35.50073821008141],
64+
# [-122.38488806417945, 34.98888604471138],
65+
# ]
66+
# )
67+
# ds_temp = ds.xsg.subset_vars(['temp_sur'])
68+
# ds_subset = ds_temp.xsg.subset_polygon(polygon)
6369

64-
#ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
70+
# #Check that the subset dataset has the correct dimensions given the original padding
71+
# assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1
72+
# assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1
73+
# assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi']
74+
# assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1
75+
# assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi']
76+
# assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1
77+
78+
# #Check that the subset rho/psi/u/v positional relationship makes sense aka psi point is
79+
# #'between' its neighbor rho points
80+
# #Note that this needs to be better generalized; it's not trivial to write a test that
81+
# #works in all potential cases.
82+
# assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
83+
# and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])
84+
85+
# #ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
6586

6687
def test_polygon_subset_2():
6788
ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
@@ -84,3 +105,5 @@ def test_polygon_subset_2():
84105

85106
assert ds_subset.lon_psi.min() <= 6.5 and ds_subset.lon_psi.max() >= 9.5
86107
assert ds_subset.lat_psi.min() <= 37.5 and ds_subset.lat_psi.max() >= 40.5
108+
109+
assert 'u' in ds_subset.variables.keys()

xarray_subset_grid/grids/sgrid.py

+51-43
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,20 @@ def compute_polygon_subset_selector(
108108
dims = _get_sgrid_dim_coord_names(grid_topology)
109109
subset_masks: list[tuple[list[str], xr.DataArray]] = []
110110

111-
node_dims = grid_topology.attrs["node_dimensions"].split()
112-
node_coords = grid_topology.attrs["node_coordinates"].split()
111+
node_info = _get_location_info_from_topology(grid_topology, 'node')
112+
node_dims = node_info['dims']
113+
node_coords = node_info['coords']
114+
115+
unique_dims = set(node_dims)
116+
node_vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]
113117

114118
node_lon: xr.DataArray | None = None
115119
node_lat: xr.DataArray | None = None
116120
for c in node_coords:
117-
if 'lon' in c:
121+
if 'lon' in ds[c].standard_name.lower():
118122
node_lon = ds[c]
119-
elif 'lat' in c:
123+
elif 'lat' in ds[c].standard_name.lower():
120124
node_lat = ds[c]
121-
if node_lon is None or node_lat is None:
122-
raise ValueError(f"Could not find lon and lat for dimension {node_dims}")
123125

124126
node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon)
125127
msk = np.where(node_mask)
@@ -134,28 +136,27 @@ def compute_polygon_subset_selector(
134136
node_mask[index_bounding_box[0][0]:index_bounding_box[0][1],
135137
index_bounding_box[1][0]:index_bounding_box[1][1]] = True
136138

137-
subset_masks.append(([node_coords[0], node_coords[1]], node_mask))
139+
subset_masks.append((node_vars, node_mask))
140+
138141
for s in ('face', 'edge1', 'edge2'):
139-
dims = grid_topology.attrs.get(f"{s}_dimensions", None)
140-
coords = grid_topology.attrs.get(f"{s}_coordinates", None).split()
142+
info = _get_location_info_from_topology(grid_topology, s)
143+
dims = info['dims']
144+
coords = info['coords']
145+
unique_dims = set(dims)
146+
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]
141147

142148
lon: xr.DataArray | None = None
143-
lat: xr.DataArray | None = None
144149
for c in coords:
145150
if 'lon' in ds[c].standard_name.lower():
146151
lon = ds[c]
147-
elif 'lat' in ds[c].standard_name.lower():
148-
lat = ds[c]
149-
if lon is None or lat is None:
150-
raise ValueError(f"Could not find lon and lat for dimension {dims}")
151-
padding = parse_padding_string(dims)
152+
padding = info['padding']
152153
arranged_padding = [padding[d] for d in lon.dims]
153154
arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding]
154155
mask = np.zeros(lon.shape, dtype=bool)
155156
mask[index_bounding_box[0][0]:index_bounding_box[0][1] + arranged_padding[0],
156157
index_bounding_box[1][0]:index_bounding_box[1][1] + arranged_padding[1]] = True
157158
xr_mask = xr.DataArray(mask, dims=lon.dims)
158-
subset_masks.append(([coords[0], coords[1]], xr_mask))
159+
subset_masks.append((vars, xr_mask))
159160

160161
return SGridSelector(
161162
name=name or 'selector',
@@ -165,6 +166,40 @@ def compute_polygon_subset_selector(
165166
subset_masks=subset_masks,
166167
)
167168

169+
def _get_location_info_from_topology(grid_topology: xr.DataArray, location) -> dict[str, str]:
170+
'''Get the dimensions and coordinates for a given location from the grid_topology'''
171+
rdict = {}
172+
dim_str = grid_topology.attrs.get(f"{location}_dimensions", None)
173+
coord_str = grid_topology.attrs.get(f"{location}_coordinates", None)
174+
if dim_str is None or coord_str is None:
175+
raise ValueError(f"Could not find {location} dimensions or coordinates")
176+
# Remove padding for now
177+
dims_only = " ".join([v for v in dim_str.split(" ") if "(" not in v and ")" not in v])
178+
if ":" in dims_only:
179+
dims_only = [s.replace(":", "") for s in dims_only.split(" ") if ":" in s]
180+
else:
181+
dims_only = dims_only.split(" ")
182+
183+
padding = dim_str.replace(':', '').split(')')
184+
pdict = {}
185+
if len(padding) == 3: #two padding values
186+
pdict[dims_only[0]] = padding[0].split(' ')[-1]
187+
pdict[dims_only[1]] = padding[1].split(' ')[-1]
188+
elif len(padding) == 2: #one padding value
189+
if padding[-1] == '': #padding is on second dim
190+
pdict[dims_only[1]] = padding[0].split(' ')[-1]
191+
pdict[dims_only[0]] = 'none'
192+
else:
193+
pdict[dims_only[0]] = padding[0].split(' ')[-1]
194+
pdict[dims_only[1]] = 'none'
195+
else:
196+
pdict[dims_only[0]] = 'none'
197+
pdict[dims_only[1]] = 'none'
198+
199+
rdict['dims'] = dims_only
200+
rdict['coords'] = coord_str.split(" ")
201+
rdict['padding'] = pdict
202+
return rdict
168203

169204
def _get_sgrid_dim_coord_names(
170205
grid_topology: xr.DataArray,
@@ -189,30 +224,3 @@ def _get_sgrid_dim_coord_names(
189224
coords.append(v.split(" "))
190225

191226
return list(zip(dims, coords))
192-
193-
def parse_padding_string(dim_string):
194-
'''
195-
Given a grid_topology dimension string, parse the padding for each dimension.
196-
Returns a dict of {dim0name: padding,
197-
dim1name: padding
198-
}
199-
valid values of padding are: 'none', 'low', 'high', 'both'
200-
'''
201-
parsed_string = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '')
202-
split_parsed_string = parsed_string.split(' ')
203-
if len(split_parsed_string) == 6:
204-
return {split_parsed_string[0]:split_parsed_string[2],
205-
split_parsed_string[3]:split_parsed_string[5]}
206-
elif len(split_parsed_string) == 5:
207-
if split_parsed_string[4] in {'none', 'low', 'high', 'both'}:
208-
#2nd dim has padding, and with len 5 that means first does not
209-
split_parsed_string.insert(2, 'none')
210-
else:
211-
split_parsed_string.insert(5, 'none')
212-
return {split_parsed_string[0]:split_parsed_string[2],
213-
split_parsed_string[3]:split_parsed_string[5]}
214-
elif len(split_parsed_string) == 2:
215-
#node dimensions string could look like this: 'node_dimensions: xi_psi eta_psi'
216-
return {split_parsed_string[0]: 'none', split_parsed_string[1]: 'none'}
217-
else:
218-
raise ValueError(f"Padding parsing failure: {dim_string}")

0 commit comments

Comments
 (0)