|
| 1 | +"""Shared test fixtures and a FakeCMOR stand-in for unit tests.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from pathlib import Path |
| 6 | +from typing import Any, Tuple |
| 7 | +import uuid |
| 8 | +import numpy as np |
| 9 | +import pytest |
| 10 | + |
| 11 | +# Import production module at top-level (avoids C0415) |
| 12 | +from cmip7_prep import cmor_writer as cw |
| 13 | + |
| 14 | + |
| 15 | +class FakeCMOR: # pylint: disable=too-many-instance-attributes |
| 16 | + """Minimal CMOR stand-in that mimics constants and API surface used by cmor_writer.""" |
| 17 | + |
| 18 | + # Common constants across CMOR builds |
| 19 | + CMOR_REPLACE = 0 |
| 20 | + CMOR_REPLACE_3 = 0 |
| 21 | + CMOR_APPEND = 1 |
| 22 | + CMOR_APPEND_3 = 1 |
| 23 | + CMOR_NORMAL = 1 |
| 24 | + CMOR_VERBOSE = 2 |
| 25 | + CMOR_QUIET = 0 |
| 26 | + |
| 27 | + def __init__(self) -> None: |
| 28 | + """Create a fake CMOR session state.""" |
| 29 | + self.attrs: dict[str, Any] = {} |
| 30 | + self.inpath: str | None = None |
| 31 | + self.dataset_json_path: str | None = None |
| 32 | + self.last_table: str | None = None |
| 33 | + self.axis_calls: list[tuple] = [] |
| 34 | + self.variable_calls: list[tuple] = [] |
| 35 | + self.write_calls: list[tuple] = [] |
| 36 | + self.closed = False |
| 37 | + self.closed_file = "" |
| 38 | + |
| 39 | + # --- core API used by CmorSession --- |
| 40 | + # pylint: disable=unused-argument |
| 41 | + def setup( |
| 42 | + self, |
| 43 | + inpath: str, |
| 44 | + netcdf_file_action: int | None = None, |
| 45 | + set_verbosity: int | None = None, |
| 46 | + ) -> None: |
| 47 | + """Record the tables path.""" |
| 48 | + self.inpath = str(inpath) |
| 49 | + |
| 50 | + def dataset_json(self, rcfile: str) -> None: |
| 51 | + """Record the dataset JSON path and seed a few CV attributes.""" |
| 52 | + self.dataset_json_path = str(rcfile) |
| 53 | + ip = Path(self.inpath or ".") |
| 54 | + self.attrs.setdefault("_controlled_vocabulary_file", str(ip / "CMIP6_CV.json")) |
| 55 | + self.attrs.setdefault("_AXIS_ENTRY_FILE", str(ip / "CMIP6_coordinate.json")) |
| 56 | + self.attrs.setdefault("_FORMULA_VAR_FILE", str(ip / "CMIP6_formula_terms.json")) |
| 57 | + |
| 58 | + # New API names |
| 59 | + def set_cur_dataset_attribute(self, key: str, value: Any) -> None: |
| 60 | + """Set a dataset/global attribute.""" |
| 61 | + self.attrs[str(key)] = value |
| 62 | + |
| 63 | + def get_cur_dataset_attribute(self, key: str) -> Any: |
| 64 | + """Get a dataset/global attribute.""" |
| 65 | + return self.attrs.get(str(key), "") |
| 66 | + |
| 67 | + # Legacy API names (keep camelCase to match CMOR) # pylint: disable=invalid-name |
| 68 | + def setGblAttr(self, key: str, value: Any) -> None: # noqa: N802 |
| 69 | + """Legacy alias for set_cur_dataset_attribute.""" |
| 70 | + self.set_cur_dataset_attribute(key, value) |
| 71 | + |
| 72 | + def getGblAttr(self, key: str) -> Any: # noqa: N802 |
| 73 | + """Legacy alias for get_cur_dataset_attribute.""" |
| 74 | + return self.get_cur_dataset_attribute(key) |
| 75 | + |
| 76 | + # Tables & variable definitions |
| 77 | + def load_table(self, name: str) -> int: |
| 78 | + """Record last table loaded and return a fake handle.""" |
| 79 | + self.last_table = str(name) |
| 80 | + return 0 |
| 81 | + |
| 82 | + def axis( |
| 83 | + self, table_entry: str, units: str, coord_vals, cell_bounds=None, **_ |
| 84 | + ) -> int: |
| 85 | + """Record axis definition and return a fake axis id.""" |
| 86 | + self.axis_calls.append( |
| 87 | + ( |
| 88 | + table_entry, |
| 89 | + units, |
| 90 | + np.asarray(coord_vals), |
| 91 | + None if cell_bounds is None else np.asarray(cell_bounds), |
| 92 | + ) |
| 93 | + ) |
| 94 | + return len(self.axis_calls) |
| 95 | + |
| 96 | + def variable(self, table_entry: str, units: str, axis_ids, **_) -> int: |
| 97 | + """Record variable definition and return a fake var id.""" |
| 98 | + self.variable_calls.append((table_entry, units, tuple(axis_ids))) |
| 99 | + return 10 |
| 100 | + |
| 101 | + def write(self, var_id: int, data, **_) -> None: # pylint: disable=unused-argument |
| 102 | + """Record a write call; auto-generate tracking_id from tracking_prefix if needed.""" |
| 103 | + tid = self.attrs.get("tracking_id", "") |
| 104 | + prefix = self.attrs.get("tracking_prefix", "") |
| 105 | + if ( |
| 106 | + (tid == "" or tid is None) |
| 107 | + and isinstance(prefix, str) |
| 108 | + and prefix.startswith("hdl:") |
| 109 | + ): |
| 110 | + self.attrs["tracking_id"] = f"{prefix}{uuid.uuid4()}" |
| 111 | + self.write_calls.append((var_id, np.asarray(data))) |
| 112 | + |
| 113 | + # pylint: disable=unused-argument |
| 114 | + def close(self, var_id: int | None = None, file_name: str | None = None) -> None: |
| 115 | + """Close the current object/file.""" |
| 116 | + self.closed = True |
| 117 | + self.closed_file = file_name or "" |
| 118 | + |
| 119 | + |
| 120 | +@pytest.fixture() |
| 121 | +def fake_cmor(monkeypatch, tmp_path) -> Tuple[FakeCMOR, Path]: |
| 122 | + """Patch cmor_writer.cmor with FakeCMOR and return (fake, tables_path).""" |
| 123 | + fake = FakeCMOR() |
| 124 | + monkeypatch.setattr(cw, "cmor", fake, raising=True) |
| 125 | + tables = tmp_path / "CMIP6_Tables" |
| 126 | + tables.mkdir() |
| 127 | + return fake, tables |
| 128 | + |
| 129 | + |
| 130 | +# Back-compat for any older tests that reference `_FakeCMOR` |
| 131 | +_FakeCMOR = FakeCMOR # noqa: N816 |
0 commit comments