@@ -26,140 +26,146 @@ class MapSpec:
     path: Path

 # -------------------------
-# helpers to build minimal grids from weight file
+# NetCDF opener (backends)
 # -------------------------

-def _first_var(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]:
+def _open_nc(path: Path) -> xr.Dataset:
+    """Open NetCDF with explicit engine(s) to avoid backend autodetect failures."""
+    path = Path(path)
+    for engine in ("netcdf4", "scipy"):
+        try:
+            return xr.open_dataset(str(path), engine=engine)
+        except Exception:
+            pass
+    raise ValueError(
+        f"Could not open {path} with xarray engines ['netcdf4', 'scipy']. "
+        "Check that the file is a NetCDF and the backends are installed."
+    )
+
+# -------------------------
+# Minimal dummy grids derived from the weight file
+# -------------------------
+
+def _read_array(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]:
     for n in names:
         if n in m:
             return m[n]
     return None

-def _src_coords_from_map(mapfile: Path) -> Tuple[xr.DataArray, xr.DataArray]:
-    """Return (lat_in, lon_in) as 1-D arrays for the source (unstructured) grid."""
-    with xr.open_dataset(mapfile) as m:
-        lat = _first_var(m, "lat_a", "yc_a", "src_grid_center_lat", "y_a", "y_A")
-        lon = _first_var(m, "lon_a", "xc_a", "src_grid_center_lon", "x_a", "x_A")
-        if lat is not None and lon is not None:
-            lat = lat.rename("lat").astype("f8").reset_coords(drop=True)
-            lon = lon.rename("lon").astype("f8").reset_coords(drop=True)
-            # ensure 1-D
-            lat = lat if lat.ndim == 1 else lat.stack(points=lat.dims).rename("lat")
-            lon = lon if lon.ndim == 1 else lon.stack(points=lon.dims).rename("lon")
-            return lat, lon
-
-        # fallback: size only
-        size = None
-        for c in ("src_grid_size", "n_a"):
-            if c in m:
-                try:
-                    size = int(np.asarray(m[c]).ravel()[0])
-                    break
-                except Exception:
-                    pass
-        if size is None:
-            # try dims on sparse index variables
-            for v in ("row", "col"):
-                if v in m and m[v].dims:
-                    # ESMF uses 1-based indices; not directly size, but better than nothing
-                    size = int(np.asarray(m[v]).max())
-                    break
-        if size is None:
-            raise ValueError("Cannot infer source grid size from weight file.")
-
-        lat = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lat")
-        lon = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lon")
-        return lat, lon
+def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]:
+    """Infer the source grid 'shape' expected by xESMF's ds_to_ESMFgrid.

-def _dst_coords_from_map(mapfile: Path) -> Dict[str, xr.DataArray]:
-    """Extract dest lat/lon (+bounds if present) from an ESMF map file."""
-    with xr.open_dataset(mapfile) as m:
-        # centers
-        lat = _first_var(m, "lat_b", "yc_b", "lat", "yc", "dst_grid_center_lat")
-        lon = _first_var(m, "lon_b", "xc_b", "lon", "xc", "dst_grid_center_lon")
-
-        if lat is not None and lon is not None:
-            lat = lat.rename("lat").astype("f8")
-            lon = lon.rename("lon").astype("f8")
-            # If 2-D curvilinear, keep dims; if 1-D, leave as-is
+    We provide a dummy 2D shape even when the true source is unstructured.
+    """
+    a = _read_array(m, "src_grid_dims")
+    if a is not None:
+        vals = np.asarray(a).ravel().astype(int)
+        if vals.size == 1:
+            return (1, int(vals[0]))
+        if vals.size >= 2:
+            return (int(vals[-2]), int(vals[-1]))
+    # fallbacks for unstructured
+    for n in ("src_grid_size", "n_a"):
+        if n in m:
+            size = int(np.asarray(m[n]).ravel()[0])
+            return (1, size)
+    # very last resort: infer from the sparse-matrix indices ("col" indexes the source grid)
+    if "col" in m:
+        size = int(np.asarray(m["col"]).max())
+        return (1, size)
+    raise ValueError("Cannot infer source grid size from weight file.")
+
+def _get_dst_latlon_1d(m: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]:
+    """Return 1D dest lat, lon arrays from weight file.
+
+    Prefers 2D center lat/lon (yc_b/xc_b or lat_b/lon_b), reshaped to (ny, nx),
+    then converts to 1D centers by taking first column/row, which is valid for
+    regular 1° lat/lon weights.
+    """
+    lat2d = _read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat")
+    lon2d = _read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon")
+    if lat2d is not None and lon2d is not None:
+        # figure out (ny, nx)
+        if "dst_grid_dims" in m:
+            ny, nx = [int(x) for x in np.asarray(m["dst_grid_dims"]).ravel()][-2:]
         else:
-            # fallback from dims
-            dims = None
-            for name in ("dst_grid_dims", "dst_grid_size"):
-                if name in m:
-                    dims = np.asarray(m[name]).ravel()
-                    break
-            if dims is None or dims.size < 2:
-                # assume 1° regular
+            # try to infer the shape directly from the array size
+            size = int(np.asarray(lat2d).size)
+            # a regular 1° x 1° grid has 180 * 360 cells
+            if size == 180 * 360:
                 ny, nx = 180, 360
             else:
-                if dims.size == 1:
-                    ny, nx = int(dims[0]), 1
-                else:
-                    ny, nx = int(dims[-2]), int(dims[-1])
-            lat = xr.DataArray(np.linspace(-89.5, 89.5, ny), dims=("lat",), name="lat")
-            lon = xr.DataArray((np.arange(nx) + 0.5) * (360.0 / nx), dims=("lon",), name="lon")
-
-        # bounds (optional)
-        lat_b = _first_var(m, "lat_bnds", "lat_b", "bounds_lat", "lat_bounds", "y_bnds", "yb")
-        lon_b = _first_var(m, "lon_bnds", "lon_b", "bounds_lon", "lon_bounds", "x_bnds", "xb")
-
-        coords = {"lat": lat, "lon": lon}
-        if lat_b is not None:
-            coords["lat_bnds"] = lat_b.astype("f8")
-        if lon_b is not None:
-            coords["lon_bnds"] = lon_b.astype("f8")
-        return coords
-
-def _make_ds_in_out_from_map(mapfile: Path) -> Tuple[xr.Dataset, xr.Dataset]:
-    """Construct minimal CF-like ds_in (unstructured) and ds_out (structured/curvilinear) from weight file."""
-    lat_in, lon_in = _src_coords_from_map(mapfile)
-    dst = _dst_coords_from_map(mapfile)
-
-    # ds_in: unstructured → 1-D lat/lon on 'points'
-    if lat_in.dims != ("points",):
-        lat_in = lat_in.rename({lat_in.dims[0]: "points"})
-    if lon_in.dims != ("points",):
-        lon_in = lon_in.rename({lon_in.dims[0]: "points"})
-    ds_in = xr.Dataset({"lat": lat_in, "lon": lon_in})
+                # fallback: assume square-ish
+                ny = int(round(np.sqrt(size)))
+                nx = size // ny
+        lat2d = np.asarray(lat2d).reshape(ny, nx)
+        lon2d = np.asarray(lon2d).reshape(ny, nx)
+        return lat2d[:, 0].astype("f8"), lon2d[0, :].astype("f8")
+
+    # If 1D lat/lon already present
+    lat1d = _read_array(m, "lat", "yc")
+    lon1d = _read_array(m, "lon", "xc")
+    if lat1d is not None and lon1d is not None and lat1d.ndim == 1 and lon1d.ndim == 1:
+        return np.asarray(lat1d, dtype="f8"), np.asarray(lon1d, dtype="f8")
+
+    # Final fallback: fabricate a 1° grid
+    ny, nx = 180, 360
+    lat = np.linspace(-89.5, 89.5, ny, dtype="f8")
+    lon = (np.arange(nx, dtype="f8") + 0.5) * (360.0 / nx)
+    return lat, lon
+
+def _make_dummy_grids(mapfile: Path) -> Tuple[xr.Dataset, xr.Dataset]:
+    """Construct minimal ds_in/ds_out satisfying xESMF when reusing weights."""
+    with _open_nc(mapfile) as m:
+        nlat_in, nlon_in = _get_src_shape(m)
+        lat_out_1d, lon_out_1d = _get_dst_latlon_1d(m)
+
+    # Dummy input: arbitrary 2D indices, only shapes matter when weights are provided.
+    ds_in = xr.Dataset(
+        {
+            "lat": ("lat", np.arange(nlat_in, dtype="f8")),
+            "lon": ("lon", np.arange(nlon_in, dtype="f8")),
+        }
+    )
     ds_in["lat"].attrs.update({"units": "degrees_north", "standard_name": "latitude"})
     ds_in["lon"].attrs.update({"units": "degrees_east", "standard_name": "longitude"})

-    # ds_out: accept 1-D lat/lon (regular) or 2-D (curvilinear) from weights
-    ds_out = xr.Dataset({k: v for k, v in dst.items() if k in {"lat", "lon"}})
-    for k in ("lat", "lon"):
-        if k in ds_out:
-            ds_out[k].attrs.update(
-                {"units": f"degrees_{'north' if k == 'lat' else 'east'}",
-                 "standard_name": "latitude" if k == "lat" else "longitude"}
-            )
+    # Output: 1D regular lat/lon extracted from weights
+    ds_out = xr.Dataset(
+        {"lat": ("lat", lat_out_1d), "lon": ("lon", lon_out_1d)}
+    )
+    ds_out["lat"].attrs.update({"units": "degrees_north", "standard_name": "latitude"})
+    ds_out["lon"].attrs.update({"units": "degrees_east", "standard_name": "longitude"})
+
     return ds_in, ds_out

+# -------------------------
+# Cache
+# -------------------------
+
 class _RegridderCache:
     """Cache of xESMF Regridders constructed from weight files.

-    This avoids reconstructing regridders for the same weight file multiple times
-    and provides a small API to fetch or clear cached instances.
+    We build minimal `ds_in`/`ds_out` from the weight file to satisfy CF checks,
+    then reuse the weight file for the actual mapping.
     """
     _cache: Dict[Path, xe.Regridder] = {}

     @classmethod
     def get(cls, mapfile: Path, method_label: str) -> xe.Regridder:
-        """Return a cached regridder for the given weight file and method.
-
-        We build minimal `ds_in`/`ds_out` by reading lon/lat from the weight file.
-        This satisfies xESMF's CF checks even when we reuse weights.
-        """
+        """Return a cached regridder for the given weight file and method."""
         mapfile = mapfile.expanduser().resolve()
         if mapfile not in cls._cache:
             if not mapfile.exists():
                 raise FileNotFoundError(f"Regrid weights not found: {mapfile}")
-            ds_in, ds_out = _make_ds_in_out_from_map(mapfile)
+            ds_in, ds_out = _make_dummy_grids(mapfile)
             cls._cache[mapfile] = xe.Regridder(
-                ds_in, ds_out,
+                ds_in,
+                ds_out,
                 method=method_label,
-                filename=str(mapfile),
+                filename=str(mapfile),  # reuse the ESMF weight file on disk
                 reuse_weights=True,
+                periodic=True,  # 0..360 longitudes
             )
         return cls._cache[mapfile]

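A minimal usage sketch of the cache above, to show the intended call pattern; the weight-file path and method label here are placeholders, not values from this commit:

from pathlib import Path

weights = Path("~/maps/map_ne30pg2_to_1x1d_aave.nc")        # placeholder weight file
rg = _RegridderCache.get(weights, "conservative")           # builds dummy grids, reuses weights
assert rg is _RegridderCache.get(weights, "conservative")   # second call returns the cached instance
_RegridderCache.clear()                                     # release cached regridders, e.g. in tests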
@@ -168,6 +174,10 @@ def clear(cls) -> None:
         """Clear all cached regridders (useful for tests or releasing resources)."""
         cls._cache.clear()

+# -------------------------
+# Selection & utilities
+# -------------------------
+
 def _pick_maps(
     varname: str,
     conservative_map: Optional[Path] = None,
@@ -207,6 +217,10 @@ def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
         dim_map["x"] = "lon"
     return da.rename(dim_map) if dim_map else da

+# -------------------------
+# Public API
+# -------------------------
+
 def regrid_to_1deg(
     ds_in: xr.Dataset,
     varname: str,
@@ -215,14 +229,41 @@ def regrid_to_1deg(
     conservative_map: Optional[Path] = None,
     bilinear_map: Optional[Path] = None,
     keep_attrs: bool = True,
+    dtype: str | None = "float32",
+    output_time_chunk: int | None = 12,
 ) -> xr.DataArray:
-    """Regrid a field on (time, ncol[, ...]) to (time, [lev,] lat, lon)."""
+    """Regrid a field on (time, ncol[, ...]) to (time, [lev,] lat, lon).
+
+    Parameters
+    ----------
+    ds_in : xr.Dataset with the variable on (..., ncol)
+    varname : str
+        Variable name to regrid.
+    method : Optional[str]
+        Force "conservative" or "bilinear". If None, the method is chosen from the variable type.
+    conservative_map, bilinear_map : Optional[Path]
+        ESMF weight files. If not given, defaults are used.
+    keep_attrs : bool
+        Copy attributes from the input variable to the output.
+    dtype : str or None
+        Cast the input to this dtype before regridding (default "float32"). Set to None to disable.
+    output_time_chunk : int or None
+        If set and 'time' is present, ask xESMF to return output chunked along 'time'.
+    """
     if varname not in ds_in:
         raise KeyError(f"{varname!r} not in dataset.")

     da = ds_in[varname]
     da2, non_spatial = _ensure_ncol_last(da)

+    # cast to save memory
+    if dtype is not None and str(da2.dtype) != dtype:
+        da2 = da2.astype(dtype)
+
+    # keep dask lazy and chunk along time if present
+    if "time" in da2.dims and output_time_chunk:
+        da2 = da2.chunk({"time": output_time_chunk})
+
     spec = _pick_maps(
         varname,
         conservative_map=conservative_map,
@@ -231,10 +272,14 @@ def regrid_to_1deg(
     )
     regridder = _RegridderCache.get(spec.path, spec.method_label)

-    out = regridder(da2)  # -> (*non_spatial, y/x or lat/lon)
+    # tell xESMF to produce chunked output
+    kwargs = {}
+    if "time" in da2.dims and output_time_chunk:
+        kwargs["output_chunks"] = {"time": output_time_chunk}
+
+    out = regridder(da2, **kwargs)  # -> (*non_spatial, y/x or lat/lon)
     out = _rename_xy_to_latlon(out)

-    # Try to attach standard attrs
     if keep_attrs:
         out.attrs.update(da.attrs)

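For context, a hedged end-to-end sketch of the new public entry point; the dataset file, variable name, and weight path below are placeholders, not values from this commit:

from pathlib import Path
import xarray as xr

ds = xr.open_dataset("model_history.nc")   # placeholder: any dataset with fields on (time, ncol)
out = regrid_to_1deg(
    ds,
    "PRECT",                               # placeholder variable name
    method="conservative",
    conservative_map=Path("~/maps/map_ne30pg2_to_1x1d_aave.nc"),  # placeholder weight file
)
print(out.dims)                            # expected: ('time', 'lat', 'lon')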