#!/usr/bin/env python

"""
Usage:
    $ ./netcdf_comparison.py --baseline BASELINE_FILE --new-file NEW_FILE
          --strict {exact,loose} [-r RTOL] [-a ATOL] [-t THRES]

Use xarray and numpy to compare two netCDF files.
For each variable, flag
1. Variables that are present in one file but not the other
2. Variables whose data type doesn't match across files
3. Variables whose dimensions don't match across files
4. Variables whose missing values are not aligned
5. Variables that differ in one of two ways (the user specifies which strictness level to use):
   a. Variables that are not exactly the same (--strict exact)
   b. Variables that are not "close" to each other (--strict loose):
      -- for values very close to 0 (|value| <= THRES), variables that differ
         by more than ATOL are flagged
      -- for values larger than THRES, variables with a relative difference
         of more than RTOL are flagged

There is also some knowledge of unit equivalence between mks and cgs.
"""
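
# Example invocation (the file names here are hypothetical):
#   $ ./netcdf_comparison.py -b baseline.nc -n new_run.nc --strict loose
#   $ ./netcdf_comparison.py -b baseline.nc -n new_run.nc --strict exact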

import logging

##################

# Store default values of rtol, atol, and thres in a global dictionary
# to make it easy to update the default values if necessary.
# rtol = 1e-11 fails the cgs vs mks comparison
# (POC_REMIN_DIC and PON_REMIN_NH4 have relative errors of ~1.4e-11)
DEFAULT_TOLS = {'rtol': 1e-9, 'atol': 1e-16, 'thres': 1e-16}
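
# How the defaults partition the loose comparison (the numbers below are
# illustrative only): a baseline value of 5e-17 falls below thres, so the new
# value must match it to within atol = 1e-16; a baseline value of 1.0 exceeds
# thres, so the new value must match it to within a relative error of
# rtol = 1e-9.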

##################

def ds_comparison_exact(ds_base, ds_new):
    """
    Compare baseline to new_file using xarray
    Will only pass if the two files are identical (bfb)
    (run with --strict=exact)
    """
    # An exact comparison is the loose comparison with every tolerance zeroed out
    # (atol must be 0 as well, otherwise tiny nonzero values where the baseline
    # is 0 would still pass)
    return ds_comparison_loose(ds_base, ds_new, rtol=0, atol=0, thres=0)

##################

def ds_comparison_loose(ds_base, ds_new, rtol=DEFAULT_TOLS['rtol'], atol=DEFAULT_TOLS['atol'],
                        thres=DEFAULT_TOLS['thres']):
    """
    Compare baseline to new_file using xarray
    Will pass if the two files are within the specified tolerances
    (run with --strict=loose)
    """
    header_fail, ds_base, ds_new = _reduce_to_matching_variables(ds_base, ds_new)

    # Compare remaining variables
    return _variable_check_loose(ds_base, ds_new, rtol, atol, thres) or header_fail
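
# Both comparisons can also be used programmatically; a minimal sketch
# (the file names are hypothetical):
#
#   import xarray as xr
#   ds_a = xr.open_dataset('baseline.nc')
#   ds_b = xr.open_dataset('new_run.nc')
#   if ds_comparison_loose(ds_a, ds_b):
#       print('datasets differ')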

##################

def _open_files(baseline, new_file):
    """
    Reads two netCDF files, returns xarray datasets
    """
    import xarray as xr

    logger = logging.getLogger(__name__)
    logger.info("Comparing %s to the baseline %s\n", new_file, baseline)

    return xr.open_dataset(baseline), xr.open_dataset(new_file)

##################

def _reduce_to_matching_variables(ds_base, ds_new):
    """
    Are variables and dimensions the same in both datasets? Check for:
    1. Variables in one dataset but not the other
    2. Variables that have different types
    3. Variables that have different dimensions
       * check both name and size of each dimension
    Returns header_fail, ds_base, ds_new
    * header_fail is True if any of the three checks above fail
    * ds_base and ds_new are the same as the inputs, except any variables that fail
      the above checks are dropped from the dataset
    """
    logger = logging.getLogger(__name__)
    header_fail = False
    failed_vars = []

    # 1. Any variables in one file but not the other?
    base_vars = ds_base.variables
    new_vars = ds_new.variables
    common_vars = set(base_vars) & set(new_vars)
    base_vars = list(set(base_vars) - common_vars)
    new_vars = list(set(new_vars) - common_vars)
    if base_vars:
        header_fail = True
        ds_base = ds_base.drop_vars(base_vars)
        logger.info("The following variables are in the baseline file but not the new file:")
        base_vars.sort()
        for var in base_vars:
            logger.info("* %s", var)
        logger.info("")
    if new_vars:
        header_fail = True
        ds_new = ds_new.drop_vars(new_vars)
        logger.info("The following variables are in the new file but not the baseline file:")
        new_vars.sort()
        for var in new_vars:
            logger.info("* %s", var)
        logger.info("")

    # 2. Can variables be compared?
    common_vars = sorted(common_vars)
    for var in common_vars:
        err_messages = []
        # (a) Are they the same type?
        if ds_base[var].dtype != ds_new[var].dtype:
            err_messages.append(
                "Variable is {} in baseline and {} in new file".format(
                    ds_base[var].dtype, ds_new[var].dtype
                )
            )
        # (b) Do the dimension names match?
        if ds_base[var].dims != ds_new[var].dims:
            err_messages.append(
                "Baseline dimensions are {} and new file dimensions are {}".format(
                    ds_base[var].dims, ds_new[var].dims
                )
            )
        # (c) Do the dimension sizes match?
        if ds_base[var].data.shape != ds_new[var].data.shape:
            err_messages.append(
                "Baseline shape is {} and new file shape is {}".format(
                    ds_base[var].data.shape, ds_new[var].data.shape
                )
            )

        # Report errors
        if _report_errs(var, err_messages):
            header_fail = True
            failed_vars.append(var)

    return header_fail, ds_base.drop_vars(failed_vars), ds_new.drop_vars(failed_vars)

##################

def _get_conversion_factor(ds_base, ds_new, var):
    """
    Returns the factor that data from the new file should be multiplied by to
    convert it to the baseline units (1 if the units already match).
    Raises KeyError if the units differ and no conversion is tabulated.
    """
    unit_conversion = {key: {} for key in
                       ['cm/s', 'cm', 'nmol/cm^3', 'nmol/cm^3/s', 'nmol/cm^2/s',
                        '(nmol/cm^3)^-1 s^-1', 'g/cm^3/s', 'g/cm^2/s', 'meq/m^3', 'meq/m^3/s',
                        'neq/cm^2/s', 'meq/m^3 cm/s', 'mg/m^3 cm/s']}
    unit_conversion['cm/s']['m/s'] = 0.01  # cm/s -> m/s
    unit_conversion['cm']['m'] = 0.01  # cm -> m
    unit_conversion['nmol/cm^3']['mmol/m^3'] = 1  # nmol/cm^3 -> mmol/m^3
    unit_conversion['nmol/cm^3/s']['mmol/m^3/s'] = 1  # nmol/cm^3/s -> mmol/m^3/s
    unit_conversion['nmol/cm^2/s']['mmol/m^2/s'] = 0.01  # nmol/cm^2/s -> mmol/m^2/s
    unit_conversion['(nmol/cm^3)^-1 s^-1']['(mmol/m^3)^-1 s^-1'] = 1  # same as nmol/cm^3 -> mmol/m^3
    unit_conversion['g/cm^3/s']['kg/m^3/s'] = 1000  # g/cm^3/s -> kg/m^3/s
    unit_conversion['g/cm^2/s']['kg/m^2/s'] = 10  # g/cm^2/s -> kg/m^2/s
    unit_conversion['meq/m^3']['neq/cm^3'] = 1  # meq/m^3 -> neq/cm^3
    unit_conversion['meq/m^3/s']['neq/cm^3/s'] = 1  # meq/m^3/s -> neq/cm^3/s
    unit_conversion['neq/cm^2/s']['meq/m^2/s'] = 0.01  # neq/cm^2/s -> meq/m^2/s
    unit_conversion['meq/m^3 cm/s']['neq/cm^2/s'] = 1  # meq/m^3 cm/s -> neq/cm^2/s
    unit_conversion['meq/m^3 cm/s']['meq/m^2/s'] = 0.01  # meq/m^3 cm/s -> meq/m^2/s
    unit_conversion['mg/m^3 cm/s']['mg/m^2/s'] = 0.01  # mg/m^3 cm/s -> mg/m^2/s

    conversion_factor = 1.
    if 'units' in ds_base[var].attrs and 'units' in ds_new[var].attrs:
        old_units = ds_base[var].attrs['units']
        new_units = ds_new[var].attrs['units']
        if old_units != new_units:
            if new_units in unit_conversion and old_units in unit_conversion[new_units]:
                conversion_factor = unit_conversion[new_units][old_units]
            elif old_units in unit_conversion and new_units in unit_conversion[old_units]:
                # only the reverse direction is tabulated, so use the reciprocal
                conversion_factor = 1. / unit_conversion[old_units][new_units]
            else:
                raise KeyError(f'Can not convert from {new_units} to {old_units} for {var}')
    return conversion_factor
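
# For example (with a hypothetical variable): if the baseline stores a velocity
# in 'm/s' and the new file stores it in 'cm/s', the lookup finds
# unit_conversion['cm/s']['m/s'] = 0.01, so the new data is multiplied by 0.01
# before comparison; with the units swapped, the reciprocal (100) is used.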

##################

def _variable_check_loose(ds_base, ds_new, rtol, atol, thres):
    """
    Assumes both datasets contain the same variables with the same dimensions
    Checks:
    1. Are NaNs in the same place?
    2. If baseline value = 0, then |new value| must be <= atol
    3. Absolute vs relative error:
       i.  If 0 < |baseline value| <= thres, then the absolute difference must be <= atol
       ii. If |baseline value| > thres, then the relative difference must be <= rtol
    Note: if thres = 0 and rtol = 0, then this reduces to an exact test
    (Are non-NaN values identical?)
    """
    import numpy as np

    error_checking = {'var_check_count': 0}
    common_vars = sorted(set(ds_base.variables) & set(ds_new.variables))

    for var in common_vars:
        # Skip the time coordinate (its values are not compared)
        if var.lower() == 'time':
            continue
        error_checking['messages'] = []

        # (0) Determine factor needed to convert new file's units to baseline's units
        conversion_factor = _get_conversion_factor(ds_base, ds_new, var)

        # (1) Are NaNs in the same place?
        mask = np.isfinite(ds_base[var].data)
        if np.any(mask ^ np.isfinite(ds_new[var].data)):
            error_checking['messages'].append('NaNs are not in same place')

        # (2) Compare everywhere that baseline is 0
        if np.any(np.where(ds_base[var].data[mask] == 0,
                           np.abs(ds_new[var].data[mask]) > atol,
                           False)):
            error_checking['messages'].append(
                'Baseline is 0 at some indices where abs value ' +
                'of new data is larger than {}'.format(atol))

        # (3i) Compare everywhere that 0 < |baseline| <= thres
        base_data = np.where((ds_base[var].data[mask] != 0) &
                             (np.abs(ds_base[var].data[mask]) <= thres),
                             ds_base[var].data[mask], 0)
        new_data = np.where((ds_base[var].data[mask] != 0) &
                            (np.abs(ds_base[var].data[mask]) <= thres),
                            conversion_factor*ds_new[var].data[mask], 0)
        abs_err = np.abs(new_data - base_data)
        if np.any(abs_err > atol):
            error_checking['messages'].append("Max absolute error ({}) exceeds {}".format(
                np.max(abs_err), atol))

        # (3ii) Compare everywhere that |baseline| > thres
        base_data = np.where(np.abs(ds_base[var].data[mask]) > thres,
                             ds_base[var].data[mask],
                             0)
        new_data = np.where(np.abs(ds_base[var].data[mask]) > thres,
                            conversion_factor*ds_new[var].data[mask],
                            0)
        rel_err = _compute_rel_err(ds_base[var], base_data, new_data, mask)
        if np.any(rel_err > rtol):
            if rtol == 0:
                abs_err = np.abs(new_data - base_data)
                error_checking['messages'].append("Values are not the same everywhere\n{}".format(
                    "    Max relative error: {}\n    Max absolute error: {}".format(
                        np.max(rel_err), np.max(abs_err))
                ))
            else:
                error_checking['messages'].append("Max relative error ({}) exceeds {}".format(
                    np.max(rel_err), rtol))

        error_checking['var_check_count'] += _report_errs(var, error_checking['messages'])

    return error_checking['var_check_count'] > 0
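
# A worked sketch of the branching above (numbers are illustrative, using the
# default tolerances): baseline = 1.0 vs new = 1.0 + 1e-8 gives a relative
# error of ~1e-8 > rtol = 1e-9, so the variable is flagged; baseline = 5e-17
# vs new = 6e-17 differ by 1e-17 <= atol = 1e-16, so that value passes.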

##################

def _compute_rel_err(da_base, base_data, new_data, mask):
    # denominator for relative error is the local max value (3-point stencil)
    # note the assumption that the column is the first dimension
    import numpy as np

    if 'num_levels' in da_base.dims:
        # For variables with a depth dimension, we use a 3-point stencil in depth to find the
        # maximum value of the baseline to use in the denominator of the relative error.
        # For the top (bottom) of the column, we use a 2-point stencil with the value
        # below (above) the given level
        rel_denom = da_base.rolling(num_levels=3, center=True, min_periods=2).max().data[mask]
    else:
        # For variables without a depth dimension, the denominator is the baseline value
        rel_denom = da_base.data[mask]
    # Use the magnitude of the denominator so that fields with negative values
    # still produce positive relative errors that can exceed rtol
    rel_denom = np.abs(np.where(np.isfinite(rel_denom), rel_denom, 0))
    return np.where(base_data != 0, np.abs(new_data - base_data), 0) / \
           np.where(rel_denom != 0, rel_denom, 1)
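
# A small numeric sketch of the rolling max above (values are hypothetical):
# for a column [1, 4, 2], rolling(num_levels=3, center=True, min_periods=2).max()
# yields [4, 4, 4] -- the interior point sees all three values, while each end
# point uses a 2-point stencil (e.g. max(1, 4) = 4 at the top).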

##################

def _report_errs(var, err_messages):
    """
    err_messages is a list of all accumulated errors;
    logs them and returns True if the list is non-empty
    """
    logger = logging.getLogger(__name__)
    if err_messages:
        logger.info("Variable: %s", var)
        for err in err_messages:
            logger.info("... %s", err)
        logger.info("")
        return True
    return False

##################

def _parse_args():
    """ Parse command line arguments
    """

    import argparse

    parser = argparse.ArgumentParser(description="Compare two netCDF files using xarray and numpy",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Baseline for comparison
    parser.add_argument('-b', '--baseline', action='store', dest='baseline', required=True,
                        help='Baseline for comparison')

    # File to compare to baseline
    parser.add_argument('-n', '--new-file', action='store', dest='new_file', required=True,
                        help="File to compare to baseline")

    # Tolerances
    parser.add_argument('--strict', choices=['exact', 'loose'], required=True,
                        help="Should files be bit-for-bit [exact] or within some tolerance [loose]?")

    parser.add_argument('-r', '--rtol', action='store', dest='rtol',
                        default=DEFAULT_TOLS['rtol'], type=float,
                        help="Maximum allowable relative tolerance (only used if --strict=loose)")

    parser.add_argument('-a', '--atol', action='store', dest='atol',
                        default=DEFAULT_TOLS['atol'], type=float,
                        help="Maximum allowable absolute tolerance (only used if --strict=loose)")

    parser.add_argument('-t', '--thres', action='store', dest='thres',
                        default=DEFAULT_TOLS['thres'], type=float,
                        help="Threshold for switching from absolute to relative tolerance (only used if --strict=loose)")

    return parser.parse_args()

##################

if __name__ == "__main__":
    import sys

    # Set up logging
    # logging.basicConfig(format='%(levelname)s (%(funcName)s): %(message)s', level=logging.INFO)
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    LOGGER = logging.getLogger(__name__)

    args = _parse_args()
    ds_base_in, ds_new_in = _open_files(args.baseline, args.new_file)
    if args.strict == 'loose':
        if ds_comparison_loose(ds_base_in, ds_new_in, args.rtol, args.atol, args.thres):
            LOGGER.error("Differences found between files!")
            sys.exit(1)
        LOGGER.info("PASS: All variables match and are within specified tolerance.")
    if args.strict == 'exact':
        if ds_comparison_exact(ds_base_in, ds_new_in):
            LOGGER.error("Differences found between files!")
            sys.exit(1)
        LOGGER.info("PASS: All variables match and have exactly the same values.")