|
| 1 | + |
1 | 2 | # cmip7_prep/cmor_writer.py |
2 | | -import cmor, numpy as np, xarray as xr |
| 3 | +from __future__ import annotations |
| 4 | +import cmor |
| 5 | +import numpy as np |
| 6 | +import xarray as xr |
3 | 7 | from contextlib import AbstractContextManager |
4 | 8 | from pathlib import Path |
5 | 9 |
|
class CmorSession(AbstractContextManager):
    """Thin wrapper around CMOR that supports a dataset_json path.

    The session is a context manager: entering it initialises CMOR and
    registers the dataset metadata, leaving it closes CMOR.

    Parameters
    ----------
    tables_path:
        Directory holding the CMOR table JSON files (passed to ``cmor.setup``
        as ``inpath``).
    dataset_json:
        Path of a CMOR dataset JSON file.  Takes precedence over
        ``dataset_attrs`` when both are given.
    dataset_attrs:
        Mapping of dataset attributes, used only when ``dataset_json`` is
        not provided.
    """

    def __init__(
        self,
        tables_path: str | Path,
        dataset_json: str | Path | None = None,
        dataset_attrs: dict | None = None,
    ):
        self.tables_path = str(tables_path)
        self.dataset_json_path = str(dataset_json) if dataset_json else None
        self.dataset_attrs = dataset_attrs or {}

    def __enter__(self):
        """Initialise CMOR and register the dataset metadata.

        Raises
        ------
        ValueError
            If neither a dataset JSON path nor dataset attributes were given.
        """
        # Validate before touching CMOR so a misconfigured session never
        # leaves the library initialised.
        if not self.dataset_json_path and not self.dataset_attrs:
            raise ValueError("CmorSession requires either dataset_json path or dataset_attrs.")

        cmor.setup(inpath=self.tables_path, netcdf_file_action=cmor.CMOR_REPLACE_3)
        try:
            if self.dataset_json_path:
                cmor.dataset_json(self.dataset_json_path)
            else:
                # Fallback: inject attributes directly when no JSON is given.
                # NOTE(review): CMOR may still require a dataset JSON for a
                # complete configuration -- confirm against the CMOR version in use.
                for key, value in self.dataset_attrs.items():
                    cmor.set_cur_dataset_attribute(key, value)
        except Exception:
            # __exit__ is not called when __enter__ fails, so release CMOR here.
            cmor.close()
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close CMOR; exceptions raised inside the ``with`` body propagate."""
        cmor.close()
        return None

    def _define_axes(self, ds: xr.Dataset, vdef, dims=None):
        """Register CMOR axes for ``ds`` and return their ids.

        Axes are created in the fixed order time, lat, lon.

        Parameters
        ----------
        ds:
            Dataset providing the ``time``, ``lat`` and ``lon`` coordinates and,
            optionally, ``lat_bnds`` / ``lon_bnds``.
        vdef:
            Variable definition; currently unused, kept for future vertical axes.
        dims:
            Optional collection of dimension names of the variable being
            written.  When given, only axes for those dimensions are created;
            when ``None``, every axis found in ``ds`` is created.

        Returns
        -------
        list
            Axis ids returned by ``cmor.axis``.
        """
        wanted = None if dims is None else set(dims)

        def _needed(*names):
            return wanted is None or wanted.issuperset(names)

        axes = []

        # time axis
        if "time" in ds and _needed("time"):
            time = ds["time"]
            # xarray keeps units/calendar of decoded times in ``encoding``.
            t_units = (
                time.attrs.get("units")
                or time.encoding.get("units")
                or "days since 1850-01-01"
            )
            tvals = time.values
            if np.issubdtype(time.dtype, np.datetime64):
                from xarray.coding.times import encode_cf_datetime

                calendar = (
                    time.attrs.get("calendar")
                    or time.encoding.get("calendar")
                    or "standard"
                )
                # encode_cf_datetime returns (values, units, calendar); use the
                # returned units so they always describe the encoded values.
                tvals, t_units, _ = encode_cf_datetime(tvals, str(t_units), calendar=calendar)
            # NOTE(review): no time bounds are passed; CMOR may require them for
            # time-mean variables -- confirm.
            axes.append(cmor.axis(table_entry="time", units=str(t_units), coord_vals=tvals))

        # latitude / longitude (regular grid)
        if "lat" in ds.coords and "lon" in ds.coords and _needed("lat", "lon"):
            lat_b = ds["lat_bnds"].values if "lat_bnds" in ds else None
            lon_b = ds["lon_bnds"].values if "lon_bnds" in ds else None
            axes.append(
                cmor.axis(
                    table_entry="lat",
                    units="degrees_north",
                    coord_vals=ds["lat"].values,
                    cell_bounds=lat_b,
                )
            )
            axes.append(
                cmor.axis(
                    table_entry="lon",
                    units="degrees_east",
                    coord_vals=ds["lon"].values,
                    cell_bounds=lon_b,
                )
            )

        # pressure or model level axes would be added here if needed based on vdef
        return axes

    def write_variable(self, ds: xr.Dataset, varname: str, vdef, outdir: Path):
        """Write ``ds[varname]`` through CMOR and return the output file path.

        Parameters
        ----------
        ds:
            Dataset holding the variable and its coordinates.
        varname:
            Name of the variable inside ``ds``.
        vdef:
            Variable definition providing ``realm`` and ``name`` and, optionally,
            ``units``, ``positive`` and ``cell_methods``.
        outdir:
            Accepted for interface compatibility.  CMOR chooses the output
            location itself (``file_name`` of ``cmor.close`` is a flag, not a
            destination), so this value does not influence where the file goes.

        Returns
        -------
        The path of the written file as reported by ``cmor.close``.
        """
        # load proper table per realm (Amon, Lmon, etc.)
        try:
            cmor.load_table(f"CMIP7_{vdef.realm}.json")
        except Exception:
            # Broad on purpose: the exception type raised by CMOR for a missing
            # table is not pinned down here.  Fall back to CMIP6-style names.
            cmor.load_table(f"CMIP6_{vdef.realm}.json")

        da = ds[varname]
        axes_ids = self._define_axes(ds, vdef, dims=da.dims)
        var_id = cmor.variable(
            vdef.name,
            getattr(vdef, "units", ""),
            axes_ids,
            positive=getattr(vdef, "positive", None),
        )

        # cell_methods as variable attribute
        # NOTE(review): some CMOR releases take an extra data_type argument in
        # set_variable_attribute -- confirm against the installed version.
        cell_methods = getattr(vdef, "cell_methods", None)
        if cell_methods:
            cmor.set_variable_attribute(var_id, "cell_methods", cell_methods)

        # Lay the data out in the same order the axes were registered
        # (time, lat, lon); any other dimensions keep their relative order.
        leading = [name for name in ("time", "lat", "lon") if name in da.dims]
        data = da.transpose(*leading, ...).values

        if "time" in da.dims:
            cmor.write(var_id, data, ntimes_passed=data.shape[0])
        else:
            # Without a time dimension the leading axis length is not a time
            # count, so let CMOR work it out.
            cmor.write(var_id, data)

        # Close the variable and ask CMOR for the name of the file it wrote.
        return cmor.close(var_id, file_name=True)
0 commit comments