Skip to content

Commit dc92668

Browse files
committed
Update complete_run_script.py to save netCDF
1 parent 9878bd0 commit dc92668

File tree

5 files changed

+60
-13
lines changed

5 files changed

+60
-13
lines changed

e3sm_diags/parameter/meridional_mean_2d_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self):
1010
super(MeridionalMean2dParameter, self).__init__()
1111
# Override existing attributes
1212
# =============================
13-
self.plevs = numpy.logspace(2.0, 3.0, num=17).tolist()
13+
self.plevs = numpy.logspace(2.0, 3.0, num=17).tolist() # type: ignore
1414
self.plot_log_plevs = False
1515
self.plot_plevs = False
1616
# Granulating plevs causes duplicate plots in this case.

e3sm_diags/parameter/zonal_mean_2d_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self):
1414
super(ZonalMean2dParameter, self).__init__()
1515
# Override existing attributes
1616
# =============================
17-
self.plevs = copy.deepcopy(DEFAULT_PLEVS)
17+
self.plevs = copy.deepcopy(DEFAULT_PLEVS) # type: ignore
1818
self.plot_log_plevs = False
1919
self.plot_plevs = False
2020
# Granulating plevs causes duplicate plots in this case.

e3sm_diags/parameter/zonal_mean_2d_stratosphere_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ def __init__(self):
1212
super(ZonalMean2dStratosphereParameter, self).__init__()
1313
# Override existing attributes
1414
# =============================
15-
self.plevs = copy.deepcopy(DEFAULT_PLEVS)
15+
self.plevs = copy.deepcopy(DEFAULT_PLEVS) # type: ignore
1616
self.plot_log_plevs = True

tests/complete_run_script.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646

4747
case = "extendedOutput.v3.LR.historical_0101"
4848
short_name = "v3.LR.historical_0101"
49-
results_dir = "/global/cfs/cdirs/e3sm/www/chengzhu/tutorial2024/e3sm_diags_extended_int"
49+
50+
# TODO: Update `MAIN_DIR` as needed.
51+
MAIN_DIR = "24-12-09-main"
52+
results_dir = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/"
5053

5154
test_climo = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim/15yr"
5255
test_ts = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/ts/monthly/15yr"
@@ -78,6 +81,7 @@
7881
param.output_format_subplot = []
7982
param.multiprocessing = True
8083
param.num_workers = 24
84+
param.save_netcdf = True
8185
param.seasons = ["ANN"]
8286
params = [param]
8387

@@ -93,6 +97,7 @@
9397
enso_param.ref_start_yr = start_yr
9498
enso_param.ref_end_yr = end_yr
9599

100+
enso_param.save_netcdf = True
96101
params.append(enso_param)
97102

98103
trop_param = TropicalSubseasonalParameter()
@@ -106,7 +111,9 @@
106111
trop_param.ref_start_yr = "2001"
107112
trop_param.ref_end_yr = "2010"
108113

114+
trop_param.save_netcdf = True
109115
params.append(trop_param)
116+
110117
qbo_param = QboParameter()
111118
qbo_param.test_data_path = test_ts
112119
# qbo_param.test_name = short_name
@@ -118,7 +125,9 @@
118125
# Obs
119126
qbo_param.reference_data_path = ref_ts
120127

128+
qbo_param.save_netcdf = True
121129
params.append(qbo_param)
130+
122131
dc_param = DiurnalCycleParameter()
123132
dc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim_diurnal_8xdaily/"
124133
# dc_param.short_test_name = short_name
@@ -128,7 +137,9 @@
128137
# Obs
129138
dc_param.reference_data_path = ref_climo
130139

140+
dc_param.save_netcdf = True
131141
params.append(dc_param)
142+
132143
streamflow_param = StreamflowParameter()
133144
streamflow_param.reference_data_path = ref_ts
134145
streamflow_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/rof/native/ts/monthly/15yr/"
@@ -143,7 +154,9 @@
143154
)
144155
streamflow_param.ref_end_yr = "1995"
145156

