Skip to content

Commit ed5c9f2

Browse files
committed
Refactors regrid2 main function to work on a DataArray, removes dataset
from land-sea mask generation
1 parent 993ae15 commit ed5c9f2

File tree

4 files changed

+118
-143
lines changed

4 files changed

+118
-143
lines changed

tests/test_mask.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def test_generate_land_sea_mask_pcmdi(ds):
239239

240240
output = mask.generate_land_sea_mask(ds["ts"], method="pcmdi")
241241

242-
xr.testing.assert_equal(output.lsmask, expected)
242+
xr.testing.assert_equal(output, expected)
243243

244244

245245
@mock.patch("xcdat.mask._improve_mask")
@@ -248,7 +248,7 @@ def test_generate_land_sea_mask_pcmdi_multiple_iterations(_improve_mask, ds):
248248
[
249249
[1, 1, 1, 1],
250250
[0, 0, 0, 0],
251-
[1, 1, 0, 1],
251+
[1, 1, 1, 1],
252252
[0, 0, 0, 0],
253253
],
254254
dims=("lat", "lon"),
@@ -258,24 +258,24 @@ def test_generate_land_sea_mask_pcmdi_multiple_iterations(_improve_mask, ds):
258258
[
259259
[1, 1, 1, 1],
260260
[0, 0, 0, 0],
261-
[1, 1, 1, 1],
261+
[1, 1, 0, 1],
262262
[0, 0, 0, 0],
263263
],
264264
dims=("lat", "lon"),
265265
coords={"lat": ds.lat.copy(), "lon": ds.lon.copy()},
266266
)
267267

268268
_improve_mask.side_effect = [
269-
xr.Dataset({"sftlf": mask1.copy()}),
270-
xr.Dataset({"sftlf": mask2.copy()}),
271-
xr.Dataset({"sftlf": mask2.copy()}),
269+
mask1.copy(),
270+
mask2.copy(),
271+
mask2.copy(),
272272
]
273273

