@@ -339,6 +339,8 @@ def pcmdi_land_sea_mask(
339339 da : xr .DataArray ,
340340 threshold1 : float = 0.2 ,
341341 threshold2 : float = 0.3 ,
342+ source : xr .Dataset | None = None ,
343+ source_data_var : str | None = None ,
342344) -> xr .DataArray :
343345 """Generate a land-sea mask using the PCMDI method.
344346
@@ -353,6 +355,10 @@ def pcmdi_land_sea_mask(
353355 The first threshold for improving the mask, by default 0.2.
354356 threshold2 : float, optional
355357 The second threshold for improving the mask, by default 0.3.
358+ source : xr.Dataset | None, optional
359+ The Dataset containing the variable to use as the high resolution source.
360+ source_data_var : str | None, optional
361+ Name of the variable in `source` to use as the high resolution source.
356362
357363 Returns
358364 -------
@@ -362,24 +368,47 @@ def pcmdi_land_sea_mask(
362368 Examples
363369 --------
364370
371+ Generate a land-sea mask using the PCMDI method:
365372
373+ >>> import xcdat
374+ >>> ds = xcdat.open_dataset("/path/to/file")
375+ >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"])
376+
377+ Generate a land-sea mask using the PCMDI method with custom thresholds:
378+
379+ >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], threshold1=0.3, threshold2=0.4)
380+
381+ Generate a land-sea mask using the PCMDI method with a custom high resolution source:
382+
383+ >>> highres_ds = xcdata.open_dataset("/path/to/file")
384+ >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], source=highres_ds, source_data_var="highres")
366385 """
367- resource_path = str (_get_resource_path ("navy_land.nc" , Path .cwd ()))
386+ if source is not None and source_data_var is None :
387+ raise ValueError (
388+ "The 'source_data_var' value cannot be None when using the 'source' option."
389+ )
390+
391+ if source is None :
392+ source_data_var = "sftlf"
368393
369- highres = open_dataset ( resource_path )
394+ resource_path = str ( _get_resource_path ( "navy_land.nc" , Path . cwd ()) )
370395
371- highres_regrid = highres .regridder .horizontal (
372- "sftlf" , obj_to_grid_ds (da ), tool = "regrid2"
396+ source = open_dataset (resource_path )
397+
398+ source_regrid = source .regridder .horizontal (
399+ source_data_var , obj_to_grid_ds (da ), tool = "regrid2"
373400 )
374401
375- mask = highres_regrid .copy ()
376- mask ["sftlf" ] = xr .where (highres_regrid .sftlf > 0.5 , 1 , 0 ).astype ("i" )
402+ mask = source_regrid .copy ()
403+ mask [source_data_var ] = xr .where (source_regrid [source_data_var ] > 0.5 , 1 , 0 ).astype (
404+ "i"
405+ )
377406
378- lon = mask . sftlf .cf ["X" ]
407+ lon = mask [ source_data_var ] .cf ["X" ]
379408 lon_bnds = mask .bounds .get_bounds ("X" )
380409 is_circular = _is_circular (lon , lon_bnds )
381410
382- surrounds = _generate_surrounds (mask . sftlf , is_circular )
411+ surrounds = _generate_surrounds (mask [ source_data_var ] , is_circular )
383412
384413 i = 0
385414
@@ -388,8 +417,8 @@ def pcmdi_land_sea_mask(
388417
389418 improved_mask = _improve_mask (
390419 mask .copy (deep = True ),
391- highres_regrid ,
392- "sftlf" ,
420+ source_regrid ,
421+ source_data_var , # type: ignore[arg-type]
393422 surrounds ,
394423 is_circular ,
395424 threshold1 ,
@@ -403,7 +432,7 @@ def pcmdi_land_sea_mask(
403432
404433 i += 1
405434
406- return mask ["sftlf" ]
435+ return mask [source_data_var ]
407436
408437
409438def _get_resource_path (filename : str , default_path : Path | None = None ) -> Path :
@@ -610,9 +639,9 @@ def _convert_points(
610639 diff : xr.DataArray
611640 The difference between the source and an approximated mask.
612641 threshold1 : float
613- The first threshold for conversion .
642+ Threshold for points in the `diff` DataArray .
614643 threshold2 : float
615- The second threshold for conversion .
644+ Threshold for points in the `source` DataArray .
616645 is_circular : bool
617646 Whether the longitude axis is circular.
618647 surrounds : list[np.ndarray]
@@ -627,16 +656,16 @@ def _convert_points(
627656 """
628657 UL , UC , UR , ML , MR , LL , LC , LR = surrounds
629658
630- flip_value = 0.0
631659 mask_value = 1.0
632660 compare_func : Callable
633661 if convert_land :
634662 compare_func = np .greater
635663 else :
636664 compare_func = np .less
637- flip_value = 1.0
638665 mask_value = 0.0
639666
667+ flip_value = abs (mask_value - 1.0 )
668+
640669 c1 = compare_func (diff , threshold1 )
641670 c2 = compare_func (source , threshold2 )
642671 c = np .logical_and (c1 , c2 )
0 commit comments