|
| 1 | +# pylint: disable=too-many-lines |
1 | 2 | """Thin CMOR wrapper used by cmip7_prep. |
2 | 3 |
|
3 | 4 | This module centralizes CMOR session setup and writing so that the rest of the |
@@ -122,6 +123,47 @@ def _encode_time_to_num(obj, units: str, calendar: str) -> np.ndarray: |
122 | 123 | return nums.reshape(arr.shape) |
123 | 124 |
|
124 | 125 |
|
def _bounds_from_centers_1d(vals: np.ndarray, kind: str) -> np.ndarray:
    """Compute ``[n, 2]`` cell bounds from 1-D cell centers.

    Interior edges are the midpoints between neighbouring centers, so
    non-uniform spacing is supported. The two outer edges are extrapolated by
    half of the adjacent step.

    Args:
        vals: Cell centers. Flattened to 1-D and cast to float64.
        kind: ``"lat"`` or ``"lon"``.

            - ``"lat"``: every edge is clipped to ``[-90, 90]``, for both
              ascending and descending latitude.
            - ``"lon"``: edges are kept continuous with the centers (for
              example ``[-0.625, 0.625]`` for a first center at 0). Wrapping
              the edges into ``[0, 360)`` would make the rows non-monotonic
              and leave some cells not containing their own center.

    Returns:
        float64 array of shape ``(n, 2)``. ``[:, 0]`` is the edge shared with
        the previous cell and ``[:, 1]`` the edge shared with the next one.

    Raises:
        ValueError: If fewer than two centers are given, or ``kind`` is not
            ``"lat"`` or ``"lon"``.
    """
    centers = np.asarray(vals, dtype="f8").reshape(-1)
    n = centers.size
    if n < 2:
        raise ValueError("Need at least 2 points to compute bounds")
    if kind not in ("lat", "lon"):
        raise ValueError("kind must be 'lat' or 'lon'")

    # Interior edges: midpoints between neighbouring centers (length n-1).
    mid = 0.5 * (centers[1:] + centers[:-1])
    bounds = np.empty((n, 2), dtype="f8")
    bounds[1:, 0] = mid
    bounds[:-1, 1] = mid

    # End caps: extrapolate by half of the adjacent step.
    bounds[0, 0] = centers[0] - 0.5 * (centers[1] - centers[0])
    bounds[-1, 1] = centers[-1] + 0.5 * (centers[-1] - centers[-2])

    if kind == "lat":
        # Clip both columns: with descending latitude the overshoot beyond
        # +90 sits in column 0 and the one beyond -90 in column 1.
        np.clip(bounds, -90.0, 90.0, out=bounds)

    return bounds
| 165 | + |
| 166 | + |
125 | 167 | def _encode_time_bounds_to_num(tb, units: str, calendar: str) -> np.ndarray: |
126 | 168 | """ |
127 | 169 | Encode bounds array of shape (..., 2) to numeric CF time. |
@@ -382,6 +424,31 @@ def _resolve_table_filename(tables_path: Path, key: str) -> str: |
382 | 424 | DatasetJsonLike = Union[str, Path, AbstractContextManager] |
383 | 425 |
|
384 | 426 |
|
def _fx_glob_pattern(name: str) -> str:
    """Return a recursive glob pattern for ``fx`` NetCDF files of *name*.

    The pattern matches, at any directory depth, files named either
    ``<name>_fx_*.nc`` or ``<prefix>_<name>_fx_*.nc``.

    Glob syntax has no alternation, so the leading wildcard also accepts a
    file whose variable name merely ends with *name*. Callers should still
    confirm that the variable is present in the opened file.
    """
    # No underscore before {name}: a required "_" would exclude files whose
    # name starts with the variable name.
    return f"**/*{name}_fx_*.nc"
