Skip to content

Commit 4237ae4

Browse files
jasonb5tomvothecoder
authored andcommitted
Adds ability to define the source mask to use in the PCMDI land-sea mask generation
1 parent 92c14ba commit 4237ae4

File tree

2 files changed

+93
-17
lines changed

2 files changed

+93
-17
lines changed

tests/test_mask.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,55 @@ def test_generate_land_sea_mask_pcmdi(ds):
252252
xr.testing.assert_equal(output, expected)
253253

254254

255+
def test_pcmdi_land_sea_mask_custom_source(ds):
256+
source = xr.DataArray(
257+
[
258+
[0.1, 0.1, 0.9, 0.2],
259+
[0.1, 0.9, 0.9, 0.1],
260+
[0.0, 0.1, 0.9, 0.9],
261+
[0.1, 0.1, 0.9, 0.1],
262+
],
263+
dims=("lat", "lon"),
264+
coords={"lat": ds.lat.copy(), "lon": ds.lon.copy()},
265+
attrs={"Conventions": "CF-1.0"},
266+
).to_dataset(name="highres_mask")
267+
268+
output = mask.pcmdi_land_sea_mask(
269+
ds["ts"], source=source, source_data_var="highres_mask"
270+
)
271+
272+
expected = xr.DataArray(
273+
[[0, 0, 1, 0], [0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 1, 0]],
274+
dims=("lat", "lon"),
275+
coords={"lat": ds.lat.copy(), "lon": ds.lon.copy()},
276+
attrs={"Conventions": "CF-1.0"},
277+
)
278+
279+
xr.testing.assert_allclose(output, expected)
280+
281+
282+
def test_pcmdi_land_sea_mask_custom_source_error(ds):
283+
source = xr.DataArray(
284+
[
285+
[0.1, 0.1, 0.9, 0.2],
286+
[0.1, 0.9, 0.9, 0.1],
287+
[0.0, 0.1, 0.9, 0.9],
288+
[0.1, 0.1, 0.9, 0.1],
289+
],
290+
dims=("lat", "lon"),
291+
coords={"lat": ds.lat.copy(), "lon": ds.lon.copy()},
292+
attrs={"Conventions": "CF-1.0"},
293+
).to_dataset(name="highres_mask")
294+
295+
with pytest.raises(
296+
ValueError,
297+
match="The 'source_data_var' value cannot be None when using the 'source' option.",
298+
):
299+
mask.pcmdi_land_sea_mask(ds["ts"], source=source)
300+
301+
255302
@mock.patch("xcdat.mask._improve_mask")
256-
def test_generate_land_sea_mask_pcmdi_multiple_iterations(_improve_mask, ds):
303+
def test_pcmdi_land_sea_mask_multiple_iterations(_improve_mask, ds):
257304
mask1 = xr.DataArray(
258305
[
259306
[1, 1, 1, 1],
@@ -292,7 +339,7 @@ def test_generate_land_sea_mask_pcmdi_multiple_iterations(_improve_mask, ds):
292339
coords={"lat": ds.lat.copy(), "lon": ds.lon.copy()},
293340
)
294341

295-
output = mask.generate_land_sea_mask(ds["ts"], method="pcmdi")
342+
output = mask.pcmdi_land_sea_mask(ds["ts"])
296343

297344
xr.testing.assert_equal(output, expected)
298345

xcdat/mask.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ def pcmdi_land_sea_mask(
339339
da: xr.DataArray,
340340
threshold1: float = 0.2,
341341
threshold2: float = 0.3,
342+
source: xr.Dataset | None = None,
343+
source_data_var: str | None = None,
342344
) -> xr.DataArray:
343345
"""Generate a land-sea mask using the PCMDI method.
344346
@@ -353,6 +355,10 @@ def pcmdi_land_sea_mask(
353355
The first threshold for improving the mask, by default 0.2.
354356
threshold2 : float, optional
355357
The second threshold for improving the mask, by default 0.3.
358+
source : xr.Dataset | None, optional
359+
The Dataset containing the variable to use as the high resolution source.
360+
source_data_var : str | None, optional
361+
Name of the variable in `source` to use as the high resolution source.
356362
357363
Returns
358364
-------
@@ -362,24 +368,47 @@ def pcmdi_land_sea_mask(
362368
Examples
363369
--------
364370
371+
Generate a land-sea mask using the PCMDI method:
365372
373+
>>> import xcdat
374+
>>> ds = xcdat.open_dataset("/path/to/file")
375+
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"])
376+
377+
Generate a land-sea mask using the PCMDI method with custom thresholds:
378+
379+
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], threshold1=0.3, threshold2=0.4)
380+
381+
Generate a land-sea mask using the PCMDI method with a custom high resolution source:
382+
383+
>>> highres_ds = xcdata.open_dataset("/path/to/file")
384+
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], source=highres_ds, source_data_var="highres")
366385
"""
367-
resource_path = str(_get_resource_path("navy_land.nc", Path.cwd()))
386+
if source is not None and source_data_var is None:
387+
raise ValueError(
388+
"The 'source_data_var' value cannot be None when using the 'source' option."
389+
)
390+
391+
if source is None:
392+
source_data_var = "sftlf"
368393

369-
highres = open_dataset(resource_path)
394+
resource_path = str(_get_resource_path("navy_land.nc", Path.cwd()))
370395

371-
highres_regrid = highres.regridder.horizontal(
372-
"sftlf", obj_to_grid_ds(da), tool="regrid2"
396+
source = open_dataset(resource_path)
397+
398+
source_regrid = source.regridder.horizontal(
399+
source_data_var, obj_to_grid_ds(da), tool="regrid2"
373400
)
374401

375-
mask = highres_regrid.copy()
376-
mask["sftlf"] = xr.where(highres_regrid.sftlf > 0.5, 1, 0).astype("i")
402+
mask = source_regrid.copy()
403+
mask[source_data_var] = xr.where(source_regrid[source_data_var] > 0.5, 1, 0).astype(
404+
"i"
405+
)
377406

378-
lon = mask.sftlf.cf["X"]
407+
lon = mask[source_data_var].cf["X"]
379408
lon_bnds = mask.bounds.get_bounds("X")
380409
is_circular = _is_circular(lon, lon_bnds)
381410

382-
surrounds = _generate_surrounds(mask.sftlf, is_circular)
411+
surrounds = _generate_surrounds(mask[source_data_var], is_circular)
383412

384413
i = 0
385414

@@ -388,8 +417,8 @@ def pcmdi_land_sea_mask(
388417

389418
improved_mask = _improve_mask(
390419
mask.copy(deep=True),
391-
highres_regrid,
392-
"sftlf",
420+
source_regrid,
421+
source_data_var, # type: ignore[arg-type]
393422
surrounds,
394423
is_circular,
395424
threshold1,
@@ -403,7 +432,7 @@ def pcmdi_land_sea_mask(
403432

404433
i += 1
405434

406-
return mask["sftlf"]
435+
return mask[source_data_var]
407436

408437

409438
def _get_resource_path(filename: str, default_path: Path | None = None) -> Path:
@@ -610,9 +639,9 @@ def _convert_points(
610639
diff : xr.DataArray
611640
The difference between the source and an approximated mask.
612641
threshold1 : float
613-
The first threshold for conversion.
642+
Threshold for points in the `diff` DataArray.
614643
threshold2 : float
615-
The second threshold for conversion.
644+
Threshold for points in the `source` DataArray.
616645
is_circular : bool
617646
Whether the longitude axis is circular.
618647
surrounds : list[np.ndarray]
@@ -627,16 +656,16 @@ def _convert_points(
627656
"""
628657
UL, UC, UR, ML, MR, LL, LC, LR = surrounds
629658

630-
flip_value = 0.0
631659
mask_value = 1.0
632660
compare_func: Callable
633661
if convert_land:
634662
compare_func = np.greater
635663
else:
636664
compare_func = np.less
637-
flip_value = 1.0
638665
mask_value = 0.0
639666

667+
flip_value = abs(mask_value - 1.0)
668+
640669
c1 = compare_func(diff, threshold1)
641670
c2 = compare_func(source, threshold2)
642671
c = np.logical_and(c1, c2)

0 commit comments

Comments
 (0)