Skip to content

Commit f0e5426

Browse files
committed
Refactor implementation of I/O for retrieving time-sliced datasets
- Refactor drivers for polar, meridional_mean_2d, zonal_mean_2d, zonal_mean_xy to use `driver.utils.io._get_xarray_datasets()`
- Add `driver.utils.io._get_xarray_datasets()` utility to simplify fetching of xarray datasets based on time selection
- Update references from `season` to `time_selection` with `TimeSelection` annotation
- Move `_set_time_slice_name_yrs_attrs()` back to `LatLonNativeParameter` because it is only used there -- consider refactoring later
- Remove `time_slice.py` as these functions were converted to `CoreParameter` methods
1 parent 1d9c55e commit f0e5426

File tree

13 files changed

+493
-523
lines changed

13 files changed

+493
-523
lines changed

e3sm_diags/driver/lat_lon_driver.py

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import xarray as xr
66

77
from e3sm_diags.driver import METRICS_DEFAULT_VALUE
8-
from e3sm_diags.driver.utils.climo_xr import ClimoFreq
98
from e3sm_diags.driver.utils.dataset_xr import Dataset
109
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
1110
from e3sm_diags.driver.utils.regrid import (
@@ -16,15 +15,14 @@
1615
regrid_z_axis_to_plevs,
1716
subset_and_align_datasets,
1817
)
19-
from e3sm_diags.driver.utils.time_slice import check_time_selection
20-
from e3sm_diags.driver.utils.type_annotations import MetricsDict
2118
from e3sm_diags.logger import _setup_child_logger
2219
from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std
2320
from e3sm_diags.plot.lat_lon_plot import plot as plot_func
2421

2522
logger = _setup_child_logger(__name__)
2623

2724
if TYPE_CHECKING:
25+
from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection
2826
from e3sm_diags.parameter.core_parameter import CoreParameter
2927

3028