| 431 | + |
| 432 | + |
def _open_existing_fx(outdir: Path, name: str) -> xr.DataArray | None:
    """Return variable *name* from the first usable ``fx`` file under *outdir*.

    Candidates are the files below *outdir* that match
    ``_fx_glob_pattern(name)``. They are visited in sorted path order so the
    choice is reproducible when several files match.

    Args:
        outdir: Directory searched recursively.
        name: Variable to look for, both in the filename and in the file.

    Returns:
        The variable loaded into memory, or ``None`` when no candidate could
        be opened or none contains *name*. The file is closed before
        returning.

    Warns:
        RuntimeWarning: When a candidate cannot be opened or decoded, or when
            the ``netcdf4`` backend cannot be imported.
    """
    for path in sorted(outdir.rglob(_fx_glob_pattern(name))):
        try:
            # The context manager closes the file on every exit path.
            with xr.open_dataset(path, engine="netcdf4") as ds:
                if name in ds:
                    # Load first so the array stays usable once the file
                    # is closed.
                    return ds[name].load()
        except FileNotFoundError:
            # The file vanished between listing and opening; try the next.
            continue
        except (OSError, ValueError) as e:
            # OSError: unreadable/corrupt file, low-level I/O.
            # ValueError: engine/decoding issues.
            warnings.warn(
                f"[fx] failed to open {path} with netcdf4: {e}", RuntimeWarning
            )
        except ImportError as e:
            # Backend not installed: every other candidate would fail the
            # same way, so warn once and stop.
            warnings.warn(f"[fx] netcdf4 backend unavailable: {e}", RuntimeWarning)
            return None

    return None
| 450 | + |
| 451 | + |
385 | 452 | # --------------------------------------------------------------------- |
386 | 453 | # CMOR session |
387 | 454 | # --------------------------------------------------------------------- |
@@ -422,8 +489,14 @@ def __init__( |
422 | 489 | self._log_name = log_name |
423 | 490 | self._log_path: Path | None = None |
424 | 491 | self._pending_ps = None |
425 | | - self._outdir = Path(outdir) if outdir is not None else Path.cwd() / "CMIP7" |
| 492 | + self._outdir = Path(outdir or "./CMIP7").resolve() |
426 | 493 | self._outdir.mkdir(parents=True, exist_ok=True) |
| 494 | + self._fx_written: set[str] = ( |
| 495 | + set() |
| 496 | + ) # remembers which fx vars were written this run |
| 497 | + self._fx_cache: dict[str, xr.DataArray] = ( |
| 498 | + {} |
| 499 | + ) # regridded fx fields cached in-memory |
427 | 500 |
|
428 | 501 | def __enter__(self) -> "CmorSession": |
429 | 502 | # Resolve logfile path if requested |
@@ -765,6 +838,100 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str): |
765 | 838 | axes_ids.extend([lat_id, lon_id]) |
766 | 839 | return axes_ids |
767 | 840 |
|
def _write_fx_2d(self, ds: xr.Dataset, name: str, units: str) -> None:
    """Write the field *name* from *ds* through CMOR using the ``fx`` table.

    Returns immediately, without touching CMOR, when *name* is not in *ds*.
    Otherwise the ``fx`` table is loaded, latitude and longitude axes are
    defined from ``ds["lat"]`` and ``ds["lon"]``, and the field is written
    and closed.

    Cell bounds come from ``lat_bnds`` / ``lon_bnds`` when *ds* holds them as
    DataArrays. Otherwise they are derived from the cell centers.

    Args:
        ds: Dataset providing the field and its ``lat`` / ``lon`` coordinates.
        name: Variable to write.
        units: Units string handed to ``cmor.variable``.
    """
    if name not in ds:
        return

    cmor.load_table(_resolve_table_filename(self.tables_path, "fx"))

    lat_centers = ds["lat"].values
    lon_centers = ds["lon"].values

    stored_lat_bnds = ds.get("lat_bnds")
    stored_lon_bnds = ds.get("lon_bnds")

    # Prefer bounds shipped with the dataset; otherwise rebuild them.
    if isinstance(stored_lat_bnds, xr.DataArray):
        lat_edges = stored_lat_bnds.values
    else:
        lat_edges = _bounds_from_centers_1d(lat_centers, "lat")

    if isinstance(stored_lon_bnds, xr.DataArray):
        lon_edges = stored_lon_bnds.values
    else:
        lon_edges = _bounds_from_centers_1d(lon_centers, "lon")

    # Axis order is (latitude, longitude).
    # NOTE(review): assumes the field is laid out as (lat, lon) - confirm.
    axis_ids = [
        cmor.axis(
            "latitude",
            "degrees_north",
            coord_vals=lat_centers,
            cell_bounds=lat_edges,
        ),
        cmor.axis(
            "longitude",
            "degrees_east",
            coord_vals=lon_centers,
            cell_bounds=lon_edges,
        ),
    ]

    payload, missing = _filled_for_cmor(ds[name])
    handle = cmor.variable(name, units, axis_ids, missing_value=missing)
    print(f"write fx variable {name}")
    cmor.write(handle, np.asarray(payload))
    cmor.close(handle)
