|
| 1 | +import re |
| 2 | + |
1 | 3 | import cftime as cf |
2 | 4 | import numpy as np |
3 | 5 | import pytest |
4 | 6 | import xarray as xr |
5 | 7 |
|
| 8 | +import ilamb3.dataset as dset |
6 | 9 | from ilamb3.tests.test_run import generate_test_dset |
7 | 10 | from ilamb3.transform import ALL_TRANSFORMS |
8 | 11 |
|
9 | 12 | ALL_TRANSFORMS |
| 13 | +PYTHON_VARIABLE = r"\b[a-zA-Z_][a-zA-Z0-9_]*\b" |
10 | 14 |
|
11 | 15 |
|
12 | | -def gen_msftmz(seed: int = 1): |
| 16 | +def gen_msftmz_dset(seed: int = 1): |
13 | 17 | rs = np.random.RandomState(seed) |
14 | 18 | coords = {} |
15 | 19 | # time, basin, lev, lat |
@@ -37,11 +41,44 @@ def gen_msftmz(seed: int = 1): |
37 | 41 | ) |
38 | 42 |
|
39 | 43 |
|
| 44 | +def gen_expression_dset( |
| 45 | + expr: str, |
| 46 | + var_meta: dict[str, dict[str, float | str]], |
| 47 | + nyear: int = 2, |
| 48 | + nlat: int = 2, |
| 49 | + nlon: int = 4, |
| 50 | + base_seed: int = 1, |
| 51 | +) -> xr.Dataset: |
| 52 | + lhs, rhs_vars = _parse_expr_variables(expr) |
| 53 | + |
| 54 | + ds_list: list[xr.Dataset] = [] |
| 55 | + for i, var in enumerate(rhs_vars): |
| 56 | + meta = var_meta.get(var, {}) |
| 57 | + unit = meta.get("unit", "1") |
| 58 | + scale = meta.get("scale", 20.0) |
| 59 | + shift = meta.get("shift", 0.0) |
| 60 | + seed = base_seed + i # different seed per variable |
| 61 | + |
| 62 | + ds_var = generate_test_dset( |
| 63 | + name=var, |
| 64 | + unit=unit, |
| 65 | + seed=seed, |
| 66 | + nyear=nyear, |
| 67 | + nlat=nlat, |
| 68 | + nlon=nlon, |
| 69 | + scale=scale, |
| 70 | + shift=shift, |
| 71 | + ) |
| 72 | + ds_list.append(ds_var) |
| 73 | + |
| 74 | + return xr.merge(ds_list) |
| 75 | + |
| 76 | + |
40 | 77 | DATA = { |
41 | 78 | "soil_moisture_to_vol_fraction": generate_test_dset( |
42 | 79 | "mrsos", "kg m-2", nyear=1, nlat=2, nlon=4 |
43 | 80 | ), |
44 | | - "msftmz_to_rapid": gen_msftmz(), |
| 81 | + "msftmz_to_rapid": gen_msftmz_dset(), |
45 | 82 | "ocean_heat_content": xr.merge( |
46 | 83 | [ |
47 | 84 | generate_test_dset( |
@@ -80,4 +117,68 @@ def gen_msftmz(seed: int = 1): |
80 | 117 | def test_transform(name, kwargs, out, value): |
81 | 118 | transform = ALL_TRANSFORMS[name](**kwargs) |
82 | 119 | ds = transform(DATA[name]) |
| 120 | + print(ds[out].mean().values) # this value should be the correct one |
83 | 121 | assert np.allclose(value, ds[out].mean().values) |
| 122 | + |
| 123 | + |
| 124 | +def _parse_expr_variables(expr: str): |
| 125 | + lhs, rhs = expr.split("=") |
| 126 | + lhs_vars = re.findall(PYTHON_VARIABLE, lhs) |
| 127 | + rhs_vars = re.findall(PYTHON_VARIABLE, rhs) |
| 128 | + |
| 129 | + assert len(lhs_vars) == 1 |
| 130 | + lhs = lhs_vars[0] |
| 131 | + return lhs, rhs_vars |
| 132 | + |
| 133 | + |
| 134 | +@pytest.mark.parametrize( |
| 135 | + "expr_kwargs,var_meta,value", |
| 136 | + [ |
| 137 | + ( |
| 138 | + {"expr": "albedo = rsus / rsds", "integrate_time": False}, |
| 139 | + { |
| 140 | + "rsus": {"unit": "W m-2", "scale": 150.0, "shift": 0.0}, |
| 141 | + "rsds": {"unit": "W m-2", "scale": 300.0, "shift": 100.0}, |
| 142 | + }, |
| 143 | + 0.34482726580080203, |
| 144 | + ), |
| 145 | + ( |
| 146 | + {"expr": "albedo = rsus / rsds", "integrate_time": True}, |
| 147 | + { |
| 148 | + "rsus": {"unit": "W m-2", "scale": 150.0, "shift": 0.0}, |
| 149 | + "rsds": {"unit": "W m-2", "scale": 300.0, "shift": 100.0}, |
| 150 | + }, |
| 151 | + 0.30226715582441277, |
| 152 | + ), |
| 153 | + ], |
| 154 | +) |
| 155 | +def test_expression(expr_kwargs, var_meta, value): |
| 156 | + expr = expr_kwargs["expr"] |
| 157 | + |
| 158 | + # build a test dataset for whatever variables appear in expr |
| 159 | + ds = gen_expression_dset( |
| 160 | + expr, |
| 161 | + var_meta=var_meta, |
| 162 | + nyear=2, |
| 163 | + nlat=2, |
| 164 | + nlon=4, |
| 165 | + base_seed=1, |
| 166 | + ) |
| 167 | + |
| 168 | + transform = ALL_TRANSFORMS["expression"](**expr_kwargs) |
| 169 | + ds_out = transform(ds) |
| 170 | + |
| 171 | + lhs, rhs_vars = _parse_expr_variables(expr) |
| 172 | + assert lhs in ds_out |
| 173 | + |
| 174 | + if not expr_kwargs.get("integrate_time", False): |
| 175 | + print(ds_out[lhs].mean().values) |
| 176 | + assert np.allclose(ds_out[lhs].mean().values, value) |
| 177 | + else: |
| 178 | + ds_expected = ds.copy() |
| 179 | + for v in rhs_vars: |
| 180 | + arr = ds_expected[v] |
| 181 | + if dset.is_temporal(arr): |
| 182 | + ds_expected[v] = dset.integrate_time(ds_expected, v, mean=True) |
| 183 | + print(ds_out[lhs].mean().values) |
| 184 | + assert np.allclose(ds_out[lhs].mean().values, value) |
0 commit comments