274274
expected = xr.DataArray(
275275
[
276276
[1, 1, 1, 1],
277277
[0, 0, 0, 0],
278-
[1, 1, 1, 1],
278+
[1, 1, 0, 1],
279279
[0, 0, 0, 0],
280280
],
281281
dims=("lat", "lon"),
@@ -284,7 +284,7 @@ def test_generate_land_sea_mask_pcmdi_multiple_iterations(_improve_mask, ds):
284284

285285
output = mask.generate_land_sea_mask(ds["ts"], method="pcmdi")
286286

287-
xr.testing.assert_equal(output.lsmask, expected)
287+
xr.testing.assert_equal(output, expected)
288288

289289

290290
def test_get_resource_path(monkeypatch, tmp_path):

tests/test_regrid.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,10 @@ def test_regrid_create_nan_mask(self):
655655

656656
output_data = regridder.horizontal("ts", self.coarse_2d_ds)
657657

658+
assert output_data.ts.dtype == np.float32
659+
658660
# np.nan != np.nan, replace with 1e20
659-
output_data = output_data.fillna(1e20)
661+
output_data = output_data.fillna(1e20).astype(np.float32)
660662

661663
expected_output = np.array(
662664
[
@@ -678,8 +680,10 @@ def test_regrid_input_mask(self):
678680

679681
output_data = regridder.horizontal("ts", self.coarse_2d_ds)
680682

683+
assert output_data.ts.dtype == np.float32
684+
681685
# np.nan != np.nan, replace with 1e20
682-
output_data = output_data.fillna(1e20)
686+
output_data = output_data.fillna(1e20).astype(np.float32)
683687

684688
expected_output = np.array(
685689
[

xcdat/mask.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from xcdat.axis import get_dim_coords
1212
from xcdat.regridder.accessor import obj_to_grid_ds
1313
from xcdat.regridder.grid import create_grid
14+
from xcdat.regridder.regrid2 import _horizontal
1415

1516
logger = _setup_custom_logger(__name__)
1617

@@ -202,7 +203,6 @@ def _pcmdi_land_sea_mask(
202203
da: xr.DataArray,
203204
threshold1: float = 0.2,
204205
threshold2: float = 0.3,
205-
mask_name: str = "lsmask",
206206
) -> xr.DataArray:
207207
"""Generate a land-sea mask using the PCMDI method.
208208
@@ -233,14 +233,13 @@ def _pcmdi_land_sea_mask(
233233
"sftlf", obj_to_grid_ds(da), tool="regrid2"
234234
)
235235

236-
mask = highres_regrid.copy()
237-
mask["sftlf"] = xr.where(highres_regrid.sftlf > 0.5, 1, 0).astype("i")
236+
mask = xr.where(highres_regrid.sftlf > 0.5, 1, 0).astype("i")
238237

239-
lon = mask.sftlf.cf["X"]
240-
lon_bnds = mask.bounds.get_bounds("X")
238+
lon = mask.cf["X"]
239+
lon_bnds = highres_regrid.bounds.get_bounds("X")
241240
is_circular = _is_circular(lon, lon_bnds)
242241

243-
surrounds = _generate_surrounds(mask.sftlf, is_circular)
242+
surrounds = _generate_surrounds(mask, is_circular)
244243

245244
i = 0
246245

@@ -264,8 +263,6 @@ def _pcmdi_land_sea_mask(
264263

265264
i += 1
266265

267-
mask = mask.rename({"sftlf": mask_name})
268-
269266
return mask
270267

271268

@@ -332,14 +329,14 @@ def _is_circular(lon: xr.DataArray, lon_bnds: xr.DataArray) -> bool:
332329

333330

334331
def _improve_mask(
335-
mask: xr.Dataset,
332+
mask_da: xr.DataArray,
336333
source: xr.Dataset,
337334
data_var: str,
338335
surrounds: list[np.ndarray],
339336
is_circular: bool,
340337
threshold1=0.2,
341338
threshold2=0.3,
342-
) -> xr.Dataset:
339+
) -> xr.DataArray:
343340
"""Improve a land-sea mask.
344341
345342
This function improves a land-sea mask by converting points based on
@@ -367,15 +364,12 @@ def _improve_mask(
367364
xr.Dataset
368365
The improved mask.
369366
"""
370-
mask_approx = _map2four(
371-
mask,
372-
data_var,
373-
)
367+
mask_approx = _map2four(mask_da)
374368

375-
diff = source[data_var] - mask_approx[data_var]
369+
diff = source[data_var] - mask_approx
376370

377371
mask_convert_land = _convert_points(
378-
mask[data_var] * 1.0,
372+
mask_da * 1.0,
379373
source[data_var],
380374
diff,
381375
threshold1,
@@ -395,12 +389,10 @@ def _improve_mask(
395389
convert_land=False,
396390
)
397391

398-
mask[data_var] = mask_convert_sea.astype("i")
399-
400-
return mask
392+
return mask_convert_sea.astype("i")
401393

402394

403-
def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
395+
def _map2four(mask_da: xr.DataArray) -> xr.DataArray:
404396
"""Map a mask to four subgrids and back.
405397
406398
This function regrids a mask to four subgrids (odd-odd, odd-even,
@@ -419,9 +411,7 @@ def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
419411
xr.Dataset
420412
The processed mask.
421413
"""
422-
mask_temp = mask.copy()
423-
424-
lat, lon = mask_temp[data_var].cf["Y"], mask_temp[data_var].cf["X"]
414+
lat, lon = mask_da.cf["Y"], mask_da.cf["X"]
425415
lat_odd, lat_even = lat[::2], lat[1::2]
426416
lon_odd, lon_even = lon[::2], lon[1::2]
427417

@@ -430,23 +420,21 @@ def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
430420
even_odd = create_grid(y=lat_even, x=lon_odd, add_bounds=True)
431421
even_even = create_grid(y=lat_even, x=lon_even, add_bounds=True)
432422

433-
regrid_odd_odd = mask_temp.regridder.horizontal(data_var, odd_odd, tool="regrid2")
434-
regrid_odd_even = mask_temp.regridder.horizontal(data_var, odd_even, tool="regrid2")
435-
regrid_even_odd = mask_temp.regridder.horizontal(data_var, even_odd, tool="regrid2")
436-
regrid_even_even = mask_temp.regridder.horizontal(
437-
data_var, even_even, tool="regrid2"
423+
regrid_odd_odd, _ = _horizontal(mask_da.copy(), obj_to_grid_ds(mask_da), odd_odd)
424+
regrid_odd_even, _ = _horizontal(mask_da.copy(), obj_to_grid_ds(mask_da), odd_even)
425+
regrid_even_odd, _ = _horizontal(mask_da.copy(), obj_to_grid_ds(mask_da), even_odd)
426+
regrid_even_even, _ = _horizontal(
427+
mask_da.copy(), obj_to_grid_ds(mask_da), even_even
438428
)
439429

440-
output = np.zeros(mask_temp[data_var].shape, dtype="f")
441-
442-
output[::2, ::2] = regrid_odd_odd[data_var].data
443-
output[::2, 1::2] = regrid_odd_even[data_var].data
444-
output[1::2, ::2] = regrid_even_odd[data_var].data
445-
output[1::2, 1::2] = regrid_even_even[data_var].data
430+
output = xr.zeros_like(mask_da, dtype="f")
446431

447-
mask_temp[data_var] = (mask_temp[data_var].dims, output)
432+
output[::2, ::2] = regrid_odd_odd.data
433+
output[::2, 1::2] = regrid_odd_even.data
434+
output[1::2, ::2] = regrid_even_odd.data
435+
output[1::2, 1::2] = regrid_even_even.data
448436

449-
return mask_temp
437+
return output
450438

451439

452440
def _convert_points(

0 commit comments

Comments
 (0)