1+ # cmip7_prep/regrid.py
12"""Regridding utilities for CESM -> 1° lat/lon using precomputed ESMF weights."""
2-
33from __future__ import annotations
44from dataclasses import dataclass
55from pathlib import Path
66from typing import Optional , Dict , Tuple
77
8+ import numpy as np
89import xarray as xr
910import xesmf as xe
10- import numpy as np
1111
1212# Default weight maps; override via function args.
13- DEFAULT_CONS_MAP = Path (
14- "/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_aave.nc"
15- )
16- DEFAULT_BILIN_MAP = Path (
17- "/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_bilin.nc"
18- ) # optional bilinear map
13+ DEFAULT_CONS_MAP = Path ("map_ne30pg3_to_1x1d_aave.nc" )
14+ DEFAULT_BILIN_MAP = Path ("" ) # optional bilinear map
1915
2016# Variables treated as "intensive" → prefer bilinear when available.
2117INTENSIVE_VARS = {
22- "tas" ,
23- "tasmin" ,
24- "tasmax" ,
25- "psl" ,
26- "ps" ,
27- "huss" ,
28- "uas" ,
29- "vas" ,
30- "sfcWind" ,
31- "ts" ,
32- "prsn" ,
33- "clt" ,
34- "ta" ,
35- "ua" ,
36- "va" ,
37- "zg" ,
38- "hus" ,
39- "thetao" ,
40- "uo" ,
41- "vo" ,
42- "so" ,
18+ "tas" , "tasmin" , "tasmax" , "psl" , "ps" , "huss" , "uas" , "vas" , "sfcWind" ,
19+ "ts" , "prsn" , "clt" , "ta" , "ua" , "va" , "zg" , "hus" , "thetao" , "uo" , "vo" , "so" ,
4320}
4421
45-
4622@dataclass (frozen = True )
4723class MapSpec :
4824 """Specification of which weight map to use for a variable."""
49-
5025 method_label : str # "conservative" or "bilinear"
5126 path : Path
5227
28+ # -------------------------
29+ # helpers to build minimal grids from weight file
30+ # -------------------------
31+
32+ def _first_var (m : xr .Dataset , * names : str ) -> Optional [xr .DataArray ]:
33+ for n in names :
34+ if n in m :
35+ return m [n ]
36+ return None
37+
38+ def _src_coords_from_map (mapfile : Path ) -> Tuple [xr .DataArray , xr .DataArray ]:
39+ """Return (lat_in, lon_in) as 1-D arrays for the source (unstructured) grid."""
40+ with xr .open_dataset (mapfile ) as m :
41+ lat = _first_var (m , "lat_a" , "yc_a" , "src_grid_center_lat" , "y_a" , "y_A" )
42+ lon = _first_var (m , "lon_a" , "xc_a" , "src_grid_center_lon" , "x_a" , "x_A" )
43+ if lat is not None and lon is not None :
44+ lat = lat .rename ("lat" ).astype ("f8" ).reset_coords (drop = True )
45+ lon = lon .rename ("lon" ).astype ("f8" ).reset_coords (drop = True )
46+ # ensure 1-D
47+ lat = lat if lat .ndim == 1 else lat .stack (points = lat .dims ).rename ("lat" )
48+ lon = lon if lon .ndim == 1 else lon .stack (points = lon .dims ).rename ("lon" )
49+ return lat , lon
50+
51+ # fallback: size only
52+ size = None
53+ for c in ("src_grid_size" , "n_a" ):
54+ if c in m :
55+ try :
56+ size = int (np .asarray (m [c ]).ravel ()[0 ])
57+ break
58+ except Exception :
59+ pass
60+ if size is None :
61+ # try dims on sparse index variables
62+ for v in ("row" , "col" ):
63+ if v in m and m [v ].dims :
64+ # ESMF uses 1-based indices; not directly size, but better than nothing
65+ size = int (np .asarray (m [v ]).max ())
66+ break
67+ if size is None :
68+ raise ValueError ("Cannot infer source grid size from weight file." )
69+
70+ lat = xr .DataArray (np .zeros (size , dtype = "f8" ), dims = ("points" ,), name = "lat" )
71+ lon = xr .DataArray (np .zeros (size , dtype = "f8" ), dims = ("points" ,), name = "lon" )
72+ return lat , lon
73+
74+ def _dst_coords_from_map (mapfile : Path ) -> Dict [str , xr .DataArray ]:
75+ """Extract dest lat/lon (+bounds if present) from an ESMF map file."""
76+ with xr .open_dataset (mapfile ) as m :
77+ # centers
78+ lat = _first_var (m , "lat_b" , "yc_b" , "lat" , "yc" , "dst_grid_center_lat" )
79+ lon = _first_var (m , "lon_b" , "xc_b" , "lon" , "xc" , "dst_grid_center_lon" )
80+
81+ if lat is not None and lon is not None :
82+ lat = lat .rename ("lat" ).astype ("f8" )
83+ lon = lon .rename ("lon" ).astype ("f8" )
84+ # If 2-D curvilinear, keep dims; if 1-D, leave as-is
85+ else :
86+ # fallback from dims
87+ dims = None
88+ for name in ("dst_grid_dims" , "dst_grid_size" ):
89+ if name in m :
90+ dims = np .asarray (m [name ]).ravel ()
91+ break
92+ if dims is None or dims .size < 2 :
93+ # assume 1° regular
94+ ny , nx = 180 , 360
95+ else :
96+ if dims .size == 1 :
97+ ny , nx = int (dims [0 ]), 1
98+ else :
99+ ny , nx = int (dims [- 2 ]), int (dims [- 1 ])
100+ lat = xr .DataArray (np .linspace (- 89.5 , 89.5 , ny ), dims = ("lat" ,), name = "lat" )
101+ lon = xr .DataArray ((np .arange (nx ) + 0.5 ) * (360.0 / nx ), dims = ("lon" ,), name = "lon" )
102+
103+ # bounds (optional)
104+ lat_b = _first_var (m , "lat_bnds" , "lat_b" , "bounds_lat" , "lat_bounds" , "y_bnds" , "yb" )
105+ lon_b = _first_var (m , "lon_bnds" , "lon_b" , "bounds_lon" , "lon_bounds" , "x_bnds" , "xb" )
106+
107+ coords = {"lat" : lat , "lon" : lon }
108+ if lat_b is not None :
109+ coords ["lat_bnds" ] = lat_b .astype ("f8" )
110+ if lon_b is not None :
111+ coords ["lon_bnds" ] = lon_b .astype ("f8" )
112+ return coords
113+
114+ def _make_ds_in_out_from_map (mapfile : Path ) -> Tuple [xr .Dataset , xr .Dataset ]:
115+ """Construct minimal CF-like ds_in (unstructured) and ds_out (structured/curvilinear) from weight file."""
116+ lat_in , lon_in = _src_coords_from_map (mapfile )
117+ dst = _dst_coords_from_map (mapfile )
118+
119+ # ds_in: unstructured → 1-D lat/lon on 'points'
120+ if lat_in .dims != ("points" ,):
121+ lat_in = lat_in .rename ({lat_in .dims [0 ]: "points" })
122+ if lon_in .dims != ("points" ,):
123+ lon_in = lon_in .rename ({lon_in .dims [0 ]: "points" })
124+ ds_in = xr .Dataset ({"lat" : lat_in , "lon" : lon_in })
125+ ds_in ["lat" ].attrs .update ({"units" : "degrees_north" , "standard_name" : "latitude" })
126+ ds_in ["lon" ].attrs .update ({"units" : "degrees_east" , "standard_name" : "longitude" })
127+
128+ # ds_out: accept 1-D lat/lon (regular) or 2-D (curvilinear) from weights
129+ ds_out = xr .Dataset ({k : v for k , v in dst .items () if k in {"lat" , "lon" }})
130+ for k in ("lat" , "lon" ):
131+ if k in ds_out :
132+ ds_out [k ].attrs .update (
133+ {"units" : f"degrees_{ 'north' if k == 'lat' else 'east' } " ,
134+ "standard_name" : "latitude" if k == "lat" else "longitude" }
135+ )
136+ return ds_in , ds_out
53137
54138class _RegridderCache :
55139 """Cache of xESMF Regridders constructed from weight files.
56140
57141 This avoids reconstructing regridders for the same weight file multiple times
58142 and provides a small API to fetch or clear cached instances.
59143 """
60-
61144 _cache : Dict [Path , xe .Regridder ] = {}
62145
63146 @classmethod
64147 def get (cls , mapfile : Path , method_label : str ) -> xe .Regridder :
65148 """Return a cached regridder for the given weight file and method.
66149
67- If no regridder exists yet for `mapfile`, it is created using xESMF with
68- `filename=mapfile` (so source/destination grids are read from the weight
69- file) and stored in the cache. Subsequent calls reuse the same instance.
70-
71- Parameters
72- ----------
73- mapfile : Path
74- Path to an ESMF weight file.
75- method_label : str
76- xESMF method label; used only for constructor parity.
77-
78- Returns
79- -------
80- xe.Regridder
81- Cached or newly created regridder.
150+ We build minimal `ds_in`/`ds_out` by reading lon/lat from the weight file.
151+ This satisfies xESMF's CF checks even when we reuse weights.
82152 """
83153 mapfile = mapfile .expanduser ().resolve ()
84154 if mapfile not in cls ._cache :
85155 if not mapfile .exists ():
86156 raise FileNotFoundError (f"Regrid weights not found: { mapfile } " )
157+ ds_in , ds_out = _make_ds_in_out_from_map (mapfile )
87158 cls ._cache [mapfile ] = xe .Regridder (
88- xr .Dataset (),
89- xr .Dataset (),
159+ ds_in , ds_out ,
90160 method = method_label ,
91161 filename = str (mapfile ),
92162 reuse_weights = True ,
@@ -98,7 +168,6 @@ def clear(cls) -> None:
98168 """Clear all cached regridders (useful for tests or releasing resources)."""
99169 cls ._cache .clear ()
100170
101-
102171def _pick_maps (
103172 varname : str ,
104173 conservative_map : Optional [Path ] = None ,
@@ -122,54 +191,13 @@ def _pick_maps(
122191 return MapSpec ("bilinear" , bilin )
123192 return MapSpec ("conservative" , cons )
124193
125-
126194def _ensure_ncol_last (da : xr .DataArray ) -> Tuple [xr .DataArray , Tuple [str , ...]]:
127195 """Move 'ncol' to the last position; return (da, non_spatial_dims)."""
128196 if "ncol" not in da .dims :
129197 raise ValueError (f"Expected 'ncol' in dims; got { da .dims } " )
130198 non_spatial = tuple (d for d in da .dims if d != "ncol" )
131199 return da .transpose (* non_spatial , "ncol" ), non_spatial
132200
133-
134- def _dst_coords_from_map (mapfile : Path ) -> Dict [str , xr .DataArray ]:
135- """Extract dest lat/lon (+bounds if present) from an ESMF map file."""
136- with xr .open_dataset (mapfile ) as m :
137- # centers
138- if "lat" in m and "lon" in m :
139- lat = m ["lat" ]
140- lon = m ["lon" ]
141- elif "yc" in m and "xc" in m :
142- lat = m ["yc" ].rename ("lat" )
143- lon = m ["xc" ].rename ("lon" )
144- else :
145- dims = np .asarray (m .get ("dst_grid_dims" , [180 , 360 ])).ravel ()
146- ny = int (dims [- 2 ]) if dims .size >= 2 else 180
147- nx = int (dims [- 1 ]) if dims .size >= 2 else 360
148- lat = xr .DataArray (np .linspace (- 89.5 , 89.5 , ny ), dims = ("lat" ,), name = "lat" )
149- lon = xr .DataArray (
150- (np .arange (nx ) + 0.5 ) * (360.0 / nx ), dims = ("lon" ,), name = "lon"
151- )
152-
153- # bounds
154- lat_b = None
155- lon_b = None
156- for cand in ("lat_bnds" , "lat_b" , "bounds_lat" , "lat_bounds" , "y_bnds" , "yb" ):
157- if cand in m :
158- lat_b = m [cand ]
159- break
160- for cand in ("lon_bnds" , "lon_b" , "bounds_lon" , "lon_bounds" , "x_bnds" , "xb" ):
161- if cand in m :
162- lon_b = m [cand ]
163- break
164-
165- coords = {"lat" : lat , "lon" : lon }
166- if lat_b is not None :
167- coords ["lat_bnds" ] = lat_b
168- if lon_b is not None :
169- coords ["lon_bnds" ] = lon_b
170- return coords
171-
172-
173201def _rename_xy_to_latlon (da : xr .DataArray ) -> xr .DataArray :
174202 """Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
175203 dim_map = {}
@@ -179,7 +207,6 @@ def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
179207 dim_map ["x" ] = "lon"
180208 return da .rename (dim_map ) if dim_map else da
181209
182-
183210def regrid_to_1deg (
184211 ds_in : xr .Dataset ,
185212 varname : str ,
@@ -207,20 +234,7 @@ def regrid_to_1deg(
207234 out = regridder (da2 ) # -> (*non_spatial, y/x or lat/lon)
208235 out = _rename_xy_to_latlon (out )
209236
210- # Attach lat/lon (+bounds) from map if available
211- try :
212- dst_coords = _dst_coords_from_map (spec .path )
213- if {"lat" , "lon" }.issubset (out .dims ):
214- out = out .assign_coords (
215- {k : v for k , v in dst_coords .items () if k in {"lat" , "lon" }}
216- )
217- for bname in ("lat_bnds" , "lon_bnds" ):
218- if bname in dst_coords :
219- out = out .assign_coords ({bname : dst_coords [bname ]})
220- except (OSError , ValueError , KeyError ):
221- # Non-fatal: keep whatever coords xESMF provided
222- pass
223-
237+ # Try to attach standard attrs
224238 if keep_attrs :
225239 out .attrs .update (da .attrs )
226240
@@ -237,7 +251,6 @@ def regrid_to_1deg(
237251
238252 return out
239253
240-
241254def regrid_mask_or_area (
242255 da_in : xr .DataArray ,
243256 * ,
@@ -249,24 +262,10 @@ def regrid_mask_or_area(
249262 if "time" in da_in .dims :
250263 da_in = da_in .transpose ("time" , "ncol" , ...)
251264
252- spec = MapSpec (
253- "conservative" , Path (conservative_map ) if conservative_map else DEFAULT_CONS_MAP
254- )
265+ spec = MapSpec ("conservative" , Path (conservative_map ) if conservative_map else DEFAULT_CONS_MAP )
255266 regridder = _RegridderCache .get (spec .path , spec .method_label )
256267
257268 out = regridder (da_in )
258269 out = _rename_xy_to_latlon (out )
259-
260- try :
261- dst_coords = _dst_coords_from_map (spec .path )
262- if {"lat" , "lon" }.issubset (out .dims ):
263- out = out .assign_coords (
264- {k : v for k , v in dst_coords .items () if k in {"lat" , "lon" }}
265- )
266- for bname in ("lat_bnds" , "lon_bnds" ):
267- if bname in dst_coords :
268- out = out .assign_coords ({bname : dst_coords [bname ]})
269- except (OSError , ValueError , KeyError ):
270- pass
271-
272270 return out
271+
0 commit comments