Skip to content

Commit f46499d

Browse files
authored
Merge pull request #2 from ESMCI/jpe/merge_branch
Jpe/merge branch
2 parents 6ad59db + 6dbc5fd commit f46499d

File tree

12 files changed

+814
-67
lines changed

12 files changed

+814
-67
lines changed

.github/workflows/pylint.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,8 @@ jobs:
3636
run: poetry install --no-interaction --no-ansi
3737
- name: Run pre-commit on all files
3838
run: poetry run pre-commit run --all-files
39+
# The following step can be enabled by developers to log in to the GitHub runner when a job fails.
40+
# see https://github.com/marketplace/actions/debugging-with-tmate for further details
41+
# - name: Setup tmate session
42+
# if: ${{ failure() }}
43+
# uses: mxschmitt/action-tmate@v3

cmip7_prep/cmor_writer.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=too-many-lines
12
"""Thin CMOR wrapper used by cmip7_prep.
23
34
This module centralizes CMOR session setup and writing so that the rest of the
@@ -122,6 +123,47 @@ def _encode_time_to_num(obj, units: str, calendar: str) -> np.ndarray:
122123
return nums.reshape(arr.shape)
123124

124125

126+
def _bounds_from_centers_1d(vals: np.ndarray, kind: str) -> np.ndarray:
127+
"""Compute [n,2] cell bounds from 1-D centers for 'lat' or 'lon'.
128+
129+
- For 'lat': clamps to [-90, 90]
130+
- For 'lon': treats as periodic [0, 360)
131+
- Works with non-uniform spacing (uses midpoints between neighbors)
132+
"""
133+
v = np.asarray(vals, dtype="f8").reshape(-1)
134+
n = v.size
135+
if n < 2:
136+
raise ValueError("Need at least 2 points to compute bounds")
137+
138+
# neighbor midpoints
139+
mid = 0.5 * (v[1:] + v[:-1]) # length n-1
140+
bounds = np.empty((n, 2), dtype="f8")
141+
bounds[1:, 0] = mid
142+
bounds[:-1, 1] = mid
143+
144+
# end caps: extrapolate by half-step at ends
145+
first_step = v[1] - v[0]
146+
last_step = v[-1] - v[-2]
147+
bounds[0, 0] = v[0] - 0.5 * first_step
148+
bounds[-1, 1] = v[-1] + 0.5 * last_step
149+
150+
if kind == "lat":
151+
# clamp to physical limits
152+
bounds[:, 0] = np.maximum(bounds[:, 0], -90.0)
153+
bounds[:, 1] = np.minimum(bounds[:, 1], 90.0)
154+
elif kind == "lon":
155+
# wrap to [0, 360)
156+
bounds = bounds % 360.0
157+
# ensure each row is increasing in modulo arithmetic
158+
wrap = bounds[:, 1] < bounds[:, 0]
159+
if np.any(wrap):
160+
bounds[wrap, 1] += 360.0
161+
else:
162+
raise ValueError("kind must be 'lat' or 'lon'")
163+
164+
return bounds
165+
166+
125167
def _encode_time_bounds_to_num(tb, units: str, calendar: str) -> np.ndarray:
126168
"""
127169
Encode bounds array of shape (..., 2) to numeric CF time.
@@ -382,6 +424,31 @@ def _resolve_table_filename(tables_path: Path, key: str) -> str:
382424
DatasetJsonLike = Union[str, Path, AbstractContextManager]
383425

384426

