Skip to content

Commit aa5d29b

Browse files
committed
EAMxx: rm mvk, hi est
1 parent f974b60 commit aa5d29b

File tree

17 files changed

+1628
-1239
lines changed

17 files changed

+1628
-1239
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""
2+
Ensemble Statistical Test using multi-instance capability.
3+
4+
This test runs multiple EAMxx instances with different perturbation seeds
5+
and uses statistical tests to verify that the climate state is identical
6+
between different runs.
7+
8+
EST inherits from SystemTestsCommon and only overrides:
9+
- setup_phase (to setup multi-instance with per-instance perturbed seeds)
10+
- _generate_baseline (move needed hist files to baseline directory)
11+
- _compare_baseline (to run the statistical tests)
12+
13+
EST relies on two util files:
14+
- est_perts.py: functions to duplicate and modify yaml files
15+
- est_stats.py: functions to conduct statistical testing
16+
"""
17+
18+
import os
19+
import glob
20+
import logging
21+
import sys
22+
23+
import CIME.test_status
24+
import CIME.utils
25+
from CIME.status import append_testlog
26+
from CIME.SystemTests.system_tests_common import SystemTestsCommon
27+
from CIME.case.case_setup import case_setup
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
# pylint: disable=too-few-public-methods
33+
class EST(SystemTestsCommon):
    """Ensemble Statistical Test using multi-instance capability.

    The class overrides three hooks of ``SystemTestsCommon``:

    * ``setup_phase`` creates one perturbed input/output YAML pair per
      atmosphere instance in ``<RUNDIR>/data``.
    * ``_generate_baseline`` copies the per-instance averaged history files
      into the baseline directory.
    * ``_compare_baseline`` runs the statistical comparison implemented in
      the sibling module ``est_stats.py`` and records the outcome.
    """

    @staticmethod
    def _est_add_helper_dir_to_path(module_file):
        """Make a helper module living next to this file importable.

        Args:
            module_file: File name of the helper module, e.g.
                ``"est_perts.py"``.

        Raises:
            ImportError: If ``module_file`` does not exist in the directory
                containing this file.
        """
        helper_dir = os.path.dirname(__file__)
        helper_path = os.path.join(helper_dir, module_file)
        if not os.path.exists(helper_path):
            raise ImportError(f"Cannot find {module_file} at {helper_path}")
        # Prepend (only once) so the plain ``import <module>`` statements in
        # the methods below can locate the helper module.
        if helper_dir not in sys.path:
            sys.path.insert(0, helper_dir)

    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def setup_phase(
        self,
        clean=False,
        test_mode=False,
        reset=False,
        keep=False,
        disable_git=False,
    ):
        """Set up the case and create per-instance perturbed YAML files.

        All arguments are forwarded unchanged to ``self.setup_indv``.

        Raises:
            ValueError: If ``NINST_ATM`` is not greater than 1.
            ImportError: If ``est_perts.py`` is not found next to this file.
            FileNotFoundError: If an expected per-instance YAML copy is
                missing after duplication.
        """
        # Regular setup first, then flush the case.
        self.setup_indv(
            clean=clean,
            test_mode=test_mode,
            reset=reset,
            keep=keep,
            disable_git=disable_git,
        )
        self._case.flush()
        # NOTE(review): case_setup is run a second time with reset=True; the
        # reason is not evident from this file -- confirm it is required.
        case_setup(self._case, test_mode=False, reset=True)

        run_dir = self._case.get_value("RUNDIR")
        n_inst = int(self._case.get_value("NINST_ATM"))
        # An ensemble needs more than one member; refuse to continue
        # rather than run a meaningless single-instance test.
        if n_inst <= 1:
            msg = (
                f"NINST_ATM = {n_inst}. This test requires NINST_ATM > 1. "
                "Consider setting NINST_ATM > 1 in your env_run.xml "
                "or use _C# specifier in test name for a multi-driver "
                "multi-instance setup (producing # pelayout copies), "
                "or _N# for a single-driver multi-instance setup "
                "(dividing specified pelayout among # instances)."
            )
            raise ValueError(msg)

        self._est_add_helper_dir_to_path("est_perts.py")
        # pylint: disable=import-outside-toplevel
        from est_perts import duplicate_yaml_file, update_yaml_file

        data_dir = os.path.join(run_dir, "data")
        input_yaml = os.path.join(data_dir, "scream_input.yaml")
        output_yaml = os.path.join(data_dir, "monthly_average.yaml")

        # One copy of each YAML file per instance, suffixed _0001, _0002, ...
        duplicate_yaml_file(input_yaml, n_inst)
        duplicate_yaml_file(output_yaml, n_inst)

        # Give every instance its own seed and its own output file prefix.
        for inst in range(1, n_inst + 1):
            yaml_file = f"{input_yaml}_{inst:04d}"
            out_file = f"{output_yaml}_{inst:04d}"
            if not os.path.isfile(yaml_file):
                raise FileNotFoundError(
                    f"File {yaml_file} does not exist.")
            if not os.path.isfile(out_file):
                raise FileNotFoundError(f"File {out_file} does not exist.")
            update_yaml_file(yaml_file, inst, "pert")
            update_yaml_file(out_file, inst, "out")

    def _generate_baseline(self):
        """Generate a baseline and add the per-instance history files to it.

        The parent implementation runs first; afterwards every file in
        ``RUNDIR`` matching ``*scream_????.h.AVERAGE.*.nc`` is copied into
        ``<BASELINE_ROOT>/<BASEGEN_CASE>``, replacing any existing file of
        the same name.
        """
        super()._generate_baseline()

        with CIME.utils.SharedArea():
            base_gen_dir = os.path.join(
                self._case.get_value("BASELINE_ROOT"),
                self._case.get_value("BASEGEN_CASE"),
            )
            run_dir = self._case.get_value("RUNDIR")

            # for eamxx, collect every per-instance averaged history file,
            # i.e. everything matching *scream_????.h.AVERAGE.*.nc
            pattern = os.path.join(run_dir, "*scream_????.h.AVERAGE.*.nc")
            hists = glob.glob(pattern)
            if not hists:
                # Nothing to archive: make this visible instead of silently
                # producing a baseline without ensemble history files.
                logger.warning(
                    "No history files matching %s were found", pattern
                )

            for src in hists:
                tgt = os.path.join(base_gen_dir, os.path.basename(src))
                # remove baselines if they exist
                # this is safe because cime forces users to use -o
                if os.path.exists(tgt):
                    os.remove(tgt)

                logger.info(
                    "Copying ... \n \t %s \n ... to ... \n \t %s \n\n",
                    src, tgt
                )
                CIME.utils.safe_copy(src, tgt, preserve_meta=False)

    def _compare_baseline(self):
        """Compare this run against the baseline with statistical tests.

        The baseline phase is marked PASS without comparing when the case
        still has resubmissions pending. Otherwise the phase is first set
        to FAIL and only updated once ``est_stats.run_stats_comparison``
        has returned, so an exception during the comparison leaves FAIL
        recorded.

        Raises:
            ImportError: If ``est_stats.py`` is not found next to this file.
        """
        with self._test_status as ts:
            # if we are resubmitting, then we don't do the comparison
            if int(self._case.get_value("RESUBMIT")) > 0:
                ts.set_status(
                    CIME.test_status.BASELINE_PHASE,
                    CIME.test_status.TEST_PASS_STATUS
                )
                return

            # set to FAIL to start with, will update later
            ts.set_status(
                CIME.test_status.BASELINE_PHASE,
                CIME.test_status.TEST_FAIL_STATUS
            )

            run_dir = self._case.get_value("RUNDIR")
            base_dir = os.path.join(
                self._case.get_value("BASELINE_ROOT"),
                self._case.get_value("BASECMP_CASE"),
            )

            self._est_add_helper_dir_to_path("est_stats.py")
            # Import the whole module, and only here, so that missing
            # packages needed by est_stats cannot break importing this file.
            # pylint: disable=import-outside-toplevel
            import est_stats as est

            comments, new_ts = est.run_stats_comparison(
                run_dir,
                base_dir,
                analysis_type="spatiotemporal",
                test_type="ks",
                alpha=0.01,
            )

            # Anything other than an explicit "PASS" counts as a failure.
            if new_ts == "PASS":
                out_ts = CIME.test_status.TEST_PASS_STATUS
            else:
                out_ts = CIME.test_status.TEST_FAIL_STATUS

            # log the results and set the test status
            append_testlog(comments, self._orig_caseroot)
            ts.set_status(
                CIME.test_status.BASELINE_PHASE, out_ts
            )
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Perturbation functions for EST system test.
3+
"""
4+
5+
import os
6+
import shutil
7+
8+
9+
def duplicate_yaml_file(yaml_file, num_copies):
    """Create ``num_copies`` numbered copies of ``yaml_file`` next to it.

    Copy number ``k`` (counting from 1) is written to
    ``<yaml_file>_<k as four digits>``, e.g. ``input.yaml_0001``.

    Raises:
        FileNotFoundError: If ``yaml_file`` is not an existing file.
    """
    source_exists = os.path.isfile(yaml_file)
    if not source_exists:
        raise FileNotFoundError(f"The file {yaml_file} does not exist.")

    copy_index = 1
    while copy_index <= num_copies:
        shutil.copyfile(yaml_file, f"{yaml_file}_{copy_index:04d}")
        copy_index += 1
