Skip to content

Commit 71e3bd4

Browse files
committed
update regrid again
1 parent aa8f874 commit 71e3bd4

File tree

1 file changed

+151
-106
lines changed

1 file changed

+151
-106
lines changed

cmip7_prep/regrid.py

Lines changed: 151 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -26,140 +26,146 @@ class MapSpec:
2626
path: Path
2727

2828
# -------------------------
29-
# helpers to build minimal grids from weight file
29+
# NetCDF opener (backends)
3030
# -------------------------
3131

32-
def _first_var(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]:
32+
def _open_nc(path: Path) -> xr.Dataset:
    """Open a NetCDF file with explicit engines to avoid backend autodetect failures.

    Tries the ``netcdf4`` engine first, then ``scipy``.

    Parameters
    ----------
    path : Path
        Path to the NetCDF file (coerced with ``Path`` so str inputs work too).

    Returns
    -------
    xr.Dataset
        The dataset opened by the first engine that succeeds.

    Raises
    ------
    ValueError
        If no engine can open the file.  The per-engine failure reasons are
        appended to the message and the last exception is chained as
        ``__cause__`` instead of being silently discarded.
    """
    path = Path(path)
    errors: list[str] = []
    last_exc: Exception | None = None
    for engine in ("netcdf4", "scipy"):
        try:
            return xr.open_dataset(str(path), engine=engine)
        except Exception as exc:  # keep the reason, then try the next backend
            errors.append(f"{engine}: {exc}")
            last_exc = exc
    raise ValueError(
        f"Could not open {path} with xarray engines ['netcdf4','scipy']. "
        "Check that the file is a NetCDF and the backends are installed. "
        f"Errors: {'; '.join(errors)}"
    ) from last_exc
44+
45+
# -------------------------
46+
# Minimal dummy grids from the weight file (based on your approach)
47+
# -------------------------
48+
49+
def _read_array(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]:
3350
for n in names:
3451
if n in m:
3552
return m[n]
3653
return None
3754

