Skip to content

Commit 597d415

Browse files
committed
add more unit tests
1 parent ad63fd7 commit 597d415

File tree

7 files changed

+307
-35
lines changed

7 files changed

+307
-35
lines changed

.github/workflows/pylint.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ jobs:
6767
with:
6868
python-version: ${{ matrix.python-version }}
6969
cache: 'pip'
70+
- name: Install NetCDF
71+
run: sudo apt-get update && sudo apt-get install -y libnetcdf-dev
72+
- name: Install ESMF
73+
uses: esmf-org/install-esmf-action@v1
74+
with:
75+
version: latest
7076

7177
- name: Install Poetry
7278
run: |

cmip7_prep/cmor_writer.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def packaged_dataset_json(filename: str = "cmor_dataset.json") -> Any:
3535
cmor.dataset_json(str(p))
3636
"""
3737
res = ir_files("cmip7_prep.data").joinpath(filename)
38-
return as_file(res)
38+
p = as_file(res)
39+
print(f" p is {p}")
40+
return p
3941

4042

4143
# ---------------------------------------------------------------------
@@ -150,6 +152,19 @@ def _get_attr(name: str):
150152
return cmor.getGblAttr(name) # type: ignore[attr-defined]
151153

152154

155+
def _resolve_table_filename(tables_path: Path, table_key: str) -> str:
156+
"""Return basename like 'CMIP6_Amon.json' or 'CMIP7_Amon.json' based on tables_path."""
157+
# If caller passed an explicit filename, keep it.
158+
if table_key.endswith(".json"):
159+
return table_key
160+
pstr = str(tables_path)
161+
if "CMIP6" in pstr:
162+
return f"CMIP6_{table_key}.json"
163+
if "CMIP7" in pstr:
164+
return f"CMIP7_{table_key}.json"
165+
return f"{table_key}.json"
166+
167+
153168
# ---------------------------------------------------------------------
154169
# CMOR session
155170
# ---------------------------------------------------------------------
@@ -178,26 +193,29 @@ def __init__(
178193

179194
def __enter__(self) -> "CmorSession":
180195
"""Initialize CMOR and register dataset metadata."""
181-
cmor.setup(inpath=self.tables_path, netcdf_file_action=cmor.CMOR_REPLACE_3)
182-
183-
if self.dataset_json_path:
184-
cmor.dataset_json(self.dataset_json_path)
185-
elif self.dataset_attrs:
186-
for key, value in self.dataset_attrs.items():
187-
cmor.set_cur_dataset_attribute(key, value)
188-
else:
189-
# Fallback to packaged cmor_dataset.json if available
190-
with packaged_dataset_json() as p:
191-
cmor.dataset_json(str(p))
192-
193-
# product must be exactly "model-output" for CMIP6 tables
194-
if "cmip6" in str(self.tables_path):
195-
try:
196-
prod = cmor.get_cur_dataset_attribute("product") # type: ignore[attr-defined]
197-
except Exception: # pylint: disable=broad-except
198-
prod = None
199-
if prod != "model-output":
200-
cmor.set_cur_dataset_attribute("product", "model-output")
196+
# Setup CMOR with the Tables directory
197+
replace_flag = getattr(cmor, "CMOR_REPLACE_3", getattr(cmor, "CMOR_REPLACE", 0))
198+
verbosity = getattr(cmor, "CMOR_NORMAL", getattr(cmor, "CMOR_VERBOSE", 0))
199+
cmor.setup(
200+
inpath=str(self.tables_path),
201+
netcdf_file_action=replace_flag,
202+
set_verbosity=verbosity,
203+
)
204+
205+
# ALWAYS seed CMOR’s internal dataset state from a JSON file,
206+
# then apply dataset_attrs as overrides.
207+
p = (
208+
self.dataset_json_path
209+
if self.dataset_json_path is not None
210+
else packaged_dataset_json()
211+
)
212+
cmor.dataset_json(str(p))
213+
try:
214+
prod = cmor.get_cur_dataset_attribute("product") # type: ignore[attr-defined]
215+
except Exception: # pylint: disable=broad-except
216+
prod = None
217+
if prod != "model-output":
218+
cmor.set_cur_dataset_attribute("product", "model-output")
201219

202220
# long paragraph; split to keep lines < 100
203221
inst = _get_attr("institution_id") or "NCAR"
@@ -354,19 +372,11 @@ def write_variable(
354372
) -> None:
355373
"""Write one variable from `ds` to a CMOR-compliant NetCDF file."""
356374
# Pick CMOR table: prefer vdef.table, else vdef.realm (default Amon)
357-
table_key = getattr(vdef, "table", None) or getattr(vdef, "realm", "Amon")
358-
table_key = str(table_key)
359-
candidate7 = Path(self.tables_path) / f"CMIP7_{table_key}.json"
360-
candidate6 = Path(self.tables_path) / f"CMIP6_{table_key}.json"
361-
print(f"table_key is {table_key} candidate6 is {candidate6}")
362-
363-
if candidate7.exists():
364-
cmor.load_table(str(candidate7))
365-
elif candidate6.exists():
366-
cmor.load_table(str(candidate6))
367-
else:
368-
# Let CMOR search inpath; will raise if not found
369-
cmor.load_table(f"CMIP7_{table_key}.json")
375+
table_key = (
376+
getattr(vdef, "table", None) or getattr(vdef, "realm", None) or "Amon"
377+
)
378+
table_filename = _resolve_table_filename(self.tables_path, table_key)
379+
cmor.load_table(table_filename)
370380

371381
axes_ids = self._define_axes(ds, vdef)
372382

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
[tool.poetry]
32
name = "cmip7-prep"
43
version = "0.1.0"

tests/conftest.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Shared test fixtures and a FakeCMOR stand-in for unit tests."""
2+
3+
from __future__ import annotations
4+
5+
from pathlib import Path
6+
from typing import Any, Tuple
7+
import uuid
8+
import numpy as np
9+
import pytest
10+
11+
# Import production module at top-level (avoids C0415)
12+
from cmip7_prep import cmor_writer as cw
13+
14+
15+
class FakeCMOR: # pylint: disable=too-many-instance-attributes
16+
"""Minimal CMOR stand-in that mimics constants and API surface used by cmor_writer."""
17+
18+
# Common constants across CMOR builds
19+
CMOR_REPLACE = 0
20+
CMOR_REPLACE_3 = 0
21+
CMOR_APPEND = 1
22+
CMOR_APPEND_3 = 1
23+
CMOR_NORMAL = 1
24+
CMOR_VERBOSE = 2
25+
CMOR_QUIET = 0
26+
27+
def __init__(self) -> None:
28+
"""Create a fake CMOR session state."""
29+
self.attrs: dict[str, Any] = {}
30+
self.inpath: str | None = None
31+
self.dataset_json_path: str | None = None
32+
self.last_table: str | None = None
33+
self.axis_calls: list[tuple] = []
34+
self.variable_calls: list[tuple] = []
35+
self.write_calls: list[tuple] = []
36+
self.closed = False
37+
self.closed_file = ""
38+
39+
# --- core API used by CmorSession ---
40+
# pylint: disable=unused-argument
41+
def setup(
42+
self,
43+
inpath: str,
44+
netcdf_file_action: int | None = None,
45+
set_verbosity: int | None = None,
46+
) -> None:
47+
"""Record the tables path."""
48+
self.inpath = str(inpath)
49+
50+
def dataset_json(self, rcfile: str) -> None:
51+
"""Record the dataset JSON path and seed a few CV attributes."""
52+
self.dataset_json_path = str(rcfile)
53+
ip = Path(self.inpath or ".")
54+
self.attrs.setdefault("_controlled_vocabulary_file", str(ip / "CMIP6_CV.json"))
55+
self.attrs.setdefault("_AXIS_ENTRY_FILE", str(ip / "CMIP6_coordinate.json"))
56+
self.attrs.setdefault("_FORMULA_VAR_FILE", str(ip / "CMIP6_formula_terms.json"))
57+
58+
# New API names
59+
def set_cur_dataset_attribute(self, key: str, value: Any) -> None:
60+
"""Set a dataset/global attribute."""
61+
self.attrs[str(key)] = value
62+
63+
def get_cur_dataset_attribute(self, key: str) -> Any:
64+
"""Get a dataset/global attribute."""
65+
return self.attrs.get(str(key), "")
66+
67+
# Legacy API names (keep camelCase to match CMOR) # pylint: disable=invalid-name
68+
def setGblAttr(self, key: str, value: Any) -> None: # noqa: N802
69+
"""Legacy alias for set_cur_dataset_attribute."""
70+
self.set_cur_dataset_attribute(key, value)
71+
72+
def getGblAttr(self, key: str) -> Any: # noqa: N802
73+
"""Legacy alias for get_cur_dataset_attribute."""
74+
return self.get_cur_dataset_attribute(key)
75+
76+
# Tables & variable definitions
77+
def load_table(self, name: str) -> int:
78+
"""Record last table loaded and return a fake handle."""
79+
self.last_table = str(name)
80+
return 0
81+
82+
def axis(
83+
self, table_entry: str, units: str, coord_vals, cell_bounds=None, **_
84+
) -> int:
85+
"""Record axis definition and return a fake axis id."""
86+
self.axis_calls.append(
87+
(
88+
table_entry,
89+
units,
90+
np.asarray(coord_vals),
91+
None if cell_bounds is None else np.asarray(cell_bounds),
92+
)
93+
)
94+
return len(self.axis_calls)
95+
96+
def variable(self, table_entry: str, units: str, axis_ids, **_) -> int:
97+
"""Record variable definition and return a fake var id."""
98+
self.variable_calls.append((table_entry, units, tuple(axis_ids)))
99+
return 10
100+
101+
def write(self, var_id: int, data, **_) -> None: # pylint: disable=unused-argument
102+
"""Record a write call; auto-generate tracking_id from tracking_prefix if needed."""
103+
tid = self.attrs.get("tracking_id", "")
104+
prefix = self.attrs.get("tracking_prefix", "")
105+
if (
106+
(tid == "" or tid is None)
107+
and isinstance(prefix, str)
108+
and prefix.startswith("hdl:")
109+
):
110+
self.attrs["tracking_id"] = f"{prefix}{uuid.uuid4()}"
111+
self.write_calls.append((var_id, np.asarray(data)))
112+
113+
# pylint: disable=unused-argument
114+
def close(self, var_id: int | None = None, file_name: str | None = None) -> None:
115+
"""Close the current object/file."""
116+
self.closed = True
117+
self.closed_file = file_name or ""
118+
119+
120+
@pytest.fixture()
121+
def fake_cmor(monkeypatch, tmp_path) -> Tuple[FakeCMOR, Path]:
122+
"""Patch cmor_writer.cmor with FakeCMOR and return (fake, tables_path)."""
123+
fake = FakeCMOR()
124+
monkeypatch.setattr(cw, "cmor", fake, raising=True)
125+
tables = tmp_path / "CMIP6_Tables"
126+
tables.mkdir()
127+
return fake, tables
128+
129+
130+
# Back-compat for any older tests that reference `_FakeCMOR`
131+
_FakeCMOR = FakeCMOR # noqa: N816

