Skip to content

Commit aa8f874

Browse files
committed
new regrid
1 parent 67a3fad commit aa8f874

File tree

1 file changed

+122
-123
lines changed

1 file changed

+122
-123
lines changed

cmip7_prep/regrid.py

Lines changed: 122 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,162 @@
1+
# cmip7_prep/regrid.py
12
"""Regridding utilities for CESM -> 1° lat/lon using precomputed ESMF weights."""
2-
33
from __future__ import annotations
44
from dataclasses import dataclass
55
from pathlib import Path
66
from typing import Optional, Dict, Tuple
77

8+
import numpy as np
89
import xarray as xr
910
import xesmf as xe
10-
import numpy as np
1111

1212
# Default weight maps; override via function args.
13-
DEFAULT_CONS_MAP = Path(
14-
"/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_aave.nc"
15-
)
16-
DEFAULT_BILIN_MAP = Path(
17-
"/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_bilin.nc"
18-
) # optional bilinear map
13+
DEFAULT_CONS_MAP = Path("map_ne30pg3_to_1x1d_aave.nc")
14+
DEFAULT_BILIN_MAP = Path("") # optional bilinear map
1915

2016
# Variables treated as "intensive" → prefer bilinear when available.
2117
INTENSIVE_VARS = {
22-
"tas",
23-
"tasmin",
24-
"tasmax",
25-
"psl",
26-
"ps",
27-
"huss",
28-
"uas",
29-
"vas",
30-
"sfcWind",
31-
"ts",
32-
"prsn",
33-
"clt",
34-
"ta",
35-
"ua",
36-
"va",
37-
"zg",
38-
"hus",
39-
"thetao",
40-
"uo",
41-
"vo",
42-
"so",
18+
"tas", "tasmin", "tasmax", "psl", "ps", "huss", "uas", "vas", "sfcWind",
19+
"ts", "prsn", "clt", "ta", "ua", "va", "zg", "hus", "thetao", "uo", "vo", "so",
4320
}
4421

45-
4622
@dataclass(frozen=True)
4723
class MapSpec:
4824
"""Specification of which weight map to use for a variable."""
49-
5025
method_label: str # "conservative" or "bilinear"
5126
path: Path
5227

28+
# -------------------------
29+
# helpers to build minimal grids from weight file
30+
# -------------------------
31+
32+
def _first_var(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]:
33+
for n in names:
34+
if n in m:
35+
return m[n]
36+
return None
37+
38+
def _src_coords_from_map(mapfile: Path) -> Tuple[xr.DataArray, xr.DataArray]:
    """Return (lat_in, lon_in) as 1-D arrays for the source (unstructured) grid.

    Center coordinates are read from the usual ESMF/SCRIP variable names.  If
    none are present, a zero-filled placeholder of the correct length is built
    from the source-grid size recorded in the weight file.

    Parameters
    ----------
    mapfile : Path
        Path to an ESMF weight (map) file.

    Returns
    -------
    (xr.DataArray, xr.DataArray)
        1-D latitude and longitude arrays named "lat" / "lon".

    Raises
    ------
    ValueError
        If neither coordinates nor a source-grid size can be inferred.
    """
    with xr.open_dataset(mapfile) as m:
        lat = _first_var(m, "lat_a", "yc_a", "src_grid_center_lat", "y_a", "y_A")
        lon = _first_var(m, "lon_a", "xc_a", "src_grid_center_lon", "x_a", "x_A")
        if lat is not None and lon is not None:
            lat = lat.rename("lat").astype("f8").reset_coords(drop=True)
            lon = lon.rename("lon").astype("f8").reset_coords(drop=True)
            # ensure 1-D: flatten curvilinear centers onto a 'points' axis
            lat = lat if lat.ndim == 1 else lat.stack(points=lat.dims).rename("lat")
            lon = lon if lon.ndim == 1 else lon.stack(points=lon.dims).rename("lon")
            return lat, lon

        # fallback: infer only the number of source points
        size = None
        for c in ("src_grid_size", "n_a"):
            if c in m:
                try:
                    size = int(np.asarray(m[c]).ravel()[0])
                    break
                except (TypeError, ValueError, IndexError):
                    # malformed size variable; try the next candidate
                    pass
        if size is None and "col" in m and m["col"].dims:
            # In ESMF sparse-matrix weights, 'col' holds 1-based *source*
            # indices ('row' indexes the destination grid), so max(col) is a
            # lower bound on the source size.  The previous code consulted
            # 'row' first, which would yield the destination size instead.
            size = int(np.asarray(m["col"]).max())
        if size is None:
            raise ValueError("Cannot infer source grid size from weight file.")

        lat = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lat")
        lon = xr.DataArray(np.zeros(size, dtype="f8"), dims=("points",), name="lon")
        return lat, lon