38-
def _src_coords_from_map(mapfile: Path) -> Tuple[xr.DataArray, xr.DataArray]:
39-
"""Return (lat_in, lon_in) as 1-D arrays for the source (unstructured) grid."""
40-
with xr.open_dataset(mapfile) as m:
41-
lat = _first_var(m, "lat_a", "yc_a", "src_grid_center_lat", "y_a", "y_A")
42-
lon = _first_var(m, "lon_a", "xc_a", "src_grid_center_lon", "x_a", "x_A")
43-
if lat is not None and lon is not None:
44-
lat = lat.rename("lat").astype("f8").reset_coords(drop=True)
45-
lon = lon.rename("lon").astype("f8").reset_coords(drop=True)
46-
# ensure 1-D
47-
lat = lat if lat.ndim == 1 else lat.stack(points=lat.dims).rename("lat")
48-
lon = lon if lon.ndim == 1 else lon.stack(points=lon.dims).rename("lon")
49-
return lat, lon
50-
51-
# fallback: size only
52-
size = None
53-
for c in ("src_grid_size", "n_a"):
54-
if c in m:
55-
try:
56-
size = int(np.asarray(m[c]).ravel()[0])
57-
break
58-
except Exception:
59-
pass
60-
if size is None:
61-
# try dims on sparse index variables
62-
for v in ("row", "col"):
63-
if v in m and m[v].dims:
64-
# ESMF uses 1-based indices; not directly size, but better than nothing
65-
size = int(np.asarray(m[v]).max())
66-
break
67-
if size is None:
68-
raise ValueError("Cannot infer source grid size from weight file.")
69-
70-
lat = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lat")
71-
lon = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lon")
72-
return lat, lon
55+
def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]:
56+
"""Infer the source grid 'shape' expected by xESMF's ds_to_ESMFgrid.
7357
74-
def _dst_coords_from_map(mapfile: Path) -> Dict[str, xr.DataArray]:
75-
"""Extract dest lat/lon (+bounds if present) from an ESMF map file."""
76-
with xr.open_dataset(mapfile) as m:
77-
# centers
78-
lat = _first_var(m, "lat_b", "yc_b", "lat", "yc", "dst_grid_center_lat")
79-
lon = _first_var(m, "lon_b", "xc_b", "lon", "xc", "dst_grid_center_lon")
80-
81-
if lat is not None and lon is not None:
82-
lat = lat.rename("lat").astype("f8")
83-
lon = lon.rename("lon").astype("f8")
84-
# If 2-D curvilinear, keep dims; if 1-D, leave as-is
58+
We provide a dummy 2D shape even when the true source is unstructured.
59+
"""
60+
a = _read_array(m, "src_grid_dims")
61+
if a is not None:
62+
vals = np.asarray(a).ravel().astype(int)
63+
if vals.size == 1:
64+
return (1, int(vals[0]))
65+
if vals.size >= 2:
66+
return (int(vals[-2]), int(vals[-1]))
67+
# fallbacks for unstructured
68+
for n in ("src_grid_size", "n_a"):
69+
if n in m:
70+
size = int(np.asarray(m[n]).ravel()[0])
71+
return (1, size)
72+
# very last resort: infer from max index of sparse matrix rows
73+
if "row" in m:
74+
size = int(np.asarray(m["row"]).max())
75+
return (1, size)
76+
raise ValueError("Cannot infer source grid size from weight file.")
77+
78+
def _get_dst_latlon_1d(m: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """Return 1D dest lat, lon arrays from weight file.

    Prefers 2D center lat/lon (yc_b/xc_b or lat_b/lon_b), reshaped to (ny, nx),
    then converts to 1D centers by taking first column/row, which is valid for
    regular 1° lat/lon weights.
    """
    # Center coordinates of the destination grid; may be flattened 2D arrays.
    lat2d = _read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat")
    lon2d = _read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon")
    if lat2d is not None and lon2d is not None:
        # figure out (ny, nx)
        if "dst_grid_dims" in m:
            # NOTE(review): ESMF/SCRIP files often store grid_dims Fortran-ordered
            # as (nx, ny); this takes the last two entries as (ny, nx) — confirm
            # against the actual weight files in use.
            ny, nx = [int(x) for x in np.asarray(m["dst_grid_dims"]).ravel()][-2:]
        else:
            # try to infer directly from array size
            size = int(np.asarray(lat2d).size)
            # default 1x1 grid size is 180*360
            if size == 180 * 360:
                ny, nx = 180, 360
            else:
                # fallback: assume square-ish
                ny = int(round(np.sqrt(size)))
                nx = size // ny
        # NOTE(review): reshape assumes C-order flattening with latitude as the
        # leading axis; taking the first column/row is only exact for a regular
        # (separable) lat/lon grid — verify for curvilinear destinations.
        lat2d = np.asarray(lat2d).reshape(ny, nx)
        lon2d = np.asarray(lon2d).reshape(ny, nx)
        return lat2d[:, 0].astype("f8"), lon2d[0, :].astype("f8")

    # If 1D lat/lon already present
    # ("lat"/"yc" also appear in the 2D candidate lists above, so this branch
    # is reached only when at most one of the 2D lookups matched.)
    lat1d = _read_array(m, "lat", "yc")
    lon1d = _read_array(m, "lon", "xc")
    if lat1d is not None and lon1d is not None and lat1d.ndim == 1 and lon1d.ndim == 1:
        return np.asarray(lat1d, dtype="f8"), np.asarray(lon1d, dtype="f8")

    # Final fallback: fabricate a 1° grid (cell centers at -89.5..89.5 and
    # 0.5..359.5) when the weight file exposes no coordinates at all.
    ny, nx = 180, 360
    lat = np.linspace(-89.5, 89.5, ny, dtype="f8")
    lon = (np.arange(nx, dtype="f8") + 0.5) * (360.0 / nx)
    return lat, lon
116+
117+
def _make_dummy_grids(mapfile: Path) -> Tuple[xr.Dataset, xr.Dataset]:
    """Construct minimal ds_in/ds_out satisfying xESMF when reusing weights."""
    with _open_nc(mapfile) as weights:
        src_ny, src_nx = _get_src_shape(weights)
        dst_lat, dst_lon = _get_dst_latlon_1d(weights)

    cf_lat = {"units": "degrees_north", "standard_name": "latitude"}
    cf_lon = {"units": "degrees_east", "standard_name": "longitude"}

    # Dummy input grid: arbitrary 2D index values — only the shapes matter
    # when an existing weight file is supplied.
    ds_in = xr.Dataset(
        {
            "lat": ("lat", np.arange(src_ny, dtype="f8")),
            "lon": ("lon", np.arange(src_nx, dtype="f8")),
        }
    )
    # Output grid: true 1D regular lat/lon extracted from the weights.
    ds_out = xr.Dataset({"lat": ("lat", dst_lat), "lon": ("lon", dst_lon)})

    # Attach CF metadata so xESMF's coordinate checks pass on both grids.
    for grid in (ds_in, ds_out):
        grid["lat"].attrs.update(cf_lat)
        grid["lon"].attrs.update(cf_lon)

    return ds_in, ds_out
137141

142+
# -------------------------
143+
# Cache
144+
# -------------------------
145+
138146
class _RegridderCache:
139147
"""Cache of xESMF Regridders constructed from weight files.
140148
141-
This avoids reconstructing regridders for the same weight file multiple times
142-
and provides a small API to fetch or clear cached instances.
149+
We build minimal `ds_in`/`ds_out` from the weight file to satisfy CF checks,
150+
then reuse the weight file for the actual mapping.
143151
"""
144152
_cache: Dict[Path, xe.Regridder] = {}
145153

146154
@classmethod
147155
def get(cls, mapfile: Path, method_label: str) -> xe.Regridder:
148-
"""Return a cached regridder for the given weight file and method.
149-
150-
We build minimal `ds_in`/`ds_out` by reading lon/lat from the weight file.
151-
This satisfies xESMF's CF checks even when we reuse weights.
152-
"""
156+
"""Return a cached regridder for the given weight file and method."""
153157
mapfile = mapfile.expanduser().resolve()
154158
if mapfile not in cls._cache:
155159
if not mapfile.exists():
156160
raise FileNotFoundError(f"Regrid weights not found: {mapfile}")
157-
ds_in, ds_out = _make_ds_in_out_from_map(mapfile)
161+
ds_in, ds_out = _make_dummy_grids(mapfile)
158162
cls._cache[mapfile] = xe.Regridder(
159-
ds_in, ds_out,
163+
ds_in,
164+
ds_out,
160165
method=method_label,
161-
filename=str(mapfile),
166+
filename=str(mapfile), # reuse the ESMF weight file on disk
162167
reuse_weights=True,
168+
periodic=True, # 0..360 longitudes
163169
)
164170
return cls._cache[mapfile]
165171

@@ -168,6 +174,10 @@ def clear(cls) -> None:
168174
"""Clear all cached regridders (useful for tests or releasing resources)."""
169175
cls._cache.clear()
170176

177+
# -------------------------
178+
# Selection & utilities
179+
# -------------------------
180+
171181
def _pick_maps(
172182
varname: str,
173183
conservative_map: Optional[Path] = None,
@@ -207,6 +217,10 @@ def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
207217
dim_map["x"] = "lon"
208218
return da.rename(dim_map) if dim_map else da
209219

220+
# -------------------------
221+
# Public API
222+
# -------------------------
223+
210224
def regrid_to_1deg(
211225
ds_in: xr.Dataset,
212226
varname: str,
@@ -215,14 +229,41 @@ def regrid_to_1deg(
215229
conservative_map: Optional[Path] = None,
216230
bilinear_map: Optional[Path] = None,
217231
keep_attrs: bool = True,
232+
dtype: str | None = "float32",
233+
output_time_chunk: int | None = 12,
218234
) -> xr.DataArray:
219-
"""Regrid a field on (time, ncol[, ...]) to (time, [lev,] lat, lon)."""
235+
"""Regrid a field on (time, ncol[, ...]) to (time, [lev,] lat, lon).
236+
237+
Parameters
238+
----------
239+
ds_in : Dataset with var on (..., ncol)
240+
varname : str
241+
Variable name to regrid.
242+
method : Optional[str]
243+
Force "conservative" or "bilinear". If None, choose based on var type.
244+
conservative_map, bilinear_map : Path
245+
ESMF weight files. If missing, defaults are used.
246+
keep_attrs : bool
247+
Copy attrs from input variable to output.
248+
dtype : str or None
249+
Cast input to this dtype before regridding (default float32). Set None to disable.
250+
output_time_chunk : int or None
251+
If set and 'time' present, make xESMF return chunked output along 'time'.
252+
"""
220253
if varname not in ds_in:
221254
raise KeyError(f"{varname!r} not in dataset.")
222255

223256
da = ds_in[varname]
224257
da2, non_spatial = _ensure_ncol_last(da)
225258

259+
# cast to save memory
260+
if dtype is not None and str(da2.dtype) != dtype:
261+
da2 = da2.astype(dtype)
262+
263+
# keep dask lazy and chunk along time if present
264+
if "time" in da2.dims and output_time_chunk:
265+
da2 = da2.chunk({"time": output_time_chunk})
266+
226267
spec = _pick_maps(
227268
varname,
228269
conservative_map=conservative_map,
@@ -231,10 +272,14 @@ def regrid_to_1deg(
231272
)
232273
regridder = _RegridderCache.get(spec.path, spec.method_label)
233274

234-
out = regridder(da2) # -> (*non_spatial, y/x or lat/lon)
275+
# tell xESMF to produce chunked output
276+
kwargs = {}
277+
if "time" in da2.dims and output_time_chunk:
278+
kwargs["output_chunks"] = {"time": output_time_chunk}
279+
280+
out = regridder(da2, **kwargs) # -> (*non_spatial, y/x or lat/lon)
235281
out = _rename_xy_to_latlon(out)
236282

237-
# Try to attach standard attrs
238283
if keep_attrs:
239284
out.attrs.update(da.attrs)
240285

0 commit comments

Comments
 (0)