55
66import uxarray as ux
77
8+ from e3sm_diags .derivations .default_regions_xr import REGION_SPECS
89from e3sm_diags .driver import METRICS_DEFAULT_VALUE
910from e3sm_diags .driver .utils .dataset_native import NativeDataset
1011from e3sm_diags .driver .utils .type_annotations import MetricsDict
@@ -223,14 +224,25 @@ def _run_diags_2d(
223224 _compute_diff_between_grids (uxds_test , uxds_ref , var_key )
224225 )
225226
226- # Create metrics dictionary using remapped datasets (following lat_lon_driver pattern)
227+ # Apply regional subsetting to all datasets before metrics calculation
228+ uxds_test_subset = _apply_regional_subsetting (uxds_test , var_key , region )
229+ uxds_ref_subset = _apply_regional_subsetting (uxds_ref , var_key , region )
230+ uxds_test_remapped_subset = _apply_regional_subsetting (
231+ uxds_test_remapped , var_key , region
232+ )
233+ uxds_ref_remapped_subset = _apply_regional_subsetting (
234+ uxds_ref_remapped , var_key , region
235+ )
236+ uxds_diff_subset = _apply_regional_subsetting (uxds_diff , var_key , region )
237+
238+ # Create metrics dictionary using regionally subsetted datasets
227239 metrics_dict = _create_metrics_dict (
228240 var_key ,
229- uxds_test ,
230- uxds_ref ,
231- uxds_test_remapped ,
232- uxds_ref_remapped ,
233- uxds_diff ,
241+ uxds_test_subset ,
242+ uxds_ref_subset ,
243+ uxds_test_remapped_subset ,
244+ uxds_ref_remapped_subset ,
245+ uxds_diff_subset ,
234246 )
235247
236248 # Store metrics in parameter for plot function to access
@@ -240,7 +252,7 @@ def _run_diags_2d(
240252 var_key , season , region , ref_name , ilev = None
241253 )
242254
243- # Call plot function directly (pass region parameter)
255+ # Call plot function with original datasets for visualization
244256 plot_func (
245257 parameter ,
246258 var_key ,
@@ -252,10 +264,13 @@ def _run_diags_2d(
252264 else :
253265 logger .info (f"Processing { var_key } for region { region } (model-only)" )
254266
255- # Create metrics dictionary for model-only run
267+ # Apply regional subsetting to test dataset before metrics calculation
268+ uxds_test_subset = _apply_regional_subsetting (uxds_test , var_key , region )
269+
270+ # Create metrics dictionary for model-only run using regionally subsetted dataset
256271 metrics_dict = _create_metrics_dict (
257272 var_key ,
258- uxds_test ,
273+ uxds_test_subset ,
259274 None , # No reference dataset
260275 None , # No remapped test dataset (not needed for model-only)
261276 None , # No remapped reference dataset
@@ -269,7 +284,7 @@ def _run_diags_2d(
269284 var_key , season , region , ref_name , ilev = None
270285 )
271286
272- # Call plot function directly (pass region parameter)
287+ # Call plot function with original dataset for visualization
273288 plot_func (
274289 parameter ,
275290 var_key ,
@@ -588,8 +603,8 @@ def _create_metrics_dict(
588603 var_test = uxds_test [var_key ]
589604 metrics_dict : MetricsDict = {
590605 "test" : {
591- "min" : var_test .min ().item (),
592- "max" : var_test .max ().item (),
606+ "min" : [ var_test .min ().item ()] ,
607+ "max" : [ var_test .max ().item ()] ,
593608 "mean" : [var_test .weighted_mean ().item ()],
594609 "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
595610 },
@@ -603,8 +618,8 @@ def _create_metrics_dict(
603618 if uxds_ref is not None and var_key in uxds_ref :
604619 var_ref = uxds_ref [var_key ]
605620 metrics_dict ["ref" ] = {
606- "min" : var_ref .min ().item (),
607- "max" : var_ref .max ().item (),
621+ "min" : [ var_ref .min ().item ()] ,
622+ "max" : [ var_ref .max ().item ()] ,
608623 "mean" : [var_ref .weighted_mean ().item ()],
609624 "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
610625 }
@@ -613,8 +628,8 @@ def _create_metrics_dict(
613628 if uxds_test_remapped is not None and var_key in uxds_test_remapped :
614629 var_test_remapped = uxds_test_remapped [var_key ]
615630 metrics_dict ["test_regrid" ] = {
616- "min" : var_test_remapped .min ().item (),
617- "max" : var_test_remapped .max ().item (),
631+ "min" : [ var_test_remapped .min ().item ()] ,
632+ "max" : [ var_test_remapped .max ().item ()] ,
618633 "mean" : [var_test_remapped .weighted_mean ().item ()],
619634 "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
620635 }
@@ -623,8 +638,8 @@ def _create_metrics_dict(
623638 if uxds_ref_remapped is not None and var_key in uxds_ref_remapped :
624639 var_ref_remapped = uxds_ref_remapped [var_key ]
625640 metrics_dict ["ref_regrid" ] = {
626- "min" : var_ref_remapped .min ().item (),
627- "max" : var_ref_remapped .max ().item (),
641+ "min" : [ var_ref_remapped .min ().item ()] ,
642+ "max" : [ var_ref_remapped .max ().item ()] ,
628643 "mean" : [var_ref_remapped .weighted_mean ().item ()],
629644 "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
630645 }
@@ -653,8 +668,8 @@ def _create_metrics_dict(
653668 if uxds_diff is not None and var_key in uxds_diff :
654669 var_diff = uxds_diff [var_key ]
655670 metrics_dict ["diff" ] = {
656- "min" : var_diff .min ().item (),
657- "max" : var_diff .max ().item (),
671+ "min" : [ var_diff .min ().item ()] ,
672+ "max" : [ var_diff .max ().item ()] ,
658673 "mean" : [var_diff .weighted_mean ().item ()],
659674 "std" : METRICS_DEFAULT_VALUE , # Not implemented yet for native grids
660675 }
@@ -683,3 +698,72 @@ def _set_default_metric_values(metrics_dict: MetricsDict) -> MetricsDict:
683698 }
684699
685700 return metrics_dict
701+
702+
703+ def _apply_regional_subsetting (
704+ uxds : ux .UxDataset | None , var_key : str , region : str
705+ ) -> ux .UxDataset | None :
706+ """Apply regional subsetting to a uxarray dataset based on region specification.
707+
708+ This function follows the same pattern as the regional subsetting in
709+ lat_lon_native_plot.py but moves it to the driver for consistency.
710+
711+ Parameters
712+ ----------
713+ uxds : ux.UxDataset or None
714+ The uxarray dataset to subset.
715+ var_key : str
716+ The variable key to subset.
717+ region : str
718+ The region specification (e.g., "global", "CONUS", etc.).
719+
720+ Returns
721+ -------
722+ ux.UxDataset or None
723+ The regionally subsetted dataset, or None if input was None.
724+ """
725+ if uxds is None :
726+ return uxds
727+
728+ # Get region specs (same logic as in plot function)
729+ region_specs = REGION_SPECS .get (region , None )
730+
731+ if region_specs is None :
732+ # Unknown region, return original dataset
733+ logger .warning (
734+ f"Region '{ region } ' not found in REGION_SPECS. Using global dataset."
735+ )
736+ return uxds
737+
738+ # Get bounds (same logic as in plot function)
739+ lat_bounds = region_specs .get ("lat" , (- 90 , 90 )) # type: ignore
740+ lon_bounds = region_specs .get ("lon" , (0 , 360 )) # type: ignore
741+ is_global_domain = lat_bounds == (- 90 , 90 ) and lon_bounds == (0 , 360 )
742+
743+ if is_global_domain :
744+ # Global domain, no subsetting needed
745+ return uxds
746+
747+ try :
748+ # Check if target variable exists
749+ if var_key not in uxds .data_vars :
750+ logger .warning (
751+ f"Variable '{ var_key } ' not found in dataset. Available vars: { list (uxds .data_vars )} "
752+ )
753+ return uxds
754+
755+ # Apply subsetting to the specific variable
756+ var_subset = uxds [var_key ].subset .bounding_box (lon_bounds , lat_bounds )
757+
758+ # Create new dataset from subsetted variable
759+ uxds_subset = var_subset .to_dataset ()
760+ uxds_subset .attrs .update (uxds .attrs )
761+ uxds_subset [var_key ].attrs .update (uxds [var_key ].attrs )
762+ return uxds_subset
763+
764+ except Exception as e :
765+ logger .warning (
766+ f"Failed to apply regional subsetting for region '{ region } ': { e } "
767+ )
768+ logger .warning ("Using global dataset instead." )
769+ return uxds
0 commit comments