
Commit 48a9728

Subset time series variables and load into memory
- Required for faster downstream operations, which require in-memory NumPy arrays
1 parent fa9c176 commit 48a9728

File tree: 1 file changed (+69, -56)

e3sm_diags/driver/utils/dataset_xr.py

Lines changed: 69 additions & 56 deletions
@@ -44,6 +44,11 @@
 TS_EXT_FILEPATTERN = r"_.{13}.nc"
 
 
+# Additional variables to keep when subsetting.
+HYBRID_VAR_KEYS = set(list(sum(HYBRID_SIGMA_KEYS.values(), ())))
+MISC_VARS = ["area"]
+
+
 def squeeze_time_dim(ds: xr.Dataset) -> xr.Dataset:
     """Squeeze single coordinate climatology time dimensions.
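
For context, the new HYBRID_VAR_KEYS constant flattens the tuples in HYBRID_SIGMA_KEYS into a single set of variable names. A minimal sketch of the idiom, using a hypothetical stand-in for HYBRID_SIGMA_KEYS (the real mapping lives elsewhere in e3sm_diags):

    # Hypothetical stand-in: HYBRID_SIGMA_KEYS maps each hybrid-sigma variable
    # to a tuple of possible source names.
    HYBRID_SIGMA_KEYS = {"ps": ("ps", "PS"), "hyam": ("hyam",), "hybm": ("hybm",)}

    # sum(tuples, ()) concatenates the tuples into one flat tuple, and set()
    # deduplicates the names.
    HYBRID_VAR_KEYS = set(sum(HYBRID_SIGMA_KEYS.values(), ()))
    # -> {"ps", "PS", "hyam", "hybm"}

The intermediate list() call in the committed line is redundant, since set() accepts any iterable.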
@@ -404,7 +409,7 @@ def _get_climo_dataset(self, season: str) -> xr.Dataset:
         )
 
         ds = squeeze_time_dim(ds)
-        ds = self._subset_vars_and_load(ds)
+        ds = self._subset_vars_and_load(ds, self.var)
 
         return ds

@@ -740,50 +745,6 @@ def _get_matching_climo_src_vars(
 
         return None
 
-    def _subset_vars_and_load(self, ds: xr.Dataset) -> xr.Dataset:
-        """Subset for variables needed for processing and load into memory.
-
-        Subsetting the dataset reduces its memory footprint. Loading is
-        necessary because there seems to be an issue with `open_mfdataset()`
-        and using the multiprocessing scheduler defined in e3sm_diags,
-        resulting in timeouts and resource locking. To avoid this, we load the
-        multi-file dataset into memory before performing downstream operations.
-
-        Source: https://github.com/pydata/xarray/issues/3781
-
-        Parameters
-        ----------
-        ds : xr.Dataset
-            The dataset.
-
-        Returns
-        -------
-        xr.Dataset
-            The dataset subsetted and loaded into memory.
-        """
-        # slat and slon are lat lon pair for staggered FV grid included in
-        # remapped files.
-        if "slat" in ds.dims:
-            ds = ds.drop_dims(["slat", "slon"])
-
-        all_vars_keys = list(ds.data_vars.keys())
-
-        hybrid_var_keys = set(list(sum(HYBRID_SIGMA_KEYS.values(), ())))
-        misc_vars = ["area"]
-        keep_vars = [
-            var
-            for var in all_vars_keys
-            if "bnd" in var
-            or "bounds" in var
-            or var in hybrid_var_keys
-            or var in misc_vars
-        ]
-        ds = ds[[self.var] + keep_vars]
-
-        ds.load(scheduler="sync")
-
-        return ds
-
     # --------------------------------------------------------------------------
     # Time series related methods
     # --------------------------------------------------------------------------
@@ -973,7 +934,7 @@ def _get_matching_time_series_src_vars(
         # time series filepath.
         for tuple_of_vars in possible_vars:
             all_vars_found = all(
-                self._get_timeseries_filepaths(path, var) is not None
+                self._get_time_series_filepaths(path, var) is not None
                 for var in tuple_of_vars
             )

@@ -983,7 +944,7 @@ def _get_matching_time_series_src_vars(
         # None of the entries in the derived variables dictionary are valid,
         # so try to get the dataset for the variable directly.
         # Example file name: {var}_{start_yr}01_{end_yr}12.nc.
-        if self._get_timeseries_filepaths(path, self.var) is not None:
+        if self._get_time_series_filepaths(path, self.var) is not None:
             return {(self.var,): lambda x: x}
 
         raise IOError(
@@ -1028,25 +989,25 @@ def _get_time_series_dataset_obj(self, var) -> xr.Dataset:
         xr.Dataset
             The dataset for the variable.
         """
-        filepaths = self._get_timeseries_filepaths(self.root_path, var)
+        filepaths = self._get_time_series_filepaths(self.root_path, var)
 
         if filepaths is None:
             raise IOError(
                 f"No time series `.nc` file was found for '{var}' in '{self.root_path}'"
             )
+
         ds = xc.open_mfdataset(
             filepaths,
             add_bounds=["X", "Y", "T"],
             decode_times=True,
-            use_cftime=True,
             coords="minimal",
             compat="override",
         )
-        ds_subset = self._subset_time_series_dataset(ds, filepaths)
+        ds_subset = self._subset_time_series_dataset(ds, filepaths, var)
 
         return ds_subset
 
-    def _get_timeseries_filepaths(
+    def _get_time_series_filepaths(
         self, root_path: str, var_key: str
     ) -> List[str] | None:
         """Get the matching variable time series filepaths.
@@ -1083,7 +1044,7 @@ def _get_timeseries_filepaths(
 
         # Attempt 1 - try to find the file directly in `data_path`
         # Example: {path}/ts_200001_200112.nc"
-        match = self._get_matching_time_series_filepath(
+        match = self._get_matching_time_series_filepaths(
             root_path, var_key, filename_pattern
         )
 
@@ -1092,13 +1053,13 @@
         # Example: {path}/*/{ref_name}/*/ts_200001_200112.nc"
         ref_name = getattr(self.parameter, "ref_name", None)
         if match is None and ref_name is not None:
-            match = self._get_matching_time_series_filepath(
+            match = self._get_matching_time_series_filepaths(
                 root_path, var_key, filename_pattern, ref_name
             )
 
         return match
 
-    def _get_matching_time_series_filepath(
+    def _get_matching_time_series_filepaths(
         self,
         root_path: str,
         var_key: str,
@@ -1145,22 +1106,30 @@ def _get_matching_time_series_filepath(
         return matches
 
     def _subset_time_series_dataset(
-        self, ds: xr.Dataset, filepaths: List[str]
+        self, ds: xr.Dataset, filepaths: List[str], var: str
     ) -> xr.Dataset:
-        """Subset the time series dataset based on the filepath.
+        """Subset the time series dataset.
+
+        This method subsets the variables in the dataset and loads the data
+        into memory, then subsets on the time slice based on the specified
+        files.
 
         Parameters
         ----------
         ds : xr.Dataset
             The time series dataset.
         filepaths : List[str]
             The list of filepaths.
+        var : str
+            The main variable to keep.
 
         Returns
         -------
         xr.Dataset
             The subsetted time series dataset.
         """
+        ds_subset = self._subset_vars_and_load(ds, var)
+
         time_slice = self._get_time_slice(ds, filepaths)
-        ds_subset = ds.sel(time=time_slice).squeeze()
+        ds_subset = ds_subset.sel(time=time_slice).squeeze()

@@ -1335,6 +1304,7 @@ def _get_time_bounds_delta(self, time_bnds: xr.DataArray) -> timedelta:
             The time delta.
         """
         time_delta = time_bnds[0][-1] - time_bnds[0][0]
+        # FIXME: This line is slow with open_mfdataset and dask arrays.
         time_delta_py = pd.to_timedelta(time_delta.values).to_pytimedelta()
 
         return time_delta_py
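
The FIXME most likely points at `.values`, which forces a compute of the dask graph behind a lazily opened dataset every time it is accessed; loading the dataset up front (as this commit does) turns it into a cheap NumPy access. For reference, the same arithmetic on in-memory data:

    import numpy as np
    import pandas as pd
    import xarray as xr

    # In-memory time bounds: one time step with [start, end] bounds.
    time_bnds = xr.DataArray(
        np.array([["2000-01-01", "2000-02-01"]], dtype="datetime64[ns]"),
        dims=["time", "bnds"],
    )

    # End minus start of the first time step, converted to a Python timedelta.
    time_delta = time_bnds[0][-1] - time_bnds[0][0]
    time_delta_py = pd.to_timedelta(time_delta.values).to_pytimedelta()
    # -> datetime.timedelta(days=31)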
@@ -1516,3 +1486,46 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset:
         ds_mask = xr.merge([ds_land_frac, ds_ocean_frac])
 
         return ds_mask
+
+    def _subset_vars_and_load(self, ds: xr.Dataset, var: str) -> xr.Dataset:
+        """Subset for variables needed for processing and load into memory.
+
+        Subsetting the dataset reduces its memory footprint. Loading is
+        necessary because there seems to be an issue with `open_mfdataset()`
+        when using the multiprocessing scheduler defined in e3sm_diags,
+        resulting in timeouts and resource locking. To avoid this, we load the
+        multi-file dataset into memory before performing downstream operations.
+
+        Source: https://github.com/pydata/xarray/issues/3781
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            The dataset.
+        var : str
+            The main variable to keep.
+
+        Returns
+        -------
+        xr.Dataset
+            The dataset subsetted and loaded into memory.
+        """
+        # slat and slon are the lat-lon pair for the staggered FV grid
+        # included in remapped files.
+        if "slat" in ds.dims:
+            ds = ds.drop_dims(["slat", "slon"])
+
+        all_vars_keys = list(ds.data_vars.keys())
+        keep_vars = [
+            v
+            for v in all_vars_keys
+            if "bnd" in v
+            or "bounds" in v
+            or v in HYBRID_VAR_KEYS
+            or v in MISC_VARS
+        ]
+        ds = ds[[var] + keep_vars]
+
+        ds.load(scheduler="sync")
+
+        return ds
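
To illustrate what the relocated method keeps, here is a sketch of the filter on a toy dataset (variable names are hypothetical except for the constants defined at the top of this diff). The comprehension variable is renamed to v above to avoid shadowing the var parameter. Note that ds.load(scheduler="sync") forces a single-threaded dask compute, sidestepping the multiprocessing-scheduler issue referenced in the docstring:

    import numpy as np
    import xarray as xr

    HYBRID_VAR_KEYS = {"hyam", "hybm", "ps"}  # mirrors the module constant
    MISC_VARS = ["area"]

    ds = xr.Dataset(
        {
            "PRECT": ("time", np.zeros(3)),                     # main variable
            "time_bnds": (("time", "bnds"), np.zeros((3, 2))),  # kept: "bnd"
            "hyam": ("lev", np.zeros(4)),                       # kept: hybrid key
            "area": ("ncol", np.zeros(5)),                      # kept: misc var
            "FLUT": ("time", np.zeros(3)),                      # dropped
        }
    )

    keep_vars = [
        v
        for v in ds.data_vars
        if "bnd" in v or "bounds" in v or v in HYBRID_VAR_KEYS or v in MISC_VARS
    ]
    ds = ds[["PRECT"] + keep_vars]
    # ds.data_vars -> PRECT, time_bnds, hyam, area (FLUT is dropped)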
