33This module centralizes CMOR session setup and writing so that the rest of the
44pipeline can stay xarray-first. It supports either a dataset JSON file (preferred)
55or directly injected global attributes, and creates axes based on the coordinates
6- present in the provided dataset.
6+ present in the provided dataset. It also supports a packaged default
7+ `cmor_dataset.json` living under `cmip7_prep/data/`.
78"""
89
910from __future__ import annotations
1011
1112from contextlib import AbstractContextManager
13+ from importlib .resources import as_file , files
1214from pathlib import Path
1315from typing import Any , Dict , Optional
1416
15- import cmor # type: ignore
1617import numpy as np
1718import xarray as xr
19+ from xarray .coding .times import encode_cf_datetime
20+
21+ import cmor # type: ignore
22+
23+
24+ # ---------------------------------------------------------------------
25+ # Packaged resource helpers
26+ # ---------------------------------------------------------------------
27+ def packaged_dataset_json (filename : str = "cmor_dataset.json" ) -> Any :
28+ """Return a context manager yielding a real FS path to a packaged dataset JSON.
29+
30+ Looks under cmip7_prep/data/.
31+ Usage:
32+ with packaged_dataset_json() as p:
33+ cmor.dataset_json(str(p))
34+ """
35+ res = files ("cmip7_prep" ).joinpath (f"data/{ filename } " )
36+ return as_file (res )
1837
1938
39+ # ---------------------------------------------------------------------
40+ # Time encoding
41+ # ---------------------------------------------------------------------
42+ def _encode_time_to_num (time_da : xr .DataArray , units : str , calendar : str ) -> np .ndarray :
43+ """Return numeric CF time values (float64) acceptable to CMOR.
44+
45+ Tries xarray's encoder first; if that fails and cftime is available,
46+ falls back to cftime.date2num. Raises a ValueError with details otherwise.
47+ """
48+ # 1) xarray encoder (handles numpy datetime64 and cftime objects if cftime present)
49+ try :
50+ out = encode_cf_datetime (time_da .values , units = units , calendar = calendar )
51+ return np .asarray (out , dtype = "f8" )
52+ except (ValueError , TypeError ) as exc_xr :
53+ last_err = exc_xr
54+
55+ # 2) Optional cftime path (lazy import to keep lint/typecheck happy)
56+ try :
57+ import cftime # type: ignore # pylint: disable=import-outside-toplevel
58+
59+ seq = time_da .values .tolist () # handles object arrays / cftime arrays
60+ out = cftime .date2num (seq , units = units , calendar = calendar )
61+ return np .asarray (out , dtype = "f8" )
62+ except Exception as exc_cf : # noqa: BLE001 - we surface both causes together
63+ raise ValueError (
64+ f"Could not encode time to numeric CF values with units={ units !r} , "
65+ f"calendar={ calendar !r} . xarray error: { last_err } ; cftime error: { exc_cf } "
66+ ) from exc_cf
67+
68+
69+ # ---------------------------------------------------------------------
70+ # CMOR session
71+ # ---------------------------------------------------------------------
2072class CmorSession (AbstractContextManager ):
2173 """Context manager for CMOR sessions.
2274
@@ -43,74 +95,72 @@ def __init__(
4395 def __enter__ (self ) -> "CmorSession" :
4496 """Initialize CMOR and register dataset metadata."""
4597 cmor .setup (inpath = self .tables_path , netcdf_file_action = cmor .CMOR_REPLACE_3 )
98+
4699 if self .dataset_json_path :
47100 cmor .dataset_json (self .dataset_json_path )
48101 elif self .dataset_attrs :
49102 for key , value in self .dataset_attrs .items ():
50103 cmor .setGblAttr (key , value )
51104 else :
52- raise ValueError (
53- "CmorSession requires either dataset_json path or dataset_attrs."
54- )
105+ # Fallback to packaged cmor_dataset.json if available
106+ with packaged_dataset_json () as p :
107+ cmor .dataset_json (str (p ))
108+
55109 return self
56110
57- def __exit__ (self , exc_type , exc , tb ) -> None :
111+ def __exit__ (self , exc_type , exc , tb ) -> None : # noqa: D401
58112 """Finalize CMOR, closing any open handles."""
59113 cmor .close ()
60114
61115 # -------------------------
62116 # internal helpers
63117 # -------------------------
64-
65118 def _define_axes (self , ds : xr .Dataset , vdef : Any ) -> list [int ]:
66- """Create CMOR axis IDs based on dataset coordinates and optional vdef levels.
67-
68- Parameters
69- ----------
70- ds : xr.Dataset
71- Dataset containing coordinates like time, lat, lon, and optionally plev/lev.
72- vdef : Any
73- Variable definition object. If it has a ``levels`` mapping (e.g.,
74- ``{'axis_entry': 'plev19', 'name': 'plev', 'units': 'Pa'}``), that will
75- be used to define the vertical axis.
76-
77- Returns
78- -------
79- list[int]
80- List of CMOR axis IDs in the order they should be used for the variable.
81- """
119+ """Create CMOR axis IDs based on dataset coordinates and optional vdef levels."""
82120 axes : list [int ] = []
83121
84- # Time axis
122+ # ---- time axis ----
85123 if "time" in ds .coords or "time" in ds :
86124 time = ds ["time" ]
87125 t_units = time .attrs .get ("units" , "days since 1850-01-01" )
88- cal = time .attrs .get ("calendar" , "standard" )
89- if np .issubdtype (time .dtype , np .datetime64 ):
90- # convert to numeric time using CF units if needed
91- tvals = xr .conventions .times .encode_cf_datetime (
92- time .values , t_units , calendar = cal
93- )
94- else :
95- tvals = time .values
126+ cal = time .attrs .get ("calendar" , time .encoding .get ("calendar" , "noleap" ))
127+ tvals = _encode_time_to_num (time , t_units , cal )
128+
129+ # Optional bounds: try common names or CF 'bounds' attribute
130+ tb = ds .get ("time_bnds" ) or ds .get ("time_bounds" )
131+ if tb is None and isinstance (time .attrs .get ("bounds" ), str ):
132+ bname = time .attrs ["bounds" ]
133+ tb = ds .get (bname )
134+
135+ t_bnds = None
136+ if tb is not None :
137+ try :
138+ t_bnds = _encode_time_to_num (tb , t_units , cal )
139+ except ValueError :
140+ t_bnds = None
141+
96142 axes .append (
97- cmor .axis (table_entry = "time" , units = str (t_units ), coord_vals = tvals )
143+ cmor .axis (
144+ table_entry = "time" ,
145+ units = str (t_units ),
146+ coord_vals = tvals ,
147+ cell_bounds = t_bnds if t_bnds is not None else None ,
148+ )
98149 )
99150
100- # Vertical axis (pressure or model levels)
151+ # ---- vertical axis (plev or lev) ----
101152 levels_info = getattr (vdef , "levels" , None )
102- # Prefer an explicit plev axis in the data
103153 if "plev" in ds .coords :
104154 p = ds ["plev" ]
105155 p_units = p .attrs .get ("units" , "Pa" )
106- # If vdef specifies an axis entry (e.g., plev19), use that; else let CMOR infer
156+ # If vdef specifies an axis entry (e.g., plev19), use that; else infer
107157 table_entry = None
108158 if isinstance (levels_info , dict ):
109159 table_entry = levels_info .get ("axis_entry" ) or levels_info .get ("name" )
110160 if table_entry is None :
111- # Heuristic: choose plev19 if there are 19 levels, else generic plev
112161 table_entry = "plev19" if p .size == 19 else "plev"
113- p_bnds = ds .get ("plev_bnds" )
162+
163+ p_bnds = ds .get ("plev_bnds" ) or ds .get ("plev_bounds" )
114164 axes .append (
115165 cmor .axis (
116166 table_entry = table_entry ,
@@ -120,7 +170,6 @@ def _define_axes(self, ds: xr.Dataset, vdef: Any) -> list[int]:
120170 )
121171 )
122172 elif "lev" in ds .coords :
123- # Generic hybrid "lev" axis; rely on table entry provided via vdef or default to "alev"
124173 lev = ds ["lev" ]
125174 table_entry = "alev"
126175 if isinstance (levels_info , dict ):
@@ -129,12 +178,13 @@ def _define_axes(self, ds: xr.Dataset, vdef: Any) -> list[int]:
129178 cmor .axis (table_entry = table_entry , units = "1" , coord_vals = lev .values )
130179 )
131180
132- # Latitude / Longitude
181+ # ---- horizontal axes ----
133182 if "lat" in ds .coords and "lon" in ds .coords :
134183 lat = ds ["lat" ]
135184 lon = ds ["lon" ]
136- lat_b = ds .get ("lat_bnds" )
137- lon_b = ds .get ("lon_bnds" )
185+ lat_b = ds .get ("lat_bnds" ) or ds .get ("lat_bounds" )
186+ lon_b = ds .get ("lon_bnds" ) or ds .get ("lon_bounds" )
187+
138188 axes .append (
139189 cmor .axis (
140190 table_entry = "lat" ,
@@ -157,58 +207,48 @@ def _define_axes(self, ds: xr.Dataset, vdef: Any) -> list[int]:
157207 # -------------------------
158208 # public API
159209 # -------------------------
160-
161210 def write_variable (
162211 self , ds : xr .Dataset , varname : str , vdef : Any , outdir : Path
163212 ) -> None :
164- """Write one variable from `ds` to a CMOR-compliant NetCDF file.
165-
166- Parameters
167- ----------
168- ds : xr.Dataset
169- Dataset containing the variable and its coordinates.
170- varname : str
171- Name of the variable in `ds` to CMORize.
172- vdef : Any
173- An object with fields: ``name``, ``realm``,
174- optional ``units``, ``positive``, and optional
175- ``levels`` dict (see :meth:`_define_axes`).
176- This is typically a light-weight holder.
177- outdir : Path
178- Output directory for the CMORized NetCDF file.
179- """
180- # Load the appropriate CMOR table without relying on exceptions.
181- realm = getattr (vdef , "realm" , "Amon" )
182- candidate7 = Path (self .tables_path ) / f"CMIP7_{ realm } .json"
183- candidate6 = Path (self .tables_path ) / f"CMIP6_{ realm } .json"
213+ """Write one variable from `ds` to a CMOR-compliant NetCDF file."""
214+ # Pick CMOR table: prefer vdef.table, else vdef.realm (default Amon)
215+ table_key = getattr (vdef , "table" , None ) or getattr (vdef , "realm" , "Amon" )
216+ table_key = str (table_key )
217+ candidate7 = Path (self .tables_path ) / f"CMIP7_{ table_key } .json"
218+ candidate6 = Path (self .tables_path ) / f"CMIP6_{ table_key } .json"
219+
184220 if candidate7 .exists ():
185221 cmor .load_table (str (candidate7 ))
186222 elif candidate6 .exists ():
187223 cmor .load_table (str (candidate6 ))
188224 else :
189- # Fall back to passing the bare table name; CMOR will search its inpath.
190- # This branch avoids broad exception handling while still being flexible.
191- cmor .load_table (f"CMIP7_{ realm } .json" )
225+ # Let CMOR search inpath; will raise if not found
226+ cmor .load_table (f"CMIP7_{ table_key } .json" )
192227
193228 axes_ids = self ._define_axes (ds , vdef )
194- units = getattr (vdef , "units" , "" )
229+
230+ units = getattr (vdef , "units" , "" ) or ""
195231 var_id = cmor .variable (
196232 getattr (vdef , "name" , varname ),
197233 units ,
198234 axes_ids ,
199235 positive = getattr (vdef , "positive" , None ),
200236 )
201237
202- # Optional variable attributes (e.g., cell_methods)
238+ # Optional variable attributes (e.g., cell_methods, long_name, standard_name )
203239 if getattr (vdef , "cell_methods" , None ):
204240 cmor .set_variable_attribute (var_id , "cell_methods" , vdef .cell_methods )
241+ if getattr (vdef , "long_name" , None ):
242+ cmor .set_variable_attribute (var_id , "long_name" , vdef .long_name )
243+ if getattr (vdef , "standard_name" , None ):
244+ cmor .set_variable_attribute (var_id , "standard_name" , vdef .standard_name )
205245
206- # Ensure time is the leading dimension if present
207246 data = ds [varname ]
208247 if "time" in data .dims :
209248 data = data .transpose ("time" , ...)
210249
211- cmor .write (var_id , data .values , ntimes_passed = data .sizes .get ("time" , 1 ))
250+ # CMOR expects a NumPy array; this will materialize data as needed.
251+ cmor .write (var_id , np .asarray (data ), ntimes_passed = data .sizes .get ("time" , 1 ))
212252
213253 outdir = Path (outdir )
214254 outdir .mkdir (parents = True , exist_ok = True )
0 commit comments