74+
def _dst_coords_from_map(mapfile: Path) -> Dict[str, xr.DataArray]:
    """Extract dest lat/lon (+bounds if present) from an ESMF map file.

    Parameters
    ----------
    mapfile : Path
        Path to an ESMF weight (map) file.

    Returns
    -------
    dict
        Mapping with "lat"/"lon" centers and, when found in the file,
        "lat_bnds"/"lon_bnds" bounds arrays.
    """
    with xr.open_dataset(mapfile) as m:
        # centers
        lat = _first_var(m, "lat_b", "yc_b", "lat", "yc", "dst_grid_center_lat")
        lon = _first_var(m, "lon_b", "xc_b", "lon", "xc", "dst_grid_center_lon")

        if lat is not None and lon is not None:
            lat = lat.rename("lat").astype("f8")
            lon = lon.rename("lon").astype("f8")
            # If 2-D curvilinear, keep dims; if 1-D, leave as-is
        else:
            # fallback: reconstruct a regular grid from the recorded shape
            dims = None
            for name in ("dst_grid_dims", "dst_grid_size"):
                if name in m:
                    dims = np.asarray(m[name]).ravel()
                    break
            if dims is None or dims.size < 2:
                # No usable shape (or only a flat size): assume 1° regular.
                # (A previous `dims.size == 1` branch here was unreachable —
                # that case is already caught by `dims.size < 2` above.)
                ny, nx = 180, 360
            else:
                ny, nx = int(dims[-2]), int(dims[-1])
            lat = xr.DataArray(np.linspace(-89.5, 89.5, ny), dims=("lat",), name="lat")
            lon = xr.DataArray((np.arange(nx) + 0.5) * (360.0 / nx), dims=("lon",), name="lon")

        # bounds (optional)
        lat_b = _first_var(m, "lat_bnds", "lat_b", "bounds_lat", "lat_bounds", "y_bnds", "yb")
        lon_b = _first_var(m, "lon_bnds", "lon_b", "bounds_lon", "lon_bounds", "x_bnds", "xb")

        coords = {"lat": lat, "lon": lon}
        if lat_b is not None:
            coords["lat_bnds"] = lat_b.astype("f8")
        if lon_b is not None:
            coords["lon_bnds"] = lon_b.astype("f8")
        return coords