tests/test_cmor_sessions_attrs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# tests/test_cmor_session_attrs.py
2+
"""test cmor session attributes"""
3+
from cmip7_prep import cmor_writer as cw
4+
5+
6+
def test_session_sets_tracking_prefix_and_normalizes_license_product(fake_cmor):
7+
"""test tracking_prefix and license strings"""
8+
fake, tables_path = fake_cmor
9+
10+
# No tracking_id provided; product wrong to test normalization
11+
sess = cw.CmorSession(
12+
tables_path=tables_path,
13+
dataset_attrs={
14+
"institution_id": "NCAR",
15+
"product": "output", # will be normalized to model-output
16+
},
17+
)
18+
with sess:
19+
pass
20+
21+
# setup/dataset_json called
22+
assert fake.inpath == str(tables_path)
23+
assert isinstance(fake.dataset_json_path, str)
24+
25+
# tracking_prefix set; tracking_id cleared (so CMOR can generate)
26+
assert fake.attrs.get("tracking_prefix") == "hdl:21.14100/"
27+
assert fake.attrs.get("tracking_id") == ""
28+
29+
# product normalized
30+
assert fake.attrs.get("product") == "model-output"
31+
32+
# license is a long paragraph; just sanity check start and URL
33+
lic = fake.attrs.get("license", "")
34+
assert lic.startswith("CMIP6 model data produced by NCAR")
35+
assert "https://creativecommons.org/licenses/by/4.0/" in lic

