
Commit a4a6fb0

Add regression tests (#107)

* Add actual regression testing to GitHub Actions

  I ran the tidal-Simmons and kpp tests on derecho, and copied the resulting .nc files into a new baselines/ directory. I also copied MARBL's netcdf comparison tool into CVMix_tools/ and then updated run_test_suite.sh to (a) run the KPP test with the -nc flag and (b) compare the contents of each netcdf file to the baseline. Lastly, I updated the github action to make sure the python environment contains xarray, numpy, and netcdf4.

* Remove large tidal baseline files

  Add Bryan-Lewis baselines, and ensure data_memcopy.nc == data_pointer.nc

* Add forgotten -nc flag

  The Bryan-Lewis test was not writing netcdf files.

* Add baselines for shear and double diffusion

  Also update the test suite to compare against these new baselines.
1 parent d20b989 commit a4a6fb0

File tree

18 files changed: +425 -2 lines changed


.github/workflows/run_test_suite.yml

Lines changed: 8 additions & 0 deletions
@@ -7,10 +7,18 @@ jobs:
     steps:
     - uses: actions/checkout@v3
 
+    - name: Setup Python
+      uses: actions/setup-python@v4
+
+    - name: Setup python environment
+      run: |
+        pip install numpy xarray netcdf4
+
     - name: Load Environment
       run: |
         sudo apt-get update
         sudo apt install make gfortran netcdf-bin libnetcdf-dev libnetcdff-dev openmpi-bin libopenmpi-dev
+
     - name: Run Test Suite
       run: |
         ./reg_tests/common/setup_inputdata.sh
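
The pip install step above provides the packages the new comparison tool needs. Below is a minimal sketch, with hypothetical file paths, of the kind of comparison call run_test_suite.sh makes after a test writes its netcdf output; the tool exits nonzero when the files differ:

    import subprocess

    # Hypothetical paths -- the actual test suite picks the files per test
    cmd = ['./CVMix_tools/netcdf_comparison.py',
           '--baseline', 'baselines/data_memcopy.nc',
           '--new-file', 'data_memcopy.nc',
           '--strict', 'exact']

    # netcdf_comparison.py exits 1 (and logs the differences) on failure
    if subprocess.run(cmd, check=False).returncode != 0:
        raise SystemExit('Regression test failed')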

CVMix_tools/netcdf_comparison.py

Lines changed: 354 additions & 0 deletions
@@ -0,0 +1,354 @@
#!/usr/bin/env python

"""
Usage:
    $ ./netcdf_comparison.py --baseline BASELINE_FILE --new-file NEW_FILE
                             --strict {exact,loose} [-r RTOL] [-a ATOL] [-t THRES]

Use xarray and numpy to compare two netcdf files.
For each variable, flag
1. Variables that are present in one file but not the other
2. Variables where the data type doesn't match across files
3. Variables where the dimensions don't match across files
4. Variables where the missing values are not aligned
5. Variables that differ in one of two ways (user specifies which strictness level to use):
   a. Variables that are not exactly the same (--strict exact)
   b. Variables that are not "close" to each other (--strict loose)
      -- For values very close to 0 (< THRES), variables that differ by
         more than ATOL are flagged
      -- For values larger than THRES, variables with a relative difference
         of more than RTOL are flagged

There is some knowledge of unit equivalence between mks and cgs.
"""

import logging

##################

# Store default values of rtol, atol, and thres in a global dictionary
# to make it easy to update the default values if necessary
# rtol = 1e-11 fails the cgs vs mks comparison
# (POC_REMIN_DIC and PON_REMIN_NH4 have rel errors of ~1.4e-11)
DEFAULT_TOLS = {'rtol': 1e-9, 'atol': 1e-16, 'thres': 1e-16}
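
# Illustrative example (not from the original file): with these defaults, a
# baseline value of 1.0 may differ by up to roughly 1e-9 in relative terms,
# while a baseline value below thres (e.g. 1e-20) may differ by up to
# atol = 1e-16 in absolute terms, before the comparison is flagged.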

##################

def ds_comparison_exact(ds_base, ds_new):
    """
    Compare baseline to new_file using xarray
    Will only pass if two files are identical (bfb)
    (run with --strict=exact)
    """
    # An exact comparison is the loose comparison with zero tolerances
    return ds_comparison_loose(ds_base, ds_new, rtol=0, thres=0)

##################

def ds_comparison_loose(ds_base, ds_new, rtol=DEFAULT_TOLS['rtol'], atol=DEFAULT_TOLS['atol'],
                        thres=DEFAULT_TOLS['thres']):
    """
    Compare baseline to new_file using xarray
    Will pass if two files are within specified tolerance
    (run with --strict=loose)
    """
    header_fail, ds_base, ds_new = _reduce_to_matching_variables(ds_base, ds_new)

    # Compare remaining variables
    return _variable_check_loose(ds_base, ds_new, rtol, atol, thres) or header_fail

##################

def _open_files(baseline, new_file):
    """
    Reads two netCDF files, returns xarray datasets
    """
    import xarray as xr

    logger = logging.getLogger(__name__)
    logger.info("Comparing %s to the baseline %s\n", new_file, baseline)

    return xr.open_dataset(baseline), xr.open_dataset(new_file)

##################

def _reduce_to_matching_variables(ds_base, ds_new):
    """
    Are variables and dimensions the same in both datasets? Check for:
    1. Variables in one dataset but not the other
    2. Variables that are different types
    3. Variables that are different dimensions
       * check both name and size of dimension
    Returns header_fail, ds_base, ds_new
    * header_fail is True if any of the three checks above fail
    * ds_base and ds_new are the same as the inputs, except any variables that do not match
      the above criteria are dropped from the dataset
    """
    logger = logging.getLogger(__name__)
    header_fail = False
    failed_vars = []

    # 1. Any variables in one file but not the other?
    base_vars = ds_base.variables
    new_vars = ds_new.variables
    common_vars = set(base_vars) & set(new_vars)
    base_vars = list(set(base_vars) - common_vars)
    new_vars = list(set(new_vars) - common_vars)
    if base_vars:
        header_fail = True
        ds_base = ds_base.drop_vars(base_vars)
        logger.info("The following variables are in the baseline file but not the new file:")
        base_vars.sort()
        for var in base_vars:
            logger.info("* %s", var)
        logger.info("")
    if new_vars:
        header_fail = True
        ds_new = ds_new.drop_vars(new_vars)
        logger.info("The following variables are in the new file but not the baseline file:")
        new_vars.sort()
        for var in new_vars:
            logger.info("* %s", var)
        logger.info("")

    # 2. Can variables be compared?
    common_vars = list(common_vars)
    common_vars.sort()
    for var in common_vars:
        err_messages = []
        # (a) Are they the same type?
        if ds_base[var].dtype != ds_new[var].dtype:
            err_messages.append(
                "Variable is {} in baseline and {} in new file".format(
                    ds_base[var].dtype, ds_new[var].dtype
                )
            )
        # (b) Do the dimension names match?
        if ds_base[var].dims != ds_new[var].dims:
            err_messages.append(
                "Baseline dimensions are {} and new file dimensions are {}".format(
                    ds_base[var].dims, ds_new[var].dims
                )
            )
        # (c) Do the dimension sizes match?
        if ds_base[var].data.shape != ds_new[var].data.shape:
            err_messages.append(
                "Baseline dimensions are {} and new file dimensions are {}".format(
                    ds_base[var].data.shape, ds_new[var].data.shape
                )
            )

        # Report errors
        if _report_errs(var, err_messages):
            header_fail = True
            failed_vars.append(var)

    return header_fail, ds_base.drop_vars(failed_vars), ds_new.drop_vars(failed_vars)

##################

def _get_conversion_factor(ds_base, ds_new, var):
    """
    Returns the factor that converts the new file's values for var into the
    baseline's units (1 if no conversion is needed)
    """
    unit_conversion = {key: {} for key in
                       ['cm/s', 'cm', 'nmol/cm^3', 'nmol/cm^3/s', 'nmol/cm^2/s',
                        '(nmol/cm^3)^-1 s^-1', 'g/cm^3/s', 'g/cm^2/s', 'meq/m^3', 'meq/m^3/s',
                        'neq/cm^2/s', 'meq/m^3 cm/s', 'mg/m^3 cm/s']}
    unit_conversion['cm/s']['m/s'] = 0.01 # cm/s -> m/s
    unit_conversion['cm']['m'] = 0.01 # cm -> m
    unit_conversion['nmol/cm^3']['mmol/m^3'] = 1 # nmol/cm^3 -> mmol/m^3
    unit_conversion['nmol/cm^3/s']['mmol/m^3/s'] = 1 # nmol/cm^3/s -> mmol/m^3/s
    unit_conversion['nmol/cm^2/s']['mmol/m^2/s'] = 0.01 # nmol/cm^2/s -> mmol/m^2/s
    unit_conversion['(nmol/cm^3)^-1 s^-1']['(mmol/m^3)^-1 s^-1'] = 1 # same as nmol/cm^3 -> mmol/m^3
    unit_conversion['g/cm^3/s']['kg/m^3/s'] = 1000 # g/cm^3/s -> kg/m^3/s
    unit_conversion['g/cm^2/s']['kg/m^2/s'] = 10 # g/cm^2/s -> kg/m^2/s
    unit_conversion['meq/m^3']['neq/cm^3'] = 1 # meq/m^3 -> neq/cm^3
    unit_conversion['meq/m^3/s']['neq/cm^3/s'] = 1 # meq/m^3/s -> neq/cm^3/s
    unit_conversion['neq/cm^2/s']['meq/m^2/s'] = 0.01 # neq/cm^2/s -> meq/m^2/s
    unit_conversion['meq/m^3 cm/s']['neq/cm^2/s'] = 1 # meq/m^3 cm/s -> neq/cm^2/s
    unit_conversion['meq/m^3 cm/s']['meq/m^2/s'] = 0.01 # meq/m^3 cm/s -> meq/m^2/s
    unit_conversion['mg/m^3 cm/s']['mg/m^2/s'] = 0.01 # mg/m^3 cm/s -> mg/m^2/s

    conversion_factor = 1.
    if 'units' in ds_base[var].attrs and 'units' in ds_new[var].attrs:
        old_units = ds_base[var].attrs['units']
        new_units = ds_new[var].attrs['units']
        if old_units != new_units:
            found = False
            if new_units in unit_conversion and old_units in unit_conversion[new_units]:
                conversion_factor = unit_conversion[new_units][old_units]
                found = True
            if not found:
                if old_units in unit_conversion and new_units in unit_conversion[old_units]:
                    conversion_factor = 1. / unit_conversion[old_units][new_units]
                    found = True
            if not found:
                raise KeyError(f'Can not convert from {new_units} to {old_units} for {var}')
    return conversion_factor
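
# Illustrative example (not from the original file): if the baseline stores a
# velocity in m/s and the new file stores it in cm/s, the first lookup above
# finds unit_conversion['cm/s']['m/s'] = 0.01, and the caller multiplies the
# new data by the returned 0.01 (cm/s -> m/s) before comparing to the baseline.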

##################

def _variable_check_loose(ds_base, ds_new, rtol, atol, thres):
    """
    Assumes both datasets contain the same variables with the same dimensions
    Checks:
    1. Are NaNs in the same place?
    2. If baseline value = 0, then |new value| must be <= atol
    3. Absolute vs relative error:
       i. If 0 < |baseline value| <= thres then want absolute difference < atol
       ii. If |baseline value| > thres, then want relative difference < rtol
    Note: if thres = 0 and rtol = 0, then this reduces to an exact test
          (Are non-NaN values identical?)
    """
    import numpy as np

    error_checking = {'var_check_count': 0}
    common_vars = list(set(ds_base.variables) & set(ds_new.variables))
    common_vars.sort()

    for var in common_vars:
        error_checking['messages'] = []

        # (0) Update units of new file to match baseline
        conversion_factor = _get_conversion_factor(ds_base, ds_new, var)

        # (1) Are NaNs in the same place?
        if var.lower() == 'time':
            continue
        mask = np.isfinite(ds_base[var].data)
        if np.any(mask ^ np.isfinite(ds_new[var].data)):
            error_checking['messages'].append('NaNs are not in same place')

        # (2) compare everywhere that baseline is 0
        if np.any(np.where(ds_base[var].data[mask] == 0,
                           np.abs(ds_new[var].data[mask]) > atol,
                           False)):
            error_checking['messages'].append(
                'Baseline is 0 at some indices where abs value ' +
                'of new data is larger than {}'.format(atol))

        # (3i) Compare everywhere that 0 < |baseline| <= thres
        base_data = np.where((ds_base[var].data[mask] != 0) &
                             (np.abs(ds_base[var].data[mask]) <= thres),
                             ds_base[var].data[mask], 0)
        new_data = np.where((ds_base[var].data[mask] != 0) &
                            (np.abs(ds_base[var].data[mask]) <= thres),
                            conversion_factor*ds_new[var].data[mask], 0)
        abs_err = np.abs(new_data - base_data)
        if np.any(abs_err > atol):
            error_checking['messages'].append("Max absolute error ({}) exceeds {}".format(
                np.max(abs_err), atol))

        # (3ii) Compare everywhere that |baseline| is > thres
        base_data = np.where(np.abs(ds_base[var].data[mask]) > thres,
                             ds_base[var].data[mask],
                             0)
        new_data = np.where(np.abs(ds_base[var].data[mask]) > thres,
                            conversion_factor*ds_new[var].data[mask],
                            0)
        rel_err = _compute_rel_err(ds_base[var], base_data, new_data, mask)
        if np.any(rel_err > rtol):
            if rtol == 0:
                abs_err = np.abs(new_data - base_data)
                error_checking['messages'].append("Values are not the same everywhere\n{}".format(
                    "    Max relative error: {}\n    Max absolute error: {}".format(
                        np.max(rel_err), np.max(abs_err))
                ))
            else:
                error_checking['messages'].append("Max relative error ({}) exceeds {}".format(
                    np.max(rel_err), rtol))

        error_checking['var_check_count'] += _report_errs(var, error_checking['messages'])

    return error_checking['var_check_count'] > 0

##################

def _compute_rel_err(da_base, base_data, new_data, mask):
    # denominator for relative error is local max value (3-point stencil)
    # note the assumption that column is first dimension
    import numpy as np

    if 'num_levels' in da_base.dims:
        # For variables with a depth dimension, we use a 3-point stencil in depth to find the
        # maximum value of the baseline to use in the denominator of the relative error.
        # For the top (bottom) of the column, we use a 2-point stencil with the value
        # below (above) the given level
        rel_denom = da_base.rolling(num_levels=3, center=True, min_periods=2).max().data[mask]
    else:
        # For variables without a depth dimension, the denominator is the baseline value
        rel_denom = da_base.data[mask]
    rel_denom = np.where(np.isfinite(rel_denom), rel_denom, 0)
    return np.where(base_data != 0, np.abs(new_data - base_data), 0) / \
           np.where(rel_denom != 0, rel_denom, 1)
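
# Illustrative example (not from the original file): for a column of baseline
# values [1, 10, 100], rolling(num_levels=3, center=True, min_periods=2).max()
# yields denominators [10, 100, 100], so each difference is measured against
# the largest value in its 3-point neighborhood rather than a small local value.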

##################

def _report_errs(var, err_messages):
    """
    err_messages is a list of all accumulated errors
    """
    logger = logging.getLogger(__name__)
    if err_messages:
        logger.info("Variable: %s", var)
        for err in err_messages:
            logger.info("... %s", err)
        logger.info("")
        return True
    return False

##################

def _parse_args():
    """ Parse command line arguments
    """

    import argparse

    parser = argparse.ArgumentParser(description="Compare two netCDF files using xarray and numpy",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Baseline for comparison
    parser.add_argument('-b', '--baseline', action='store', dest='baseline', required=True,
                        help='Baseline for comparison')

    # File to compare to baseline
    parser.add_argument('-n', '--new-file', action='store', dest='new_file', required=True,
                        help="File to compare to baseline")

    # Tolerances
    parser.add_argument('--strict', choices=['exact', 'loose'], required=True,
                        help="Should files be bit-for-bit [exact] or within some tolerance [loose]")

    parser.add_argument('-r', '--rtol', action='store', dest='rtol',
                        default=DEFAULT_TOLS['rtol'], type=float,
                        help="Maximum allowable relative tolerance (only if strict=loose)")

    parser.add_argument('-a', '--atol', action='store', dest='atol',
                        default=DEFAULT_TOLS['atol'], type=float,
                        help="Maximum allowable absolute tolerance (only if strict=loose)")

    parser.add_argument('-t', '--thres', action='store', dest='thres',
                        default=DEFAULT_TOLS['thres'], type=float,
                        help="Threshold to switch from abs tolerance to rel (only if strict=loose)")

    return parser.parse_args()

##################

if __name__ == "__main__":
    import os

    # Set up logging
    # logging.basicConfig(format='%(levelname)s (%(funcName)s): %(message)s', level=logging.INFO)
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    LOGGER = logging.getLogger(__name__)

    args = _parse_args()
    ds_base_in, ds_new_in = _open_files(args.baseline, args.new_file)
    if args.strict == 'loose':
        if ds_comparison_loose(ds_base_in, ds_new_in, args.rtol, args.atol, args.thres):
            LOGGER.error("Differences found between files!")
            os.sys.exit(1)
        LOGGER.info("PASS: All variables match and are within specified tolerance.")
    if args.strict == 'exact':
        if ds_comparison_exact(ds_base_in, ds_new_in):
            LOGGER.error("Differences found between files!")
            os.sys.exit(1)
        LOGGER.info("PASS: All variables match and have exactly the same values.")
