Skip to content

Commit 3af5471

Browse files
committed
make metrics work on regional subset
1 parent 5a71782 commit 3af5471

File tree

3 files changed

+167
-105
lines changed

3 files changed

+167
-105
lines changed

auxiliary_tools/debug/968-native-grid-vis/run_lat_lon_native.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,37 @@
1515
# Create parameter objects for 3 different runs
1616
params = []
1717

18-
# (1) First test configuration
19-
param1 = LatLonNativeParameter()
20-
param1.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_1"
21-
param1.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
22-
param1.test_name = "v3.LR.amip_0101"
23-
param1.short_test_name = "v3.LR.amip_0101"
24-
param1.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
25-
param1.ref_name = "v3.HR.test4"
26-
param1.short_ref_name = "v3.HR.test4"
27-
param1.seasons = ["DJF"]
28-
param1.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc"
29-
param1.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc"
30-
param1.case_id = "model_vs_model"
31-
param1.run_type = "model_vs_model"
32-
params.append(param1)
33-
34-
# (2) Second test configuration
35-
param2 = LatLonNativeParameter()
36-
param2.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_2"
37-
param2.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
38-
param2.test_file = "v3.LR.amip_0101_DJF_climo.nc"
39-
param2.short_test_name = "v3.LR.amip_0101"
40-
param2.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
41-
param2.ref_file = "v3.HR.test4_DJF_climo.nc"
42-
param2.short_ref_name = "v3.HR.test4"
43-
param2.seasons = ["DJF"]
44-
param2.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc"
45-
param2.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc"
46-
param2.case_id = "model_vs_model"
47-
param2.run_type = "model_vs_model"
48-
params.append(param2)
18+
## (1) First test configuration
19+
#param1 = LatLonNativeParameter()
20+
#param1.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_1"
21+
#param1.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
22+
#param1.test_name = "v3.LR.amip_0101"
23+
#param1.short_test_name = "v3.LR.amip_0101"
24+
#param1.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
25+
#param1.ref_name = "v3.HR.test4"
26+
#param1.short_ref_name = "v3.HR.test4"
27+
#param1.seasons = ["DJF"]
28+
#param1.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc"
29+
#param1.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc"
30+
#param1.case_id = "model_vs_model"
31+
#param1.run_type = "model_vs_model"
32+
#params.append(param1)
33+
#
34+
## (2) Second test configuration
35+
#param2 = LatLonNativeParameter()
36+
#param2.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_2"
37+
#param2.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
38+
#param2.test_file = "v3.LR.amip_0101_DJF_climo.nc"
39+
#param2.short_test_name = "v3.LR.amip_0101"
40+
#param2.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid"
41+
#param2.ref_file = "v3.HR.test4_DJF_climo.nc"
42+
#param2.short_ref_name = "v3.HR.test4"
43+
#param2.seasons = ["DJF"]
44+
#param2.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc"
45+
#param2.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc"
46+
#param2.case_id = "model_vs_model"
47+
#param2.run_type = "model_vs_model"
48+
#params.append(param2)
4949

5050
# (3) Third test configuration
5151
param3 = LatLonNativeParameter()

e3sm_diags/driver/lat_lon_native_driver.py

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import uxarray as ux
77

8+
from e3sm_diags.derivations.default_regions_xr import REGION_SPECS
89
from e3sm_diags.driver import METRICS_DEFAULT_VALUE
910
from e3sm_diags.driver.utils.dataset_native import NativeDataset
1011
from e3sm_diags.driver.utils.type_annotations import MetricsDict
@@ -223,14 +224,25 @@ def _run_diags_2d(
223224
_compute_diff_between_grids(uxds_test, uxds_ref, var_key)
224225
)
225226

