
Commit 8fa1800

add time snapshot selection function for visualizing native history files

1 parent 8f9521e commit 8fa1800

File tree

2 files changed: +329 -49 lines changed

e3sm_diags/driver/lat_lon_native_driver.py

Lines changed: 188 additions & 48 deletions
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Sequence
 
 import uxarray as ux
 import xarray as xr
 
-from e3sm_diags.driver.utils.climo_xr import ClimoFreq
 from e3sm_diags.driver.utils.dataset_xr import Dataset
 from e3sm_diags.driver.utils.regrid import (
     _apply_land_sea_mask,
@@ -19,7 +18,10 @@
 logger = _setup_child_logger(__name__)
 
 if TYPE_CHECKING:
-    from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter
+    from e3sm_diags.parameter.lat_lon_native_parameter import (
+        LatLonNativeParameter,
+        TimeSelection,
+    )
 
 # The default value for metrics if it is not calculated. This value was
 # preserved from the legacy CDAT codebase because the viewer expects this
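Note on the new TYPE_CHECKING import: `TimeSelection` is only used in annotations below, where it stands for either a climatology frequency (e.g. "ANN") or a time-slice string (e.g. "0:10:2"). Its actual definition lives in lat_lon_native_parameter.py and is not part of this diff; the following is a plausible sketch of such an alias, for orientation only:

    # Hypothetical sketch only: the real alias is defined in
    # e3sm_diags/parameter/lat_lon_native_parameter.py, not shown in this commit.
    from typing import Union

    from e3sm_diags.driver.utils.climo_xr import ClimoFreq

    TimeSelection = Union[ClimoFreq, str]  # season frequency or "start:end:step" string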
@@ -31,7 +33,7 @@
 def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
     """Get metrics for the lat_lon_native diagnostic set.
 
-    This function loops over each variable, season, pressure level (if 3-D),
+    This function loops over each variable, season/time_slice, pressure level (if 3-D),
     and region.
 
     Parameters
@@ -51,10 +53,19 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
         (e.g., one is 2-D and the other is 3-D).
     """
     variables = parameter.variables
-    seasons = parameter.seasons
     ref_name = getattr(parameter, "ref_name", "")
     regions = parameter.regions
 
+    # Determine whether to use seasons or time_slices
+    if len(parameter.time_slices) > 0:
+        time_periods: Sequence["TimeSelection"] = parameter.time_slices
+        using_time_slices = True
+        logger.info(f"Using time_slices: {time_periods}")
+    else:
+        time_periods = parameter.seasons
+        using_time_slices = False
+        logger.info(f"Using seasons: {time_periods}")
+
     # Variables storing xarray `Dataset` objects start with `ds_` and
     # variables storing e3sm_diags `Dataset` objects end with `_ds`. This
     # is to help distinguish both objects from each other.
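The block above gives `time_slices` precedence: any non-empty list switches the whole run into snapshot mode, and `seasons` is ignored. A minimal run-script sketch of the two modes (the variable and slice values are illustrative, not taken from this commit):

    # Hypothetical usage sketch: a non-empty time_slices list puts run_diag
    # into snapshot mode; an empty list falls back to seasonal climatologies.
    from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter

    param = LatLonNativeParameter()
    param.variables = ["PRECT"]      # illustrative variable name
    param.seasons = ["ANN", "JJA"]   # used only when time_slices is empty
    param.time_slices = ["0:12:3"]   # every 3rd snapshot of the first 12; wins here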
@@ -65,16 +76,28 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
         logger.info("Variable: {}".format(var_key))
         parameter.var_id = var_key
 
-        for season in seasons:
-            parameter._set_name_yrs_attrs(test_ds, ref_ds, season)
+        for time_period in time_periods:
+            if using_time_slices:
+                # For time_slices, we need to pass the slice info differently
+                logger.info(f"Processing time slice: {time_period}")
+                # Set up period-specific attributes for file naming
+                parameter._set_time_slice_attrs(test_ds, ref_ds, time_period)
+            else:
+                # For seasons, use existing logic
+                logger.info(f"Processing season: {time_period}")
+                parameter._set_name_yrs_attrs(test_ds, ref_ds, time_period)
 
-            ds_test = _get_native_dataset(test_ds, var_key, season)
-            ds_ref = _get_native_dataset(ref_ds, var_key, season, allow_missing=True)
+            # Use the same function for both cases - it will handle the logic internally
+            ds_test = _get_native_dataset(
+                test_ds, var_key, time_period, using_time_slices
+            )
+            ds_ref = _get_native_dataset(
+                ref_ds, var_key, time_period, using_time_slices, allow_missing=True
+            )
 
             # Log basic dataset info
-            logger.debug(f"Test dataset variables: {list(ds_test.variables)}")
-
-
+            if ds_test is not None:
+                logger.debug(f"Test dataset variables: {list(ds_test.variables)}")
 
             # Load the native grid information for test data
             uxds_test = None
@@ -145,7 +168,7 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
             if uxds_test is not None:
                 _run_diags_2d_model_only(
                     parameter,
-                    season,
+                    time_period,
                     regions,
                     var_key,
                     ref_name,
@@ -159,7 +182,7 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
                 # Only handle 2D variables for now
                 _run_diags_2d(
                     parameter,
-                    season,
+                    time_period,
                     regions,
                     var_key,
                     ref_name,
@@ -171,7 +194,11 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter:  # noqa: C901
 
 
 def _get_native_dataset(
-    dataset: Dataset, var_key: str, season: ClimoFreq, allow_missing: bool = False
+    dataset: Dataset,
+    var_key: str,
+    season: "TimeSelection",
+    is_time_slice: bool = False,
+    allow_missing: bool = False,
 ) -> xr.Dataset | None:
     """Get the climatology dataset for the variable and season for native grid processing.
 
@@ -188,8 +215,11 @@ def _get_native_dataset(
         The dataset object (test or reference).
     var_key : str
         The key of the variable.
-    season : CLIMO_FREQ
-        The climatology frequency.
+    season : TimeSelection
+        The climatology frequency or time slice string.
+    is_time_slice : bool, optional
+        If True, treat season as a time slice string rather than climatology frequency.
+        Default is False.
     allow_missing : bool, optional
         If True, return None when dataset cannot be loaded instead of raising
         an exception. This enables model-only runs when reference data is missing.
@@ -207,30 +237,38 @@ def _get_native_dataset(
         If the dataset cannot be loaded and allow_missing=False.
     """
     try:
-        # Get the climatology dataset
-        ds = dataset.get_climo_dataset(var_key, season)
-
-        # Try to get file_path from different possible sources and store it in parameter
-        file_path = None
-        if hasattr(ds, "file_path"):
-            file_path = ds.file_path
-        elif hasattr(ds, "filepath"):
-            file_path = ds.filepath
-        elif hasattr(ds, "_file_obj") and hasattr(ds._file_obj, "name"):
-            file_path = ds._file_obj.name
-
-        # Store path in parameter based on dataset type
-        if file_path:
-            if dataset.data_type == "test" and not hasattr(dataset.parameter, "test_data_file_path"):
-                dataset.parameter.test_data_file_path = file_path
-            elif dataset.data_type == "ref" and not hasattr(dataset.parameter, "ref_data_file_path"):
-                dataset.parameter.ref_data_file_path = file_path
+        if is_time_slice:
+            # For time slices, get the full dataset without averaging
+            ds = _get_full_native_dataset(dataset, var_key)
+            # Apply the time slice
+            ds = _apply_time_slice(ds, season)
+        else:
+            # Standard climatology processing
+            from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS
+
+            if season in CLIMO_FREQS:
+                ds = dataset.get_climo_dataset(var_key, season)  # type: ignore
+            else:
+                raise ValueError(f"Invalid season for climatology: {season}")
+
+        # Store file path in parameter for native grid processing
+        if is_time_slice:
+            # For time slices, we know the exact file path we used
+            filepath = dataset._get_climo_filepath_with_params()
+            if filepath:
+                if dataset.data_type == "test":
+                    dataset.parameter.test_data_file_path = filepath
+                elif dataset.data_type == "ref":
+                    dataset.parameter.ref_data_file_path = filepath
+        # Note: For climatology case, get_climo_dataset() already handles file path storage
 
         return ds
 
     except (RuntimeError, IOError) as e:
         if allow_missing:
-            logger.info(f"Cannot process {dataset.data_type} data: {e}. Using model-only mode.")
+            logger.info(
+                f"Cannot process {dataset.data_type} data: {e}. Using model-only mode."
+            )
             return None
         else:
             raise
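In slice mode the function now bypasses `get_climo_dataset` entirely (open the raw file, then subset it), so no temporal averaging is applied; in season mode it validates the requested frequency against `CLIMO_FREQS` before delegating. A small runnable check of that validation step (the candidate strings are illustrative):

    # Sketch of the season-mode validation above: only recognized
    # climatology frequencies pass; anything else raises ValueError.
    from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS

    for candidate in ("ANN", "JJA", "0:10:2"):
        status = "valid frequency" if candidate in CLIMO_FREQS else "would raise ValueError"
        print(f"{candidate}: {status}")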
@@ -521,7 +559,115 @@ def _compare_grids(uxds_test, uxds_ref):
 
     except Exception as e:
         logger.warning(f"Error comparing grids: {e}")
-        return False, 0, 0
+        return False
+
+
+def _get_full_native_dataset(dataset: Dataset, var_key: str) -> xr.Dataset:
+    """Get the full native dataset without any time averaging.
+
+    This function uses the dataset's file path parameters to directly open
+    the raw data file for time slicing operations.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        The dataset object (test or reference).
+    var_key : str
+        The key of the variable.
+
+    Returns
+    -------
+    xr.Dataset
+        The full dataset with all time steps.
+
+    Raises
+    ------
+    RuntimeError
+        If unable to get the full dataset.
+    """
+    import os
+
+    import xarray as xr
+
+    # Get the file path using the parameter-based method
+    filepath = dataset._get_climo_filepath_with_params()
+
+    if filepath is None:
+        raise RuntimeError(
+            f"Unable to get file path for {dataset.data_type} dataset. "
+            f"For time slicing, please ensure that "
+            f"{'ref_file' if dataset.data_type == 'ref' else 'test_file'} parameter is set."
+        )
+
+    if not os.path.exists(filepath):
+        raise RuntimeError(f"File not found: {filepath}")
+
+    logger.info(f"Opening full native dataset from: {filepath}")
+
+    try:
+        # Open the dataset directly without any averaging
+        ds = xr.open_dataset(filepath, decode_times=True)
+        logger.info(
+            f"Successfully opened dataset with time dimension size: {ds.sizes.get('time', 'N/A')}"
+        )
+        return ds
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to open dataset {filepath}: {e}") from e
+
+
+def _apply_time_slice(ds: xr.Dataset, time_slice: str) -> xr.Dataset:
+    """Apply time slice selection to a dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The input dataset with time dimension.
+    time_slice : str
+        The time slice specification (e.g., "0:10:2", "5:15", "7").
+
+    Returns
+    -------
+    xr.Dataset
+        The dataset with time slice applied.
+    """
+    # Parse the time slice string
+    time_dim = None
+    for dim in ds.dims:
+        if str(dim).lower() in ["time", "t"]:
+            time_dim = dim
+            break
+
+    if time_dim is None:
+        logger.warning(
+            "No time dimension found in dataset. Returning original dataset."
+        )
+        return ds
+
+    # Parse slice notation
+    if ":" in time_slice:
+        # Handle slice notation like "0:10:2" or "5:15" or ":10" or "5:" or "::2"
+        parts = time_slice.split(":")
+
+        start = int(parts[0]) if parts[0] else None
+        end = int(parts[1]) if len(parts) > 1 and parts[1] else None
+        step = int(parts[2]) if len(parts) > 2 and parts[2] else None
+
+        # Apply the slice
+        ds_sliced = ds.isel({time_dim: slice(start, end, step)})
+    else:
+        # Single index
+        index = int(time_slice)
+        ds_sliced = ds.isel({time_dim: index})
+
+    logger.info(
+        f"Applied time slice '{time_slice}' to dataset. "
+        f"Original time length: {ds.sizes[time_dim]}, "
+        f"Sliced time length: {ds_sliced.sizes.get(time_dim, 1)}"
+    )
+
+    return ds_sliced
 
 
 def _compute_direct_difference(uxds_test, uxds_ref, var_key):
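Two details of `_apply_time_slice` are worth noting: the string follows Python's start:end:step slice notation with any part omittable, and a bare index (no colon) selects a single snapshot via an integer `isel`, which drops the time dimension entirely, hence the `ds_sliced.sizes.get(time_dim, 1)` fallback in the log message. A self-contained sketch of the same parsing against a synthetic dataset:

    # Standalone illustration of the slice-string semantics used above;
    # the dataset and variable are synthetic stand-ins.
    import numpy as np
    import xarray as xr

    def parse_time_slice(spec: str):
        if ":" in spec:
            parts = spec.split(":")
            start = int(parts[0]) if parts[0] else None
            end = int(parts[1]) if len(parts) > 1 and parts[1] else None
            step = int(parts[2]) if len(parts) > 2 and parts[2] else None
            return slice(start, end, step)
        return int(spec)  # single snapshot; isel() drops the time dim

    ds = xr.Dataset({"tas": ("time", np.arange(20.0))})
    print(ds.isel(time=parse_time_slice("0:10:2")).sizes["time"])  # -> 5
    print(ds.isel(time=parse_time_slice("7")).sizes)               # -> no time dim left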
@@ -748,7 +894,7 @@ def _get_matching_src_vars(dataset, target_var_map):
 def _apply_derivation_func(dataset, func, src_var_keys, target_var_key):
     """Apply derivation function following dataset_xr pattern."""
     from e3sm_diags.derivations.derivations import FUNC_NEEDS_TARGET_VAR
-    
+
     func_args = [dataset[var] for var in src_var_keys]
 
     if func in FUNC_NEEDS_TARGET_VAR:
@@ -762,16 +908,14 @@ def _apply_derivation_func(dataset, func, src_var_keys, target_var_key):
 def _process_variable_derivations(dataset, var_key, dataset_name=""):
     """Process variable derivations following dataset_xr approach."""
     from e3sm_diags.derivations.derivations import DERIVED_VARIABLES
-    
+
     name_suffix = f" in {dataset_name} dataset" if dataset_name else ""
 
     # Follow dataset_xr._get_climo_dataset logic:
     # 1. If var is in derived_vars_map, try to derive it
     if var_key in DERIVED_VARIABLES:
         target_var_map = DERIVED_VARIABLES[var_key]
-        matching_target_var_map = _get_matching_src_vars(
-            dataset, target_var_map
-        )
+        matching_target_var_map = _get_matching_src_vars(dataset, target_var_map)
 
         if matching_target_var_map is not None:
             # Get derivation function and source variables
@@ -783,14 +927,10 @@ def _process_variable_derivations(dataset, var_key, dataset_name=""):
             )
 
             try:
-                _apply_derivation_func(
-                    dataset, derivation_func, src_var_keys, var_key
-                )
+                _apply_derivation_func(dataset, derivation_func, src_var_keys, var_key)
                 return True
             except Exception as e:
-                logger.warning(
-                    f"Failed to derive {var_key}{name_suffix}: {e}"
-                )
+                logger.warning(f"Failed to derive {var_key}{name_suffix}: {e}")
                 # Fall through to check if variable exists directly
 
             # 2. Check if variable exists directly in dataset
