Skip to content

Commit 39e90bb

Browse files
committed
test: transform/expression with temporal integration added
1 parent 8fd7393 commit 39e90bb

File tree

1 file changed

+103
-2
lines changed

1 file changed

+103
-2
lines changed

ilamb3/tests/test_transform.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
import re
2+
13
import cftime as cf
24
import numpy as np
35
import pytest
46
import xarray as xr
57

8+
import ilamb3.dataset as dset
69
from ilamb3.tests.test_run import generate_test_dset
710
from ilamb3.transform import ALL_TRANSFORMS
811

912
ALL_TRANSFORMS
13+
PYTHON_VARIABLE = r"\b[a-zA-Z_][a-zA-Z0-9_]*\b"
1014

1115

12-
def gen_msftmz(seed: int = 1):
16+
def gen_msftmz_dset(seed: int = 1):
1317
rs = np.random.RandomState(seed)
1418
coords = {}
1519
# time, basin, lev, lat
@@ -37,11 +41,44 @@ def gen_msftmz(seed: int = 1):
3741
)
3842

3943

44+
def gen_expression_dset(
45+
expr: str,
46+
var_meta: dict[str, dict[str, float | str]],
47+
nyear: int = 2,
48+
nlat: int = 2,
49+
nlon: int = 4,
50+
base_seed: int = 1,
51+
) -> xr.Dataset:
52+
lhs, rhs_vars = _parse_expr_variables(expr)
53+
54+
ds_list: list[xr.Dataset] = []
55+
for i, var in enumerate(rhs_vars):
56+
meta = var_meta.get(var, {})
57+
unit = meta.get("unit", "1")
58+
scale = meta.get("scale", 20.0)
59+
shift = meta.get("shift", 0.0)
60+
seed = base_seed + i # different seed per variable
61+
62+
ds_var = generate_test_dset(
63+
name=var,
64+
unit=unit,
65+
seed=seed,
66+
nyear=nyear,
67+
nlat=nlat,
68+
nlon=nlon,
69+
scale=scale,
70+
shift=shift,
71+
)
72+
ds_list.append(ds_var)
73+
74+
return xr.merge(ds_list)
75+
76+
4077
DATA = {
4178
"soil_moisture_to_vol_fraction": generate_test_dset(
4279
"mrsos", "kg m-2", nyear=1, nlat=2, nlon=4
4380
),
44-
"msftmz_to_rapid": gen_msftmz(),
81+
"msftmz_to_rapid": gen_msftmz_dset(),
4582
"ocean_heat_content": xr.merge(
4683
[
4784
generate_test_dset(
@@ -80,4 +117,68 @@ def gen_msftmz(seed: int = 1):
80117
def test_transform(name, kwargs, out, value):
81118
transform = ALL_TRANSFORMS[name](**kwargs)
82119
ds = transform(DATA[name])
120+
print(ds[out].mean().values) # this value should be the correct one
83121
assert np.allclose(value, ds[out].mean().values)
122+
123+
124+
def _parse_expr_variables(expr: str):
125+
lhs, rhs = expr.split("=")
126+
lhs_vars = re.findall(PYTHON_VARIABLE, lhs)
127+
rhs_vars = re.findall(PYTHON_VARIABLE, rhs)
128+
129+
assert len(lhs_vars) == 1
130+
lhs = lhs_vars[0]
131+
return lhs, rhs_vars
132+
133+
134+
@pytest.mark.parametrize(
135+
"expr_kwargs,var_meta,value",
136+
[
137+
(
138+
{"expr": "albedo = rsus / rsds", "integrate_time": False},
139+
{
140+
"rsus": {"unit": "W m-2", "scale": 150.0, "shift": 0.0},
141+
"rsds": {"unit": "W m-2", "scale": 300.0, "shift": 100.0},
142+
},
143+
0.34482726580080203,
144+
),
145+
(
146+
{"expr": "albedo = rsus / rsds", "integrate_time": True},
147+
{
148+
"rsus": {"unit": "W m-2", "scale": 150.0, "shift": 0.0},
149+
"rsds": {"unit": "W m-2", "scale": 300.0, "shift": 100.0},
150+
},
151+
0.30226715582441277,
152+
),
153+
],
154+
)
155+
def test_expression(expr_kwargs, var_meta, value):
156+
expr = expr_kwargs["expr"]
157+
158+
# build a test dataset for whatever variables appear in expr
159+
ds = gen_expression_dset(
160+
expr,
161+
var_meta=var_meta,
162+
nyear=2,
163+
nlat=2,
164+
nlon=4,
165+
base_seed=1,
166+
)
167+
168+
transform = ALL_TRANSFORMS["expression"](**expr_kwargs)
169+
ds_out = transform(ds)
170+
171+
lhs, rhs_vars = _parse_expr_variables(expr)
172+
assert lhs in ds_out
173+
174+
if not expr_kwargs.get("integrate_time", False):
175+
print(ds_out[lhs].mean().values)
176+
assert np.allclose(ds_out[lhs].mean().values, value)
177+
else:
178+
ds_expected = ds.copy()
179+
for v in rhs_vars:
180+
arr = ds_expected[v]
181+
if dset.is_temporal(arr):
182+
ds_expected[v] = dset.integrate_time(ds_expected, v, mean=True)
183+
print(ds_out[lhs].mean().values)
184+
assert np.allclose(ds_out[lhs].mean().values, value)

0 commit comments

Comments
 (0)