Skip to content

Commit 8b72480

Browse files
committed
refactor metrics generation and using weighted mean,rsme,corr
1 parent 8944b82 commit 8b72480

File tree

5 files changed

+334
-45
lines changed

5 files changed

+334
-45
lines changed

auxiliary_tools/debug/968-native-grid-vis/TGCLDLWP.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ sets = ["lat_lon_native"]
33
case_id = "model_vs_model"
44
variables = ["TGCLDLWP"]
55
seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"]
6-
regions = ["global", "30S30N-150E90W"]
6+
regions = ["global"]
7+
#regions = ["global", "30S30N-150E90W"]
78
#test_colormap = "Blues"
89
#reference_colormap = "Blues"
910
diff_colormap = "RdBu"

auxiliary_tools/debug/968-native-grid-vis/run_lat_lon_native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
#param.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc"
4848
#param.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc"
4949

50-
##(3)
50+
#(3)
5151
param.test_data_path = "/lcrc/group/e3sm2/ac.wlin/E3SMv3/v3.LR.historical_0051/archive/atm/hist"
5252
param.test_file = "v3.LR.historical_0051.eam.h0.1989-12.nc"
5353
#param.short_test_name = "v3.LR.amip_0101"

e3sm_diags/driver/lat_lon_native_driver.py

Lines changed: 201 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
import uxarray as ux
77

8+
from e3sm_diags.driver import METRICS_DEFAULT_VALUE
89
from e3sm_diags.driver.utils.dataset_native import NativeDataset
10+
from e3sm_diags.driver.utils.type_annotations import MetricsDict
911
from e3sm_diags.logger import _setup_child_logger
12+
from e3sm_diags.metrics.metrics import native_correlation, native_rmse
1013
from e3sm_diags.plot.lat_lon_native_plot import plot as plot_func
1114