@@ -51,57 +49,44 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
5149
(e.g., one is 2-D and the other is 3-D).
5250
"""
5351
variables = parameter.variables
54-
seasons = parameter.seasons
55-
time_slices = getattr(parameter, "time_slices", [])
5652
ref_name = getattr(parameter, "ref_name", "")
5753
regions = parameter.regions
5854

59-
# Check that either seasons or time_slices is specified, but not both
60-
has_seasons, has_time_slices = check_time_selection(
61-
seasons, time_slices, require_one=True
62-
)
63-
64-
# Determine which time selection to use
65-
time_selections = time_slices if has_time_slices else seasons
66-
is_time_slice_mode = has_time_slices
67-
6855
# Variables storing xarray `Dataset` objects start with `ds_` and
6956
# variables storing e3sm_diags `Dataset` objects end with `_ds`. This
7057
# is to help distinguish both objects from each other.
7158
test_ds = Dataset(parameter, data_type="test")
7259
ref_ds = Dataset(parameter, data_type="ref")
7360

61+
time_selection_type, time_selections = parameter._get_time_selection_to_use()
62+
7463
for var_key in variables:
7564
logger.info("Variable: {}".format(var_key))
7665
parameter.var_id = var_key
7766

7867
for time_selection in time_selections:
79-
# Set name/yrs attributes based on time selection mode
80-
if is_time_slice_mode:
81-
# For time slices, we set attributes after loading data
82-
# since we need the actual time coordinates
83-
pass
84-
else:
85-
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
68+
is_time_slice = time_selection_type == "time_slices"
8669

87-
# Get datasets - pass is_time_slice flag if it's a time slice
88-
if is_time_slice_mode:
89-
ds_test = test_ds.get_climo_dataset(
90-
var_key, time_selection, is_time_slice=True
91-
)
92-
ds_ref = _get_ref_dataset(
93-
ref_ds, var_key, time_selection, is_time_slice=True
94-
)
95-
# For time slices, set name_yrs after data is loaded
96-
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
97-
# Use the climatology season (ANN) for land sea mask
70+
# Get test and reference datasets.
71+
# NOTE: lat_lon diagnostics get reference datasets differently than
72+
# other sets using its own helper function `_get_ref_dataset`.
73+
if is_time_slice:
74+
ds_test = test_ds.get_time_sliced_dataset(var_key, time_selection)
75+
76+
# For time slices, always use the annual land-sea mask.
9877
ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask("ANN")
9978
else:
100-
ds_test = test_ds.get_climo_dataset(var_key, time_selection)
101-
ds_ref = _get_ref_dataset(
102-
ref_ds, var_key, time_selection, is_time_slice=False
79+
# time_selection will be ClimoFreq, so ignore type checking here.
80+
ds_test = test_ds.get_climo_dataset(var_key, time_selection) # type: ignore[arg-type]
81+
ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask( # type: ignore[no-redef]
82+
time_selection # type: ignore[arg-type]
10383
)
104-
ds_land_sea_mask = test_ds._get_land_sea_mask(time_selection)
84+
85+
ds_ref = _get_ref_dataset(ref_ds, var_key, time_selection, is_time_slice)
86+
87+
# Set name_yrs after loading data because time sliced datasets
88+
# have the required attributes only after loading the data.
89+
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
10590

10691
if ds_ref is None:
10792
is_vars_3d = has_z_axis(ds_test[var_key])
@@ -317,7 +302,7 @@ def _run_diags_2d(
317302
ds_test: xr.Dataset,
318303
ds_ref: xr.Dataset,
319304
ds_land_sea_mask: xr.Dataset,
320-
season: str,
305+
time_selection: TimeSelection,
321306
regions: list[str],
322307
var_key: str,
323308
ref_name: str,
@@ -338,8 +323,8 @@ def _run_diags_2d(
338323
ds_land_sea_mask : xr.Dataset
339324
The land sea mask dataset, which is only used for masking if the region
340325
is "land" or "ocean".
341-
season : str
342-
The season.
326+
time_selection : TimeSelection
327+
The time slice or season.
343328
regions : list[str]
344329
The list of regions.
345330
var_key : str
@@ -372,7 +357,9 @@ def _run_diags_2d(
372357
ds_diff_region,
373358
)
374359

375-
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None)
360+
parameter._set_param_output_attrs(
361+
var_key, time_selection, region, ref_name, ilev=None
362+
)
376363
_save_data_metrics_and_plots(
377364
parameter,
378365
plot_func,
@@ -391,7 +378,7 @@ def _run_diags_3d(
391378
ds_test: xr.Dataset,
392379
ds_ref: xr.Dataset,
393380
ds_land_sea_mask: xr.Dataset,
394-
season: str,
381+
time_selection: str,
395382
regions: list[str],
396383
var_key: str,
397384
ref_name: str,
@@ -412,8 +399,8 @@ def _run_diags_3d(
412399
ds_land_sea_mask : xr.Dataset
413400
The land sea mask dataset, which is only used for masking if the region
414401
is "land" or "ocean".
415-
season : str
416-
The season.
402+
time_selection : TimeSelection
403+
The time slice or season.
417404
regions : list[str]
418405
The list of regions.
419406
var_key : str
@@ -457,7 +444,9 @@ def _run_diags_3d(
457444
ds_diff_region,
458445
)
459446

460-
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev)
447+
parameter._set_param_output_attrs(
448+
var_key, time_selection, region, ref_name, ilev
449+
)
461450
_save_data_metrics_and_plots(
462451
parameter,
463452
plot_func,
@@ -474,7 +463,7 @@ def _run_diags_3d(
474463
def _get_ref_dataset(
475464
dataset: Dataset,
476465
var_key: str,
477-
time_selection: ClimoFreq | str,
466+
time_selection: TimeSelection,
478467
is_time_slice: bool = False,
479468
) -> xr.Dataset | None:
480469
"""Get the reference dataset for the variable and time selection.
@@ -506,9 +495,11 @@ def _get_ref_dataset(
506495
"""
507496
if dataset.data_type == "ref":
508497
try:
509-
ds_ref = dataset.get_climo_dataset(
510-
var_key, time_selection, is_time_slice=is_time_slice
511-
)
498+
if is_time_slice:
499+
ds_ref = dataset.get_time_sliced_dataset(var_key, time_selection)
500+
else:
501+
# time_selection will be ClimoFreq, so ignore type checking here.
502+
ds_ref = dataset.get_climo_dataset(var_key, time_selection) # type: ignore[arg-type]
512503
except (RuntimeError, IOError):
513504
ds_ref = None
514505

e3sm_diags/driver/meridional_mean_2d_driver.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import xcdat as xc
88

99
from e3sm_diags.driver.utils.dataset_xr import Dataset
10-
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
10+
from e3sm_diags.driver.utils.io import (
11+
_get_xarray_datasets,
12+
_save_data_metrics_and_plots,
13+
)
1114
from e3sm_diags.driver.utils.regrid import (
1215
align_grids_to_lower_res,
1316
has_z_axis,
1417
regrid_z_axis_to_plevs,
1518
)
16-
from e3sm_diags.driver.utils.time_slice import check_time_selection
17-
from e3sm_diags.driver.utils.type_annotations import MetricsDict
1819
from e3sm_diags.logger import _setup_child_logger
1920
from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg
2021
from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS
@@ -23,6 +24,7 @@
2324
logger = _setup_child_logger(__name__)
2425

2526
if TYPE_CHECKING:
27+
from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection
2628
from e3sm_diags.parameter.meridional_mean_2d_parameter import (
2729
MeridionalMean2dParameter,
2830
)
@@ -51,47 +53,25 @@ def run_diag(parameter: MeridionalMean2dParameter) -> MeridionalMean2dParameter:
5153
If the test or ref variables are not 3-D (no Z-axis).
5254
"""
5355
variables = parameter.variables
54-
seasons = parameter.seasons
55-
time_slices = getattr(parameter, "time_slices", [])
5656
ref_name = getattr(parameter, "ref_name", "")
5757

58-
# Check that either seasons or time_slices is specified, but not both
59-
has_seasons, has_time_slices = check_time_selection(
60-
seasons, time_slices, require_one=True
61-
)
62-
63-
# Determine which time selection to use
64-
time_selections = time_slices if has_time_slices else seasons
65-
is_time_slice_mode = has_time_slices
66-
6758
test_ds = Dataset(parameter, data_type="test")
6859
ref_ds = Dataset(parameter, data_type="ref")
6960

61+
time_selection_type, time_selections = parameter._get_time_selection_to_use()
62+
7063
for var_key in variables:
7164
logger.info("Variable: {}".format(var_key))
7265
parameter.var_id = var_key
7366

7467
for time_selection in time_selections:
75-
# Set name/yrs attributes based on time selection mode
76-
if is_time_slice_mode:
77-
# For time slices, we set attributes after loading data
78-
pass
79-
else:
80-
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
81-
82-
# Get datasets - pass is_time_slice flag if it's a time slice
83-
if is_time_slice_mode:
84-
ds_test = test_ds.get_climo_dataset(
85-
var_key, time_selection, is_time_slice=True
86-
)
87-
ds_ref = ref_ds.get_climo_dataset(
88-
var_key, time_selection, is_time_slice=True
89-
)
90-
# For time slices, set name_yrs after data is loaded
91-
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
92-
else:
93-
ds_test = test_ds.get_climo_dataset(var_key, time_selection)
94-
ds_ref = ref_ds.get_climo_dataset(var_key, time_selection)
68+
ds_test, ds_ref, _ = _get_xarray_datasets(
69+
test_ds, ref_ds, var_key, time_selection_type, time_selection
70+
)
71+
72+
# Set name_yrs after loading data because time sliced datasets
73+
# have the required attributes only after loading the data.
74+
parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection)
9575

9676
dv_test = ds_test[var_key]
9777
dv_ref = ds_ref[var_key]
@@ -122,7 +102,7 @@ def _run_diags_3d(
122102
parameter: MeridionalMean2dParameter,
123103
ds_test: xr.Dataset,
124104
ds_ref: xr.Dataset,
125-
season: str,
105+
time_selection: TimeSelection,
126106
var_key: str,
127107
ref_name: str,
128108
):
@@ -171,7 +151,7 @@ def _run_diags_3d(
171151
)
172152

173153
parameter._set_param_output_attrs(
174-
var_key, season, parameter.regions[0], ref_name, ilev=None
154+
var_key, time_selection, parameter.regions[0], ref_name, ilev=None
175155
)
176156
_save_data_metrics_and_plots(
177157
parameter,

0 commit comments

Comments
 (0)