diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index da14b61..e29fc85 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -35,14 +35,30 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + - name: Cache CMOR build + uses: actions/cache@v4 + with: + path: | + /home/runner/cmor + /home/runner/cmor3-source + /home/runner/.cache/pip + key: cmor-${{ runner.os }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('**/cmor_version.txt') }} + restore-keys: | + cmor-${{ runner.os }}- + - name: build and install cmor run: | - git clone https://github.com/PCMDI/cmor.git $HOME/cmor3-source - cd $HOME/cmor3-source - ./configure --prefix=$HOME/cmor --with-netcdf=/usr/local --with-udunits2=/usr/local - make - make install - pip install . + if [ ! -d "$HOME/cmor" ]; then + git clone https://github.com/PCMDI/cmor.git $HOME/cmor3-source + cd $HOME/cmor3-source + ./configure --prefix=$HOME/cmor --with-netcdf=/usr/local --with-udunits2=/usr/local + make + make install + pip install . + else + cd /home/runner/cmor3-source + pip install . + fi - name: Run pytest env: PYTHONPATH: ${{ github.workspace }} @@ -51,4 +67,9 @@ jobs: run: | #pip uninstall -y cmip7-prep || true #pip install . - pytest -q + pytest -q --doctest-modules --cov=cmip7_prep --cov-report=xml + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage-xml + path: coverage.xml diff --git a/.pylintrc b/.pylintrc index 83cbb91..e2e8180 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,7 +8,8 @@ disable= too-many-arguments, too-many-locals, too-many-branches, - too-many-statements + too-many-statements, + invalid-name [FORMAT] max-line-length=100 diff --git a/cmip7_prep/cache_tools.py b/cmip7_prep/cache_tools.py new file mode 100644 index 0000000..35ad0aa --- /dev/null +++ b/cmip7_prep/cache_tools.py @@ -0,0 +1,235 @@ +# cmip7_prep/cache_tools.py +"""Tools for caching and reuse in regridding.""" +from pathlib import Path +from typing import Dict, Optional, Tuple +import logging +import xesmf as xe +import xarray as xr +import numpy as np + +from cmip7_prep.cmor_utils import bounds_from_centers_1d + +logger = logging.getLogger(__name__) + + +# ------------------------- +# NetCDF opener (backends) +# ------------------------- +def open_nc(path: Path) -> xr.Dataset: + """Open NetCDF with explicit engines and narrow exception handling. + + Tries 'netcdf4' then 'scipy'. Collects the failure reasons and raises a + single RuntimeError if neither works. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Weight file not found: {path}") + + errors: dict[str, Exception] = {} + for engine in ("netcdf4", "scipy"): + try: + return xr.open_dataset(str(path), engine=engine) + except (ValueError, OSError, ImportError, ModuleNotFoundError) as exc: + # ValueError: invalid/unavailable engine or decode issue + # OSError: low-level file I/O/HDF5 issues + # ImportError/ModuleNotFoundError: backend not installed + errors[engine] = exc + + details = "; ".join( + f"{eng}: {type(err).__name__}: {err}" for eng, err in errors.items() + ) + raise RuntimeError( + f"Could not open {path} with xarray engines ['netcdf4', 'scipy']. " + f"Tried both; reasons: {details}" + ) + + +def _make_dummy_grids(mapfile: Path) -> tuple[xr.Dataset, xr.Dataset]: + """Construct minimal ds_in/ds_out satisfying xESMF when reusing weights. 
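A usage sketch for open_nc above (filename hypothetical): the engine fallback is transparent to the caller, and if both backends fail a single RuntimeError reports what each one said.

    from pathlib import Path
    weights = open_nc(Path("map_ne30pg3_to_1x1d_aave.nc"))  # hypothetical local weight file
    try:
        print(sorted(weights.data_vars))  # ESMF/SCRIP weight files typically carry 'col', 'row', 'S'
    finally:
        weights.close()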
+ Adds CF-style bounds for both lat and lon so conservative methods don’t + trigger cf-xarray’s bounds inference on size-1 dimensions.""" + with open_nc(mapfile) as m: + nlon_in, nlat_in = _get_src_shape(m) + lat_out_1d, lon_out_1d = _get_dst_latlon_1d() + + # --- Dummy INPUT grid (unstructured → represent as 2D with length-1 lat) --- + lat_in = np.arange( + -180.0, 180.0, 360.0 / nlat_in, dtype="f8" + ) # e.g., [0], length can be 1 + lon_in = np.arange(0.5, 360.5, 360.0 / nlon_in, dtype="f8") + ds_in = xr.Dataset( + data_vars={ + "lat_bnds": ( + ("lat", "nbnds"), + bounds_from_centers_1d(lat_in, "lat"), + ), + "lon_bnds": ( + ("lon", "nbnds"), + bounds_from_centers_1d(lon_in, "lon"), + ), + }, + coords={ + "lat": ( + "lat", + lat_in, + { + "units": "degrees_north", + "standard_name": "latitude", + "bounds": "lat_bnds", + }, + ), + "lon": ( + "lon", + lon_in, + { + "units": "degrees_east", + "standard_name": "longitude", + "bounds": "lon_bnds", + }, + ), + "nbnds": ("nbnds", np.array([0, 1], dtype="i4")), + }, + ) + + # --- OUTPUT grid from weights (canonical 1° lat/lon) --- + lat_out_bnds = bounds_from_centers_1d(lat_out_1d, "lat") + lon_out_bnds = bounds_from_centers_1d(lon_out_1d, "lon") + + ds_out = xr.Dataset( + data_vars={ + "lat_bnds": (("lat", "nbnds"), lat_out_bnds), + "lon_bnds": (("lon", "nbnds"), lon_out_bnds), + }, + coords={ + "lat": ( + "lat", + lat_out_1d, + { + "units": "degrees_north", + "standard_name": "latitude", + "bounds": "lat_bnds", + }, + ), + "lon": ( + "lon", + lon_out_1d, + { + "units": "degrees_east", + "standard_name": "longitude", + "bounds": "lon_bnds", + }, + ), + "nbnds": ("nbnds", np.array([0, 1], dtype="i4")), + }, + ) + return ds_in, ds_out + + +# ------------------------- +# Minimal dummy grids from the weight file (based on your approach) +# ------------------------- + + +def read_array(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]: + """helper tool to read array from xarray datasets""" + for n in names: + if n in m: + return m[n] + return None + + +def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]: + """Infer the source grid 'shape' expected by xESMF's ds_to_ESMFgrid. + + We provide a dummy 2D shape even when the true source is unstructured. + """ + a = read_array(m, "src_grid_dims") + if a is not None: + vals = np.asarray(a).ravel().astype(int) + if vals.size == 1: + return (int(vals[0]), 1) + if vals.size >= 2: + return (int(vals[-2]), int(vals[-1])) + # fallbacks for unstructured + for n in ("src_grid_size", "n_a"): + if n in m: + size = int(np.asarray(m[n]).ravel()[0]) + return (size, 1) + # very last resort: infer from max index of sparse matrix rows + if "row" in m: + size = int(np.asarray(m["row"]).max()) + return (size, 1) + raise ValueError("Cannot infer source grid size from weight file.") + + +def _get_dst_latlon_1d() -> Tuple[np.ndarray, np.ndarray]: + """Return 1D dest lat, lon arrays from weight file. + + Prefers 2D center lat/lon (yc_b/xc_b or lat_b/lon_b), reshaped to (ny, nx), + then converts to 1D centers by taking first column/row, which is valid for + regular 1° lat/lon weights. 
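A quick illustration of _get_src_shape's fallback order under SCRIP/ESMF-style metadata names (src_grid_dims first, then src_grid_size/n_a, finally the sparse 'row' index); the values below are fabricated:

    import numpy as np
    import xarray as xr

    m = xr.Dataset({"src_grid_dims": ("grid_rank", np.array([360, 180]))})
    assert _get_src_shape(m) == (360, 180)   # last two dims win

    m = xr.Dataset({"n_a": 48600})           # unstructured source, e.g. an ne30pg3-sized mesh
    assert _get_src_shape(m) == (48600, 1)   # dummy 2-D shape with a length-1 second axis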
+ """ + # Final fallback: fabricate a 1° grid + ny, nx = 180, 360 + lat = np.linspace(-89.5, 89.5, ny, dtype="f8") + lon = (np.arange(nx, dtype="f8") + 0.5) * (360.0 / nx) + + return lat, lon + + +# ------------------------- +# Cache +# ------------------------- +class FXCache: + """Cache of regridded FX fields (sftlf, areacella) keyed by mapfile.""" + + _cache: Dict[Path, xr.Dataset] = {} + + @classmethod + def get(cls, key: Path) -> xr.Dataset | None: + """get cached variable""" + return cls._cache.get(key) + + @classmethod + def put(cls, key: Path, ds_fx: xr.Dataset) -> None: + """put variable into cache""" + cls._cache[key] = ds_fx + + @classmethod + def clear(cls) -> None: + """clear cache""" + cls._cache.clear() + + +class RegridderCache: + """Cache of xESMF Regridders constructed from weight files. + + We build minimal `ds_in`/`ds_out` from the weight file to satisfy CF checks, + then reuse the weight file for the actual mapping. + """ + + _cache: Dict[Path, xe.Regridder] = {} + + @classmethod + def get(cls, mapfile: Path, method_label: str) -> xe.Regridder: + """Return a cached regridder for the given weight file and method.""" + mapfile = mapfile.expanduser().resolve() + if mapfile not in cls._cache: + if not mapfile.exists(): + raise FileNotFoundError(f"Regrid weights not found: {mapfile}") + ds_in, ds_out = _make_dummy_grids(mapfile) + logger.info("Creating xESMF Regridder from weights: %s", mapfile) + cls._cache[mapfile] = xe.Regridder( + ds_in, + ds_out, + method=method_label, + filename=str(mapfile), # reuse the ESMF weight file on disk + reuse_weights=True, + periodic=True, # 0..360 longitudes + ) + return cls._cache[mapfile] + + @classmethod + def clear(cls) -> None: + """Clear all cached regridders (useful for tests or releasing resources).""" + cls._cache.clear() diff --git a/cmip7_prep/cmor_utils.py b/cmip7_prep/cmor_utils.py index 4452342..1ef190b 100644 --- a/cmip7_prep/cmor_utils.py +++ b/cmip7_prep/cmor_utils.py @@ -57,6 +57,8 @@ def filled_for_cmor( # keep attrs helpful for downstream da2.attrs["_FillValue"] = f da2.attrs["missing_value"] = f + if da2.dtype != np.float32: + da2 = da2.astype(np.float32) return da2, f @@ -116,20 +118,23 @@ def bounds_from_centers_1d(vals: np.ndarray, kind: str) -> np.ndarray: """ v = np.asarray(vals, dtype="f8").reshape(-1) n = v.size - if n < 2: - raise ValueError("Need at least 2 points to compute bounds") - - # neighbor midpoints - mid = 0.5 * (v[1:] + v[:-1]) # length n-1 - bounds = np.empty((n, 2), dtype="f8") - bounds[1:, 0] = mid - bounds[:-1, 1] = mid - - # end caps: extrapolate by half-step at ends - first_step = v[1] - v[0] - last_step = v[-1] - v[-2] - bounds[0, 0] = v[0] - 0.5 * first_step - bounds[-1, 1] = v[-1] + 0.5 * last_step + if n == 1: + # Special case: single value, make a cell of width 1 centered on v[0] + bounds = np.array([[v[0] - 0.5, v[0] + 0.5]], dtype="f8") + elif n < 2: + raise ValueError("Need at least 1 point to compute bounds") + else: + # neighbor midpoints + mid = 0.5 * (v[1:] + v[:-1]) # length n-1 + bounds = np.empty((n, 2), dtype="f8") + bounds[1:, 0] = mid + bounds[:-1, 1] = mid + + # end caps: extrapolate by half-step at ends + first_step = v[1] - v[0] + last_step = v[-1] - v[-2] + bounds[0, 0] = v[0] - 0.5 * first_step + bounds[-1, 1] = v[-1] + 0.5 * last_step if kind == "lat": # clamp to physical limits @@ -142,8 +147,6 @@ def bounds_from_centers_1d(vals: np.ndarray, kind: str) -> np.ndarray: wrap = bounds[:, 1] < bounds[:, 0] if np.any(wrap): bounds[wrap, 1] += 360.0 - else: - raise 
ValueError("kind must be 'lat' or 'lon'") return bounds diff --git a/cmip7_prep/cmor_writer.py b/cmip7_prep/cmor_writer.py index b89ecb8..89125b9 100644 --- a/cmip7_prep/cmor_writer.py +++ b/cmip7_prep/cmor_writer.py @@ -290,14 +290,17 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str): plev_id = None lat_id = None lon_id = None + sdepth_id = None - print(f"[CMOR axis debug] var_dims: {var_dims}") + logger.debug("[CMOR axis debug] var_dims: %s", var_dims) if "xh" in var_dims and "yh" in var_dims: # MOM6/curvilinear grid: register xh/yh as generic axes (i/j), not as lat/lon # Define the native grid using the coordinate arrays - print( - f"[CMOR axis debug] Defining unstructured grid for variable {var_name}." + logger.debug( + "[CMOR axis debug] Defining unstructured grid for variable %s.", + var_name, ) + i_id = cmor.axis( table_entry="i", units="1", @@ -308,7 +311,7 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str): units="1", length=ds["yh"].size, ) - print("[CMOR axis debug] Defining unstructured grid_id.") + logger.debug("[CMOR axis debug] Defining unstructured grid_id.") grid_id = cmor.grid( axis_ids=[j_id, i_id], # note CMOR wants fastest varying last longitude=ds["geolon"].values, @@ -330,9 +333,9 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str): f" in variable '{var_name}' (curvilinear)" ) axes_ids.append(axis_id) - print(f"[CMOR axis debug] Appending grid_id: {grid_id}") + logger.debug("[CMOR axis debug] Appending grid_id: %s", grid_id) axes_ids.append(grid_id) - print(f"[CMOR axis debug] axes_ids: {axes_ids}") + logger.debug("[CMOR axis debug] axes_ids: %s", axes_ids) return axes_ids # --- horizontal axes (use CMOR names) ---- @@ -477,12 +480,24 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str): coord_vals=pvals, cell_bounds=pb if pb is not None else None, ) - + elif "sdepth" in var_dims: + values = ds["sdepth"].values + logger.info("write sdepth axis") + bnds = bounds_from_centers_1d(values, "sdepth") + if bnds[0, 0] < 0: + bnds[0, 0] = 0.0 # no negative soil depth bounds + sdepth_id = cmor.axis( + table_entry="sdepth", + units="m", + coord_vals=np.asarray(values), + cell_bounds=bnds, + ) # Map dimension names to axis IDs dim_to_axis = { "time": time_id, "alev": alev_id, # hybrid sigma "lev": alev_id, # sometimes used for hybrid + "sdepth": sdepth_id, "plev": plev_id, "lat": lat_id, "latitude": lat_id, @@ -544,7 +559,24 @@ def ensure_fx_written_and_cached(self, ds_regr: xr.Dataset) -> xr.Dataset: If present in ds_regr but not yet written this run, write and cache them. Returns ds_regr augmented with any missing fx fields. 
""" - need = [("sftlf", "%"), ("areacella", "m2")] + need = [ + ("sftlf", "%"), + ("areacella", "m2"), + ("sftof", "%"), + ("areacello", "m3"), + ("mrsofc", "m3 s-1"), + ("orog", "m"), + ("thkcello", "m"), + ("slthick", "m"), + ("basin", "m2"), + ("deptho", "m"), + ("hfgeou", "m"), + ("masscello", "m3"), + ("thkcello", "m"), + ("rootd", "m"), + ("sftgif", "%"), + ("sftif", "%"), + ] # land fraction, ocean cell area, soil moisture fraction out = ds_regr for name, units in need: @@ -613,23 +645,22 @@ def write_variable( cmor.load_table(table_filename) data = ds[vdef.name] + logger.info("Prepare data for CMOR %s", data.dtype) # debug data_filled, fillv = filled_for_cmor(data) - logger.info("Define axes") + logger.info("Define axes data_filled dtype: %s", data_filled.dtype) # debug axes_ids = self._define_axes(ds, vdef) units = getattr(vdef, "units", "") or "" # Debug logging for axis mapping # Try to get axis table entries for each axis_id - # pylint: disable=broad-exception-caught try: for i, aid in enumerate(axes_ids): entry = cmor.axis_entry(aid) if hasattr(cmor, "axis_entry") else None - logger.info( + logger.debug( "[CMOR DEBUG] axis %d: id=%s, table_entry=%s", i, aid, entry ) - except ( - Exception - ) as e: # Broad except needed: cmor.axis_entry may raise various errors + # pylint: disable=broad-exception-caught + except Exception as e: logger.warning("[CMOR DEBUG] Could not retrieve axis table entries: %s", e) var_id = cmor.variable( getattr(vdef, "name", varname), @@ -638,7 +669,6 @@ def write_variable( positive=getattr(vdef, "positive", None), missing_value=fillv, ) - data = ds[varname] # ---- Prepare time info for this write (local, not cached) ---- time_da = ds.coords.get("time") diff --git a/cmip7_prep/data/cesm_to_cmip7.yaml b/cmip7_prep/data/cesm_to_cmip7.yaml index ba0484d..86dd8ec 100644 --- a/cmip7_prep/data/cesm_to_cmip7.yaml +++ b/cmip7_prep/data/cesm_to_cmip7.yaml @@ -1,4 +1,3 @@ - # cesm_to_cmip7.yaml — CMIP-keyed schema with CMIP6 compatibility # # Top-level keys are CMIP variable names. 
@@ -42,6 +41,15 @@ variables: sources: - cesm_var: CLDICE + cLitter: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (TOTLITC + CWD_C)/1000.0 + sources: + - cesm_var: TOTLITC + - cesm_var: CWD_C + clivi: table: Amon units: "kg m-2" @@ -76,6 +84,47 @@ variables: sources: - cesm_var: TGCLDLWP + cProduct: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (TOTPRODC + CWD_P)/1000.0 + sources: + - cesm_var: TOTPRODC + - cesm_var: CWD_P + + cSoilFast: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (SOIL1C)/1000.0 + sources: + - cesm_var: SOIL1C + + cSoilMedium: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (SOIL2C)/1000.0 + sources: + - cesm_var: SOIL2C + + cSoilSlow: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (SOIL3C)/1000.0 + sources: + - cesm_var: SOIL3C + + cVeg: + table: Lmon + units: "kg m-2" + dims: [time, lat, lon] + formula: (TOTVEGC)/1000.0 + sources: + - cesm_var: TOTVEGC + evspsbl: table: Amon units: kg m-2 s-1 @@ -105,6 +154,33 @@ variables: sources: - cesm_var: QVEGE + fFire: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + positive: up + formula: COL_FIRE_CLOSS/1000.0 + sources: + - cesm_var: COL_FIRE_CLOSS + + fHarvest: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + positive: up + formula: WOOD_HARVESTC/1000.0 + sources: + - cesm_var: WOOD_HARVESTC + + gpp: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + formula: GPP/1000.0 + positive: down + sources: + - cesm_var: GPP + hfls: table: Amon units: "W m-2" @@ -202,6 +278,24 @@ variables: sources: - cesm_var: SOILWATER_10CM + nbp: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + formula: NBP/1000.0 + positive: down + sources: + - cesm_var: NBP + + npp: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + formula: NPP/1000.0 + positive: down + sources: + - cesm_var: NPP + pr: table: Amon units: kg m-2 s-1 @@ -232,6 +326,13 @@ variables: - cesm_var: PRECSC - cesm_var: PRECSL + prveg: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + sources: + - cesm_var: QINTR + prw: table: Amon units: "kg m-2" @@ -248,6 +349,7 @@ variables: regrid_method: bilinear sources: - cesm_var: PSL + ps: table: Amon units: Pa @@ -257,13 +359,24 @@ variables: regrid_method: bilinear sources: - cesm_var: PS - rlds: - table: Amon - units: "W m-2" + + ra: + table: Lmon + units: "kg m-2 s-1" dims: [time, lat, lon] positive: up + formula: AR/1000.0 sources: - - cesm_var: FLDS + - cesm_var: AR + + rh: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + positive: up + formula: HR/1000.0 + sources: + - cesm_var: HR rldscs: table: Amon @@ -285,6 +398,7 @@ variables: sources: - cesm_var: FLDS - cesm_var: FLNS + rluscs: table: Amon units: "W m-2" @@ -371,10 +485,19 @@ variables: - cesm_var: SOLIN - cesm_var: FSNTOAC + snld: + table: Lmon + units: "m" + dims: [time, lat, lon] + sources: + - cesm_var: SNOW_DEPTH + sos: table: Omon units: "1e-3" - dims: [time, j, i] + dims: + - [time, lat, lon] + regrid_method: conservative sources: - cesm_var: sos @@ -438,6 +561,16 @@ variables: sources: - cesm_var: TAUY + tran: + table: Lmon + units: "kg m-2 s-1" + dims: [time, lat, lon] + positive: up + formula: QSOIL + QVEGT + sources: + - cesm_var: QVEGT + - cesm_var: QSOIL + ts: table: Amon units: "K" @@ -445,6 +578,13 @@ variables: sources: - cesm_var: TS + tsl: + table: Lmon + units: "K" + dims: [time, sdepth, lat, lon] + sources: + - cesm_var: TSOI + ua: table: Amon units: m s-1 diff 
--git a/cmip7_prep/mapping_compat.py b/cmip7_prep/mapping_compat.py index 68eb2a6..06535be 100644 --- a/cmip7_prep/mapping_compat.py +++ b/cmip7_prep/mapping_compat.py @@ -47,6 +47,7 @@ class VarConfig: regrid_method: Optional[str] = None long_name: Optional[str] = None standard_name: Optional[str] = None + dims: Optional[List[str]] = None def as_cfg(self) -> Dict[str, Any]: """Return a plain dict view for convenience in other modules. @@ -68,6 +69,7 @@ def as_cfg(self) -> Dict[str, Any]: "regrid_method": self.regrid_method, "long_name": self.long_name, "standard_name": self.standard_name, + "dims": self.dims, } return {k: v for k, v in d.items() if v is not None} @@ -215,6 +217,7 @@ def realize(self, ds: xr.Dataset, cmip_name: str) -> xr.DataArray: vc = self._vars[cmip_name] da = _realize_core(ds, vc) + if da is not None: if vc.unit_conversion is not None: da = _apply_unit_conversion(da, vc.unit_conversion) @@ -270,6 +273,7 @@ def _to_varconfig(name: str, cfg: TMapping[str, Any]) -> VarConfig: regrid_method=cfg.get("regrid_method"), long_name=cfg.get("long_name"), standard_name=cfg.get("standard_name"), + dims=cfg.get("dims"), ) return vc diff --git a/cmip7_prep/mom6_static.py b/cmip7_prep/mom6_static.py index 85c3214..a2c3519 100644 --- a/cmip7_prep/mom6_static.py +++ b/cmip7_prep/mom6_static.py @@ -4,6 +4,30 @@ import xarray as xr import numpy as np +from cmip7_prep.regrid import _sftof_from_native + + +def ocean_fx_fields(static_path, out_path=None): + """ + Read MOM6 static grid and write ocean-related fx fields (sftof, areacello, etc). + Returns a dict of DataArrays for use in regridding normalization/denormalization. + If out_path is given, writes a NetCDF with these fields. + """ + + ds = xr.open_dataset(static_path) + fx = {} + # Extract sftof (sea fraction) using the helper + sftof = _sftof_from_native(ds) + if sftof is not None: + fx["sftof"] = sftof + # Extract areacello if present + if "areacello" in ds: + fx["areacello"] = ds["areacello"] + # Optionally add other ocean mask/area fields as needed + # Save to NetCDF if requested + if out_path is not None: + xr.Dataset(fx).to_netcdf(out_path) + return fx def compute_cell_bounds_from_corners(corner_array): @@ -32,7 +56,7 @@ def compute_cell_bounds_from_corners(corner_array): return bounds -def load_mom6_static(static_path): +def load_mom6_grid(static_path): """Load MOM6 static file and return geolat, geolon, geolat_c, geolon_c arrays.""" ds = xr.open_dataset(static_path) # For a supergrid, centers are every other point, bounds are full array diff --git a/cmip7_prep/pipeline.py b/cmip7_prep/pipeline.py index 8017c2b..8aec309 100644 --- a/cmip7_prep/pipeline.py +++ b/cmip7_prep/pipeline.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Optional, Sequence, Union, Dict, List import re +import logging import warnings import glob import sys @@ -15,7 +16,7 @@ from .regrid import regrid_to_1deg_ds from .vertical import to_plev19 - +logger = logging.getLogger(__name__) # --------------------------- file discovery --------------------------- _VAR_TOKEN = re.compile(r"(? 
List[str]: """Gather all native CESM vars needed to realize the requested CMIP vars.""" needed: set[str] = set() - for v in cmip_vars: + for var in cmip_vars: try: - cfg = mapping.get_cfg(v) or {} + cfg = mapping.get_cfg(var) or {} except KeyError: print( - f"WARNING: skipping '{v}': no mapping found in {mapping.path}", + f"WARNING: skipping '{var}': no mapping found in {mapping.path}", file=sys.stderr, ) continue @@ -58,12 +59,12 @@ def _collect_required_cesm_vars( raws = cfg.get("raw_variables") or cfg.get("sources") or [] if src: needed.add(src) - for r in raws: + for raw in raws: # 'sources' items may be dicts with 'cesm_var' - if isinstance(r, dict) and "cesm_var" in r: - needed.add(r["cesm_var"]) - elif isinstance(r, str): - needed.add(r) + if isinstance(raw, dict) and "cesm_var" in raw: + needed.add(raw["cesm_var"]) + elif isinstance(raw, str): + needed.add(raw) # vertical dependencies if plev19 levels = cfg.get("levels") or {} if (levels.get("name") or "").lower() == "plev19": @@ -201,6 +202,7 @@ def realize_regrid_prepare( *, tables_path: Optional[Union[str, Path]] = None, time_chunk: Optional[int] = 12, + mom6_grid: Optional[Dict[str, xr.DataArray]] = None, regrid_kwargs: Optional[dict] = None, open_kwargs: Optional[dict] = None, ) -> xr.Dataset: @@ -211,7 +213,7 @@ def realize_regrid_prepare( """ regrid_kwargs = dict(regrid_kwargs or {}) open_kwargs = dict(open_kwargs or {}) - + aux = [] # 1) Get native dataset if isinstance(ds_or_glob, xr.Dataset): ds_native = ds_or_glob @@ -220,7 +222,21 @@ def realize_regrid_prepare( ds_native = open_native_for_cmip_vars( [cmip_var], ds_or_glob, mapping, **open_kwargs ) + if "landfrac" not in ds_native and "ncol" not in ds_native: + logger.info("Variable has no 'landfrac' or 'ncol' dim; assuming ocn variable.") + # Add MOM6 grid info if provided + if mom6_grid: + aux = ["geolat", "geolon", "geolat_c", "geolon_c"] + ds_native["geolat"] = xr.DataArray(mom6_grid[0], dims=("yh", "xh")) + ds_native["geolon"] = xr.DataArray(mom6_grid[1], dims=("yh", "xh")) + ds_native["geolat_c"] = xr.DataArray(mom6_grid[2], dims=("yhp", "xhp")) + ds_native["geolon_c"] = xr.DataArray(mom6_grid[3], dims=("yhp", "xhp")) + else: + logger.error("No MOM6 grid info provided; geolat/geolon not added.") + raise ValueError( + f"MOM6 grid information is required for variable {cmip_var} but was not provided." + ) # 2) Realize the target variable ds_v = mapping.realize(ds_native, cmip_var) da = ds_v if isinstance(ds_v, xr.DataArray) else ds_v[cmip_var] @@ -228,11 +244,9 @@ def realize_regrid_prepare( da = da.chunk({"time": int(time_chunk)}) ds_vars = xr.Dataset({cmip_var: da}) - - if "landfrac" in ds_native and "landfrac" not in ds_vars: - ds_vars = ds_vars.assign(landfrac=ds_native["landfrac"]) - if "area" in ds_native and "area" not in ds_vars: - ds_vars = ds_vars.assign(area=ds_native["area"]) + for var in ("landfrac", "area", "sftof"): + if var in ds_native and var not in ds_vars: + ds_vars = ds_vars.assign(**{var: ds_native[var]}) # 3) Check whether hybrid-σ is required cfg = mapping.get_cfg(cmip_var) or {} @@ -247,7 +261,11 @@ def realize_regrid_prepare( # can produce PS(time,lat,lon) if "PS" in ds_native and "PS" not in ds_vars: ds_vars = ds_vars.assign(PS=ds_native["PS"]) - + aux = [ + nm + for nm in ("hyai", "hybi", "hyam", "hybm", "P0", "ilev", "lev") + if nm in ds_native + ] # 5) Apply vertical transform if needed (plev19, etc.). 
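A usage sketch for the new mom6_grid hook (glob and grid-file path hypothetical); load_mom6_grid returns the (geolat, geolon, geolat_c, geolon_c) arrays attached above:

    from cmip7_prep.mom6_static import load_mom6_grid
    grid = load_mom6_grid("ocean_hgrid_221123.nc")  # MOM6 supergrid file, path assumed
    ds_regr = realize_regrid_prepare(
        mapping, "ts/ocn/*.mom6.h.native.*.nc", "sos",
        tables_path=TABLES, mom6_grid=grid,
        regrid_kwargs={"output_time_chunk": 12, "dtype": "float32"},
    )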
# Single-var helper already takes cfg + tables_path ds_vert = _apply_vertical_if_needed( @@ -259,19 +277,29 @@ def realize_regrid_prepare( if is_hybrid and "PS" in ds_vert: names_to_regrid.append("PS") + # 7) Rename levgrnd if present to sdepth + + # Check if 'levgrnd' is a dimension of any variable in names_to_regrid + needs_levgrnd_rename = any( + (v in ds_vert and "levgrnd" in getattr(ds_vert[v], "dims", [])) + for v in names_to_regrid + ) + if ( + needs_levgrnd_rename + and "levgrnd" in ds_native.dims + and "levgrnd" in ds_native.coords + ): + logger.info("Renaming 'levgrnd' dimension to 'sdepth'") + ds_vert = ds_vert.rename_dims({"levgrnd": "sdepth"}) + # Ensure the coordinate variable is also copied + ds_vert = ds_vert.assign_coords(sdepth=ds_native["levgrnd"].values) + ds_regr = regrid_to_1deg_ds( ds_vert, names_to_regrid, time_from=ds_native, **regrid_kwargs ) - # 7) If hybrid: merge in 1-D hybrid coefficients directly from native (no regridding needed) - if is_hybrid: - aux = [ - nm - for nm in ("hyai", "hybi", "hyam", "hybm", "P0", "ilev", "lev") - if nm in ds_native - ] - if aux: - ds_regr = ds_regr.merge(ds_native[aux], compat="override") + if aux: + ds_regr = ds_regr.merge(ds_native[aux], compat="override") return ds_regr diff --git a/cmip7_prep/regrid.py b/cmip7_prep/regrid.py index ab7c8f1..3688ac6 100644 --- a/cmip7_prep/regrid.py +++ b/cmip7_prep/regrid.py @@ -3,15 +3,14 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Tuple +from typing import Optional, Tuple import logging import xarray as xr -import xesmf as xe # import warnings import numpy as np from cmip7_prep.cmor_utils import bounds_from_centers_1d - +from cmip7_prep.cache_tools import FXCache, RegridderCache, read_array, open_nc logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -28,11 +27,19 @@ _HAS_DASK = False # Default weight maps; override via function args. -DEFAULT_CONS_MAP = Path( - "/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_aave.nc" +INPUTDATA_DIR = Path("/glade/campaign/cesm/cesmdata/inputdata/") +DEFAULT_CONS_MAP_NE30 = Path( + INPUTDATA_DIR / "cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_aave.nc" +) +DEFAULT_BILIN_MAP_NE30 = Path( + INPUTDATA_DIR / "cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_bilin.nc" +) # optional bilinear map + +DEFAULT_CONS_MAP_T232 = Path( + INPUTDATA_DIR / "cpl/gridmaps/tx2_3v2/map_t232_TO_1x1d_aave.251023.nc" ) -DEFAULT_BILIN_MAP = Path( - "/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_bilin.nc" +DEFAULT_BILIN_MAP_T232 = Path( + INPUTDATA_DIR / "cpl/gridmaps/tx2_3v2/map_t232_TO_1x1d_blin.251023.nc" ) # optional bilinear map # Variables treated as "intensive" → prefer bilinear when available. 
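A minimal check of the realm-aware selection below: realm="ocn" routes to the tx2_3v2 maps, anything else keeps the ne30pg3 defaults (a sketch; force_method must be "conservative" or "bilinear"):

    spec = _pick_maps("sos", force_method="conservative", realm="ocn")
    assert spec.path == DEFAULT_CONS_MAP_T232
    spec = _pick_maps("ta", force_method="bilinear", realm="atm")
    assert spec.path == DEFAULT_BILIN_MAP_NE30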
@@ -69,6 +76,34 @@ class MapSpec: path: Path +# --- OCEAN FRACTION (sftof) extraction --- +def _sftof_from_native(ds: xr.Dataset) -> xr.DataArray | None: + """Return sea fraction (sftof) from ocean static grid, using 'wet' mask if present.""" + # Try common names for ocean fraction + for name in ["sftof", "ocnfrac", "wet"]: + if name in ds: + v = ds[name] + # If 'wet', convert 0/1 to percent + if name == "wet": + vmax = float(np.nanmax(np.asarray(v))) + # If 0/1, convert to percent + if vmax <= 1.0 + 1e-6: + out = v * 100.0 + else: + out = v + else: + out = v + out = out.clip(min=0.0, max=100.0) + out = out.astype("f8") + attrs = dict(out.attrs) + attrs["units"] = "%" + attrs.setdefault("standard_name", "sea_area_fraction") + attrs.setdefault("long_name", "Percentage of sea area") + out.attrs = attrs + return out + return None + + def _pick_from_candidates(ds: xr.Dataset, *names: str) -> xr.DataArray | None: """Return the first present variable among candidate names (case-insensitive).""" for nm in names: @@ -91,38 +126,6 @@ def _normalize_land_aux(da: xr.DataArray, hdim: str) -> xr.DataArray: return da -# ------------------------- -# NetCDF opener (backends) -# ------------------------- -def _open_nc(path: Path) -> xr.Dataset: - """Open NetCDF with explicit engines and narrow exception handling. - - Tries 'netcdf4' then 'scipy'. Collects the failure reasons and raises a - single RuntimeError if neither works. - """ - path = Path(path) - if not path.exists(): - raise FileNotFoundError(f"Weight file not found: {path}") - - errors: dict[str, Exception] = {} - for engine in ("netcdf4", "scipy"): - try: - return xr.open_dataset(str(path), engine=engine) - except (ValueError, OSError, ImportError, ModuleNotFoundError) as exc: - # ValueError: invalid/unavailable engine or decode issue - # OSError: low-level file I/O/HDF5 issues - # ImportError/ModuleNotFoundError: backend not installed - errors[engine] = exc - - details = "; ".join( - f"{eng}: {type(err).__name__}: {err}" for eng, err in errors.items() - ) - raise RuntimeError( - f"Could not open {path} with xarray engines ['netcdf4', 'scipy']. " - f"Tried both; reasons: {details}" - ) - - def _attach_vertical_metadata(ds_out: xr.Dataset, ds_src: xr.Dataset) -> xr.Dataset: """ Pass-through vertical metadata needed for hybrid-sigma: @@ -152,239 +155,6 @@ def _attach_vertical_metadata(ds_out: xr.Dataset, ds_src: xr.Dataset) -> xr.Data return ds_out -# ------------------------- -# Minimal dummy grids from the weight file (based on your approach) -# ------------------------- - - -def _read_array(m: xr.Dataset, *names: str) -> Optional[xr.DataArray]: - for n in names: - if n in m: - return m[n] - return None - - -def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]: - """Infer the source grid 'shape' expected by xESMF's ds_to_ESMFgrid. - - We provide a dummy 2D shape even when the true source is unstructured. 
- """ - a = _read_array(m, "src_grid_dims") - if a is not None: - vals = np.asarray(a).ravel().astype(int) - if vals.size == 1: - return (1, int(vals[0])) - if vals.size >= 2: - return (int(vals[-2]), int(vals[-1])) - # fallbacks for unstructured - for n in ("src_grid_size", "n_a"): - if n in m: - size = int(np.asarray(m[n]).ravel()[0]) - return (1, size) - # very last resort: infer from max index of sparse matrix rows - if "row" in m: - size = int(np.asarray(m["row"]).max()) - return (1, size) - raise ValueError("Cannot infer source grid size from weight file.") - - -def _dst_latlon_1d_from_map(mapfile: Path) -> tuple[np.ndarray, np.ndarray]: - """Return canonical 1-D (lat, lon) for the destination grid. - - Robust to map files whose stored dst_grid_dims or reshape order is swapped. - We derive 1-D axes by taking UNIQUE values from the 2-D center fields. - """ - with _open_nc(mapfile) as m: - lat2d = _read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat") - lon2d = _read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon") - - if lat2d is not None and lon2d is not None: - lat2 = np.asarray(lat2d).reshape(-1) # flatten safely - lon2 = np.asarray(lon2d).reshape(-1) - - # Take unique values (rounded to avoid tiny FP noise) - lat_unique = np.unique(lat2.round(6)) - lon_unique = np.unique(lon2.round(6)) - - # If either came out descending, sort ascending - lat1d = np.sort(lat_unique).astype("f8") - lon1d = np.sort(lon_unique).astype("f8") - - # If longitudes are in [-180,180], convert to [0,360) - if lon1d.min() < 0.0 or lon1d.max() <= 180.0: - lon1d = (lon1d % 360.0).astype("f8") - lon1d.sort() - - # Prefer the classic 180/360 lengths if present - if lat1d.size == 360 and lon1d.size == 180: - # swapped: pick every other for lat (=180) and expand lon to 360 if needed - # However, for standard 1° grids stored swapped, lat values repeat; use stride 2. - lat1d = lat1d[::2] - # For lon 180 centers, mapfile likely only stored 0.5..179.5. - # We'll fabricate 0.5..359.5 to be safe. - if lon1d.size == 180: - lon1d = np.arange(360, dtype="f8") + 0.5 - - # sanity bounds - if not ( - -91.0 <= float(lat1d.min()) <= -89.0 - and 89.0 <= float(lat1d.max()) <= 91.0 - ): - # If still odd, try the alternative: - # extract along the other axis by reshaping via dst_grid_dims - # Fallback to canonical 1° - lat1d = np.linspace(-89.5, 89.5, 180, dtype="f8") - if lon1d.size != 360: - lon1d = np.arange(360, dtype="f8") + 0.5 - - return lat1d, lon1d - - # 1-D fields present already - lat1d = _read_array(m, "lat", "yc") - lon1d = _read_array(m, "lon", "xc") - if lat1d is not None and lon1d is not None: - lat1 = np.sort(np.asarray(lat1d, dtype="f8")) - lon1 = np.sort(np.asarray(lon1d, dtype="f8")) - if lon1.min() < 0.0 or lon1.max() <= 180.0: - lon1 = lon1 % 360.0 - lon1.sort() - if lat1.size == 360: - lat1 = lat1[::2] - if lon1.size != 360: - lon1 = np.arange(360, dtype="f8") + 0.5 - return lat1, lon1 - - # Last resort: fabricate standard 1° - lat = np.linspace(-89.5, 89.5, 180, dtype="f8") - lon = np.arange(360, dtype="f8") + 0.5 - return lat, lon - - -def _get_dst_latlon_1d() -> Tuple[np.ndarray, np.ndarray]: - """Return 1D dest lat, lon arrays from weight file. - - Prefers 2D center lat/lon (yc_b/xc_b or lat_b/lon_b), reshaped to (ny, nx), - then converts to 1D centers by taking first column/row, which is valid for - regular 1° lat/lon weights. 
- """ - # Final fallback: fabricate a 1° grid - ny, nx = 180, 360 - lat = np.linspace(-89.5, 89.5, ny, dtype="f8") - lon = (np.arange(nx, dtype="f8") + 0.5) * (360.0 / nx) - - return lat, lon - - -def _bounds_from_centers_1d( - centers: np.ndarray, *, periodic: bool = False, period: float = 360.0 -) -> np.ndarray: - """Compute simple 1D bounds from 1D cell centers. - For regular spacing this sets bounds[i] = [c[i]-dx/2, c[i]+dx/2]. - For periodic=True, wrap the last bound to period (e.g., lon 0..360).""" - - centers = np.asarray(centers, dtype="f8").ravel() - n = centers.size - b = np.empty((n, 2), dtype="f8") - if n == 1: - # Any small cell is fine for dummy grid; choose +/- 0.5 around center - dx = 1.0 - b[0, 0] = centers[0] - dx / 2.0 - b[0, 1] = centers[0] + dx / 2.0 - else: - # Estimate dx from adjacent centers (assume uniform spacing) - dx = np.diff(centers).mean() - b[:, 0] = centers - dx / 2.0 - b[:, 1] = centers + dx / 2.0 - if periodic: - # Keep within [0, period] - b = np.mod(b, period) - # Ensure right bound of last cell connects to period exactly - # (helps some ESMF periodic checks for 0..360) - b[-1, 1] = ( - period if abs(b[-1, 1] - period) < 1e-6 or b[-1, 1] < 1e-6 else b[-1, 1] - ) - return b - - -def _make_dummy_grids(mapfile: Path) -> tuple[xr.Dataset, xr.Dataset]: - """Construct minimal ds_in/ds_out satisfying xESMF when reusing weights. - Adds CF-style bounds for both lat and lon so conservative methods don’t - trigger cf-xarray’s bounds inference on size-1 dimensions.""" - with _open_nc(mapfile) as m: - nlat_in, nlon_in = _get_src_shape(m) - lat_out_1d, lon_out_1d = _get_dst_latlon_1d() - - # --- Dummy INPUT grid (unstructured → represent as 2D with length-1 lat) --- - lat_in = np.arange(nlat_in, dtype="f8") # e.g., [0], length can be 1 - lon_in = np.arange(nlon_in, dtype="f8") - ds_in = xr.Dataset( - data_vars={ - "lat_bnds": ( - ("lat", "nbnds"), - _bounds_from_centers_1d(lat_in, periodic=False), - ), - "lon_bnds": ( - ("lon", "nbnds"), - _bounds_from_centers_1d(lon_in, periodic=False), - ), - }, - coords={ - "lat": ( - "lat", - lat_in, - { - "units": "degrees_north", - "standard_name": "latitude", - "bounds": "lat_bnds", - }, - ), - "lon": ( - "lon", - lon_in, - { - "units": "degrees_east", - "standard_name": "longitude", - "bounds": "lon_bnds", - }, - ), - "nbnds": ("nbnds", np.array([0, 1], dtype="i4")), - }, - ) - - # --- OUTPUT grid from weights (canonical 1° lat/lon) --- - lat_out_bnds = _bounds_from_centers_1d(lat_out_1d, periodic=False) - lon_out_bnds = _bounds_from_centers_1d(lon_out_1d, periodic=True, period=360.0) - - ds_out = xr.Dataset( - data_vars={ - "lat_bnds": (("lat", "nbnds"), lat_out_bnds), - "lon_bnds": (("lon", "nbnds"), lon_out_bnds), - }, - coords={ - "lat": ( - "lat", - lat_out_1d, - { - "units": "degrees_north", - "standard_name": "latitude", - "bounds": "lat_bnds", - }, - ), - "lon": ( - "lon", - lon_out_1d, - { - "units": "degrees_east", - "standard_name": "longitude", - "bounds": "lon_bnds", - }, - ), - "nbnds": ("nbnds", np.array([0, 1], dtype="i4")), - }, - ) - return ds_in, ds_out - - # ------------------------- # Selection & utilities # ------------------------- @@ -413,10 +183,15 @@ def _pick_maps( conservative_map: Optional[Path] = None, bilinear_map: Optional[Path] = None, force_method: Optional[str] = None, + realm: Optional[str] = None, ) -> MapSpec: """Choose which precomputed map file to use for a variable.""" - cons = Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP - bilin = Path(bilinear_map) if 
bilinear_map else DEFAULT_BILIN_MAP + if realm == "ocn": + cons = Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP_T232 + bilin = Path(bilinear_map) if bilinear_map else DEFAULT_BILIN_MAP_T232 + else: + cons = Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP_NE30 + bilin = Path(bilinear_map) if bilinear_map else DEFAULT_BILIN_MAP_NE30 if force_method: if force_method not in {"conservative", "bilinear"}: @@ -461,6 +236,7 @@ def regrid_to_1deg_ds( keep_attrs: bool = True, dtype: str | None = "float32", output_time_chunk: int | None = 12, + sftlf_path: Optional[Path] = None, ) -> xr.Dataset: """Regrid var(s) and return a Dataset.""" @@ -484,6 +260,14 @@ def regrid_to_1deg_ds( # Attach time (and bounds) from the original dataset if requested if time_from is not None: ds_out = _attach_time_and_bounds(ds_out, time_from) + if "ncol" in ds_in.dims: + realm = "atm" + elif "lndgrid" in ds_in.dims: + realm = "lnd" + else: + realm = "ocn" + if "sftof" in ds_in: + logger.info("Regridding sftof (sea fraction) from source dataset") # Pick the mapfile you used for conservative/bilinear selection spec = _pick_maps( @@ -491,9 +275,11 @@ def regrid_to_1deg_ds( conservative_map=conservative_map, bilinear_map=bilinear_map, force_method="conservative", + realm=realm, ) # fx always conservative - print(f"using fx map: {spec.path}") - ds_fx = _regrid_fx_once(spec.path, ds_in) # ← uses cache + logger.info("using fx map: %s", spec.path) + ds_fx = _regrid_fx_once(spec.path, ds_in, sftlf_path) # ← uses cache + if ds_fx: # Don’t overwrite if user already computed and passed them in for name in ( @@ -538,6 +324,90 @@ def _denormalize_land_field( return out +def _denormalize_ocn_field(out_norm: xr.DataArray, ds_in: xr.Dataset) -> xr.DataArray: + """Denormalize field by destination sftof (sea fraction).""" + logger.info("Denormalizing ocean field by destination sftof (sea fraction)") + sftof_dst = _sftof_from_native(ds_in) # fallback: use source if no destination + if sftof_dst is not None: + frac_dst = sftof_dst / 100.0 + out = out_norm / frac_dst.where(frac_dst > 0) + else: + out = out_norm + return out + + +def _dst_latlon_1d_from_map(mapfile: Path) -> tuple[np.ndarray, np.ndarray]: + """Return canonical 1-D (lat, lon) for the destination grid. + + Robust to map files whose stored dst_grid_dims or reshape order is swapped. + We derive 1-D axes by taking UNIQUE values from the 2-D center fields. + """ + with open_nc(mapfile) as m: + lat2d = read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat") + lon2d = read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon") + + if lat2d is not None and lon2d is not None: + lat2 = np.asarray(lat2d).reshape(-1) # flatten safely + lon2 = np.asarray(lon2d).reshape(-1) + + # Take unique values (rounded to avoid tiny FP noise) + lat_unique = np.unique(lat2.round(6)) + lon_unique = np.unique(lon2.round(6)) + + # If either came out descending, sort ascending + lat1d = np.sort(lat_unique).astype("f8") + lon1d = np.sort(lon_unique).astype("f8") + + # If longitudes are in [-180,180], convert to [0,360) + if lon1d.min() < 0.0 or lon1d.max() <= 180.0: + lon1d = (lon1d % 360.0).astype("f8") + lon1d.sort() + + # Prefer the classic 180/360 lengths if present + if lat1d.size == 360 and lon1d.size == 180: + # swapped: pick every other for lat (=180) and expand lon to 360 if needed + # However, for standard 1° grids stored swapped, lat values repeat; use stride 2. + lat1d = lat1d[::2] + # For lon 180 centers, mapfile likely only stored 0.5..179.5. 
+ # We'll fabricate 0.5..359.5 to be safe. + if lon1d.size == 180: + lon1d = np.arange(360, dtype="f8") + 0.5 + + # sanity bounds + if not ( + -91.0 <= float(lat1d.min()) <= -89.0 + and 89.0 <= float(lat1d.max()) <= 91.0 + ): + # If still odd, try the alternative: + # extract along the other axis by reshaping via dst_grid_dims + # Fallback to canonical 1° + lat1d = np.linspace(-89.5, 89.5, 180, dtype="f8") + if lon1d.size != 360: + lon1d = np.arange(360, dtype="f8") + 0.5 + + return lat1d, lon1d + + # 1-D fields present already + lat1d = read_array(m, "lat", "yc") + lon1d = read_array(m, "lon", "xc") + if lat1d is not None and lon1d is not None: + lat1 = np.sort(np.asarray(lat1d, dtype="f8")) + lon1 = np.sort(np.asarray(lon1d, dtype="f8")) + if lon1.min() < 0.0 or lon1.max() <= 180.0: + lon1 = lon1 % 360.0 + lon1.sort() + if lat1.size == 360: + lat1 = lat1[::2] + if lon1.size != 360: + lon1 = np.arange(360, dtype="f8") + 0.5 + return lat1, lon1 + + # Last resort: fabricate standard 1° + lat = np.linspace(-89.5, 89.5, 180, dtype="f8") + lon = np.arange(360, dtype="f8") + 0.5 + return lat, lon + + def regrid_to_1deg( ds_in: xr.Dataset, varname: str, @@ -569,12 +439,28 @@ def regrid_to_1deg( """ if varname not in ds_in: raise KeyError(f"{varname!r} not in dataset.") - + realm = None var_da = ds_in[varname] # always a DataArray - - da2, non_spatial, hdim = _ensure_ncol_last(var_da) - if hdim == "lndgrid": - da2 = _normalize_land_field(da2, ds_in) + if "ncol" not in var_da.dims and "lndgrid" not in var_da.dims: + logger.info("Variable has no 'ncol' or 'lndgrid' dim; assuming ocn variable.") + hdim = "tripolar" + da2 = var_da # Use the DataArray, not the whole Dataset + non_spatial = [d for d in da2.dims if d not in ("yh", "xh")] + realm = "ocn" + method = method or "conservative" # force conservative for ocn + # --- OCEAN: Normalize by sftof (sea fraction) if present --- + sftof = _sftof_from_native(ds_in) + if sftof is not None: + logger.info("Normalizing ocean field by source sftof (sea fraction)") + frac = sftof / 100.0 + da2 = da2.fillna(0) * frac + else: + da2, non_spatial, hdim = _ensure_ncol_last(var_da) + if hdim == "lndgrid": + da2 = _normalize_land_field(da2, ds_in) + realm = "lnd" + elif hdim == "ncol": + realm = "atm" # cast to save memory if dtype is not None and str(da2.dtype) != dtype: @@ -589,26 +475,51 @@ def regrid_to_1deg( conservative_map=conservative_map, bilinear_map=bilinear_map, force_method=method, + realm=realm, ) - regridder = _RegridderCache.get(spec.path, spec.method_label) + logger.info( + "Regridding %s using %s map: %s for realm %s", + varname, + spec.method_label, + spec.path, + realm, + ) + regridder = RegridderCache.get(spec.path, spec.method_label) # tell xESMF to produce chunked output kwargs = {} if "time" in da2.dims and output_time_chunk: kwargs["output_chunks"] = {"time": output_time_chunk} + if realm in ("atm", "lnd"): + da2_2d = ( + da2.rename({hdim: "lon"}) + .expand_dims({"lat": 1}) # add a dummy 'lat' of length 1 + .transpose( + *non_spatial, "lat", "lon" + ) # ensure last two dims are ('lat','lon') + ) + else: + da2_2d = da2.rename({"xh": "lon", "yh": "lat"}).transpose( + *non_spatial, "lat", "lon" + ) - da2_2d = ( - da2.rename({hdim: "lon"}) - .expand_dims({"lat": 1}) # add a dummy 'lat' of length 1 - .transpose(*non_spatial, "lat", "lon") # ensure last two dims are ('lat','lon') + da2_2d = da2_2d.assign_coords(lon=((da2_2d.lon % 360))) + logger.debug( + "da2_2d range: %f to %f lat, %f to %f lon", + da2_2d["lat"].min().item(), + 
da2_2d["lat"].max().item(), + da2_2d["lon"].min().item(), + da2_2d["lon"].max().item(), ) out_norm = regridder(da2_2d, **kwargs) - if hdim == "lndgrid": + + if realm == "lnd": out = _denormalize_land_field(out_norm, ds_in, spec.path) + elif realm == "ocn": + out = _denormalize_ocn_field(out_norm, ds_in) else: - # default path (atm or no landfrac/sftlf available) - out = regridder(da2_2d, **kwargs) + out = out_norm # --- NEW: robust lat/lon assignment based on destination grid lengths --- lat1d, lon1d = _dst_latlon_1d_from_map(spec.path) @@ -619,9 +530,23 @@ def regrid_to_1deg( if len(spatial_dims) < 2: raise ValueError(f"Unexpected output dims {out.dims}; need two spatial dims.") + if len(spatial_dims) > 2: + logger.warning( + "More than two spatial dims found in output: %s; using last two.", + spatial_dims, + ) + spatial_dims = spatial_dims[-2:] da, db = spatial_dims[-2], spatial_dims[-1] na, nb = out.sizes[da], out.sizes[db] - + logger.info( + "Output spatial dims: %s (%d), %s (%d); target (lat %d, lon %d)", + da, + na, + db, + nb, + ny, + nx, + ) # Decide mapping by comparing lengths to (ny, nx) if na == ny and nb == nx: out = out.rename({da: "lat", db: "lon"}) @@ -637,7 +562,7 @@ def regrid_to_1deg( choose_lat = da if abs(na - 180) <= abs(nb - 180) else db choose_lon = db if choose_lat == da else da out = out.rename({choose_lat: "lat", choose_lon: "lon"}) - + logger.info("Final output dims: %s", out.dims) # assign canonical 1-D coords out = out.assign_coords(lat=("lat", lat1d), lon=("lon", lon1d)) @@ -738,6 +663,31 @@ def _build_fx_native(ds_native: xr.Dataset) -> xr.Dataset: sftlf = _sftlf_from_native(ds_native) if sftlf is not None: pieces["sftlf"] = sftlf + # Also extract sftof (sea fraction) if present + sftof = None + for name in ["sftof", "ocnfrac", "wet"]: + if name in ds_native: + logger.info("Extracting sftof from native variable %s", name) + v = ds_native[name] + # If 'wet', convert 0/1 to percent + if name == "wet": + vmax = float(np.nanmax(np.asarray(v))) + if vmax <= 1.0 + 1e-6: + sftof = v * 100.0 + else: + sftof = v + else: + sftof = v + sftof = sftof.clip(min=0.0, max=100.0) + sftof = sftof.astype("f8") + attrs = dict(sftof.attrs) + attrs["units"] = "%" + attrs.setdefault("standard_name", "sea_area_fraction") + attrs.setdefault("long_name", "Percentage of sea area") + sftof.attrs = attrs + break + if sftof is not None: + pieces["sftof"] = sftof if not pieces: return xr.Dataset() ds_fx = xr.Dataset(pieces) @@ -812,21 +762,27 @@ def compute_areacella_from_bounds( return da -def _regrid_fx_once(mapfile: Path, ds_native: xr.Dataset) -> xr.Dataset: +def _regrid_fx_once( + mapfile: Path, ds_native: xr.Dataset, sftlf_path: Path | None = None +) -> xr.Dataset: """Compute & regrid sftlf/areacella once for a given mapfile; cache result. sftlf is regridded from source; areacella is computed on the destination grid. 
""" - cached = _FXCache.get(mapfile) + cached = FXCache.get(mapfile) if cached is not None: + logger.info("Getting cached fx variables") return cached + out_vars = {} + if sftlf_path: + logger.info("Getting sftlf from output path %s", sftlf_path) + out_vars["sftlf"] = xr.open_mfdataset(sftlf_path)["sftlf"] ds_fx_native = _build_fx_native(ds_native) - out_vars = {} # Regrid sftlf from source if present - if "sftlf" in ds_fx_native: - regridder = _RegridderCache.get(mapfile, "conservative") + if "sftlf" not in out_vars and "sftlf" in ds_fx_native: + regridder = RegridderCache.get(mapfile, "conservative") da = ds_fx_native["sftlf"] da2 = ( da.rename({"lndgrid": "lon"}) @@ -840,6 +796,19 @@ def _regrid_fx_once(mapfile: Path, ds_native: xr.Dataset) -> xr.Dataset: out.attrs.update(da.attrs) out_vars["sftlf"] = out + # Regrid sftof (sea fraction) from source if present + if "sftof" in ds_fx_native: + logger.info("Regridding sftof (sea fraction) from native") + regridder = RegridderCache.get(mapfile, "conservative") + da = ds_fx_native["sftof"] + da2 = da.rename({"xh": "lon", "yh": "lat"}).transpose(..., "lat", "lon") + out = regridder(da2) + spatial = [d for d in out.dims if d in ("lat", "lon")] + out = out.transpose(*spatial) + out.name = "sftof" + out.attrs.update(da.attrs) + out_vars["sftof"] = out + # Always compute areacella on the destination grid, not by regridding # Use the destination grid from the mapfile lat1d, lon1d = _dst_latlon_1d_from_map(mapfile) @@ -853,62 +822,5 @@ def _regrid_fx_once(mapfile: Path, ds_native: xr.Dataset) -> xr.Dataset: out_vars["areacella"] = areacella ds_fx = xr.Dataset(out_vars) - _FXCache.put(mapfile, ds_fx) + FXCache.put(mapfile, ds_fx) return ds_fx - - -# ------------------------- -# Cache -# ------------------------- -class _FXCache: - """Cache of regridded FX fields (sftlf, areacella) keyed by mapfile.""" - - _cache: Dict[Path, xr.Dataset] = {} - - @classmethod - def get(cls, key: Path) -> xr.Dataset | None: - """get cached variable""" - return cls._cache.get(key) - - @classmethod - def put(cls, key: Path, ds_fx: xr.Dataset) -> None: - """put variable into cache""" - cls._cache[key] = ds_fx - - @classmethod - def clear(cls) -> None: - """clear cache""" - cls._cache.clear() - - -class _RegridderCache: - """Cache of xESMF Regridders constructed from weight files. - - We build minimal `ds_in`/`ds_out` from the weight file to satisfy CF checks, - then reuse the weight file for the actual mapping. 
- """ - - _cache: Dict[Path, xe.Regridder] = {} - - @classmethod - def get(cls, mapfile: Path, method_label: str) -> xe.Regridder: - """Return a cached regridder for the given weight file and method.""" - mapfile = mapfile.expanduser().resolve() - if mapfile not in cls._cache: - if not mapfile.exists(): - raise FileNotFoundError(f"Regrid weights not found: {mapfile}") - ds_in, ds_out = _make_dummy_grids(mapfile) - cls._cache[mapfile] = xe.Regridder( - ds_in, - ds_out, - method=method_label, - filename=str(mapfile), # reuse the ESMF weight file on disk - reuse_weights=True, - periodic=True, # 0..360 longitudes - ) - return cls._cache[mapfile] - - @classmethod - def clear(cls) -> None: - """Clear all cached regridders (useful for tests or releasing resources).""" - cls._cache.clear() diff --git a/requirements.txt b/requirements.txt index 5b66a33..423fe25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,6 @@ pyyaml click esmpy pytest +pytest-cov xesmf geocat-comp diff --git a/scripts/fullLmon.sh b/scripts/fullLmon.sh index 89b73c6..270205b 100644 --- a/scripts/fullLmon.sh +++ b/scripts/fullLmon.sh @@ -8,6 +8,6 @@ module load conda conda activate CMORDEV -poetry run python ./scripts/monthly_cmor.py --realm lnd --test --workers 32 --skip-timeseries - +NCPUS=$(cat $PBS_NODEFILE | wc -l) +poetry run python ./scripts/monthly_cmor.py --realm lnd --test --workers $NCPUS --skip-timeseries --cmip-vars cProduct #/glade/u/home/cmip7/cases/b.e30_beta06.B1850C_LTso.ne30_t232_wgx3.192.wrkflw.1 /glade/work/hannay/cesm_tags/cesm3_0_beta06/cime diff --git a/scripts/fullamon.sh b/scripts/fullamon.sh index eee1054..2f00ebc 100644 --- a/scripts/fullamon.sh +++ b/scripts/fullamon.sh @@ -9,7 +9,7 @@ module load conda conda activate CMORDEV - -poetry run python ./scripts/monthly_cmor.py --realm atm --test --workers 32 --skip-timeseries +NCPUS=$(cat $PBS_NODEFILE | wc -l) +poetry run python ./scripts/monthly_cmor.py --realm atm --test --workers $NCPUS --skip-timeseries #/glade/u/home/cmip7/cases/b.e30_beta06.B1850C_LTso.ne30_t232_wgx3.192.wrkflw.1 /glade/work/hannay/cesm_tags/cesm3_0_beta06/cime diff --git a/scripts/fullomon.sh b/scripts/fullomon.sh index 3c92a87..c2846d4 100644 --- a/scripts/fullomon.sh +++ b/scripts/fullomon.sh @@ -10,4 +10,5 @@ module load conda conda activate CMORDEV -poetry run python ./scripts/monthly_cmor.py --realm ocn --test --workers 1 --cmip-vars sos --skip-timeseries +poetry run python ./scripts/monthly_cmor.py --realm ocn --test --workers 1 --cmip-vars sos --skip-timeseries \ + --ocn-static-file /glade/derecho/scratch/cmip7/archive/b.e30_beta06.B1850C_LTso.ne30_t232_wgx3.192.wrkflw.1/ocn/hist/b.e30_beta06.B1850C_LTso.ne30_t232_wgx3.192.wrkflw.1.mom6.h.static.nc diff --git a/scripts/monthly_cmor.py b/scripts/monthly_cmor.py index 406e73c..b152753 100644 --- a/scripts/monthly_cmor.py +++ b/scripts/monthly_cmor.py @@ -29,7 +29,8 @@ from cmip7_prep.pipeline import realize_regrid_prepare, open_native_for_cmip_vars from cmip7_prep.cmor_writer import CmorSession from cmip7_prep.dreq_search import find_variables_by_prefix -from cmip7_prep.mom6_static import load_mom6_static +from cmip7_prep.mom6_static import load_mom6_grid +from cmip7_prep.mom6_static import ocean_fx_fields from gents.hfcollection import HFCollection from gents.timeseries import TSCollection from dask.distributed import LocalCluster @@ -60,10 +61,16 @@ def parse_args(): ) parser.add_argument( - "--static-grid", + "--ocn-grid-file", + type=str, + 
default="/glade/campaign/cesm/cesmdata/inputdata/ocn/mom/tx2_3v2/ocean_hgrid_221123.nc", + help="Path to ocean grid description file for MOM (optional)", + ) + parser.add_argument( + "--ocn-static-file", type=str, default=None, - help="Path to static grid file for MOM variables (optional)", + help="Path to static file for MOM variables (optional)", ) parser.add_argument( "--cmip-vars", @@ -119,17 +126,17 @@ def parse_args(): run_years = int(run_freq[:-1]) run_months = run_years * 12 except Exception: - print(f"Invalid --run-freq value: {run_freq}") + logger.error(f"Invalid --run-freq value: {run_freq}") sys.exit(1) elif run_freq.endswith("m"): try: run_months = int(run_freq[:-1]) run_years = run_months // 12 except Exception: - print(f"Invalid --run-freq value: {run_freq}") + logger.error(f"Invalid --run-freq value: {run_freq}") sys.exit(1) else: - print(f"Invalid --run-freq value: {run_freq}") + logger.error(f"Invalid --run-freq value: {run_freq}") sys.exit(1) args.run_years = run_years args.run_months = run_months @@ -137,97 +144,143 @@ def parse_args(): def process_one_var( - varname: str, mapping, inputfile, tables_path, outdir, mom6_grid=None -) -> tuple[str, str]: - """Compute+write one CMIP variable. Returns (varname, 'ok' or error message).""" + varname: str, + mapping, + inputfile, + tables_path, + outdir, + realm="atm", + mom6_grid=None, + ocn_fx_fields=None, +) -> list[tuple[str, str]]: + """Compute+write one CMIP variable. Returns list of (varname, 'ok' or error message).""" logger.info(f"Starting processing for variable: {varname}") + results = [(varname, "started")] try: - - logger.info(f"Loading native data for {varname} from {inputfile}") - ds_native, var = open_native_for_cmip_vars( - varname, inputfile, mapping, use_cftime=True, parallel=True - ) - if var is None: - logger.warning(f"Source variable(s) not found for {varname}") - return (varname, "ERROR: Source variable(s) not found.") - # Example usage: attach MOM6 grid info for ocn realm - if mom6_grid is not None: - logger.info(f"MOM6 grid info available for {varname}") - # You can use mom6_grid['geolat'], mom6_grid['geolon'], etc. 
here - # For example, attach to ds_native as needed: - # ds_native['geolat'] = (('yh','xh'), mom6_grid['geolat']) - # ds_native['geolon'] = (('yh','xh'), mom6_grid['geolon']) + cfg = mapping.get_cfg(varname) except Exception as e: - logger.error(f"Exception while reading {varname}: {e!r}") - return (varname, f"ERROR: {e!r}") - try: - print(f"Regrid/prepare for {varname}") - if "ncol" in ds_native.dims or "lndgrid" in ds_native.dims: - logger.info(f"Regridding for {varname}") - ds_cmor = realize_regrid_prepare( + logger.error(f"Error retrieving config for {varname}: {e}") + results.append((varname, f"ERROR: {e}")) + return results + + dims_list = cfg.get("dims") + # If dims is a single list (atm/lnd), wrap in a list for uniformity + if dims_list and isinstance(dims_list[0], str): + dims_list = [dims_list] + for dims in dims_list: + logger.info(f"Processing {varname} with dims {dims}") + try: + open_kwargs = None + if realm == "ocn": + open_kwargs = {"decode_timedelta": False} + + ds_native, var = open_native_for_cmip_vars( + varname, + inputfile, mapping, - ds_native, + use_cftime=True, + parallel=True, + open_kwargs=open_kwargs, + ) + + # Append ocn_fx_fields to ds_native if available + if realm == "ocn" and ocn_fx_fields is not None: + ds_native = ds_native.merge(ocn_fx_fields) + logger.debug( + "ds_native keys: %s for var %s with dims %s", + list(ds_native.keys()), varname, + dims, + ) + if var is None: + logger.warning(f"Source variable(s) not found for {varname}") + results.append((varname, "ERROR: Source variable(s) not found.")) + continue + + # --- OCN: distinguish native vs regridded by dims --- + if "xh" in dims and "yh" in dims: + logger.info(f"Preparing native grid output for mom6 variable {varname}") + ds_cmor = ds_native + results.append((varname, "analyzed native mom6 grid")) + if mom6_grid is not None: + logger.info(f"Add geolat to ds_cmor") + + def _extract_array(val): + # If val is a tuple, return the first element, else return as is + return val[0] if isinstance(val, tuple) else val + + ds_cmor["geolat"] = xr.DataArray(mom6_grid[0], dims=("yh", "xh")) + ds_cmor["geolon"] = xr.DataArray(mom6_grid[1], dims=("yh", "xh")) + ds_cmor["geolat_c"] = xr.DataArray( + mom6_grid[2], dims=("yhp", "xhp") + ) + ds_cmor["geolon_c"] = xr.DataArray( + mom6_grid[3], dims=("yhp", "xhp") + ) + else: + # For lnd/atm or any other dims, use existing logic + logger.debug( + "Processing %s for dims %s (atm/lnd or other)", varname, dims + ) + ds_cmor = realize_regrid_prepare( + mapping, + ds_native, + varname, + tables_path=tables_path, + time_chunk=12, + mom6_grid=mom6_grid, + regrid_kwargs={ + "output_time_chunk": 12, + "dtype": "float32", + }, + open_kwargs={"decode_timedelta": True}, + ) + except Exception as e: + logger.error( + "Exception during regridding of %s with dims %s: %r", + varname, + dims, + e, + ) + continue + try: + # CMORize + log_dir = outdir + "/logs" + with CmorSession( tables_path=tables_path, - time_chunk=12, - regrid_kwargs={ - "output_time_chunk": 12, - "dtype": "float32", - "bilinear_map": Path( - "/glade/campaign/cesm/cesmdata/inputdata/cpl/gridmaps/ne30pg3/map_ne30pg3_to_1x1d_bilin.nc" - ), - }, + log_dir=log_dir, + log_name=f"cmor_{datetime.now(UTC).strftime('%Y%m%dT%H%M%SZ')}_{varname}.log", + dataset_attrs={"institution_id": "NCAR"}, + outdir=outdir, + ) as cm: + + vdef = type( + "VDef", + (), + { + "name": varname, + "table": cfg.get("table", "Amon"), + "units": cfg.get("units", ""), + "dims": dims, + "positive": cfg.get("positive", None), + "cell_methods": 
cfg.get("cell_methods", None), + "long_name": cfg.get("long_name", None), + "standard_name": cfg.get("standard_name", None), + "levels": cfg.get("levels", None), + }, + )() + logger.info( + f"Writing variable {varname} with dims {dims} and type {ds_cmor[varname].dtype}" + ) + cm.write_variable(ds_cmor, varname, vdef) + logger.info(f"Finished processing for {varname} with dims {dims}") + results.append((varname, "ok")) + except Exception as e: + logger.error( + f"Exception while processing {varname} with dims {dims}: {e!r}" ) - else: - logger.info(f"Skipping regrid/prepare for ocean variable {varname}") - ds_cmor = ds_native - # Attach only cell center arrays and bounds for CMOR - if mom6_grid is not None: - geolat = mom6_grid[0] - geolon = mom6_grid[1] - geolat_c = mom6_grid[2] - geolon_c = mom6_grid[3] - ncell = geolat.size - # 2D for reference, 1D for CMOR axes - ds_cmor["geolat"] = (("yh", "xh"), geolat) - ds_cmor["geolon"] = (("yh", "xh"), geolon) - ds_cmor["geolat_c"] = (("yhp", "xhp"), geolat_c) - ds_cmor["geolon_c"] = (("yhp", "xhp"), geolon_c) - except Exception as e: - logger.error(f"Exception while regridding/preparing {varname}: {e!r}") - try: - logger.info(f"CMOR writing for {varname}") - log_dir = outdir + "/logs" # or set as needed - with CmorSession( - tables_path=tables_path, - log_dir=log_dir, - log_name=f"cmor_{datetime.now(UTC).strftime('%Y%m%dT%H%M%SZ')}_{varname}.log", - dataset_attrs={"institution_id": "NCAR"}, - outdir=outdir, - ) as cm: - cfg = mapping.get_cfg(varname) - vdef = type( - "VDef", - (), - { - "name": varname, - "table": cfg.get("table", "Amon"), - "units": cfg.get("units", ""), - "dims": cfg.get("dims", []), - "positive": cfg.get("positive", None), - "cell_methods": cfg.get("cell_methods", None), - "long_name": cfg.get("long_name", None), - "standard_name": cfg.get("standard_name", None), - "levels": cfg.get("levels", None), - }, - )() - print(f"Writing variable {varname}") - cm.write_variable(ds_cmor, varname, vdef) - logger.info(f"Finished processing for {varname}") - return (varname, "ok") - except Exception as e: - logger.error(f"Exception while processing {varname}: {e!r}") - return (varname, f"ERROR: {e!r}") + results.append((varname, f"ERROR: {e!r}")) + return results process_one_var_delayed = delayed(process_one_var) @@ -245,7 +298,7 @@ def latest_monthly_file( raise NotADirectoryError(directory) found = [] seps = set() - print(f"Looking for files in {str(directory)}") + logger.debug(f"Looking for files in {str(directory)}") for p in directory.iterdir(): if not p.is_file(): continue @@ -261,7 +314,7 @@ def latest_monthly_file( return None if require_consistent_style and len(seps) > 1: raise ValueError("Mixed date styles detected (YYYYMM.nc and YYYY-MM.nc).") - print(f"Found {len(found)} files in {str(directory)}") + logger.debug(f"Found {len(found)} files in {str(directory)}") found.sort(key=lambda t: (t[0], t[1], t[2].name)) year, month, path = found[-1] return path, year, month @@ -273,6 +326,8 @@ def main(): OUTDIR = args.outdir mom6_grid = None + ocn_grid = None + ocn_fx_fields = None if args.realm == "atm": include_patterns = ["*cam.h0a*"] var_prefix = "Amon." @@ -286,8 +341,11 @@ def main(): include_patterns = ["*mom6.h.rho2.*", "*mom6.h.native.*", "*mom6.h.sfc.*"] var_prefix = "Omon." 
subdir = "ocn" - if not args.static_grid: - ocn_static_grid = "/glade/campaign/cesm/cesmdata/inputdata/ocn/mom/tx2_3v2/ocean_hgrid_221123.nc" + if args.ocn_grid_file: + ocn_grid = args.ocn_grid_file + if args.ocn_static_file: + ocn_static_file = args.ocn_static_file + ocn_fx_fields = ocean_fx_fields(ocn_static_file) # Setup input/output directories if args.caseroot and args.cimeroot: caseroot = args.caseroot @@ -298,7 +356,7 @@ def main(): try: from CIME.case import Case except ImportError as e: - print(f"Error importing CIME modules: {e}") + logger.warning(f"Error importing CIME modules: {e}") sys.exit(1) with Case(caseroot, read_only=True) as case: inputroot = case.get_value("DOUT_S_ROOT") @@ -317,7 +375,7 @@ def main(): # Calculate span in months span_months = (nyr - tsyr) * 12 if span_months < args.run_months: - print( + logger.info( f"Less than required run frequency ready ({span_months} months, need {args.run_months}), not processing {nyr}, {tsyr}" ) sys.exit(0) @@ -347,9 +405,9 @@ def main(): if not os.path.exists(str(TSDIR)): os.makedirs(str(TSDIR)) # Load MOM6 static grid if needed (ocn realm) - if args.realm == "ocn" and ocn_static_grid: - mom6_grid = load_mom6_static(ocn_static_grid) - print(f"Using MOM static grid file: {ocn_static_grid}") + if args.realm == "ocn" and ocn_grid: + mom6_grid = load_mom6_grid(ocn_grid) + logger.info(f"Using MOM grid file: {ocn_grid}") # Dask cluster setup if args.workers == 1: client = None @@ -367,7 +425,7 @@ def main(): input_head_dir = INPUTDIR output_head_dir = TSDIR if args.skip_timeseries: - print("Skipping timeseries processing as per --skip-timeseries flag.") + logger.info("Skipping timeseries processing as per --skip-timeseries flag.") else: cnt = 0 for include_pattern in include_patterns: @@ -379,7 +437,7 @@ def main(): sys.exit(0) hf_collection = HFCollection(input_head_dir, dask_client=client) for include_pattern in include_patterns: - logger.info(f"Processing files with pattern: {include_pattern}") + logger.info("Processing files with pattern: %s", include_pattern) hfp_collection = hf_collection.include_patterns([include_pattern]) hfp_collection.pull_metadata() ts_collection = TSCollection( @@ -388,16 +446,16 @@ def main(): if args.overwrite: ts_collection = ts_collection.apply_overwrite("*") ts_collection.execute() - print("Timeseries processing complete, starting CMORization...") + logger.info("Timeseries processing complete, starting CMORization...") mapping = Mapping.from_packaged_default() - print(f"Finding variables with prefix {var_prefix}") + logger.info(f"Finding variables with prefix {var_prefix}") if args.cmip_vars and len(args.cmip_vars) > 0: cmip_vars = args.cmip_vars else: cmip_vars = find_variables_by_prefix( None, var_prefix, where={"List of Experiments": "piControl"} ) - print(f"CMORIZING {len(cmip_vars)} variables") + logger.info(f"CMORIZING {len(cmip_vars)} variables") # Load requested variables if len(cmip_vars) > 0: if len(include_patterns) == 1: @@ -408,14 +466,28 @@ def main(): if args.workers == 1: results = [ process_one_var( - v, mapping, input_path, TABLES, OUTDIR, mom6_grid=mom6_grid + v, + mapping, + input_path, + TABLES, + OUTDIR, + realm=args.realm, + mom6_grid=mom6_grid, + ocn_fx_fields=ocn_fx_fields, ) for v in cmip_vars ] else: futs = [ process_one_var_delayed( - var, mapping, input_path, TABLES, OUTDIR, mom6_grid=mom6_grid + var, + mapping, + input_path, + TABLES, + OUTDIR, + realm=args.realm, + mom6_grid=mom6_grid, + ocn_fx_fields=ocn_fx_fields, ) for var in cmip_vars ] @@ -428,15 +500,28 @@ def 
         results = []
         for _, result in as_completed(futures, with_results=True):
             try:
-                results.append(result)  # (v, status)
+                # Normalize result types: list of (name, status) tuples, a tuple, or other
+                if isinstance(result, list):
+                    # A list is expected to hold (varname, status) tuples
+                    if all(isinstance(x, tuple) and len(x) == 2 for x in result):
+                        results.extend(result)
+                    else:
+                        # Not a list of 2-tuples; record it as unknown
+                        results.append((str(result), "unknown"))
+                elif isinstance(result, tuple) and len(result) == 2:
+                    results.append(result)
+                else:
+                    # Not a tuple/list; record it as unknown
+                    results.append((str(result), "unknown"))
             except Exception as e:
-                print("Task error:", e)
+                logger.error("Task error: %s", e)
                 raise
         for v, status in results:
-            print(v, "→", status)
+            logger.info("Variable %s processed with status: %s", v, status)
+
     else:
-        print("No results to process.")
+        logger.info("No results to process.")
     if client:
         client.close()
     if cluster:
diff --git a/tests/test_regrid_latlon.py b/tests/test_regrid_latlon.py
index 4d3f3ec..647e5e8 100644
--- a/tests/test_regrid_latlon.py
+++ b/tests/test_regrid_latlon.py
@@ -34,7 +34,7 @@ def test_lat_lon_named_and_sized_correctly(monkeypatch, order):
     # Fake cache → our regridder
     # pylint: disable=protected-access
     monkeypatch.setattr(
-        regrid._RegridderCache,
+        regrid.RegridderCache,
         "get",
         staticmethod(lambda path, method: _FakeRegridder(order)),
     )
@@ -75,7 +75,7 @@ def test_regrid_to_1deg_ds_carries_time_bounds(monkeypatch):
     """regrid_to_1deg_ds propagates time and existing bounds from the source dataset."""
     # pylint: disable=protected-access
     monkeypatch.setattr(
-        regrid._RegridderCache,
+        regrid.RegridderCache,
         "get",
         staticmethod(lambda path, method: _FakeRegridder("yx")),
    )
@@ -123,7 +123,7 @@ def test_attrs_propagated(monkeypatch):
     # fake regridder + lat/lon
     # pylint: disable=protected-access
     monkeypatch.setattr(
-        regrid._RegridderCache, "get", staticmethod(lambda p, m: _FakeRegridder("yx"))
+        regrid.RegridderCache, "get", staticmethod(lambda p, m: _FakeRegridder("yx"))
     )
     lat1d = np.linspace(-89.5, 89.5, 180)
     lon1d = np.arange(360, dtype="f8") + 0.5
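
# A minimal standalone sketch of the dims normalization used in process_one_var
# above, assuming cfg["dims"] is either a flat list of dim names (atm/lnd) or a
# list of such lists (ocn, one entry per output grid). The helper name
# normalize_dims is hypothetical, not part of the package.
from typing import List, Union


def normalize_dims(dims: Union[List[str], List[List[str]]]) -> List[List[str]]:
    """Wrap a flat dims list so callers can always iterate over lists of dims."""
    if dims and isinstance(dims[0], str):
        return [list(dims)]
    return [list(d) for d in dims]


assert normalize_dims(["time", "lat", "lon"]) == [["time", "lat", "lon"]]
assert normalize_dims([["time", "yh", "xh"], ["time", "lat", "lon"]]) == [
    ["time", "yh", "xh"],
    ["time", "lat", "lon"],
]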
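
# A minimal sketch of the task-result normalization in main(): process_one_var
# now returns a list of (varname, status) tuples, while older task code returned
# a single tuple, so gathered results are coerced into flat (name, status) pairs.
# The helper name normalize_result is hypothetical.
from typing import Any, List as _List, Tuple


def normalize_result(result: Any) -> _List[Tuple[str, str]]:
    """Coerce a single task result into a flat list of (varname, status) pairs."""
    if isinstance(result, list) and all(
        isinstance(x, tuple) and len(x) == 2 for x in result
    ):
        return list(result)
    if isinstance(result, tuple) and len(result) == 2:
        return [result]
    return [(str(result), "unknown")]


assert normalize_result([("tas", "started"), ("tas", "ok")]) == [
    ("tas", "started"),
    ("tas", "ok"),
]
assert normalize_result(("tas", "ok")) == [("tas", "ok")]
assert normalize_result(None) == [("None", "unknown")]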
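
# Hypothetical usage of latest_monthly_file, with the signature inferred from the
# hunk above and assuming this runs in the same module: it scans a directory for
# history files ending in YYYYMM.nc or YYYY-MM.nc and returns the newest
# (path, year, month), or None when nothing matches. The directory path below is
# illustrative only.
from pathlib import Path

hist_dir = Path("/glade/derecho/scratch/case/atm/hist")  # hypothetical path
latest = latest_monthly_file(hist_dir, require_consistent_style=True)
if latest is not None:
    path, year, month = latest
    print(f"Most recent history file: {path} ({year}-{month:02d})")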