226-
# Create metrics dictionary using remapped datasets (following lat_lon_driver pattern)
227+
# Apply regional subsetting to all datasets before metrics calculation
228+
uxds_test_subset = _apply_regional_subsetting(uxds_test, var_key, region)
229+
uxds_ref_subset = _apply_regional_subsetting(uxds_ref, var_key, region)
230+
uxds_test_remapped_subset = _apply_regional_subsetting(
231+
uxds_test_remapped, var_key, region
232+
)
233+
uxds_ref_remapped_subset = _apply_regional_subsetting(
234+
uxds_ref_remapped, var_key, region
235+
)
236+
uxds_diff_subset = _apply_regional_subsetting(uxds_diff, var_key, region)
237+
238+
# Create metrics dictionary using regionally subsetted datasets
227239
metrics_dict = _create_metrics_dict(
228240
var_key,
229-
uxds_test,
230-
uxds_ref,
231-
uxds_test_remapped,
232-
uxds_ref_remapped,
233-
uxds_diff,
241+
uxds_test_subset,
242+
uxds_ref_subset,
243+
uxds_test_remapped_subset,
244+
uxds_ref_remapped_subset,
245+
uxds_diff_subset,
234246
)
235247

236248
# Store metrics in parameter for plot function to access
@@ -240,7 +252,7 @@ def _run_diags_2d(
240252
var_key, season, region, ref_name, ilev=None
241253
)
242254