1215
logger = _setup_child_logger(__name__)
@@ -216,9 +219,28 @@ def _run_diags_2d(
216219
if has_valid_ref:
217220
logger.info(f"Processing {var_key} for region {region} (model vs model)")
218221

219-
uxds_diff = _compute_diff_between_grids(uxds_test, uxds_ref, var_key)
222+
uxds_diff, uxds_test_remapped, uxds_ref_remapped = (
223+
_compute_diff_between_grids(uxds_test, uxds_ref, var_key)
224+
)
225+
226+
# Create metrics dictionary using remapped datasets (following lat_lon_driver pattern)
227+
metrics_dict = _create_metrics_dict(
228+
var_key,
229+
uxds_test,
230+
uxds_ref,
231+
uxds_test_remapped,
232+
uxds_ref_remapped,
233+
uxds_diff,
234+
)
235+
236+
# Store metrics in parameter for plot function to access
237+
parameter.metrics_dict = metrics_dict
220238

221-
# Plot with comparison mode
239+
parameter._set_param_output_attrs(
240+
var_key, season, region, ref_name, ilev=None
241+
)
242+
243+
# Call plot function directly (pass region parameter)
222244
plot_func(
223245
parameter,
224246
var_key,
@@ -230,7 +252,24 @@ def _run_diags_2d(
230252
else:
231253
logger.info(f"Processing {var_key} for region {region} (model-only)")
232254

233-
# Plot with model-only mode.
255+
# Create metrics dictionary for model-only run
256+
metrics_dict = _create_metrics_dict(
257+
var_key,
258+
uxds_test,
259+
None, # No reference dataset
260+
None, # No remapped test dataset (not needed for model-only)
261+
None, # No remapped reference dataset
262+
None, # No difference dataset
263+
)
264+
265+
# Store metrics in parameter for plot function to access
266+
parameter.metrics_dict = metrics_dict
267+
268+
parameter._set_param_output_attrs(
269+
var_key, season, region, ref_name, ilev=None
270+
)
271+
272+
# Call plot function directly (pass region parameter)
234273
plot_func(
235274
parameter,
236275
var_key,
@@ -243,7 +282,7 @@ def _run_diags_2d(
243282

244283
def _compute_diff_between_grids(
245284
uxds_test: ux.UxDataset, uxds_ref: ux.UxDataset, var_key: str
246-
) -> ux.UxDataset:
285+
) -> tuple[ux.UxDataset | None, ux.UxDataset, ux.UxDataset]:
247286
"""Compute the difference between two native grid datasets.
248287
249288
This function handles the remapping between different grids if needed,
@@ -263,8 +302,9 @@ def _compute_diff_between_grids(
263302
264303
Returns
265304
-------
266-
ux.UxDataset or None
267-
A dataset containing the difference data, or None if computation fails
305+
tuple[ux.UxDataset | None, ux.UxDataset, ux.UxDataset]
306+
A tuple containing (difference_dataset, remapped_test, remapped_ref).
307+
The difference dataset can be None if computation fails.
268308
"""
269309
try:
270310
# Check if variables exist in both datasets
@@ -274,24 +314,27 @@ def _compute_diff_between_grids(
274314
if var_key not in uxds_ref:
275315
logger.error(f"Variable {var_key} not found in reference dataset")
276316

277-
return None
317+
return None, uxds_test, uxds_ref
278318

279319
# Determine if both grids are identical by comparing properties and
280320
# create a difference dataset accordingly. Otherwise return None.
281321
same_grid, test_face_count, ref_face_count = _compare_grids(uxds_test, uxds_ref)
282322

283323
if same_grid:
284324
uxds_diff = _compute_direct_difference(uxds_test, uxds_ref, var_key)
325+
# For same grid, no remapping needed
326+
remapped_test = uxds_test
327+
remapped_ref = uxds_ref
285328
else:
286329
# Determine which grid to use as target (prefer lower resolution grid)
287330
target_is_test = ref_face_count >= test_face_count
288331

289-
uxds_diff = _compute_remapped_difference(
332+
uxds_diff, remapped_test, remapped_ref = _compute_remapped_difference(
290333
uxds_test, uxds_ref, var_key, target_is_test
291334
)
292335

293336
if uxds_diff is None:
294-
return None
337+
return None, uxds_test, uxds_ref
295338

296339
# Copy attributes and add diff metadata
297340
if var_key in uxds_diff and var_key in uxds_test:
@@ -303,12 +346,12 @@ def _compute_diff_between_grids(
303346
f"Difference in {uxds_diff[var_key].attrs.get('long_name', var_key)}"
304347
)
305348

306-
return uxds_diff
349+
return uxds_diff, remapped_test, remapped_ref
307350

308351
except Exception as e:
309352
logger.error(f"Error in compute_diff_between_grids: {e}")
310353

311-
return None
354+
return None, uxds_test, uxds_ref
312355

313356

314357
def _compare_grids(
@@ -417,7 +460,7 @@ def _compute_direct_difference(
417460

418461
def _compute_remapped_difference(
419462
uxds_test: ux.UxDataset, uxds_ref: ux.UxDataset, var_key: str, target_is_test: bool
420-
) -> ux.UxDataset | None:
463+
) -> tuple[ux.UxDataset | None, ux.UxDataset, ux.UxDataset]:
421464
"""Compute difference with remapping for different grids.
422465
423466
FIXME: This function has too many nested blocks and should be refactored.
@@ -467,23 +510,33 @@ def _compute_remapped_difference(
467510
# Remap reference to test grid
468511
logger.info("Remapping reference data to test grid")
469512
uxds_diff = uxds_test.copy()
513+
remapped_test = uxds_test
470514

471515
ref_remapped = ref_var.remap.nearest_neighbor(
472516
uxds_test.uxgrid, remap_to="face centers"
473517
)
474518
uxds_diff[var_key] = test_var - ref_remapped
475519

520+
# Create remapped reference dataset
521+
remapped_ref = uxds_test.copy()
522+
remapped_ref[var_key] = ref_remapped
523+
476524
else:
477525
# Remap test to reference grid
478526
logger.info("Remapping test data to reference grid")
479527
uxds_diff = uxds_ref.copy()
528+
remapped_ref = uxds_ref
480529

481530
test_remapped = test_var.remap.nearest_neighbor(
482531
uxds_ref.uxgrid, remap_to="face centers"
483532
)
484533
uxds_diff[var_key] = test_remapped - ref_var
485534

486-
return uxds_diff
535+
# Create remapped test dataset
536+
remapped_test = uxds_ref.copy()
537+
remapped_test[var_key] = test_remapped
538+
539+
return uxds_diff, remapped_test, remapped_ref
487540

488541
except Exception as e:
489542
logger.error(f"Error during remapping and difference computation: {e}")
@@ -495,4 +548,138 @@ def _compute_remapped_difference(
495548
)
496549
logger.debug(traceback.format_exc())
497550

498-
return None
551+
return None, uxds_test, uxds_ref
552+
553+
554+
def _create_metrics_dict(
555+
var_key: str,
556+
uxds_test: ux.UxDataset,
557+
uxds_ref: ux.UxDataset | None,
558+
uxds_test_remapped: ux.UxDataset | None,
559+
uxds_ref_remapped: ux.UxDataset | None,
560+
uxds_diff: ux.UxDataset | None,
561+
) -> MetricsDict:
562+
"""Create a metrics dictionary for native grid datasets.
563+
564+
This function follows the same pattern as lat_lon_driver._create_metrics_dict
565+
but uses uxarray datasets and native grid operations.
566+
567+
Parameters
568+
----------
569+
var_key : str
570+
The variable key.
571+
uxds_test : ux.UxDataset
572+
The original test uxarray dataset.
573+
uxds_ref : ux.UxDataset | None
574+
The original reference uxarray dataset.
575+
uxds_test_remapped : ux.UxDataset | None
576+
The remapped test uxarray dataset.
577+
uxds_ref_remapped : ux.UxDataset | None
578+
The remapped reference uxarray dataset.
579+
uxds_diff : ux.UxDataset | None
580+
The difference uxarray dataset.
581+
582+
Returns
583+
-------
584+
MetricsDict
585+
The metrics dictionary.
586+
"""
587+
# Basic test metrics using original dataset
588+
var_test = uxds_test[var_key]
589+
metrics_dict: MetricsDict = {
590+
"test": {
591+
"min": var_test.min().item(),
592+
"max": var_test.max().item(),
593+
"mean": [var_test.weighted_mean().item()],
594+
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
595+
},
596+
"unit": uxds_test[var_key].attrs.get("units", ""),
597+
}
598+
599+
# Set default values for all optional metrics
600+
metrics_dict = _set_default_metric_values(metrics_dict)
601+
602+
# Add reference metrics if available (using original dataset)
603+
if uxds_ref is not None and var_key in uxds_ref:
604+
var_ref = uxds_ref[var_key]
605+
metrics_dict["ref"] = {
606+
"min": var_ref.min().item(),
607+
"max": var_ref.max().item(),
608+
"mean": [var_ref.weighted_mean().item()],
609+
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
610+
}
611+
612+
# Add remapped test metrics if available
613+
if uxds_test_remapped is not None and var_key in uxds_test_remapped:
614+
var_test_remapped = uxds_test_remapped[var_key]
615+
metrics_dict["test_regrid"] = {
616+
"min": var_test_remapped.min().item(),
617+
"max": var_test_remapped.max().item(),
618+
"mean": [var_test_remapped.weighted_mean().item()],
619+
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
620+
}
621+
622+
# Add remapped reference metrics if available
623+
if uxds_ref_remapped is not None and var_key in uxds_ref_remapped:
624+
var_ref_remapped = uxds_ref_remapped[var_key]
625+
metrics_dict["ref_regrid"] = {
626+
"min": var_ref_remapped.min().item(),
627+
"max": var_ref_remapped.max().item(),
628+
"mean": [var_ref_remapped.weighted_mean().item()],
629+
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
630+
}
631+
632+
# Calculate RMSE and correlation on remapped datasets (following lat_lon pattern)
633+
if uxds_test_remapped is not None and uxds_ref_remapped is not None:
634+
try:
635+
rmse_val = native_rmse(uxds_test_remapped, uxds_ref_remapped, var_key)
636+
corr_val = native_correlation(
637+
uxds_test_remapped, uxds_ref_remapped, var_key
638+
)
639+
640+
metrics_dict["misc"] = {
641+
"rmse": [rmse_val],
642+
"corr": [corr_val],
643+
}
644+
except Exception as e:
645+
logger.warning(f"Failed to calculate RMSE/correlation: {e}")
646+
# Keep default NaN values for misc metrics
647+
648+
# For model-only run, copy test metrics to test_regrid
649+
if uxds_test is not None and uxds_ref_remapped is None:
650+
metrics_dict["test_regrid"] = metrics_dict["test"]
651+
652+
# Add difference metrics if available
653+
if uxds_diff is not None and var_key in uxds_diff:
654+
var_diff = uxds_diff[var_key]
655+
metrics_dict["diff"] = {
656+
"min": var_diff.min().item(),
657+
"max": var_diff.max().item(),
658+
"mean": [var_diff.weighted_mean().item()],
659+
"std": METRICS_DEFAULT_VALUE, # Not implemented yet for native grids
660+
}
661+
662+
return metrics_dict
663+
664+
665+
def _set_default_metric_values(metrics_dict: MetricsDict) -> MetricsDict:
666+
"""Set default values for optional metrics in the dictionary.
667+
668+
This function follows the same pattern as lat_lon_driver._set_default_metric_values.
669+
"""
670+
var_keys = ["test_regrid", "ref", "ref_regrid", "diff"]
671+
metric_keys = ["min", "max", "mean", "std"]
672+
673+
for var_key in var_keys:
674+
if var_key not in metrics_dict:
675+
metrics_dict[var_key] = {
676+
metric_key: METRICS_DEFAULT_VALUE for metric_key in metric_keys
677+
}
678+
679+
if "misc" not in metrics_dict:
680+
metrics_dict["misc"] = {
681+
"rmse": METRICS_DEFAULT_VALUE,
682+
"corr": METRICS_DEFAULT_VALUE,
683+
}
684+
685+
return metrics_dict

0 commit comments

Comments
 (0)