| 877 | + |
def ensure_fx_written_and_cached(self, ds_regr: xr.Dataset) -> xr.Dataset:
    """Ensure ``sftlf`` and ``areacella`` are available and written once as fx.

    Each field is resolved by the first rule that applies:

    1. Cached earlier in this run: add it to the dataset if absent.
    2. Present in *ds_regr*: write it through ``_write_fx_2d`` unless already
       written this run, then cache it. Before writing, ``sftlf`` is scaled
       to percent when its maximum is <= 1 and its ``units`` attribute is
       neither ``"%"`` nor ``"percent"``.
    3. Otherwise look for an existing fx file under the output directory and
       adopt it when its ``lat`` / ``lon`` values equal the dataset's
       exactly.
    4. Otherwise leave the field missing for the caller to handle.

    A field is cached only after its write succeeded. If the write raises,
    the exception propagates and the next call retries it.

    Args:
        ds_regr: Regridded dataset to inspect and augment.

    Returns:
        *ds_regr* itself when nothing had to be added or converted, otherwise
        a new dataset carrying the added or converted fx fields.
    """
    out = ds_regr

    for name, units in (("sftlf", "%"), ("areacella", "m2")):
        # 1) Already cached this run?
        if name in self._fx_cache:
            if name not in out:
                out = out.assign({name: self._fx_cache[name]})
            continue

        # 2) Present in the regridded dataset (best case).
        if name in out:
            if name not in self._fx_written:
                if name == "sftlf":
                    field = out[name]
                    # Convert a land fraction to percent if needed.
                    if (np.nanmax(field.values) <= 1.0) and field.attrs.get(
                        "units", ""
                    ) not in ("%", "percent"):
                        scaled = (field * 100.0).assign_attrs(
                            field.attrs | {"units": "%"}
                        )
                        out = out.assign({name: scaled})
                self._write_fx_2d(out, name, units)
                self._fx_written.add(name)
            # Cache only after the write succeeded. Caching first would let
            # the shortcut in step 1 skip the write forever after a failure.
            self._fx_cache[name] = out[name]
            continue

        # 3) Not in the dataset: try existing CMOR fx output on disk.
        if not self._outdir:
            continue
        fx_da = _open_existing_fx(self._outdir, name)
        if fx_da is None:
            continue

        # Exact equality is required: assigning a field with slightly
        # different coordinates would be aligned by label, not by position.
        same_grid = (
            "lat" in out
            and "lon" in out
            and "lat" in fx_da.coords
            and "lon" in fx_da.coords
            and np.array_equal(out["lat"].values, fx_da["lat"].values)
            and np.array_equal(out["lon"].values, fx_da["lon"].values)
        )
        if same_grid:
            out = out.assign({name: fx_da})
            self._fx_cache[name] = out[name]
            self._fx_written.add(name)  # already exists on disk
        # On a grid mismatch the field could be regridded here; for now it
        # is skipped.

        # 4) Last resort: leave it missing; the caller may compute it later.

    return out
| 934 | + |
768 | 935 | # public API |
769 | 936 | # ------------------------- |
770 | 937 | def write_variable( |
@@ -801,6 +968,8 @@ def write_variable( |
801 | 968 | time_da = ds.get("time") |
802 | 969 | nt = 0 |
803 | 970 |
|
| 971 | + self.ensure_fx_written_and_cached(ds) |
| 972 | + |
804 | 973 | # ---- Main variable write ---- |
805 | 974 |
|
806 | 975 | cmor.write( |
|
0 commit comments