Skip to content

Commit 9136ea7

Browse files
committed
Add initial complete_run_script.py
1 parent ab8f23d commit 9136ea7

File tree

2 files changed

+225
-1
lines changed

2 files changed

+225
-1
lines changed

tests/complete_run_script.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,6 @@
199199
"aerosol_budget",
200200
"tropical_subseasonal",
201201
]
202-
runner.run_diags(params)
202+
203+
if __name__ == "__main__":
204+
runner.run_diags(params)

tests/test_regression.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import glob
2+
3+
import numpy as np
4+
import pytest
5+
import xarray as xr
6+
7+
from e3sm_diags.derivations.derivations import DERIVED_VARIABLES
8+
from e3sm_diags.logger import custom_logger
9+
from tests.complete_run_script import params, runner
10+
11+
12+
logger = custom_logger(__name__)
13+
14+
15+
DEV_DIR = "843-migration-phase3-model-vs-obs"
16+
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{DEV_DIR}/"
17+
18+
DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc"))
19+
DEV_NUM_FILES = len(DEV_GLOB)
20+
21+
MAIN_DIR = "main"
22+
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{MAIN_DIR}/"
23+
MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc"))
24+
MAIN_NUM_FILES = len(MAIN_GLOB)
25+
26+
# Absolute and relative tolerance levels for comparison of the data.
27+
# Absolute is in floating point terms, relative is in percentage terms.
28+
ATOL = 0
29+
RTOL = 1e-5
30+
31+
32+
@pytest.fixture(scope="module")
33+
def run_diags_and_get_results_dir() -> str:
34+
"""Run the diagnostics and get the results directory containing the images.
35+
36+
The scope of this fixture is at the module level so that it only runs
37+
once, then each individual test can reference the result directory.
38+
39+
Returns
40+
-------
41+
str
42+
The path to the results directory.
43+
"""
44+
results = runner.run_diags(params)
45+
46+
if results is not None:
47+
results_dir = results[0].results_dir
48+
else:
49+
results_dir = params[0].results_dir
50+
51+
return results_dir
52+
53+
54+
class TestRegression:
55+
@pytest.fixture(autouse=True)
56+
def setup(self, run_diags_and_get_results_dir):
57+
# TODO: We need to store `main` results on a data container
58+
self.results_dir = run_diags_and_get_results_dir
59+
60+
def test_check_if_files_found(self):
61+
if DEV_NUM_FILES == 0 or MAIN_NUM_FILES == 0:
62+
raise IOError(
63+
"No files found at DEV_PATH and/or MAIN_PATH. "
64+
f"Please check {DEV_PATH} and {MAIN_PATH}."
65+
)
66+
67+
def test_check_if_matching_filecount(self):
68+
if DEV_NUM_FILES != MAIN_NUM_FILES:
69+
raise IOError(
70+
"Number of files do not match at DEV_PATH and MAIN_PATH "
71+
f"({DEV_NUM_FILES} vs. {MAIN_NUM_FILES})."
72+
)
73+
74+
logger.info(f"Matching file count ({DEV_NUM_FILES} and {MAIN_NUM_FILES}).")
75+
76+
def test_check_if_missing_files(self):
77+
missing_dev_files, missing_main_files = _check_if_missing_files()
78+
79+
assert len(missing_dev_files) == 0
80+
assert len(missing_main_files) == 0
81+
82+
def test_get_relative_diffs(self):
83+
results = _get_relative_diffs()
84+
85+
assert len(results["missing_files"]) == 0
86+
assert len(results["missing_vars"]) == 0
87+
assert len(results["matching_files"]) > 0
88+
assert len(results["mismatch_errors"]) == 0
89+
assert len(results["not_equal_errors"]) == 0
90+
assert len(results["key_errors"]) == 0
91+
92+
93+
def _get_relative_diffs():
94+
results = {
95+
"missing_files": [],
96+
"missing_vars": [],
97+
"matching_files": [],
98+
"mismatch_errors": [],
99+
"not_equal_errors": [],
100+
"key_errors": [],
101+
}
102+
103+
for fp_main in MAIN_GLOB:
104+
fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)
105+
106+
logger.info("Comparing:")
107+
logger.info(f" * {fp_dev}")
108+
logger.info(f" * {fp_main}")
109+
110+
try:
111+
ds1 = xr.open_dataset(fp_dev)
112+
ds2 = xr.open_dataset(fp_main)
113+
except FileNotFoundError as e:
114+
logger.info(f" {e}")
115+
116+
if isinstance(e, FileNotFoundError) or isinstance(e, OSError):
117+
results["missing_files"].append(fp_dev)
118+
119+
continue
120+
121+
var_key = fp_main.split("-")[-3]
122+
123+
# for 3d vars such as T-200
124+
var_key.isdigit()
125+
if var_key.isdigit():
126+
var_key = fp_main.split("-")[-4]
127+
128+
dev_data = _get_var_data(ds1, var_key)
129+
main_data = _get_var_data(ds2, var_key)
130+
131+
logger.info(f" * var_key: {var_key}")
132+
133+
if dev_data is None or main_data is None:
134+
if dev_data is None:
135+
results["missing_vars"].append(fp_dev)
136+
elif main_data is None:
137+
results["missing_vars"].append(fp_main)
138+
139+
logger.error(" * Could not find variable key in the dataset(s)")
140+
141+
continue
142+
143+
try:
144+
np.testing.assert_allclose(
145+
dev_data,
146+
main_data,
147+
atol=ATOL,
148+
rtol=RTOL,
149+
)
150+
results["matching_files"].append(fp_main)
151+
except (KeyError, AssertionError) as e:
152+
msg = str(e)
153+
154+
logger.info(f" {msg}")
155+
156+
if "mismatch" in msg:
157+
results["mismatch_errors"].append(fp_dev)
158+
elif "Not equal to tolerance" in msg:
159+
results["not_equal_errors"].append(fp_dev)
160+
else:
161+
logger.info(f" * All close and within relative tolerance ({RTOL})")
162+
163+
return results
164+
165+
166+
def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray | None:
167+
"""Retrieve variable data from an xarray Dataset.
168+
169+
Parameters
170+
----------
171+
ds : xr.Dataset
172+
The xarray Dataset from which to retrieve the variable data.
173+
var_key : str
174+
The key of the variable to retrieve.
175+
176+
Returns
177+
-------
178+
np.ndarray
179+
The data of the specified variable as a NumPy array. If the variable is
180+
not found, returns None.
181+
182+
Raises
183+
------
184+
KeyError
185+
If the variable key is not found in the Dataset and is not a derived
186+
variable.
187+
"""
188+
data = None
189+
190+
try:
191+
var_keys = DERIVED_VARIABLES[var_key].keys()
192+
except KeyError:
193+
var_keys = DERIVED_VARIABLES[var_key.upper()].keys()
194+
195+
var_keys = [var_key] + list(sum(var_keys, ()))
196+
197+
for key in var_keys:
198+
if key in ds.data_vars.keys():
199+
data = ds[key].values
200+
201+
break
202+
203+
return data
204+
205+
206+
def _check_if_missing_files():
207+
missing_dev_files = []
208+
missing_main_files = []
209+
210+
for fp_main in MAIN_GLOB:
211+
fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)
212+
213+
if fp_dev not in DEV_GLOB:
214+
missing_dev_files.append(fp_dev)
215+
216+
for fp_dev in DEV_GLOB:
217+
fp_main = fp_dev.replace(DEV_DIR, MAIN_DIR)
218+
219+
if fp_main not in MAIN_GLOB:
220+
missing_main_files.append(fp_main)
221+
222+
return missing_dev_files, missing_main_files

0 commit comments

Comments
 (0)