55import xarray as xr
66
77from e3sm_diags .driver import METRICS_DEFAULT_VALUE
8- from e3sm_diags .driver .utils .climo_xr import ClimoFreq
98from e3sm_diags .driver .utils .dataset_xr import Dataset
109from e3sm_diags .driver .utils .io import _save_data_metrics_and_plots
1110from e3sm_diags .driver .utils .regrid import (
1615 regrid_z_axis_to_plevs ,
1716 subset_and_align_datasets ,
1817)
19- from e3sm_diags .driver .utils .time_slice import check_time_selection
20- from e3sm_diags .driver .utils .type_annotations import MetricsDict
2118from e3sm_diags .logger import _setup_child_logger
2219from e3sm_diags .metrics .metrics import correlation , rmse , spatial_avg , std
2320from e3sm_diags .plot .lat_lon_plot import plot as plot_func
2421
2522logger = _setup_child_logger (__name__ )
2623
2724if TYPE_CHECKING :
25+ from e3sm_diags .driver .utils .type_annotations import MetricsDict , TimeSelection
2826 from e3sm_diags .parameter .core_parameter import CoreParameter
2927
3028
@@ -51,57 +49,44 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
5149 (e.g., one is 2-D and the other is 3-D).
5250 """
5351 variables = parameter .variables
54- seasons = parameter .seasons
55- time_slices = getattr (parameter , "time_slices" , [])
5652 ref_name = getattr (parameter , "ref_name" , "" )
5753 regions = parameter .regions
5854
59- # Check that either seasons or time_slices is specified, but not both
60- has_seasons , has_time_slices = check_time_selection (
61- seasons , time_slices , require_one = True
62- )
63-
64- # Determine which time selection to use
65- time_selections = time_slices if has_time_slices else seasons
66- is_time_slice_mode = has_time_slices
67-
6855 # Variables storing xarray `Dataset` objects start with `ds_` and
6956 # variables storing e3sm_diags `Dataset` objects end with `_ds`. This
7057 # is to help distinguish both objects from each other.
7158 test_ds = Dataset (parameter , data_type = "test" )
7259 ref_ds = Dataset (parameter , data_type = "ref" )
7360
61+ time_selection_type , time_selections = parameter ._get_time_selection_to_use ()
62+
7463 for var_key in variables :
7564 logger .info ("Variable: {}" .format (var_key ))
7665 parameter .var_id = var_key
7766
7867 for time_selection in time_selections :
79- # Set name/yrs attributes based on time selection mode
80- if is_time_slice_mode :
81- # For time slices, we set attributes after loading data
82- # since we need the actual time coordinates
83- pass
84- else :
85- parameter ._set_name_yrs_attrs (test_ds , ref_ds , time_selection )
68+ is_time_slice = time_selection_type == "time_slices"
8669
87- # Get datasets - pass is_time_slice flag if it's a time slice
88- if is_time_slice_mode :
89- ds_test = test_ds .get_climo_dataset (
90- var_key , time_selection , is_time_slice = True
91- )
92- ds_ref = _get_ref_dataset (
93- ref_ds , var_key , time_selection , is_time_slice = True
94- )
95- # For time slices, set name_yrs after data is loaded
96- parameter ._set_name_yrs_attrs (test_ds , ref_ds , time_selection )
97- # Use the climatology season (ANN) for land sea mask
70+ # Get test and reference datasets.
71+ # NOTE: lat_lon diagnostics get reference datasets differently than
72+ # other sets using its own helper function `_get_ref_dataset`.
73+ if is_time_slice :
74+ ds_test = test_ds .get_time_sliced_dataset (var_key , time_selection )
75+
76+ # For time slices, always use the annual land-sea mask.
9877 ds_land_sea_mask : xr .Dataset = test_ds ._get_land_sea_mask ("ANN" )
9978 else :
100- ds_test = test_ds .get_climo_dataset (var_key , time_selection )
101- ds_ref = _get_ref_dataset (
102- ref_ds , var_key , time_selection , is_time_slice = False
79+ # time_selection will be ClimoFreq, so ignore type checking here.
80+ ds_test = test_ds .get_climo_dataset (var_key , time_selection ) # type: ignore[arg-type]
81+ ds_land_sea_mask : xr .Dataset = test_ds ._get_land_sea_mask ( # type: ignore[no-redef]
82+ time_selection # type: ignore[arg-type]
10383 )
104- ds_land_sea_mask = test_ds ._get_land_sea_mask (time_selection )
84+
85+ ds_ref = _get_ref_dataset (ref_ds , var_key , time_selection , is_time_slice )
86+
87+ # Set name_yrs after loading data because time sliced datasets
88+ # have the required attributes only after loading the data.
89+ parameter ._set_name_yrs_attrs (test_ds , ref_ds , time_selection )
10590
10691 if ds_ref is None :
10792 is_vars_3d = has_z_axis (ds_test [var_key ])
@@ -317,7 +302,7 @@ def _run_diags_2d(
317302 ds_test : xr .Dataset ,
318303 ds_ref : xr .Dataset ,
319304 ds_land_sea_mask : xr .Dataset ,
320- season : str ,
305+ time_selection : TimeSelection ,
321306 regions : list [str ],
322307 var_key : str ,
323308 ref_name : str ,
@@ -338,8 +323,8 @@ def _run_diags_2d(
338323 ds_land_sea_mask : xr.Dataset
339324 The land sea mask dataset, which is only used for masking if the region
340325 is "land" or "ocean".
341- season : str
342- The season.
326+ time_selection : TimeSelection
327+ The time slice or season.
343328 regions : list[str]
344329 The list of regions.
345330 var_key : str
@@ -372,7 +357,9 @@ def _run_diags_2d(
372357 ds_diff_region ,
373358 )
374359
375- parameter ._set_param_output_attrs (var_key , season , region , ref_name , ilev = None )
360+ parameter ._set_param_output_attrs (
361+ var_key , time_selection , region , ref_name , ilev = None
362+ )
376363 _save_data_metrics_and_plots (
377364 parameter ,
378365 plot_func ,
@@ -391,7 +378,7 @@ def _run_diags_3d(
391378 ds_test : xr .Dataset ,
392379 ds_ref : xr .Dataset ,
393380 ds_land_sea_mask : xr .Dataset ,
394- season : str ,
381+ time_selection : str ,
395382 regions : list [str ],
396383 var_key : str ,
397384 ref_name : str ,
@@ -412,8 +399,8 @@ def _run_diags_3d(
412399 ds_land_sea_mask : xr.Dataset
413400 The land sea mask dataset, which is only used for masking if the region
414401 is "land" or "ocean".
415- season : str
416- The season.
402+ time_selection : TimeSelection
403+ The time slice or season.
417404 regions : list[str]
418405 The list of regions.
419406 var_key : str
@@ -457,7 +444,9 @@ def _run_diags_3d(
457444 ds_diff_region ,
458445 )
459446
460- parameter ._set_param_output_attrs (var_key , season , region , ref_name , ilev )
447+ parameter ._set_param_output_attrs (
448+ var_key , time_selection , region , ref_name , ilev
449+ )
461450 _save_data_metrics_and_plots (
462451 parameter ,
463452 plot_func ,
@@ -474,7 +463,7 @@ def _run_diags_3d(
474463def _get_ref_dataset (
475464 dataset : Dataset ,
476465 var_key : str ,
477- time_selection : ClimoFreq | str ,
466+ time_selection : TimeSelection ,
478467 is_time_slice : bool = False ,
479468) -> xr .Dataset | None :
480469 """Get the reference dataset for the variable and time selection.
@@ -506,9 +495,11 @@ def _get_ref_dataset(
506495 """
507496 if dataset .data_type == "ref" :
508497 try :
509- ds_ref = dataset .get_climo_dataset (
510- var_key , time_selection , is_time_slice = is_time_slice
511- )
498+ if is_time_slice :
499+ ds_ref = dataset .get_time_sliced_dataset (var_key , time_selection )
500+ else :
501+ # time_selection will be ClimoFreq, so ignore type checking here.
502+ ds_ref = dataset .get_climo_dataset (var_key , time_selection ) # type: ignore[arg-type]
512503 except (RuntimeError , IOError ):
513504 ds_ref = None
514505
0 commit comments