157+
streamflow_param.save_netcdf = True
146158
params.append(streamflow_param)
159+
147160
tc_param = TCAnalysisParameter()
148161
tc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/tc-analysis_2000_2014"
149162
# tc_param.short_test_name = short_name
@@ -159,6 +172,7 @@
159172
tc_param.ref_start_yr = "1979"
160173
tc_param.ref_end_yr = "2018"
161174

175+
tc_param.save_netcdf = True
162176
params.append(tc_param)
163177

164178
arm_param = ARMDiagsParameter()
@@ -177,6 +191,7 @@
177191
arm_param.ref_start_yr = "0001"
178192
arm_param.ref_end_yr = "0001"
179193

194+
arm_param.save_netcdf = True
180195
params.append(arm_param)
181196

182197
# Run

tests/test_regression.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import glob
2+
import subprocess
3+
from datetime import datetime
4+
from typing import List, TypedDict
25

36
import numpy as np
47
import pytest
@@ -8,18 +11,37 @@
811
from e3sm_diags.logger import custom_logger
912
from tests.complete_run_script import params, runner
1013

11-
1214
logger = custom_logger(__name__)
1315

1416

15-
DEV_DIR = "843-migration-phase3-model-vs-obs"
16-
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{DEV_DIR}/"
17+
def _get_git_branch_name() -> str:
18+
"""Get the current git branch name."""
19+
try:
20+
branch_name = (
21+
subprocess.check_output(
22+
["git", "rev-parse", "--abbrev-ref", "HEAD"],
23+
stderr=subprocess.DEVNULL,
24+
)
25+
.strip()
26+
.decode("utf-8")
27+
)
28+
except subprocess.CalledProcessError:
29+
branch_name = "unknown"
30+
31+
return branch_name
32+
33+
34+
BRANCH_NAME = _get_git_branch_name()
35+
DEV_TIMESTAMP = datetime.now().strftime("%y-%m-%d")
36+
DEV_DIR = f"{DEV_TIMESTAMP}-{BRANCH_NAME}"
37+
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{DEV_DIR}"
1738

1839
DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc"))
1940
DEV_NUM_FILES = len(DEV_GLOB)
2041

21-
MAIN_DIR = "main"
22-
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{MAIN_DIR}/"
42+
# TODO: Update `MAIN_DIR` as needed.
43+
MAIN_DIR = "24-12-09-main"
44+
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/"
2345
MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc"))
2446
MAIN_NUM_FILES = len(MAIN_GLOB)
2547

@@ -54,7 +76,6 @@ def run_diags_and_get_results_dir() -> str:
5476
class TestRegression:
5577
@pytest.fixture(autouse=True)
5678
def setup(self, run_diags_and_get_results_dir):
57-
# TODO: We need to store `main` results on a data container
5879
self.results_dir = run_diags_and_get_results_dir
5980

6081
def test_check_if_files_found(self):
@@ -90,8 +111,19 @@ def test_get_relative_diffs(self):
90111
assert len(results["key_errors"]) == 0
91112

92113

93-
def _get_relative_diffs():
94-
results = {
114+
class DiffResults(TypedDict):
115+
"""Type annotation for the results of the relative differences comparison."""
116+
117+
missing_files: List[str]
118+
missing_vars: List[str]
119+
matching_files: List[str]
120+
mismatch_errors: List[str]
121+
not_equal_errors: List[str]
122+
key_errors: List[str]
123+
124+
125+
def _get_relative_diffs() -> DiffResults:
126+
results: DiffResults = {
95127
"missing_files": [],
96128
"missing_vars": [],
97129
"matching_files": [],
@@ -192,7 +224,7 @@ def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray | None:
192224
except KeyError:
193225
var_keys = DERIVED_VARIABLES[var_key.upper()].keys()
194226

195-
var_keys = [var_key] + list(sum(var_keys, ()))
227+
var_keys = [var_key] + list(sum(var_keys, ())) # type: ignore
196228

197229
for key in var_keys:
198230
if key in ds.data_vars.keys():

0 commit comments

Comments
 (0)