@@ -119,6 +119,45 @@ def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]:
     raise ValueError("Cannot infer source grid size from weight file.")
 
 
+def _dst_latlon_1d_from_map(mapfile: Path) -> tuple[np.ndarray, np.ndarray]:
+    """Return 1D (lat_out, lon_out) arrays for the destination grid from the map file."""
+    with _open_nc(mapfile) as m:
+        # prefer 2-D centers, then reshape → 1-D
+        lat2d = _read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat")
+        lon2d = _read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon")
+        if lat2d is not None and lon2d is not None:
+            size = int(np.asarray(lat2d).size)
+            # try dims if present, else infer 180x360 for 1° grid
+            if "dst_grid_dims" in m:
+                dims = np.asarray(m["dst_grid_dims"]).ravel().astype(int)
+                if dims.size >= 2:
+                    ny, nx = int(dims[-2]), int(dims[-1])
+                else:
+                    ny, nx = 180, 360
+            else:
+                if size == 180 * 360:
+                    ny, nx = 180, 360
+                else:
+                    ny = int(round(np.sqrt(size)))
+                    nx = size // ny
+            lat2d = np.asarray(lat2d).reshape(ny, nx)
+            lon2d = np.asarray(lon2d).reshape(ny, nx)
+            lat1d = lat2d[:, 0].astype("f8")
+            lon1d = lon2d[0, :].astype("f8")
+            return lat1d, lon1d
+
+        # fallback: 1-D already present
+        lat1d = _read_array(m, "lat", "yc")
+        lon1d = _read_array(m, "lon", "xc")
+        if lat1d is not None and lon1d is not None:
+            return np.asarray(lat1d, dtype="f8"), np.asarray(lon1d, dtype="f8")
+
+    # last resort: fabricate standard 1°
+    lat = np.linspace(-89.5, 89.5, 180, dtype="f8")
+    lon = np.arange(360, dtype="f8") + 0.5
+    return lat, lon
+
+
 def _get_dst_latlon_1d(m: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]:
     """Return 1D dest lat, lon arrays from weight file.
 
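A minimal usage sketch of the new helper, not part of the diff; the import path and weight-file name below are placeholders, and the expected shapes assume a 1-degree destination map:

from pathlib import Path

# Placeholder module name and map filename; point these at the real module
# and an actual ESMF/SCRIP weight file on disk.
from regrid_helpers import _dst_latlon_1d_from_map

lat, lon = _dst_latlon_1d_from_map(Path("map_source_to_180x360_aave.nc"))
print(lat.shape, lon.shape)   # (180,) (360,) for a 1-degree destination grid
print(lat[0], lon[0])         # e.g. -89.5 and 0.5 for cell centers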
@@ -258,14 +297,14 @@ def _ensure_ncol_last(da: xr.DataArray) -> Tuple[xr.DataArray, Tuple[str, ...]]:
     return da.transpose(*non_spatial, "ncol"), non_spatial
 
 
-def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
-    """Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
-    dim_map = {}
-    if "y" in da.dims:
-        dim_map["y"] = "lat"
-    if "x" in da.dims:
-        dim_map["x"] = "lon"
-    return da.rename(dim_map) if dim_map else da
+# def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
+#     """Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
+#     dim_map = {}
+#     if "y" in da.dims:
+#         dim_map["y"] = "lat"
+#     if "x" in da.dims:
+#         dim_map["x"] = "lon"
+#     return da.rename(dim_map) if dim_map else da
 
 
 # -------------------------
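Context for retiring this helper: xESMF's output dimension names can vary (for example 'y'/'x', 'lat'/'lon', or placeholder names) depending on how the destination grid is described in the weight file, so the replacement logic in the regrid_to_1deg hunk further down matches dimensions by length instead of by name. A rough standalone sketch of that idea, with made-up dimension names:

import numpy as np
import xarray as xr

# Fake regridder output: 12 time steps on a 1-degree grid with unnamed spatial dims.
out = xr.DataArray(np.zeros((12, 180, 360)), dims=("time", "dim_1", "dim_2"))
lat1d = np.linspace(-89.5, 89.5, 180)
lon1d = np.arange(360, dtype="f8") + 0.5

# Pick whichever spatial dim matches each destination-axis length, then rename.
spatial = [d for d in out.dims if d != "time"]
lat_dim = next(d for d in spatial if out.sizes[d] == lat1d.size)
lon_dim = next(d for d in spatial if out.sizes[d] == lon1d.size)
out = out.rename({lat_dim: "lat", lon_dim: "lon"}).assign_coords(lat=lat1d, lon=lon1d)
print(out.dims)  # ('time', 'lat', 'lon')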
@@ -305,8 +344,8 @@ def regrid_to_1deg(
     if varname not in ds_in:
         raise KeyError(f"{varname!r} not in dataset.")
 
-    da = ds_in[varname]
-    da2, non_spatial = _ensure_ncol_last(da)
+    var_da = ds_in[varname]  # always a DataArray
+    da2, non_spatial = _ensure_ncol_last(var_da)
 
     # cast to save memory
     if dtype is not None and str(da2.dtype) != dtype:
@@ -337,42 +376,135 @@ def regrid_to_1deg(
     if "time" in da2_2d.dims and output_time_chunk:
         kwargs["output_chunks"] = {"time": output_time_chunk}
 
-    out = regridder(da2_2d, **kwargs)  # -> (*non_spatial, y/x or lat/lon)
-    out = _rename_xy_to_latlon(out)
-
-    if keep_attrs:
-        out.attrs.update(da.attrs)
-
-    for c in non_spatial:
-        if c in ds_in.coords and c in out.dims:
-            out = out.assign_coords({c: ds_in[c]})
-
-    if "lat" in out.coords:
-        out["lat"].attrs.setdefault("units", "degrees_north")
-        out["lat"].attrs.setdefault("standard_name", "latitude")
-    if "lon" in out.coords:
-        out["lon"].attrs.setdefault("units", "degrees_east")
-        out["lon"].attrs.setdefault("standard_name", "longitude")
+    out = regridder(da2_2d, **kwargs)  # current call that returns (*non_spatial, ?, ?)
+
+    # --- NEW: robust lat/lon assignment based on destination grid lengths ---
+    lat1d, lon1d = _dst_latlon_1d_from_map(spec.path)
+    ny, nx = len(lat1d), len(lon1d)
+
+    # find the last two dims that came from xESMF
+    spatial_dims = [d for d in out.dims if d not in non_spatial]
+    if len(spatial_dims) < 2:
+        raise ValueError(f"Unexpected output dims {out.dims}; need two spatial dims.")
+
+    da, db = spatial_dims[-2], spatial_dims[-1]
+    na, nb = out.sizes[da], out.sizes[db]
+
+    # Decide mapping by comparing lengths to (ny, nx)
+    if na == ny and nb == nx:
+        out = out.rename({da: "lat", db: "lon"})
+    elif na == nx and nb == ny:
+        out = out.rename({da: "lon", db: "lat"})
+    else:
+        # Heuristic fallback: pick the dim whose size matches 180 as lat
+        if {na, nb} == {ny, nx}:
+            # covered above; should not reach here
+            pass
+        else:
+            # choose the one closer to 180 as lat
+            choose_lat = da if abs(na - 180) <= abs(nb - 180) else db
+            choose_lon = db if choose_lat == da else da
+            out = out.rename({choose_lat: "lat", choose_lon: "lon"})
+
+    # assign canonical 1-D coords
+    out = out.assign_coords(lat=("lat", lat1d), lon=("lon", lon1d))
+
+    try:
+        out = out.transpose(*non_spatial, "lat", "lon")
+    except ValueError:
+        # fallback if non_spatial is empty
+        out = out.transpose("lat", "lon")
+    if keep_attrs and hasattr(var_da, "attrs"):
+        out.attrs.update(var_da.attrs)
 
     return out
 
 
-def regrid_mask_or_area(
-    da_in: xr.DataArray,
+def regrid_to_1deg_ds(
+    ds_in: xr.Dataset,
+    varname: str,
     *,
+    time_from: xr.Dataset | None = None,
+    method: Optional[str] = None,
     conservative_map: Optional[Path] = None,
-) -> xr.DataArray:
-    """Regrid a mask or cell-area field using conservative weights."""
-    if "ncol" not in da_in.dims:
-        raise ValueError("Expected 'ncol' in dims for mask/area regridding.")
-    if "time" in da_in.dims:
-        da_in = da_in.transpose("time", "ncol", ...)
-
-    spec = MapSpec(
-        "conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP
-    )
-    regridder = _RegridderCache.get(spec.path, spec.method_label)
+    bilinear_map: Optional[Path] = None,
+    keep_attrs: bool = True,
+    dtype: str | None = "float32",
+    output_time_chunk: int | None = 12,
+) -> xr.Dataset:
+    """Regrid `varname` and return a Dataset containing the regridded variable.
 
-    out = regridder(da_in)
-    out = _rename_xy_to_latlon(out)
-    return out
+    Parameters mirror regrid_to_1deg, but this function:
+      - returns an xr.Dataset({varname: DataArray})
+      - if `time_from` is provided, copies 'time' and its bounds into the dataset
+    """
+    da = regrid_to_1deg(
+        ds_in,
+        varname,
+        method=method,
+        conservative_map=conservative_map,
+        bilinear_map=bilinear_map,
+        keep_attrs=keep_attrs,
+        dtype=dtype,
+        output_time_chunk=output_time_chunk,
+    )
+    ds_out = xr.Dataset({varname: da})
+    if time_from is not None:
+        ds_out = _attach_time_and_bounds(ds_out, time_from)
+    return ds_out
+
+
+# def regrid_mask_or_area(
+#     da_in: xr.DataArray,
+#     *,
+#     conservative_map: Optional[Path] = None,
+# ) -> xr.DataArray:
+#     """Regrid a mask or cell-area field using conservative weights."""
+#     if "ncol" not in da_in.dims:
+#         raise ValueError("Expected 'ncol' in dims for mask/area regridding.")
+#     if "time" in da_in.dims:
+#         da_in = da_in.transpose("time", "ncol", ...)
+#     spec = MapSpec(
+#         "conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP
+#     )
+#     regridder = _RegridderCache.get(spec.path, spec.method_label)
+#
+#     out = regridder(da_in)
+#     out = _rename_xy_to_latlon(out)
+#     return out
+
+# --- convenience: carry time + bounds from a source dataset into an output dataset ---
+
+
+def _attach_time_and_bounds(ds_out: xr.Dataset, time_from: xr.Dataset) -> xr.Dataset:
+    """Return ds_out with 'time' coord and existing time bounds copied from time_from.
+
+    - Looks for the bounds variable via time.attrs['bounds'], or 'time_bounds' / 'time_bnds'.
+    - Aligns bounds along time and ensures dims are (time, nbnd) before attaching.
+    - Adds bounds as a DATA VARIABLE (not a coord) so 'nbnd' can be created.
+    """
+    if "time" in time_from:
+        ds_out = ds_out.assign_coords(time=time_from["time"])
+
+    # locate bounds
+    bname = time_from["time"].attrs.get("bounds")
+    tb = None
+    if isinstance(bname, str) and bname in time_from:
+        tb = time_from[bname]
+    elif "time_bounds" in time_from:
+        bname, tb = "time_bounds", time_from["time_bounds"]
+    elif "time_bnds" in time_from:
+        bname, tb = "time_bnds", time_from["time_bnds"]
+
+    if tb is not None:
+        # align length to ds_out
+        if tb.sizes.get("time") != ds_out.sizes.get("time"):
+            tb = tb.reindex(time=ds_out["time"])
+        # ensure (time, nbnd) ordering
+        if tb.dims[0] != "time":
+            other = next(d for d in tb.dims if d != "time")
+            tb = tb.transpose("time", other)
+        ds_out[bname] = tb  # data variable (nbnd is created if needed)
+        ds_out["time"].attrs["bounds"] = bname
+
+    return ds_out
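A hedged end-to-end sketch of the new wrapper, not part of the diff; the file name and variable name are placeholders, and the functions are assumed to be importable from this module:

import xarray as xr

ds_native = xr.open_dataset("native_ne30_output.nc")  # placeholder file with dims (time, ncol)

ds_1deg = regrid_to_1deg_ds(
    ds_native,
    "FSNT",                  # placeholder variable name
    time_from=ds_native,     # copies 'time' plus any 'time_bnds'/'time_bounds' variable
    output_time_chunk=12,
)
print(ds_1deg["FSNT"].dims)                  # ('time', 'lat', 'lon')
print(ds_1deg["time"].attrs.get("bounds"))   # e.g. 'time_bnds' if present in the source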