55
66import uxarray as ux
77
8+ from e3sm_diags .driver import METRICS_DEFAULT_VALUE
89from e3sm_diags .driver .utils .dataset_native import NativeDataset
10+ from e3sm_diags .driver .utils .type_annotations import MetricsDict
911from e3sm_diags .logger import _setup_child_logger
12+ from e3sm_diags .metrics .metrics import native_correlation , native_rmse
1013from e3sm_diags .plot .lat_lon_native_plot import plot as plot_func
1114
1215logger = _setup_child_logger (__name__ )
@@ -216,9 +219,28 @@ def _run_diags_2d(
216219 if has_valid_ref :
217220 logger .info (f"Processing { var_key } for region { region } (model vs model)" )
218221
219- uxds_diff = _compute_diff_between_grids (uxds_test , uxds_ref , var_key )
222+ uxds_diff , uxds_test_remapped , uxds_ref_remapped = (
223+ _compute_diff_between_grids (uxds_test , uxds_ref , var_key )
224+ )
225+
226+ # Create metrics dictionary using remapped datasets (following lat_lon_driver pattern)
227+ metrics_dict = _create_metrics_dict (
228+ var_key ,
229+ uxds_test ,
230+ uxds_ref ,
231+ uxds_test_remapped ,
232+ uxds_ref_remapped ,
233+ uxds_diff ,
234+ )
235+
236+ # Store metrics in parameter for plot function to access
237+ parameter .metrics_dict = metrics_dict
220238
221- # Plot with comparison mode
239+ parameter ._set_param_output_attrs (
240+ var_key , season , region , ref_name , ilev = None
241+ )
242+
243+ # Call plot function directly (pass region parameter)
222244 plot_func (
223245 parameter ,
224246 var_key ,
@@ -230,7 +252,24 @@ def _run_diags_2d(
230252 else :
231253 logger .info (f"Processing { var_key } for region { region } (model-only)" )
232254
233- # Plot with model-only mode.
255+ # Create metrics dictionary for model-only run
256+ metrics_dict = _create_metrics_dict (
257+ var_key ,
258+ uxds_test ,
259+ None , # No reference dataset
260+ None , # No remapped test dataset (not needed for model-only)
261+ None , # No remapped reference dataset
262+ None , # No difference dataset
263+ )
264+
265+ # Store metrics in parameter for plot function to access
266+ parameter .metrics_dict = metrics_dict
267+
268+ parameter ._set_param_output_attrs (
269+ var_key , season , region , ref_name , ilev = None
270+ )
271+
272+ # Call plot function directly (pass region parameter)
234273 plot_func (
235274 parameter ,
236275 var_key ,
@@ -243,7 +282,7 @@ def _run_diags_2d(
243282
244283def _compute_diff_between_grids (
245284 uxds_test : ux .UxDataset , uxds_ref : ux .UxDataset , var_key : str
246- ) -> ux .UxDataset :
285+ ) -> tuple [ ux .UxDataset | None , ux . UxDataset , ux . UxDataset ] :
247286 """Compute the difference between two native grid datasets.
248287
249288 This function handles the remapping between different grids if needed,
@@ -263,8 +302,9 @@ def _compute_diff_between_grids(
263302
264303 Returns
265304 -------
266- ux.UxDataset or None
267- A dataset containing the difference data, or None if computation fails
305+ tuple[ux.UxDataset | None, ux.UxDataset, ux.UxDataset]
306+ A tuple containing (difference_dataset, remapped_test, remapped_ref).
307+ The difference dataset can be None if computation fails.
268308 """
269309 try :
270310 # Check if variables exist in both datasets
@@ -274,24 +314,27 @@ def _compute_diff_between_grids(
274314 if var_key not in uxds_ref :
275315 logger .error (f"Variable { var_key } not found in reference dataset" )
276316
277- return None
317+ return None , uxds_test , uxds_ref
278318
279319 # Determine if both grids are identical by comparing properties and
280320 # create a difference dataset accordingly. Otherwise return None.
281321 same_grid , test_face_count , ref_face_count = _compare_grids (uxds_test , uxds_ref )
282322
283323 if same_grid :
284324 uxds_diff = _compute_direct_difference (uxds_test , uxds_ref , var_key )
325+ # For same grid, no remapping needed
326+ remapped_test = uxds_test
327+ remapped_ref = uxds_ref
285328 else :
286329 # Determine which grid to use as target (prefer lower resolution grid)
287330 target_is_test = ref_face_count >= test_face_count
288331
289- uxds_diff = _compute_remapped_difference (
332+ uxds_diff , remapped_test , remapped_ref = _compute_remapped_difference (
290333 uxds_test , uxds_ref , var_key , target_is_test
291334 )
292335
293336 if uxds_diff is None :
294- return None
337+ return None , uxds_test , uxds_ref
295338
296339 # Copy attributes and add diff metadata
297340 if var_key in uxds_diff and var_key in uxds_test :
@@ -303,12 +346,12 @@ def _compute_diff_between_grids(
303346 f"Difference in { uxds_diff [var_key ].attrs .get ('long_name' , var_key )} "
304347 )
305348
306- return uxds_diff
349+ return uxds_diff , remapped_test , remapped_ref
307350
308351 except Exception as e :
309352 logger .error (f"Error in compute_diff_between_grids: { e } " )
310353
311- return None
354+ return None , uxds_test , uxds_ref
312355
313356
314357def _compare_grids (
@@ -417,7 +460,7 @@ def _compute_direct_difference(
417460
418461def _compute_remapped_difference (
419462 uxds_test : ux .UxDataset , uxds_ref : ux .UxDataset , var_key : str , target_is_test : bool
420- ) -> ux .UxDataset | None :
463+ ) -> tuple [ ux .UxDataset | None , ux . UxDataset , ux . UxDataset ] :
421464 """Compute difference with remapping for different grids.
422465
423466 FIXME: This function has too many nested blocks and should be refactored.
@@ -467,23 +510,33 @@ def _compute_remapped_difference(
467510 # Remap reference to test grid
468511 logger .info ("Remapping reference data to test grid" )
469512 uxds_diff = uxds_test .copy ()
513+ remapped_test = uxds_test
470514
471515 ref_remapped = ref_var .remap .nearest_neighbor (
472516 uxds_test .uxgrid , remap_to = "face centers"
473517 )
474518 uxds_diff [var_key ] = test_var - ref_remapped
475519
520+ # Create remapped reference dataset
521+ remapped_ref = uxds_test .copy ()
522+ remapped_ref [var_key ] = ref_remapped
523+
476524 else :
477525 # Remap test to reference grid
478526 logger .info ("Remapping test data to reference grid" )
479527 uxds_diff = uxds_ref .copy ()
528+ remapped_ref = uxds_ref
480529
481530 test_remapped = test_var .remap .nearest_neighbor (
482531 uxds_ref .uxgrid , remap_to = "face centers"
483532 )
484533 uxds_diff [var_key ] = test_remapped - ref_var
485534
486- return uxds_diff
535+ # Create remapped test dataset
536+ remapped_test = uxds_ref .copy ()
537+ remapped_test [var_key ] = test_remapped
538+
539+ return uxds_diff , remapped_test , remapped_ref
487540
488541 except Exception as e :
489542 logger .error (f"Error during remapping and difference computation: { e } " )
@@ -495,4 +548,138 @@ def _compute_remapped_difference(
495548 )
496549 logger .debug (traceback .format_exc ())
497550
498- return None
551+ return None , uxds_test , uxds_ref
552+
553+
554+ def _create_metrics_dict (
555+ var_key : str ,
556+ uxds_test : ux .UxDataset ,
557+ uxds_ref : ux .UxDataset | None ,
558+ uxds_test_remapped : ux .UxDataset | None ,
559+ uxds_ref_remapped : ux .UxDataset | None ,
560+ uxds_diff : ux .UxDataset | None ,
561+ ) -> MetricsDict :
562+ """Create a metrics dictionary for native grid datasets.
563+
564+ This function follows the same pattern as lat_lon_driver._create_metrics_dict
565+ but uses uxarray datasets and native grid operations.
566+
567+ Parameters
568+ ----------
569+ var_key : str
570+ The variable key.
571+ uxds_test : ux.UxDataset
572+ The original test uxarray dataset.
573+ uxds_ref : ux.UxDataset | None
574+ The original reference uxarray dataset.
575+ uxds_test_remapped : ux.UxDataset | None
576+ The remapped test uxarray dataset.
577+ uxds_ref_remapped : ux.UxDataset | None
578+ The remapped reference uxarray dataset.
579+ uxds_diff : ux.UxDataset | None
580+ The difference uxarray dataset.
581+
582+ Returns
583+ -------
584+ MetricsDict
585+ The metrics dictionary.
586+ """
587+ # Basic test metrics using original dataset
588+ var_test = uxds_test [var_key ]
589+ metrics_dict : MetricsDict = {
590+ "test" : {
591+ "min" : var_test .min ().item (),
592+ "max" : var_test .max ().item (),
593+ "mean" : [var_test .weighted_mean ().item ()],
594+ "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
595+ },
596+ "unit" : uxds_test [var_key ].attrs .get ("units" , "" ),
597+ }
598+
599+ # Set default values for all optional metrics
600+ metrics_dict = _set_default_metric_values (metrics_dict )
601+
602+ # Add reference metrics if available (using original dataset)
603+ if uxds_ref is not None and var_key in uxds_ref :
604+ var_ref = uxds_ref [var_key ]
605+ metrics_dict ["ref" ] = {
606+ "min" : var_ref .min ().item (),
607+ "max" : var_ref .max ().item (),
608+ "mean" : [var_ref .weighted_mean ().item ()],
609+ "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
610+ }
611+
612+ # Add remapped test metrics if available
613+ if uxds_test_remapped is not None and var_key in uxds_test_remapped :
614+ var_test_remapped = uxds_test_remapped [var_key ]
615+ metrics_dict ["test_regrid" ] = {
616+ "min" : var_test_remapped .min ().item (),
617+ "max" : var_test_remapped .max ().item (),
618+ "mean" : [var_test_remapped .weighted_mean ().item ()],
619+ "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
620+ }
621+
622+ # Add remapped reference metrics if available
623+ if uxds_ref_remapped is not None and var_key in uxds_ref_remapped :
624+ var_ref_remapped = uxds_ref_remapped [var_key ]
625+ metrics_dict ["ref_regrid" ] = {
626+ "min" : var_ref_remapped .min ().item (),
627+ "max" : var_ref_remapped .max ().item (),
628+ "mean" : [var_ref_remapped .weighted_mean ().item ()],
629+ "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
630+ }
631+
632+ # Calculate RMSE and correlation on remapped datasets (following lat_lon pattern)
633+ if uxds_test_remapped is not None and uxds_ref_remapped is not None :
634+ try :
635+ rmse_val = native_rmse (uxds_test_remapped , uxds_ref_remapped , var_key )
636+ corr_val = native_correlation (
637+ uxds_test_remapped , uxds_ref_remapped , var_key
638+ )
639+
640+ metrics_dict ["misc" ] = {
641+ "rmse" : [rmse_val ],
642+ "corr" : [corr_val ],
643+ }
644+ except Exception as e :
645+ logger .warning (f"Failed to calculate RMSE/correlation: { e } " )
646+ # Keep default NaN values for misc metrics
647+
648+ # For model-only run, copy test metrics to test_regrid
649+ if uxds_test is not None and uxds_ref_remapped is None :
650+ metrics_dict ["test_regrid" ] = metrics_dict ["test" ]
651+
652+ # Add difference metrics if available
653+ if uxds_diff is not None and var_key in uxds_diff :
654+ var_diff = uxds_diff [var_key ]
655+ metrics_dict ["diff" ] = {
656+ "min" : var_diff .min ().item (),
657+ "max" : var_diff .max ().item (),
658+ "mean" : [var_diff .weighted_mean ().item ()],
659+ "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
660+ }
661+
662+ return metrics_dict
663+
664+
665+ def _set_default_metric_values (metrics_dict : MetricsDict ) -> MetricsDict :
666+ """Set default values for optional metrics in the dictionary.
667+
668+ This function follows the same pattern as lat_lon_driver._set_default_metric_values.
669+ """
670+ var_keys = ["test_regrid" , "ref" , "ref_regrid" , "diff" ]
671+ metric_keys = ["min" , "max" , "mean" , "std" ]
672+
673+ for var_key in var_keys :
674+ if var_key not in metrics_dict :
675+ metrics_dict [var_key ] = {
676+ metric_key : METRICS_DEFAULT_VALUE for metric_key in metric_keys
677+ }
678+
679+ if "misc" not in metrics_dict :
680+ metrics_dict ["misc" ] = {
681+ "rmse" : METRICS_DEFAULT_VALUE ,
682+ "corr" : METRICS_DEFAULT_VALUE ,
683+ }
684+
685+ return metrics_dict
0 commit comments