427+
def _fx_glob_pattern(name: str) -> str:
428+
# CMOR filenames vary; this finds most fx files for this var
429+
# e.g., *_sftlf_fx_*.nc or sftlf_fx_*.nc
430+
return f"**/*_{name}_fx_*.nc"
431+
432+
433+
def _open_existing_fx(outdir: Path, name: str) -> xr.DataArray | None:
    """Return variable ``name`` from the first usable fx file under ``outdir``.

    ``outdir`` is searched recursively for files matching
    ``_fx_glob_pattern(name)``. Candidates are visited in sorted path order so
    the choice is reproducible. Files that cannot be opened are reported with
    a ``RuntimeWarning`` and skipped, as are files that have disappeared since
    the directory was listed.

    Returns
    -------
    xr.DataArray | None
        The variable, loaded into memory, or ``None`` when no candidate file
        contains it.
    """
    for p in sorted(outdir.rglob(_fx_glob_pattern(name))):
        try:
            # The context manager closes the file handle whether or not the
            # variable is present, so scanning many files does not leak handles.
            with xr.open_dataset(p, engine="netcdf4") as ds:
                if name in ds:
                    # Load eagerly: the data must stay usable once the file
                    # is closed on leaving the ``with`` block.
                    return ds[name].load()
        except FileNotFoundError:
            # File vanished between listing and opening; other candidates
            # may still be valid, so keep searching.
            continue
        except (OSError, ValueError) as e:
            # OSError: unreadable/corrupt file, low-level I/O; ValueError: engine/decoding issues
            warnings.warn(f"[fx] failed to open {p} with netcdf4: {e}", RuntimeWarning)
        except (ImportError, ModuleNotFoundError) as e:
            # netCDF4 backend not installed
            warnings.warn(f"[fx] netcdf4 backend unavailable: {e}", RuntimeWarning)

    return None
