 from __future__ import annotations

-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Sequence

 import uxarray as ux
 import xarray as xr

-from e3sm_diags.driver.utils.climo_xr import ClimoFreq
 from e3sm_diags.driver.utils.dataset_xr import Dataset
 from e3sm_diags.driver.utils.regrid import (
     _apply_land_sea_mask,
 logger = _setup_child_logger(__name__)

 if TYPE_CHECKING:
-    from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter
+    from e3sm_diags.parameter.lat_lon_native_parameter import (
+        LatLonNativeParameter,
+        TimeSelection,
+    )

 # The default value for metrics if it is not calculated. This value was
 # preserved from the legacy CDAT codebase because the viewer expects this
 def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
     """Get metrics for the lat_lon_native diagnostic set.

-    This function loops over each variable, season, pressure level (if 3-D),
+    This function loops over each variable, season/time_slice, pressure level (if 3-D),
     and region.

     Parameters
@@ -51,10 +53,19 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa
     (e.g., one is 2-D and the other is 3-D).
     """
     variables = parameter.variables
-    seasons = parameter.seasons
     ref_name = getattr(parameter, "ref_name", "")
     regions = parameter.regions
+    # Determine whether to use seasons or time_slices
+    if len(parameter.time_slices) > 0:
+        time_periods: Sequence["TimeSelection"] = parameter.time_slices
+        using_time_slices = True
+        logger.info(f"Using time_slices: {time_periods}")
+    else:
+        time_periods = parameter.seasons
+        using_time_slices = False
+        logger.info(f"Using seasons: {time_periods}")
+
     # Variables storing xarray `Dataset` objects start with `ds_` and
     # variables storing e3sm_diags `Dataset` objects end with `_ds`. This
     # is to help distinguish both objects from each other.
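
For context, a minimal sketch of how a run script might drive the two modes selected above. The attribute names (`seasons`, `time_slices`) follow this diff; the class import, variable key, and values are illustrative assumptions:

    from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter

    param = LatLonNativeParameter()
    param.variables = ["PRECT"]      # hypothetical variable key
    param.seasons = ["ANN", "JJA"]   # used only when time_slices is empty
    param.time_slices = ["0:12"]     # non-empty, so it takes precedence over seasons
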
@@ -65,16 +76,28 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa
         logger.info("Variable: {}".format(var_key))
         parameter.var_id = var_key

-        for season in seasons:
-            parameter._set_name_yrs_attrs(test_ds, ref_ds, season)
+        for time_period in time_periods:
+            if using_time_slices:
+                # For time_slices, we need to pass the slice info differently
+                logger.info(f"Processing time slice: {time_period}")
+                # Set up period-specific attributes for file naming
+                parameter._set_time_slice_attrs(test_ds, ref_ds, time_period)
+            else:
+                # For seasons, use existing logic
+                logger.info(f"Processing season: {time_period}")
+                parameter._set_name_yrs_attrs(test_ds, ref_ds, time_period)

-            ds_test = _get_native_dataset(test_ds, var_key, season)
-            ds_ref = _get_native_dataset(ref_ds, var_key, season, allow_missing=True)
+            # Use the same function for both cases - it will handle the logic internally
+            ds_test = _get_native_dataset(
+                test_ds, var_key, time_period, using_time_slices
+            )
+            ds_ref = _get_native_dataset(
+                ref_ds, var_key, time_period, using_time_slices, allow_missing=True
+            )

             # Log basic dataset info
-            logger.debug(f"Test dataset variables: {list(ds_test.variables)}")
-
-
+            if ds_test is not None:
+                logger.debug(f"Test dataset variables: {list(ds_test.variables)}")

             # Load the native grid information for test data
             uxds_test = None
@@ -145,7 +168,7 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa
             if uxds_test is not None:
                 _run_diags_2d_model_only(
                     parameter,
-                    season,
+                    time_period,
                     regions,
                     var_key,
                     ref_name,
@@ -159,7 +182,7 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa
                 # Only handle 2D variables for now
                 _run_diags_2d(
                     parameter,
-                    season,
+                    time_period,
                     regions,
                     var_key,
                     ref_name,
@@ -171,7 +194,11 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa


 def _get_native_dataset(
-    dataset: Dataset, var_key: str, season: ClimoFreq, allow_missing: bool = False
+    dataset: Dataset,
+    var_key: str,
+    season: "TimeSelection",
+    is_time_slice: bool = False,
+    allow_missing: bool = False,
 ) -> xr.Dataset | None:
     """Get the climatology dataset for the variable and season for native grid processing.
177204
@@ -188,8 +215,11 @@ def _get_native_dataset(
188215 The dataset object (test or reference).
189216 var_key : str
190217 The key of the variable.
191- season : CLIMO_FREQ
192- The climatology frequency.
218+ season : TimeSelection
219+ The climatology frequency or time slice string.
220+ is_time_slice : bool, optional
221+ If True, treat season as a time slice string rather than climatology frequency.
222+ Default is False.
193223 allow_missing : bool, optional
194224 If True, return None when dataset cannot be loaded instead of raising
195225 an exception. This enables model-only runs when reference data is missing.
@@ -207,30 +237,38 @@ def _get_native_dataset(
         If the dataset cannot be loaded and allow_missing=False.
     """
     try:
-        # Get the climatology dataset
-        ds = dataset.get_climo_dataset(var_key, season)
-
-        # Try to get file_path from different possible sources and store it in parameter
-        file_path = None
-        if hasattr(ds, "file_path"):
-            file_path = ds.file_path
-        elif hasattr(ds, "filepath"):
-            file_path = ds.filepath
-        elif hasattr(ds, "_file_obj") and hasattr(ds._file_obj, "name"):
-            file_path = ds._file_obj.name
-
-        # Store path in parameter based on dataset type
-        if file_path:
-            if dataset.data_type == "test" and not hasattr(dataset.parameter, "test_data_file_path"):
-                dataset.parameter.test_data_file_path = file_path
-            elif dataset.data_type == "ref" and not hasattr(dataset.parameter, "ref_data_file_path"):
-                dataset.parameter.ref_data_file_path = file_path
+        if is_time_slice:
+            # For time slices, get the full dataset without averaging
+            ds = _get_full_native_dataset(dataset, var_key)
+            # Apply the time slice
+            ds = _apply_time_slice(ds, season)
+        else:
+            # Standard climatology processing
+            from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS
+
+            if season in CLIMO_FREQS:
+                ds = dataset.get_climo_dataset(var_key, season)  # type: ignore
+            else:
+                raise ValueError(f"Invalid season for climatology: {season}")
+
+        # Store file path in parameter for native grid processing
+        if is_time_slice:
+            # For time slices, we know the exact file path we used
+            filepath = dataset._get_climo_filepath_with_params()
+            if filepath:
+                if dataset.data_type == "test":
+                    dataset.parameter.test_data_file_path = filepath
+                elif dataset.data_type == "ref":
+                    dataset.parameter.ref_data_file_path = filepath
+        # Note: For the climatology case, get_climo_dataset() already handles file path storage

         return ds

     except (RuntimeError, IOError) as e:
         if allow_missing:
-            logger.info(f"Cannot process {dataset.data_type} data: {e}. Using model-only mode.")
+            logger.info(
+                f"Cannot process {dataset.data_type} data: {e}. Using model-only mode."
+            )
             return None
         else:
             raise
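
A hedged usage sketch of the dispatch above, assuming `test_ds` and `ref_ds` are the e3sm_diags `Dataset` wrappers built in `run_diag`; the variable key is illustrative:

    # Climatology mode: the season string must be a member of CLIMO_FREQS.
    ds_climo = _get_native_dataset(test_ds, "PRECT", "ANN", is_time_slice=False)

    # Time-slice mode: the string is parsed by _apply_time_slice (defined below).
    ds_slice = _get_native_dataset(test_ds, "PRECT", "0:12", is_time_slice=True)

    # allow_missing=True returns None instead of raising, enabling model-only runs.
    ds_ref = _get_native_dataset(ref_ds, "PRECT", "ANN", allow_missing=True)
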
@@ -521,7 +559,115 @@ def _compare_grids(uxds_test, uxds_ref):

     except Exception as e:
         logger.warning(f"Error comparing grids: {e}")
-        return False, 0, 0
+        return False
+
+
+def _get_full_native_dataset(dataset: Dataset, var_key: str) -> xr.Dataset:
+    """Get the full native dataset without any time averaging.
+
+    This function uses the dataset's file path parameters to directly open
+    the raw data file for time slicing operations.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        The dataset object (test or reference).
+    var_key : str
+        The key of the variable.
+
+    Returns
+    -------
+    xr.Dataset
+        The full dataset with all time steps.
+
+    Raises
+    ------
+    RuntimeError
+        If unable to get the full dataset.
+    """
+    import os
+
+    # Get the file path using the parameter-based method
+    filepath = dataset._get_climo_filepath_with_params()
+
+    if filepath is None:
+        raise RuntimeError(
+            f"Unable to get file path for {dataset.data_type} dataset. "
+            "For time slicing, please ensure that the "
+            f"{'ref_file' if dataset.data_type == 'ref' else 'test_file'} parameter is set."
+        )
+
+    if not os.path.exists(filepath):
+        raise RuntimeError(f"File not found: {filepath}")
+
+    logger.info(f"Opening full native dataset from: {filepath}")
+
+    try:
+        # Open the dataset directly without any averaging
+        ds = xr.open_dataset(filepath, decode_times=True)
+        logger.info(
+            f"Successfully opened dataset with time dimension size: {ds.sizes.get('time', 'N/A')}"
+        )
+        return ds
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to open dataset {filepath}: {e}") from e
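
A short usage sketch of the helper, assuming the `test_file` parameter was set so that `_get_climo_filepath_with_params()` can resolve a path; the variable key and slice are illustrative:

    # Open the raw file with every time step, then subset it.
    ds_full = _get_full_native_dataset(test_ds, "PRECT")
    ds_first_year = _apply_time_slice(ds_full, "0:12")
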
+
+
+def _apply_time_slice(ds: xr.Dataset, time_slice: str) -> xr.Dataset:
+    """Apply time slice selection to a dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The input dataset with a time dimension.
+    time_slice : str
+        The time slice specification (e.g., "0:10:2", "5:15", "7").
+
+    Returns
+    -------
+    xr.Dataset
+        The dataset with the time slice applied.
+    """
+    # Find the time dimension
+    time_dim = None
+    for dim in ds.dims:
+        if str(dim).lower() in ["time", "t"]:
+            time_dim = dim
+            break
+
+    if time_dim is None:
+        logger.warning(
+            "No time dimension found in dataset. Returning original dataset."
+        )
+        return ds
+
+    # Parse slice notation
+    if ":" in time_slice:
+        # Handle slice notation like "0:10:2" or "5:15" or ":10" or "5:" or "::2"
+        parts = time_slice.split(":")
+
+        start = int(parts[0]) if parts[0] else None
+        end = int(parts[1]) if len(parts) > 1 and parts[1] else None
+        step = int(parts[2]) if len(parts) > 2 and parts[2] else None
+
+        # Apply the slice
+        ds_sliced = ds.isel({time_dim: slice(start, end, step)})
+    else:
+        # Single index
+        index = int(time_slice)
+        ds_sliced = ds.isel({time_dim: index})
+
+    logger.info(
+        f"Applied time slice '{time_slice}' to dataset. "
+        f"Original time length: {ds.sizes[time_dim]}, "
+        f"Sliced time length: {ds_sliced.sizes.get(time_dim, 1)}"
+    )
+
+    return ds_sliced
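
A self-contained demonstration of the slice semantics implemented above, using a toy dataset (the expected outputs follow Python's built-in slice rules):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"tas": ("time", np.arange(20.0))})

    print(_apply_time_slice(ds, "0:10:2").sizes["time"])  # 5: indices 0, 2, 4, 6, 8
    print(_apply_time_slice(ds, "5:15").sizes["time"])    # 10: indices 5 through 14
    print(_apply_time_slice(ds, "7")["tas"].item())       # 7.0: a single index drops the time dim
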


 def _compute_direct_difference(uxds_test, uxds_ref, var_key):
@@ -748,7 +894,7 @@ def _get_matching_src_vars(dataset, target_var_map):
 def _apply_derivation_func(dataset, func, src_var_keys, target_var_key):
     """Apply derivation function following dataset_xr pattern."""
     from e3sm_diags.derivations.derivations import FUNC_NEEDS_TARGET_VAR
-
+
     func_args = [dataset[var] for var in src_var_keys]

     if func in FUNC_NEEDS_TARGET_VAR:
@@ -762,16 +908,14 @@ def _apply_derivation_func(dataset, func, src_var_keys, target_var_key):
 def _process_variable_derivations(dataset, var_key, dataset_name=""):
     """Process variable derivations following dataset_xr approach."""
     from e3sm_diags.derivations.derivations import DERIVED_VARIABLES
-
+
     name_suffix = f" in {dataset_name} dataset" if dataset_name else ""

     # Follow dataset_xr._get_climo_dataset logic:
     # 1. If var is in derived_vars_map, try to derive it
     if var_key in DERIVED_VARIABLES:
         target_var_map = DERIVED_VARIABLES[var_key]
-        matching_target_var_map = _get_matching_src_vars(
-            dataset, target_var_map
-        )
+        matching_target_var_map = _get_matching_src_vars(dataset, target_var_map)

         if matching_target_var_map is not None:
             # Get derivation function and source variables
@@ -783,14 +927,10 @@ def _process_variable_derivations(dataset, var_key, dataset_name=""):
             )

             try:
-                _apply_derivation_func(
-                    dataset, derivation_func, src_var_keys, var_key
-                )
+                _apply_derivation_func(dataset, derivation_func, src_var_keys, var_key)
                 return True
             except Exception as e:
-                logger.warning(
-                    f"Failed to derive {var_key}{name_suffix}: {e}"
-                )
+                logger.warning(f"Failed to derive {var_key}{name_suffix}: {e}")
                 # Fall through to check if variable exists directly

     # 2. Check if variable exists directly in dataset
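
A hedged sketch of the lookup flow this function follows; the mapping shape (target variable to ordered candidates of source-variable tuples and derivation functions) mirrors DERIVED_VARIABLES, but the entry and data below are hypothetical:

    from collections import OrderedDict

    # Hypothetical entry: total precipitation from convective + large-scale parts.
    derived = {"PRECT": OrderedDict([(("PRECC", "PRECL"), lambda c, l: c + l)])}

    data = {"PRECC": 1.0, "PRECL": 2.0}
    for src_vars, func in derived["PRECT"].items():
        if all(v in data for v in src_vars):  # first candidate whose sources all exist
            data["PRECT"] = func(*(data[v] for v in src_vars))
            break
    print(data["PRECT"])  # 3.0
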