18+
19+
20+
_SEED_KEY = "perturbation_random_seed:"
_OUTPUT_YAML_NAME = "monthly_average.yaml"
_PREFIX_KEY = "filename_prefix:"
_PREFIX_TOKEN = ".scream"


def _set_seed_on_line(line, seed):
    """Return ``line`` with the value after the seed key replaced by ``seed``."""
    head, key, tail = line.partition(_SEED_KEY)
    rest = tail.lstrip(" \t")
    # Skip over the old value (first whitespace-delimited token, possibly
    # absent); whatever follows it (comment, newline) is kept untouched.
    end = 0
    while end < len(rest) and not rest[end].isspace():
        end += 1
    return f"{head}{key} {seed}{rest[end:]}"


def _perturb_input_lines(lines, seed, yaml_file):
    """Set the seed and point the output-YAML reference at this instance."""
    found_seed = False
    found_output = False
    new_lines = []
    for line in lines:
        if line.strip().startswith(_SEED_KEY):
            new_lines.append(_set_seed_on_line(line, seed))
            found_seed = True
        elif _OUTPUT_YAML_NAME in line.strip():
            new_lines.append(
                line.replace(
                    _OUTPUT_YAML_NAME, f"{_OUTPUT_YAML_NAME}_{seed:04d}"
                )
            )
            found_output = True
        else:
            new_lines.append(line)

    if not found_seed:
        raise ValueError(f"'perturbation_random_seed' NOT in {yaml_file}")
    if not found_output:
        raise ValueError(f"'monthly_average.yaml' NOT in {yaml_file}")
    return new_lines