tests/test_cmor_writer_helpers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""tests/test_cmor_writer_helpers.py"""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from cmip7_prep.cmor_writer import _encode_time_to_num, _encode_time_bounds_to_num
7+
8+
cftime = pytest.importorskip("cftime")
9+
10+
11+
def test_encode_time_to_num_with_cftime_no_leap():
12+
"""test 0001-01-16 12:00, 0001-02-15 12:00 in noleap"""
13+
t0 = cftime.DatetimeNoLeap(1, 1, 16, 12, 0, 0)
14+
t1 = cftime.DatetimeNoLeap(1, 2, 15, 12, 0, 0)
15+
arr = np.array([t0, t1], dtype=object)
16+
17+
out = _encode_time_to_num(arr, units="days since 0001-01-01", calendar="noleap")
18+
19+
assert out.shape == (2,)
20+
assert out.dtype == np.float64
21+
assert np.all(np.isfinite(out))
22+
assert np.all(np.diff(out) > 0) # strictly increasing
23+
24+
25+
def test_encode_time_bounds_to_num_shape_and_order():
26+
"""test encode time bounds"""
27+
tb = np.array(
28+
[
29+
[cftime.DatetimeNoLeap(1, 1, 1), cftime.DatetimeNoLeap(1, 1, 31)],
30+
[cftime.DatetimeNoLeap(1, 2, 1), cftime.DatetimeNoLeap(1, 2, 28)],
31+
],
32+
dtype=object,
33+
)
34+
out = _encode_time_bounds_to_num(tb, "days since 0001-01-01", "noleap")
35+
assert out.shape == (2, 2)
36+
assert np.all(out[:, 1] >= out[:, 0])
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""tests/test_cmor_writer_load_table.py"""
2+
3+
import types
4+
import numpy as np
5+
import xarray as xr
6+
7+
from cmip7_prep import cmor_writer as cw
8+
9+
10+
def test_write_variable_loads_basename_table_and_defines_axes(fake_cmor, tmp_path):
11+
"""test load basename table and define axes"""
12+
fake, tables_path = fake_cmor
13+
14+
# Build minimal numeric-time dataset to avoid cftime dependence
15+
lat = xr.DataArray(np.array([-89.5, 89.5]), dims=("lat",))
16+
lon = xr.DataArray(np.array([0.5, 1.5, 2.5]), dims=("lon",))
17+
time = xr.DataArray(np.array([0.5], dtype="f8"), dims=("time",))
18+
tb = xr.DataArray(np.array([[0.0, 1.0]], dtype="f8"), dims=("time", "nbnd"))
19+
20+
tas = xr.DataArray(
21+
np.ones((1, 2, 3), dtype="f4"),
22+
dims=("time", "lat", "lon"),
23+
coords={"time": time, "lat": lat, "lon": lon},
24+
name="tas",
25+
)
26+
ds = xr.Dataset({"tas": tas, "time_bounds": tb})
27+
ds["time"].attrs["bounds"] = "time_bounds"
28+
29+
outdir = tmp_path / "out"
30+
outdir.mkdir()
31+
32+
# minimal vdef with table name
33+
vdef = types.SimpleNamespace(name="tas", table="Amon", units="K")
34+
35+
# pylint: disable=using-constant-test
36+
with (
37+
cw.CmurSession
38+
if False
39+
else cw.CmorSession(
40+
tables_path=tables_path, dataset_attrs={"institution_id": "NCAR"}
41+
)
42+
) as cm: # noqa: E701
43+
cm.write_variable(ds, "tas", vdef, outdir=outdir)
44+
45+
# Table basename should be used (resolved by inpath)
46+
assert fake.last_table == "CMIP6_Amon.json"
47+
48+
# Axis calls: expect time, latitude, longitude in some order; verify entries exist
49+
entries = [a[0] for a in fake.axis_calls]
50+
assert "time" in entries
51+
assert "latitude" in entries
52+
assert "longitude" in entries
53+
54+
# Variable and write were called
55+
assert fake.variable_calls and fake.write_calls

0 commit comments

Comments
 (0)