1111from xcdat .axis import get_dim_coords
1212from xcdat .regridder .accessor import obj_to_grid_ds
1313from xcdat .regridder .grid import create_grid
14+ from xcdat .regridder .regrid2 import _horizontal
1415
1516logger = _setup_custom_logger (__name__ )
1617
@@ -202,7 +203,6 @@ def _pcmdi_land_sea_mask(
202203 da : xr .DataArray ,
203204 threshold1 : float = 0.2 ,
204205 threshold2 : float = 0.3 ,
205- mask_name : str = "lsmask" ,
206206) -> xr .DataArray :
207207 """Generate a land-sea mask using the PCMDI method.
208208
@@ -233,14 +233,13 @@ def _pcmdi_land_sea_mask(
233233 "sftlf" , obj_to_grid_ds (da ), tool = "regrid2"
234234 )
235235
236- mask = highres_regrid .copy ()
237- mask ["sftlf" ] = xr .where (highres_regrid .sftlf > 0.5 , 1 , 0 ).astype ("i" )
236+ mask = xr .where (highres_regrid .sftlf > 0.5 , 1 , 0 ).astype ("i" )
238237
239- lon = mask .sftlf . cf ["X" ]
240- lon_bnds = mask .bounds .get_bounds ("X" )
238+ lon = mask .cf ["X" ]
239+ lon_bnds = highres_regrid .bounds .get_bounds ("X" )
241240 is_circular = _is_circular (lon , lon_bnds )
242241
243- surrounds = _generate_surrounds (mask . sftlf , is_circular )
242+ surrounds = _generate_surrounds (mask , is_circular )
244243
245244 i = 0
246245
@@ -264,8 +263,6 @@ def _pcmdi_land_sea_mask(
264263
265264 i += 1
266265
267- mask = mask .rename ({"sftlf" : mask_name })
268-
269266 return mask
270267
271268
@@ -332,14 +329,14 @@ def _is_circular(lon: xr.DataArray, lon_bnds: xr.DataArray) -> bool:
332329
333330
334331def _improve_mask (
335- mask : xr .Dataset ,
332+ mask_da : xr .DataArray ,
336333 source : xr .Dataset ,
337334 data_var : str ,
338335 surrounds : list [np .ndarray ],
339336 is_circular : bool ,
340337 threshold1 = 0.2 ,
341338 threshold2 = 0.3 ,
342- ) -> xr .Dataset :
339+ ) -> xr .DataArray :
343340 """Improve a land-sea mask.
344341
345342 This function improves a land-sea mask by converting points based on
@@ -367,15 +364,12 @@ def _improve_mask(
367364 xr.Dataset
368365 The improved mask.
369366 """
370- mask_approx = _map2four (
371- mask ,
372- data_var ,
373- )
367+ mask_approx = _map2four (mask_da )
374368
375- diff = source [data_var ] - mask_approx [ data_var ]
369+ diff = source [data_var ] - mask_approx
376370
377371 mask_convert_land = _convert_points (
378- mask [ data_var ] * 1.0 ,
372+ mask_da * 1.0 ,
379373 source [data_var ],
380374 diff ,
381375 threshold1 ,
@@ -395,12 +389,10 @@ def _improve_mask(
395389 convert_land = False ,
396390 )
397391
398- mask [data_var ] = mask_convert_sea .astype ("i" )
399-
400- return mask
392+ return mask_convert_sea .astype ("i" )
401393
402394
403- def _map2four (mask : xr .Dataset , data_var : str ) -> xr .Dataset :
395+ def _map2four (mask_da : xr .DataArray ) -> xr .DataArray :
404396 """Map a mask to four subgrids and back.
405397
406398 This function regrids a mask to four subgrids (odd-odd, odd-even,
@@ -419,9 +411,7 @@ def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
419411 xr.Dataset
420412 The processed mask.
421413 """
422- mask_temp = mask .copy ()
423-
424- lat , lon = mask_temp [data_var ].cf ["Y" ], mask_temp [data_var ].cf ["X" ]
414+ lat , lon = mask_da .cf ["Y" ], mask_da .cf ["X" ]
425415 lat_odd , lat_even = lat [::2 ], lat [1 ::2 ]
426416 lon_odd , lon_even = lon [::2 ], lon [1 ::2 ]
427417
@@ -430,23 +420,21 @@ def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
430420 even_odd = create_grid (y = lat_even , x = lon_odd , add_bounds = True )
431421 even_even = create_grid (y = lat_even , x = lon_even , add_bounds = True )
432422
433- regrid_odd_odd = mask_temp . regridder . horizontal ( data_var , odd_odd , tool = "regrid2" )
434- regrid_odd_even = mask_temp . regridder . horizontal ( data_var , odd_even , tool = "regrid2" )
435- regrid_even_odd = mask_temp . regridder . horizontal ( data_var , even_odd , tool = "regrid2" )
436- regrid_even_even = mask_temp . regridder . horizontal (
437- data_var , even_even , tool = "regrid2"
423+ regrid_odd_odd , _ = _horizontal ( mask_da . copy (), obj_to_grid_ds ( mask_da ), odd_odd )
424+ regrid_odd_even , _ = _horizontal ( mask_da . copy (), obj_to_grid_ds ( mask_da ), odd_even )
425+ regrid_even_odd , _ = _horizontal ( mask_da . copy (), obj_to_grid_ds ( mask_da ), even_odd )
426+ regrid_even_even , _ = _horizontal (
427+ mask_da . copy (), obj_to_grid_ds ( mask_da ), even_even
438428 )
439429
440- output = np .zeros (mask_temp [data_var ].shape , dtype = "f" )
441-
442- output [::2 , ::2 ] = regrid_odd_odd [data_var ].data
443- output [::2 , 1 ::2 ] = regrid_odd_even [data_var ].data
444- output [1 ::2 , ::2 ] = regrid_even_odd [data_var ].data
445- output [1 ::2 , 1 ::2 ] = regrid_even_even [data_var ].data
430+ output = xr .zeros_like (mask_da , dtype = "f" )
446431
447- mask_temp [data_var ] = (mask_temp [data_var ].dims , output )
432+ output [::2 , ::2 ] = regrid_odd_odd .data
433+ output [::2 , 1 ::2 ] = regrid_odd_even .data
434+ output [1 ::2 , ::2 ] = regrid_even_odd .data
435+ output [1 ::2 , 1 ::2 ] = regrid_even_even .data
448436
449- return mask_temp
437+ return output
450438
451439
452440def _convert_points (
0 commit comments