def _retarget_output_lines(lines, seed, yaml_file):
    """Append the instance suffix to ``.scream`` in the filename prefix."""
    found_prefix = False
    new_lines = []
    for line in lines:
        if line.strip().startswith(_PREFIX_KEY):
            if _PREFIX_TOKEN not in line:
                # Leaving the prefix unchanged would make every instance
                # write to the same output files.
                raise ValueError(
                    f"'filename_prefix' in {yaml_file} does not contain "
                    f"'{_PREFIX_TOKEN}'"
                )
            new_lines.append(
                line.replace(_PREFIX_TOKEN, f"{_PREFIX_TOKEN}_{seed:04d}")
            )
            found_prefix = True
        else:
            new_lines.append(line)

    if not found_prefix:
        raise ValueError(f"Couldn't find 'filename_prefix' in {yaml_file}")
    return new_lines


def update_yaml_file(yaml_file, seed, pert_out):
    """Rewrite one per-instance YAML file in place.

    Args:
        yaml_file: Path of the YAML file to modify.
        seed: Integer instance number; used as the random seed and, zero
            padded to four digits, as the file-name suffix.
        pert_out: ``"pert"`` to update an input file (sets
            ``perturbation_random_seed`` to ``seed`` whatever its previous
            value, and renames references to ``monthly_average.yaml`` to
            ``monthly_average.yaml_<seed:04d>``), or ``"out"`` to update an
            output file (turns ``.scream`` into ``.scream_<seed:04d>`` on
            the ``filename_prefix`` line).

    Raises:
        ValueError: If ``pert_out`` is neither ``"pert"`` nor ``"out"``, or
            if a field that must be updated is missing from the file. The
            file is left untouched in that case.
    """
    if pert_out not in ("pert", "out"):
        raise ValueError(
            f"Unknown pert_out value {pert_out!r}; expected 'pert' or 'out'"
        )

    with open(yaml_file, "r", encoding="utf-8") as file:
        lines = file.readlines()

    if pert_out == "pert":
        new_lines = _perturb_input_lines(lines, seed, yaml_file)
    else:
        new_lines = _retarget_output_lines(lines, seed, yaml_file)

    # Only reached when every required field was found and updated.
    with open(yaml_file, "w", encoding="utf-8") as file:
        file.writelines(new_lines)

0 commit comments

Comments
 (0)