114+
def _make_ds_in_out_from_map(mapfile: Path) -> Tuple[xr.Dataset, xr.Dataset]:
    """Construct minimal CF-like ds_in (unstructured) and ds_out (structured/curvilinear) from weight file."""
    src_lat, src_lon = _src_coords_from_map(mapfile)
    dst = _dst_coords_from_map(mapfile)

    # Source side: unstructured grid -> 1-D lat/lon along a 'points' dimension.
    if src_lat.dims != ("points",):
        src_lat = src_lat.rename({src_lat.dims[0]: "points"})
    if src_lon.dims != ("points",):
        src_lon = src_lon.rename({src_lon.dims[0]: "points"})
    ds_in = xr.Dataset({"lat": src_lat, "lon": src_lon})
    ds_in["lat"].attrs.update({"units": "degrees_north", "standard_name": "latitude"})
    ds_in["lon"].attrs.update({"units": "degrees_east", "standard_name": "longitude"})

    # Destination side: keep only the centers; they may be 1-D (regular
    # lat/lon) or 2-D (curvilinear) depending on the weight file.
    ds_out = xr.Dataset({name: dst[name] for name in ("lat", "lon") if name in dst})
    if "lat" in ds_out:
        ds_out["lat"].attrs.update({"units": "degrees_north", "standard_name": "latitude"})
    if "lon" in ds_out:
        ds_out["lon"].attrs.update({"units": "degrees_east", "standard_name": "longitude"})
    return ds_in, ds_out
53137

54138
class _RegridderCache:
55139
"""Cache of xESMF Regridders constructed from weight files.
56140
57141
This avoids reconstructing regridders for the same weight file multiple times
58142
and provides a small API to fetch or clear cached instances.
59143
"""
60-
61144
_cache: Dict[Path, xe.Regridder] = {}
62145