243-
# Call plot function directly (pass region parameter)
255+
# Call plot function with original datasets for visualization
244256
plot_func(
245257
parameter,
246258
var_key,
@@ -252,10 +264,13 @@ def _run_diags_2d(
252264
else:
253265
logger.info(f"Processing {var_key} for region {region} (model-only)")
254266

255-
# Create metrics dictionary for model-only run
267+
# Apply regional subsetting to test dataset before metrics calculation
268+
uxds_test_subset = _apply_regional_subsetting(uxds_test, var_key, region)
269+
270+
# Create metrics dictionary for model-only run using regionally subsetted dataset
256271
metrics_dict = _create_metrics_dict(
257272
var_key,
258-
uxds_test,
273+
uxds_test_subset,
259274
None, # No reference dataset
260275
None, # No remapped test dataset (not needed for model-only)
261276
None, # No remapped reference dataset
@@ -269,7 +284,7 @@ def _run_diags_2d(
269284
var_key, season, region, ref_name, ilev=None
270285
)
271286

272-
# Call plot function directly (pass region parameter)
287+
# Call plot function with original dataset for visualization
273288
plot_func(
274289
parameter,
275290
var_key,
@@ -588,8 +603,8 @@ def _create_metrics_dict(
588603
var_test = uxds_test[var_key]
589604
metrics_dict: MetricsDict = {
590605
"test": {
591-
"min": var_test.min().item(),
592-
"max": var_test.max().item(),
606+
"min": [var_test.min().item()],
607+
"max": [var_test.max().item()],
593608
"mean": [var_test.weighted_mean().item()],
594609
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
595610
},
@@ -603,8 +618,8 @@ def _create_metrics_dict(
603618
if uxds_ref is not None and var_key in uxds_ref:
604619
var_ref = uxds_ref[var_key]
605620
metrics_dict["ref"] = {
606-
"min": var_ref.min().item(),
607-
"max": var_ref.max().item(),
621+
"min": [var_ref.min().item()],
622+
"max": [var_ref.max().item()],
608623
"mean": [var_ref.weighted_mean().item()],
609624
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
610625
}
@@ -613,8 +628,8 @@ def _create_metrics_dict(
613628
if uxds_test_remapped is not None and var_key in uxds_test_remapped:
614629
var_test_remapped = uxds_test_remapped[var_key]
615630
metrics_dict["test_regrid"] = {
616-
"min": var_test_remapped.min().item(),
617-
"max": var_test_remapped.max().item(),
631+
"min": [var_test_remapped.min().item()],
632+
"max": [var_test_remapped.max().item()],
618633
"mean": [var_test_remapped.weighted_mean().item()],
619634
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
620635
}
@@ -623,8 +638,8 @@ def _create_metrics_dict(
623638
if uxds_ref_remapped is not None and var_key in uxds_ref_remapped:
624639
var_ref_remapped = uxds_ref_remapped[var_key]
625640
metrics_dict["ref_regrid"] = {
626-
"min": var_ref_remapped.min().item(),
627-
"max": var_ref_remapped.max().item(),
641+
"min": [var_ref_remapped.min().item()],
642+
"max": [var_ref_remapped.max().item()],
628643
"mean": [var_ref_remapped.weighted_mean().item()],
629644
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
630645
}
@@ -653,8 +668,8 @@ def _create_metrics_dict(
653668
if uxds_diff is not None and var_key in uxds_diff:
654669
var_diff = uxds_diff[var_key]
655670
metrics_dict["diff"] = {
656-
"min": var_diff.min().item(),
657-
"max": var_diff.max().item(),
671+
"min": [var_diff.min().item()],
672+
"max": [var_diff.max().item()],
658673
"mean": [var_diff.weighted_mean().item()],
659674
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
660675
}
@@ -683,3 +698,72 @@ def _set_default_metric_values(metrics_dict: MetricsDict) -> MetricsDict:
683698
}
684699

685700
return metrics_dict
701+
702+
703+
def _apply_regional_subsetting(
704+
uxds: ux.UxDataset | None, var_key: str, region: str
705+
) -> ux.UxDataset | None:
706+
"""Apply regional subsetting to a uxarray dataset based on region specification.
707+
708+
This function follows the same pattern as the regional subsetting in
709+
lat_lon_native_plot.py but moves it to the driver for consistency.
710+
711+
Parameters
712+
----------
713+
uxds : ux.UxDataset or None
714+
The uxarray dataset to subset.
715+
var_key : str
716+
The variable key to subset.
717+
region : str
718+
The region specification (e.g., "global", "CONUS", etc.).
719+
720+
Returns
721+
-------
722+
ux.UxDataset or None
723+
The regionally subsetted dataset, or None if input was None.
724+
"""
725+
if uxds is None:
726+
return uxds
727+
728+
# Get region specs (same logic as in plot function)
729+
region_specs = REGION_SPECS.get(region, None)
730+
731+
if region_specs is None:
732+
# Unknown region, return original dataset
733+
logger.warning(
734+
f"Region '{region}' not found in REGION_SPECS. Using global dataset."
735+
)
736+
return uxds
737+
738+
# Get bounds (same logic as in plot function)
739+
lat_bounds = region_specs.get("lat", (-90, 90)) # type: ignore
740+
lon_bounds = region_specs.get("lon", (0, 360)) # type: ignore
741+
is_global_domain = lat_bounds == (-90, 90) and lon_bounds == (0, 360)
742+
743+
if is_global_domain:
744+
# Global domain, no subsetting needed
745+
return uxds
746+
747+
try:
748+
# Check if target variable exists
749+
if var_key not in uxds.data_vars:
750+
logger.warning(
751+
f"Variable '{var_key}' not found in dataset. Available vars: {list(uxds.data_vars)}"
752+
)
753+
return uxds
754+
755+
# Apply subsetting to the specific variable
756+
var_subset = uxds[var_key].subset.bounding_box(lon_bounds, lat_bounds)
757+
758+
# Create new dataset from subsetted variable
759+
uxds_subset = var_subset.to_dataset()
760+
uxds_subset.attrs.update(uxds.attrs)
761+
uxds_subset[var_key].attrs.update(uxds[var_key].attrs)
762+
return uxds_subset
763+
764+
except Exception as e:
765+
logger.warning(
766+
f"Failed to apply regional subsetting for region '{region}': {e}"
767+
)
768+
logger.warning("Using global dataset instead.")
769+
return uxds

e3sm_diags/plot/lat_lon_native_plot.py

Lines changed: 32 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -109,53 +109,28 @@ def plot( # noqa: C901
109109

110110
logger.info(f"Region: {region}, lat_bounds: {lat_bounds}, lon_bounds: {lon_bounds}")
111111

112-
# Extract metrics directly from the uxarray dataset
112+
# Extract metrics from parameter.metrics_dict (calculated in driver with regional subsetting)
113113
if uxds_test is not None and var_key in uxds_test:
114-
# ------------------------------------------------------------
115-
# FIXME: Metrics extraction for test, ref, and diff datasets are duplicated,
116-
# extract to a helper function.
117-
# ------------------------------------------------------------
118-
if is_global_domain:
119-
test_min = uxds_test[var_key].min().item()
120-
test_max = uxds_test[var_key].max().item()
121-
# For native grid, use weighted mean
122-
test_mean = uxds_test[var_key].weighted_mean().item()
123-
else:
124-
test_subset = uxds_test[var_key].subset.bounding_box(
125-
lon_bounds,
126-
lat_bounds,
127-
)
128-
test_min = test_subset.min().item()
129-
test_max = test_subset.max().item()
130-
# For native grid, use weighted mean
131-
test_mean = test_subset.weighted_mean().item()
132-
133114
units = uxds_test[var_key].attrs.get("units", "")
134-
# ------------------------------------------------------------
115+
116+
# Get test metrics from parameter.metrics_dict
117+
try:
118+
test_min = parameter.metrics_dict["test_regrid"]["min"][0] # type: ignore
119+
test_max = parameter.metrics_dict["test_regrid"]["max"][0] # type: ignore
120+
test_mean = parameter.metrics_dict["test_regrid"]["mean"][0] # type: ignore
121+
except (KeyError, IndexError, TypeError) as e:
122+
logger.warning(
123+
f"Failed to get test metrics from metrics_dict: {e}, using NaN"
124+
)
125+
test_min = test_max = test_mean = float("nan")
135126
else:
136127
# This should not happen since we check earlier, but just in case
137128
logger.error(f"Missing test data for variable {var_key} in native grid dataset")
138-
139129
return
140130

141131
# Extract metrics for reference data if available
142132
ref_min = ref_max = ref_mean = diff_min = diff_max = diff_mean = None
143133
if has_reference and uxds_ref is not None:
144-
# ------------------------------------------------------------
145-
# FIXME: Metrics extraction for test, ref, and diff datasets are duplicated,
146-
# extract to a helper function.
147-
# ------------------------------------------------------------
148-
if is_global_domain:
149-
ref_min = uxds_ref[var_key].min().item()
150-
ref_max = uxds_ref[var_key].max().item()
151-
ref_mean = uxds_ref[var_key].weighted_mean().item()
152-
else:
153-
ref_subset = uxds_ref[var_key].subset.bounding_box(lon_bounds, lat_bounds)
154-
ref_min = ref_subset.min().item()
155-
ref_max = ref_subset.max().item()
156-
ref_mean = ref_subset.weighted_mean().item()
157-
# ------------------------------------------------------------
158-
159134
ref_units = uxds_ref[var_key].attrs.get("units", "")
160135

161136
# Check if units match between test and reference
@@ -164,25 +139,28 @@ def plot( # noqa: C901
164139
f"Units mismatch between test ({units}) and reference ({ref_units})"
165140
)
166141

167-
# Calculate approximate metrics for difference if not already available
142+
# Get reference metrics from parameter.metrics_dict
143+
try:
144+
ref_min = parameter.metrics_dict["ref"]["min"][0] # type: ignore
145+
ref_max = parameter.metrics_dict["ref"]["max"][0] # type: ignore
146+
ref_mean = parameter.metrics_dict["ref"]["mean"][0] # type: ignore
147+
except (KeyError, IndexError, TypeError):
148+
logger.warning(
149+
"Failed to get reference metrics from metrics_dict, using NaN"
150+
)
151+
ref_min = ref_max = ref_mean = float("nan")
152+
153+
# Get difference metrics from parameter.metrics_dict
168154
if uxds_diff is not None and var_key in uxds_diff:
169-
# ------------------------------------------------------------
170-
# FIXME: Metrics extraction for test, ref, and diff datasets are duplicated,
171-
# extract to a helper function.
172-
# ------------------------------------------------------------
173-
if is_global_domain:
174-
diff_min = uxds_diff[var_key].min().item()
175-
diff_max = uxds_diff[var_key].max().item()
176-
diff_mean = uxds_diff[var_key].weighted_mean().item()
177-
else:
178-
diff_subset = uxds_diff[var_key].subset.bounding_box(
179-
lon_bounds,
180-
lat_bounds,
155+
try:
156+
diff_min = parameter.metrics_dict["diff"]["min"][0] # type: ignore
157+
diff_max = parameter.metrics_dict["diff"]["max"][0] # type: ignore
158+
diff_mean = parameter.metrics_dict["diff"]["mean"][0] # type: ignore
159+
except (KeyError, IndexError, TypeError):
160+
logger.warning(
161+
"Failed to get diff metrics from metrics_dict, using NaN"
181162
)
182-
diff_min = diff_subset.min().item()
183-
diff_max = diff_subset.max().item()
184-
diff_mean = diff_subset.weighted_mean().item()
185-
# ------------------------------------------------------------
163+
diff_min = diff_max = diff_mean = float("nan")
186164

187165
# Create panels following the lat_lon_plot layout
188166
# Panel 1: Test data (always created)

0 commit comments

Comments
 (0)