450+
451+
385452
# ---------------------------------------------------------------------
386453
# CMOR session
387454
# ---------------------------------------------------------------------
@@ -422,8 +489,14 @@ def __init__(
422489
self._log_name = log_name
423490
self._log_path: Path | None = None
424491
self._pending_ps = None
425-
self._outdir = Path(outdir) if outdir is not None else Path.cwd() / "CMIP7"
492+
self._outdir = Path(outdir or "./CMIP7").resolve()
426493
self._outdir.mkdir(parents=True, exist_ok=True)
494+
self._fx_written: set[str] = (
495+
set()
496+
) # remembers which fx vars were written this run
497+
self._fx_cache: dict[str, xr.DataArray] = (
498+
{}
499+
) # regridded fx fields cached in-memory
427500

428501
def __enter__(self) -> "CmorSession":
429502
# Resolve logfile path if requested
@@ -765,6 +838,100 @@ def _get_1d_with_bounds(dsi: xr.Dataset, name: str, units_default: str):
765838
axes_ids.extend([lat_id, lon_id])
766839
return axes_ids
767840

841+
def _write_fx_2d(self, ds: xr.Dataset, name: str, units: str) -> None:
    """Write the 2-D (lat, lon) field ``name`` from ``ds`` via the CMOR fx table.

    Does nothing when ``name`` is not in ``ds``. Cell bounds come from
    ``lat_bnds`` / ``lon_bnds`` when ``ds`` provides them as DataArrays;
    otherwise they are derived from the ``lat`` / ``lon`` cell centers.
    """
    if name not in ds:
        return

    cmor.load_table(_resolve_table_filename(self.tables_path, "fx"))

    lat_centers = ds["lat"].values
    lon_centers = ds["lon"].values

    # Prefer bounds stored in the dataset; fall back to computed ones.
    lat_edges = ds.get("lat_bnds")
    lon_edges = ds.get("lon_bnds")
    if isinstance(lat_edges, xr.DataArray):
        lat_edges = lat_edges.values
    else:
        lat_edges = _bounds_from_centers_1d(lat_centers, "lat")
    if isinstance(lon_edges, xr.DataArray):
        lon_edges = lon_edges.values
    else:
        lon_edges = _bounds_from_centers_1d(lon_centers, "lon")

    axis_ids = [
        cmor.axis(
            "latitude",
            "degrees_north",
            coord_vals=lat_centers,
            cell_bounds=lat_edges,
        ),
        cmor.axis(
            "longitude",
            "degrees_east",
            coord_vals=lon_centers,
            cell_bounds=lon_edges,
        ),
    ]

    payload, fill_value = _filled_for_cmor(ds[name])
    variable_id = cmor.variable(name, units, axis_ids, missing_value=fill_value)
    print(f"write fx variable {name}")
    cmor.write(variable_id, np.asarray(payload))
    cmor.close(variable_id)
877+
878+
def ensure_fx_written_and_cached(self, ds_regr: xr.Dataset) -> xr.Dataset:
    """Make ``sftlf`` and ``areacella`` available, writing each as fx at most once.

    For each of the two fields, in order of preference:

    1. If it is cached from earlier in this run, attach it to the dataset
       when the dataset lacks it.
    2. If it is in ``ds_regr``, cache it and, unless already recorded as
       written, write it through ``_write_fx_2d``. ``sftlf`` whose maximum
       is <= 1 and whose units are not ``%``/``percent`` is multiplied by
       100 and relabelled ``%`` first.
    3. Otherwise look for an existing fx file under the output directory and
       use it only if its lat/lon values equal those of the dataset.

    A field that none of these steps can supply is left missing.

    Returns the dataset, augmented with whatever fx fields were attached.
    """
    result = ds_regr

    for fx_name, fx_units in (("sftlf", "%"), ("areacella", "m2")):
        # 1) Cached this run: just make sure the dataset carries it.
        if fx_name in self._fx_cache:
            if fx_name not in result:
                result = result.assign({fx_name: self._fx_cache[fx_name]})
            continue

        # 2) Supplied by the regridded dataset itself.
        if fx_name in result:
            self._fx_cache[fx_name] = result[fx_name]
            if fx_name in self._fx_written:
                continue
            if fx_name == "sftlf":
                field = result[fx_name]
                looks_fractional = np.nanmax(field.values) <= 1.0
                labelled_percent = field.attrs.get("units", "") in ("%", "percent")
                if looks_fractional and not labelled_percent:
                    # Land fraction in 0..1 -> percent.
                    percent_attrs = field.attrs | {"units": "%"}
                    scaled = (field * 100.0).assign_attrs(percent_attrs)
                    result = result.assign({fx_name: scaled})
                    self._fx_cache[fx_name] = result[fx_name]
            self._write_fx_2d(result, fx_name, fx_units)
            self._fx_written.add(fx_name)
            continue

        # 3) Fall back to fx output already on disk.
        if not self._outdir:
            continue
        existing = _open_existing_fx(self._outdir, fx_name)
        if existing is None:
            continue
        # Simple grid check: lat/lon values must be exactly equal.
        grids_match = (
            "lat" in result
            and "lon" in result
            and np.array_equal(result["lat"].values, existing["lat"].values)
            and np.array_equal(result["lon"].values, existing["lon"].values)
        )
        if grids_match:
            result = result.assign({fx_name: existing})
            self._fx_cache[fx_name] = result[fx_name]
            self._fx_written.add(fx_name)  # already exists on disk
        # On a grid mismatch the file is ignored (no regridding is attempted),
        # so the field stays missing and the caller may compute it later.

    return result
934+
768935
# public API
769936
# -------------------------
770937
def write_variable(
@@ -801,6 +968,8 @@ def write_variable(
801968
time_da = ds.get("time")
802969
nt = 0
803970

971+
self.ensure_fx_written_and_cached(ds)
972+
804973
# ---- Main variable write ----
805974

806975
cmor.write(

cmip7_prep/data/cesm_to_cmip7.yaml

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ variables:
8787
- cesm_var: QFLX
8888
scale: -1.0 # flip sign if CESM positive upward
8989

90+
evspsblsoi:
91+
table: Lmon
92+
units: kg m-2 s-1
93+
long_name: Water Evaporation from Soil
94+
dims: [time, lat, lon]
95+
positive: up
96+
sources:
97+
- cesm_var: QSOIL
98+
99+
evspsblveg:
100+
table: Lmon
101+
units: kg m-2 s-1
102+
long_name: Evaporation from Canopy
103+
dims: [time, lat, lon]
104+
positive: up
105+
sources:
106+
- cesm_var: QVEGE
107+
90108
hfls:
91109
table: Amon
92110
units: "W m-2"
@@ -139,6 +157,51 @@ variables:
139157
sources:
140158
- cesm_var: QREFHT
141159

160+
lai:
161+
table: Lmon
162+
units: "1"
163+
dims: [time, lat, lon]
164+
sources:
165+
- cesm_var: TLAI
166+
167+
mrfso:
168+
table: Lmon
169+
units: kg m-2
170+
dims: [time, lat, lon]
171+
formula: verticalsum(SOILICE, capped_at=5000)
172+
sources:
173+
- cesm_var: SOILICE
174+
175+
mrro:
176+
table: Lmon
177+
units: kg m-2 s-1
178+
dims: [time, lat, lon]
179+
sources:
180+
- cesm_var: QRUNOFF
181+
182+
mrros:
183+
table: Lmon
184+
units: kg m-2 s-1
185+
dims: [time, lat, lon]
186+
sources:
187+
- cesm_var: QOVER
188+
189+
mrso:
190+
table: Lmon
191+
units: kg m-2
192+
dims: [time, lat, lon]
193+
formula: verticalsum(SOILICE + SOILLIQ, capped_at=5000)
194+
sources:
195+
- cesm_var: SOILICE
196+
- cesm_var: SOILLIQ
197+
198+
mrsos:
199+
table: Lmon
200+
units: kg m-2
201+
dims: [time, lat, lon]
202+
sources:
203+
- cesm_var: SOILWATER_10CM
204+
142205
pr:
143206
table: Amon
144207
units: kg m-2 s-1
@@ -432,10 +495,20 @@ variables:
432495
sources:
433496
- cesm_var: Z3
434497

435-
fx:
436-
areacella:
437-
units: m2
438-
source: cell_area
439-
sftlf:
440-
units: '%'
441-
source: sftlf
498+
# areacella:
499+
# table: fx
500+
# units: m2
501+
# formula: area * 1.e6
502+
# standard_name: cell_area
503+
# dims: [lat, lon]
504+
# sources:
505+
# - cesm_var: ZBOT
506+
#
507+
#
508+
# sftlf:
509+
# table: fx
510+
# units: '%'
511+
# formula: landfrac * 100
512+
# dims: [lat, lon]
513+
# sources:
514+
# - cesm_var: ZBOT

cmip7_prep/mapping_compat.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,23 @@ def _safe_eval(expr: str, local_names: Dict[str, Any]) -> Any:
102102
2.0
103103
"""
104104
safe_globals = {"__builtins__": {}}
105-
locals_safe = {"np": np, "xr": xr}
105+
106+
# Add custom formula functions here
107+
def verticalsum(arr, capped_at=None, dim="levsoi"):
108+
# arr can be a DataArray or an expression
109+
if isinstance(arr, xr.DataArray):
110+
summed = arr.sum(dim=dim, skipna=True)
111+
else:
112+
summed = arr # fallback, should be DataArray
113+
if capped_at is not None:
114+
summed = xr.where(summed > capped_at, capped_at, summed)
115+
return summed
116+
117+
locals_safe = {
118+
"np": np,
119+
"xr": xr,
120+
"verticalsum": verticalsum,
121+
}
106122
locals_safe.update(local_names)
107123
# pylint: disable=eval-used
108124
return eval(expr, safe_globals, locals_safe)
@@ -274,6 +290,11 @@ def _realize_core(ds: xr.Dataset, vc: VarConfig) -> xr.DataArray:
274290
raise KeyError(f"source variable {vc.source!r} not found in dataset")
275291
return ds[vc.source]
276292

293+
if vc.name == "sftlf":
294+
vc.raw_variables = "landfrac"
295+
elif vc.name == "areacella":
296+
vc.raw_variables = "area"
297+
277298
# 2) identity mapping from a single raw variable
278299
if (
279300
vc.raw_variables

0 commit comments

Comments
 (0)