63146
@classmethod
64147
def get(cls, mapfile: Path, method_label: str) -> xe.Regridder:
65148
"""Return a cached regridder for the given weight file and method.
66149
67-
If no regridder exists yet for `mapfile`, it is created using xESMF with
68-
`filename=mapfile` (so source/destination grids are read from the weight
69-
file) and stored in the cache. Subsequent calls reuse the same instance.
70-
71-
Parameters
72-
----------
73-
mapfile : Path
74-
Path to an ESMF weight file.
75-
method_label : str
76-
xESMF method label; used only for constructor parity.
77-
78-
Returns
79-
-------
80-
xe.Regridder
81-
Cached or newly created regridder.
150+
We build minimal `ds_in`/`ds_out` by reading lon/lat from the weight file.
151+
This satisfies xESMF's CF checks even when we reuse weights.
82152
"""
83153
mapfile = mapfile.expanduser().resolve()
84154
if mapfile not in cls._cache:
85155
if not mapfile.exists():
86156
raise FileNotFoundError(f"Regrid weights not found: {mapfile}")
157+
ds_in, ds_out = _make_ds_in_out_from_map(mapfile)
87158
cls._cache[mapfile] = xe.Regridder(
88-
xr.Dataset(),
89-
xr.Dataset(),
159+
ds_in, ds_out,
90160
method=method_label,
91161
filename=str(mapfile),
92162
reuse_weights=True,
@@ -98,7 +168,6 @@ def clear(cls) -> None:
98168
"""Clear all cached regridders (useful for tests or releasing resources)."""
99169
cls._cache.clear()
100170

101-
102171
def _pick_maps(
103172
varname: str,
104173
conservative_map: Optional[Path] = None,
@@ -122,54 +191,13 @@ def _pick_maps(
122191
return MapSpec("bilinear", bilin)
123192
return MapSpec("conservative", cons)
124193

125-
126194
def _ensure_ncol_last(da: xr.DataArray) -> Tuple[xr.DataArray, Tuple[str, ...]]:
127195
"""Move 'ncol' to the last position; return (da, non_spatial_dims)."""
128196
if "ncol" not in da.dims:
129197
raise ValueError(f"Expected 'ncol' in dims; got {da.dims}")
130198
non_spatial = tuple(d for d in da.dims if d != "ncol")
131199
return da.transpose(*non_spatial, "ncol"), non_spatial
132200

133-
134-
def _dst_coords_from_map(mapfile: Path) -> Dict[str, xr.DataArray]:
135-
"""Extract dest lat/lon (+bounds if present) from an ESMF map file."""
136-
with xr.open_dataset(mapfile) as m:
137-
# centers
138-
if "lat" in m and "lon" in m:
139-
lat = m["lat"]
140-
lon = m["lon"]
141-
elif "yc" in m and "xc" in m:
142-
lat = m["yc"].rename("lat")
143-
lon = m["xc"].rename("lon")
144-
else:
145-
dims = np.asarray(m.get("dst_grid_dims", [180, 360])).ravel()
146-
ny = int(dims[-2]) if dims.size >= 2 else 180
147-
nx = int(dims[-1]) if dims.size >= 2 else 360
148-
lat = xr.DataArray(np.linspace(-89.5, 89.5, ny), dims=("lat",), name="lat")
149-
lon = xr.DataArray(
150-
(np.arange(nx) + 0.5) * (360.0 / nx), dims=("lon",), name="lon"
151-
)
152-
153-
# bounds
154-
lat_b = None
155-
lon_b = None
156-
for cand in ("lat_bnds", "lat_b", "bounds_lat", "lat_bounds", "y_bnds", "yb"):
157-
if cand in m:
158-
lat_b = m[cand]
159-
break
160-
for cand in ("lon_bnds", "lon_b", "bounds_lon", "lon_bounds", "x_bnds", "xb"):
161-
if cand in m:
162-
lon_b = m[cand]
163-
break
164-
165-
coords = {"lat": lat, "lon": lon}
166-
if lat_b is not None:
167-
coords["lat_bnds"] = lat_b
168-
if lon_b is not None:
169-
coords["lon_bnds"] = lon_b
170-
return coords
171-
172-
173201
def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
174202
"""Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
175203
dim_map = {}
@@ -179,7 +207,6 @@ def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
179207
dim_map["x"] = "lon"
180208
return da.rename(dim_map) if dim_map else da
181209

182-
183210
def regrid_to_1deg(
184211
ds_in: xr.Dataset,
185212
varname: str,
@@ -207,20 +234,7 @@ def regrid_to_1deg(
207234
out = regridder(da2) # -> (*non_spatial, y/x or lat/lon)
208235
out = _rename_xy_to_latlon(out)
209236

210-
# Attach lat/lon (+bounds) from map if available
211-
try:
212-
dst_coords = _dst_coords_from_map(spec.path)
213-
if {"lat", "lon"}.issubset(out.dims):
214-
out = out.assign_coords(
215-
{k: v for k, v in dst_coords.items() if k in {"lat", "lon"}}
216-
)
217-
for bname in ("lat_bnds", "lon_bnds"):
218-
if bname in dst_coords:
219-
out = out.assign_coords({bname: dst_coords[bname]})
220-
except (OSError, ValueError, KeyError):
221-
# Non-fatal: keep whatever coords xESMF provided
222-
pass
223-
237+
# Try to attach standard attrs
224238
if keep_attrs:
225239
out.attrs.update(da.attrs)
226240

@@ -237,7 +251,6 @@ def regrid_to_1deg(
237251

238252
return out
239253

240-
241254
def regrid_mask_or_area(
242255
da_in: xr.DataArray,
243256
*,
@@ -249,24 +262,10 @@ def regrid_mask_or_area(
249262
if "time" in da_in.dims:
250263
da_in = da_in.transpose("time", "ncol", ...)
251264

252-
spec = MapSpec(
253-
"conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP
254-
)
265+
spec = MapSpec("conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP)
255266
regridder = _RegridderCache.get(spec.path, spec.method_label)
256267

257268
out = regridder(da_in)
258269
out = _rename_xy_to_latlon(out)
259-
260-
try:
261-
dst_coords = _dst_coords_from_map(spec.path)
262-
if {"lat", "lon"}.issubset(out.dims):
263-
out = out.assign_coords(
264-
{k: v for k, v in dst_coords.items() if k in {"lat", "lon"}}
265-
)
266-
for bname in ("lat_bnds", "lon_bnds"):
267-
if bname in dst_coords:
268-
out = out.assign_coords({bname: dst_coords[bname]})
269-
except (OSError, ValueError, KeyError):
270-
pass
271-
272270
return out
271+

0 commit comments

Comments
 (0)