Skip to content

Commit e9838fc

Browse files
[Bug]: Debug zppy diffs for v3.0.0 (#931)
* Fix RMSE and CORR text position in plots * Fix ref plot max values zonal/meridional mean * Fix rmse/corr values for zonal_mean_2d * Fix metrics and fonts in polar * Fix aerosol_aeronet X and Y axes labels - Add `aerosol_aeronet` to `SETS_USING_LAT_LON_FORMATTER` - Remove unused `plt.legend()` in `aerosol_aeronet_plot.py` to avoid UserWarning: No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument. * Update `_hybrid_to_plevs()` to keep more attrs - Keep `long_name`. `long_name` is used for downstream plotting, including zonal_mean_xy. - Update `zonal_mean_xy_plot.py` to use correct `long_name` attribute * Fix relative diff plots displaying mean nans - This was caused by performing arithmetic on the entire regridded and spatially averaged diffs dataset, which also affected other variables such as bounds * Add remote attachment config to workspace --------- Co-authored-by: chengzhuzhang <[email protected]>
1 parent 059655e commit e9838fc

File tree

26 files changed

+4389
-11
lines changed

26 files changed

+4389
-11
lines changed

.vscode/e3sm_diags.code-workspace

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@
5454
"launch": {
5555
"version": "0.2.0",
5656
"configurations": [
57+
{
58+
"name": "Python: Current File",
59+
"type": "debugpy",
60+
"request": "launch",
61+
"program": "${file}",
62+
"console": "integratedTerminal",
63+
"justMyCode": true,
64+
"env": {
65+
"PYTHONPATH": "${workspaceFolder}"
66+
}
67+
},
5768
{
5869
"name": "Python: Current File",
5970
"type": "debugpy",
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#%%
2+
import glob
3+
from typing import List
4+
5+
import numpy as np
6+
import xarray as xr
7+
8+
from e3sm_diags.derivations.derivations import DERIVED_VARIABLES
9+
import pandas as pd
10+
11+
DEV_DIR = "25-02-04-branch-930-zppy-diffs"
12+
DEV_PATH = f"/lcrc/group/e3sm/public_html/cdat-migration-fy24/{DEV_DIR}/"
13+
14+
DEV_GLOB_ALL = sorted(glob.glob(DEV_PATH + "**/**/*.nc"))
15+
DEV_NUM_FILES = len(DEV_GLOB_ALL)
16+
17+
MAIN_DIR = "25-02-04-main-zppy-diffs"
18+
MAIN_PATH = f"/lcrc/group/e3sm/public_html/cdat-migration-fy24/{MAIN_DIR}/"
19+
MAIN_GLOB_ALL = sorted(glob.glob(MAIN_PATH + "**/**/*.nc"))
20+
MAIN_NUM_FILES = len(MAIN_GLOB_ALL)
21+
22+
#%%
23+
KEEP_VARS = ["OMI-MLS-TCO-ANN-60S60N"]
24+
25+
DEV_GLOB = [fp for fp in DEV_GLOB_ALL if "diff.nc" not in fp and any(var in fp for var in KEEP_VARS)]
26+
MAIN_GLOB = [fp for fp in MAIN_GLOB_ALL if "diff.nc" not in fp and any(var in fp for var in KEEP_VARS)]
27+
28+
DEV_GLOB_DIFF = [fp for fp in DEV_GLOB_ALL if "diff.nc" in fp and any(var in fp for var in KEEP_VARS)]
29+
MAIN_GLOB_DIFF = [fp for fp in MAIN_GLOB_ALL if "diff.nc" in fp and any(var in fp for var in KEEP_VARS)]
30+
31+
def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray:
32+
"""Get the variable data using a list of matching keys.
33+
34+
The `main` branch saves the dataset using the original variable name,
35+
while the dev branch saves the variable with the derived variable name.
36+
The dev branch is performing the expected behavior here.
37+
38+
Parameters
39+
----------
40+
ds : xr.Dataset
41+
_description_
42+
var_key : str
43+
_description_
44+
45+
Returns
46+
-------
47+
np.ndarray
48+
_description_
49+
"""
50+
51+
data = None
52+
53+
try:
54+
data = ds[var_key]
55+
except KeyError:
56+
try:
57+
var_keys = DERIVED_VARIABLES[var_key].keys()
58+
except KeyError:
59+
var_keys = DERIVED_VARIABLES[var_key.upper()].keys()
60+
61+
var_keys = [var_key] + list(sum(var_keys, ()))
62+
63+
for key in var_keys:
64+
if key in ds.data_vars.keys():
65+
data = ds[key]
66+
break
67+
68+
return data
69+
#%%
70+
ATOL = 0
71+
RTOL = 1e-5
72+
73+
print(f"Relative tolerance: {RTOL}, Absolute tolerance: {ATOL}")
74+
75+
def compare_files(main_glob: List[str]):
76+
for fp_main in main_glob:
77+
var_key = fp_main.split("-")[-3]
78+
fp_type = fp_main.split("-")[-1].split("_")[-1]
79+
80+
fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)
81+
82+
print(f"{var_key} - {fp_type}")
83+
print("-" * 50)
84+
print(f"Main: {fp_main}\nDev: {fp_dev}")
85+
86+
ds_main = xr.open_dataset(fp_main)
87+
ds_dev = xr.open_dataset(fp_dev)
88+
89+
dv_main = _get_var_data(ds_main, var_key)
90+
dv_dev = _get_var_data(ds_dev, var_key)
91+
92+
if dv_main is None:
93+
dv_main = _get_var_data(ds_main, var_key + "_diff")
94+
95+
try:
96+
np.testing.assert_allclose(dv_main.values, dv_dev.values, rtol=RTOL, atol=ATOL)
97+
except AssertionError as e:
98+
print(f"{e}\n")
99+
else:
100+
print(f"Arrays are within relative tolerance.\n")
101+
102+
print("Comparing test.nc and ref.nc files")
103+
print("=" * 50)
104+
compare_files(MAIN_GLOB)
105+
106+
#%%
107+
print("Comparing diff.nc files")
108+
print("=" * 50)
109+
compare_files(MAIN_GLOB_DIFF)
110+
111+
# %%
112+
def compare_stats(main_glob_diff: List[str]):
113+
for fp_main in main_glob_diff:
114+
var_key = fp_main.split("-")[-3]
115+
fp_type = fp_main.split("-")[-1].split("_")[-1]
116+
117+
fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)
118+
119+
print(f"{var_key} - {fp_type}")
120+
print("-" * 50)
121+
print(f"Main: {fp_main}\nDev: {fp_dev}")
122+
123+
ds_main = xr.open_dataset(fp_main)
124+
ds_dev = xr.open_dataset(fp_dev)
125+
126+
dv_main = _get_var_data(ds_main, var_key)
127+
dv_dev = _get_var_data(ds_dev, var_key)
128+
129+
if dv_main is None:
130+
dv_main = _get_var_data(ds_main, var_key + "_diff")
131+
stats_main = {
132+
"min": dv_main.min().item(),
133+
"max": dv_main.max().item(),
134+
"mean": dv_main.mean().item(),
135+
"sum": dv_main.sum().item(),
136+
"nan_count": np.isnan(dv_main).sum().item()
137+
}
138+
139+
stats_dev = {
140+
"min": dv_dev.min().item(),
141+
"max": dv_dev.max().item(),
142+
"mean": dv_dev.mean().item(),
143+
"sum": dv_dev.sum().item(),
144+
"nan_count": np.isnan(dv_dev).sum().item()
145+
}
146+
147+
df_stats = pd.DataFrame([stats_main, stats_dev], index=["main", "dev"])
148+
print(df_stats)
149+
150+
print("Comparing stats of diff.nc files")
151+
print("=" * 50)
152+
compare_stats(MAIN_GLOB_DIFF)
153+
154+
155+
# %%

0 commit comments

Comments
 (0)