Skip to content

Commit bb61b84

Browse files
committed
add vertical.py
1 parent db162d0 commit bb61b84

File tree

1 file changed

+177
-6
lines changed

1 file changed

+177
-6
lines changed

cmip7_prep/vertical.py

Lines changed: 177 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,178 @@
1-
# cmip7_prep/vertical.py
2-
def to_plev19(ds, var, vdef):
3-
# Use model hybrid coefficients + surface pressure to compute pressure
4-
# then interpolate log-pressure to requested plev19 (Pa).
5-
# Return dataset with dim replaced by 'plev'.
6-
return ds
71

2+
"""Vertical coordinate handling for CESM → CMIP7.
3+
4+
This module provides utilities to convert hybrid-sigma model levels to requested
5+
pressure levels (e.g., CMIP plev19) prior to CMORization.
6+
7+
Primary entry point:
8+
to_plev19(ds, var, tables_path, ...)
9+
10+
Dependencies:
11+
- geocat-comp (preferred): uses `interp_hybrid_to_pressure`
12+
If not available, this function raises ImportError (fallback can be added later).
13+
"""
14+
from __future__ import annotations
15+
16+
import json
17+
import os
18+
from typing import Optional
19+
20+
import numpy as np
21+
import xarray as xr
22+
23+
# Try to resolve geocat.comp's interpolation function in a compatible way
24+
try:
25+
# geocat-comp >= 2023.x commonly exposes this symbol at top-level
26+
from geocat.comp import interp_hybrid_to_pressure as _interp_h2p # type: ignore
27+
except Exception: # pragma: no cover - environment-dependent
28+
try:
29+
# older layouts may nest it under .interpolation
30+
import geocat.comp as _gc # type: ignore
31+
_interp_h2p = _gc.interpolation.interp_hybrid_to_pressure
32+
except Exception: # pragma: no cover
33+
_interp_h2p = None
34+
35+
36+
def _read_requested_levels(tables_path: str | os.PathLike, axis_name: str = "plev19") -> np.ndarray:
    """Read the requested levels of one axis from a CMIP coordinate JSON file.

    The first existing file among ``CMIP7_coordinate.json``,
    ``CMIP6_coordinate.json`` and ``CMIP_coordinate.json`` (tried in that
    order) inside ``tables_path`` is used.

    Parameters
    ----------
    tables_path : str or os.PathLike
        Directory containing CMIPx coordinate/table JSON files.
    axis_name : str
        Axis entry name to read (e.g., 'plev19').

    Returns
    -------
    np.ndarray
        1-D float64 array holding the values stored under
        ``axis_entry[axis_name]["requested"]``, in table order.

    Raises
    ------
    FileNotFoundError
        If none of the candidate coordinate files exists under ``tables_path``.
    KeyError
        If the JSON has no ``axis_entry[axis_name]["requested"]`` entry.
    ValueError
        If the entry is empty, not one-dimensional, or not convertible to
        finite floats.
    """
    coord_json_candidates = [
        "CMIP7_coordinate.json",
        "CMIP6_coordinate.json",
        "CMIP_coordinate.json",
    ]
    # os.fspath honours the os.PathLike protocol; str() would return the repr
    # of path-like objects that are not pathlib paths.
    base_dir = os.fspath(tables_path)
    coord_json = None
    for name in coord_json_candidates:
        candidate = os.path.join(base_dir, name)
        if os.path.isfile(candidate):
            coord_json = candidate
            break
    if coord_json is None:
        raise FileNotFoundError(
            f"Could not find a coordinate table JSON under {tables_path!s}; "
            f"tried {coord_json_candidates}"
        )

    with open(coord_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    try:
        req = data["axis_entry"][axis_name]["requested"]
    except (KeyError, TypeError) as exc:  # pragma: no cover
        # TypeError covers a JSON document whose nesting is not dict-based.
        raise KeyError(f"Axis entry '{axis_name}' not found in {coord_json}") from exc

    bad_levels = (
        f"Axis entry '{axis_name}' in {coord_json} has no usable "
        f"'requested' levels: {req!r}"
    )
    try:
        # The tables may store numbers as strings ("100000."); a single level
        # may be a bare scalar, hence atleast_1d to honour the 1-D contract.
        levels = np.atleast_1d(np.asarray(req, dtype="f8"))
    except (TypeError, ValueError) as exc:
        raise ValueError(bad_levels) from exc
    if levels.ndim != 1 or levels.size == 0 or not np.isfinite(levels).all():
        raise ValueError(bad_levels)
    return levels
78+
79+
80+
def _scalar_pressure_or_none(raw) -> float | None:
    """Return ``raw`` as a finite, positive float, or None if it is not one."""
    arr = np.asarray(raw)
    if arr.size != 1:
        return None
    try:
        # .item() extracts the single element regardless of ndim; calling
        # float() directly on an ndim > 0 array is deprecated in NumPy >= 1.25.
        value = float(arr.item())
    except (TypeError, ValueError):
        return None
    if not np.isfinite(value) or value <= 0.0:
        return None
    return value


def _resolve_p0(ds: xr.Dataset, p0_name: str = "P0") -> float:
    """Return reference pressure P0 from the dataset, or default to 100000.

    Lookup order: the variable ``ds[p0_name]``, then the attribute
    ``ds.attrs[p0_name]``, then the constant ``100000.0``. A candidate is
    accepted only if it holds exactly one value that converts to a finite,
    positive float; otherwise the next source is tried.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset that may carry the reference pressure as a variable or
        as a global attribute.
    p0_name : str
        Name of the variable/attribute to look up.

    Returns
    -------
    float
        The resolved reference pressure. The value is returned as stored;
        no unit conversion is performed (the caller treats it as Pa).
    """
    if p0_name in ds:
        # Usually a scalar DataArray, but a length-1 array is accepted too.
        from_variable = _scalar_pressure_or_none(ds[p0_name].values)
        if from_variable is not None:
            return from_variable
    if p0_name in ds.attrs:
        from_attribute = _scalar_pressure_or_none(ds.attrs[p0_name])
        if from_attribute is not None:
            return from_attribute
    return 100000.0  # Pa
95+
96+
97+
def to_plev19(
    ds: xr.Dataset,
    var: str,
    tables_path: str | os.PathLike,
    *,
    lev_dim: str = "lev",
    ps_name: str = "PS",
    hyam_name: str = "hyam",
    hybm_name: str = "hybm",
    p0_name: str = "P0",
    method: str = "log",
) -> xr.Dataset:
    """Interpolate a hybrid-level variable to CMIP plev19 pressure levels (Pa).

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset containing the variable and required hybrid inputs.
        Expected fields: `var`, `PS`, `hyam`, `hybm` (names configurable).
    var : str
        Name of the variable to be interpolated along the vertical.
    tables_path : str or Path
        Path to CMIPx Tables directory (for coordinate JSON).
    lev_dim : str, default "lev"
        Name of the hybrid-level dimension in `var`.
    ps_name, hyam_name, hybm_name, p0_name : str
        Names for surface pressure, hybrid A/B (midpoint) coefficients, and
        reference pressure.
    method : str, default "log"
        Interpolation method forwarded to geocat-comp's
        `interp_hybrid_to_pressure`. geocat-comp's own default is believed to
        be linear-in-pressure, so "log" is passed explicitly to get the
        log-pressure interpolation this module is meant to perform. Pass
        "linear" to interpolate linearly in pressure instead.

    Returns
    -------
    xr.Dataset
        A copy of `ds` in which `var` is replaced by the interpolated array,
        with vertical dimension `plev` and coordinate `plev` set to the
        requested values. The variables named by `hyam_name`, `hybm_name`
        and `p0_name` are dropped if present. `ds` itself is not reassigned.

    Raises
    ------
    ImportError
        If geocat-comp could not be imported.
    KeyError
        If `var`, `ps_name`, `hyam_name` or `hybm_name` is missing from `ds`.
    ValueError
        If `var` has no `lev_dim` dimension, or the interpolated result has
        neither a `plev` nor a `lev_dim` dimension.
    """
    if _interp_h2p is None:  # pragma: no cover - depends on environment
        raise ImportError(
            "geocat.comp is not available. Please install 'geocat-comp' "
            "to enable hybrid->pressure vertical interpolation."
        )

    required = [var, ps_name, hyam_name, hybm_name]
    missing = [name for name in required if name not in ds]
    if missing:
        raise KeyError(f"Missing required variables in dataset: {missing}")

    # Fail early with a clear message instead of deep inside the interpolator.
    if lev_dim not in ds[var].dims:
        raise ValueError(
            f"Variable '{var}' has no vertical dimension '{lev_dim}'; "
            f"its dimensions are {tuple(ds[var].dims)}"
        )

    p0 = _resolve_p0(ds, p0_name=p0_name)
    new_levels = _read_requested_levels(tables_path, axis_name="plev19")

    # The method is passed explicitly rather than left to geocat-comp's
    # default; with "log" the interpolation is done in log-pressure.
    out_da = _interp_h2p(
        ds[var],
        ds[ps_name],
        ds[hyam_name],
        ds[hybm_name],
        p0=p0,
        new_levels=new_levels,
        lev_dim=lev_dim,
        method=method,
    )

    # Ensure dimension is named 'plev' and coordinate is present
    if "plev" not in out_da.dims:
        if lev_dim not in out_da.dims:
            raise ValueError(
                f"Interpolated '{var}' has neither a 'plev' nor a '{lev_dim}' "
                f"dimension; its dimensions are {tuple(out_da.dims)}"
            )
        out_da = out_da.rename({lev_dim: "plev"})
    out_da = out_da.assign_coords(plev=("plev", new_levels))
    out_da["plev"].attrs.update(
        {"units": "Pa", "standard_name": "air_pressure", "positive": "down"}
    )

    # Assemble return dataset on a copy so the caller's dataset keeps `var`
    # on its original levels.
    ds_out = ds.copy()
    ds_out[var] = out_da

    # Hybrid coefficients and P0 are dropped from the returned dataset.
    drop = [n for n in (hyam_name, hybm_name, p0_name) if n in ds_out.variables]
    if drop:
        ds_out = ds_out.drop_vars(drop)

    return ds_out

0 commit comments

Comments
 (0)