9 changes: 9 additions & 0 deletions e3sm_diags/__init__.py
@@ -11,3 +11,12 @@
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Settings to preserve legacy Xarray behavior when merging datasets.
# See https://xarray.pydata.org/en/stable/user-guide/io.html#combining-multiple-files
# and https://xarray.pydata.org/en/stable/whats-new.html#id14
LEGACY_XARRAY_MERGE_KWARGS = {
# "override", "exact" are the new defaults as of Xarray v2025.08.0
"compat": "no_conflicts",
"join": "outer",
}
Comment on lines +15 to +22 (Collaborator Author):

The new module-level constant that stores the legacy xarray merge settings.
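
A minimal sketch (not part of the diff) of how this constant is meant to be unpacked into xr.merge() calls elsewhere in the package; the toy datasets below are illustrative only.

import xarray as xr

from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS

# Two toy datasets whose time coordinates only partially overlap.
ds_a = xr.Dataset({"tas": ("time", [1.0, 2.0])}, coords={"time": [0, 1]})
ds_b = xr.Dataset({"pr": ("time", [0.5, 0.7])}, coords={"time": [1, 2]})

# With the new Xarray defaults (compat="override", join="exact") this merge
# would raise because the time coordinates differ; unpacking the legacy
# kwargs restores the old outer-join, no-conflicts behavior.
ds = xr.merge([ds_a, ds_b], **LEGACY_XARRAY_MERGE_KWARGS)
print(ds.sizes)  # time has length 3 after the outer join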

140 changes: 75 additions & 65 deletions e3sm_diags/driver/mp_partition_driver.py
@@ -16,7 +16,7 @@
import xarray as xr
from scipy.stats import binned_statistic

from e3sm_diags import INSTALL_PATH
from e3sm_diags import INSTALL_PATH, LEGACY_XARRAY_MERGE_KWARGS
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.logger import _setup_child_logger
from e3sm_diags.plot.mp_partition_plot import plot
@@ -60,19 +60,27 @@ def compute_lcf(cice, cliq, temp, landfrac):


def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
"""Runs the mixed-phase partition/T5050 diagnostic.

:param parameter: Parameters for the run
:type parameter: CoreParameter
:raises ValueError: Invalid run type
:return: Parameters for the run
:rtype: CoreParameter
"""
Runs the mixed-phase partition/T5050 diagnostic.

Parameters
----------
parameter : CoreParameter
Parameters for the run.

Raises
------
ValueError
If the run type is invalid.

Returns
-------
CoreParameter
Parameters for the run.
"""
run_type = parameter.run_type
season = "ANN"

# Read reference data first

benchmark_data_path = os.path.join(
INSTALL_PATH,
"control_runs",
@@ -82,35 +90,20 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
with open(benchmark_data_path, "r") as myfile:
lcf_file = myfile.read()

# parse file
metrics_dict = json.loads(lcf_file)

test_data = Dataset(parameter, data_type="test")
# test = test_data.get_timeseries_variable("LANDFRAC")
# print(dir(test))
# landfrac = test_data.get_timeseries_variable("LANDFRAC")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# temp = test_data.get_timeseries_variable("T")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# cice = test_data.get_timeseries_variable("CLDICE")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# cliq = test_data.get_timeseries_variable("CLDLIQ")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))

test_data_path = parameter.test_data_path
start_year = parameter.test_start_yr
end_year = parameter.test_end_yr
start_year = int(parameter.test_start_yr)

end_year = int(parameter.test_end_yr)

# TODO the time subsetting and variable derivation should be replaced during cdat revamp
try:
# xr.open_mfdataset() can accept an explicit list of files.
landfrac = xr.open_mfdataset(glob.glob(f"{test_data_path}/LANDFRAC_*")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["LANDFRAC"]
temp = xr.open_mfdataset(glob.glob(f"{test_data_path}/T_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["T"]
cice = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDICE_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["CLDICE"]
cliq = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDLIQ_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["CLDLIQ"]
landfrac = _open_mfdataset(test_data_path, "LANDFRAC", start_year, end_year)
temp = _open_mfdataset(test_data_path, "T", start_year, end_year)
cice = _open_mfdataset(test_data_path, "CLDICE", start_year, end_year)
cliq = _open_mfdataset(test_data_path, "CLDLIQ", start_year, end_year)
except OSError:
logger.info(
f"No files to open for variables within {start_year} and {end_year} from {test_data_path}."
@@ -126,46 +119,19 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:

if run_type == "model-vs-model":
ref_data = Dataset(parameter, data_type="ref")

ref_data_path = parameter.reference_data_path
start_year = parameter.ref_start_yr
end_year = parameter.ref_end_yr
# xr.open_mfdataset() can accept an explicit list of files.

try:
landfrac = xr.open_mfdataset(glob.glob(f"{ref_data_path}/LANDFRAC_*")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["LANDFRAC"]
temp = xr.open_mfdataset(glob.glob(f"{ref_data_path}/T_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["T"]
cice = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDICE_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["CLDICE"]
cliq = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDLIQ_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["CLDLIQ"]
landfrac = _open_mfdataset(ref_data_path, "LANDFRAC", start_year, end_year)
temp = _open_mfdataset(ref_data_path, "T", start_year, end_year)
cice = _open_mfdataset(ref_data_path, "CLDICE", start_year, end_year)
cliq = _open_mfdataset(ref_data_path, "CLDLIQ", start_year, end_year)
except OSError:
logger.info(
f"No files to open for variables within {start_year} and {end_year} from {ref_data_path}."
)
raise

# landfrac = ref_data.get_timeseries_variable("LANDFRAC")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# temp = ref_data.get_timeseries_variable("T")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# cice = ref_data.get_timeseries_variable("CLDICE")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# cliq = ref_data.get_timeseries_variable("CLDLIQ")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
parameter.ref_name_yrs = ref_data.get_name_yrs_attr(season)
metrics_dict["ref"] = {}
metrics_dict["ref"]["T"], metrics_dict["ref"]["LCF"] = compute_lcf(
@@ -177,3 +143,47 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
plot(metrics_dict, parameter)

return parameter


def _open_mfdataset(
data_path: str, var: str, start_year: int, end_year: int
) -> xr.DataArray:
"""
Open multiple NetCDF files as a single xarray Dataset and subset by time
and latitude.

This function reads multiple NetCDF files matching the specified variable
name and combines them into a single xarray Dataset. The data is then
subsetted based on the specified time range and latitude bounds.

Parameters
----------
data_path : str
The path to the directory containing the NetCDF files.
var : str
The variable name to match in the file pattern.
start_year : int
The starting year for the time subsetting.
end_year : int
The ending year for the time subsetting.

Returns
-------
xr.DataArray
The subsetted DataArray for the specified variable, filtered by time
and latitude.
"""
file_pattern = f"{data_path}/{var}_*.nc"
ds = xr.open_mfdataset(
glob.glob(file_pattern),
data_vars="minimal",
**LEGACY_XARRAY_MERGE_KWARGS, # type: ignore[ arg-type ]
)

ds_sub = ds.sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)

da_var = ds_sub[var]

return da_var
Comment on lines +151 to +192 (Collaborator Author):

The DRY helper that replaces the repeated xr.open_mfdataset() blocks above.
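
A hedged, self-contained smoke test of the new helper (not part of the diff). It assumes e3sm_diags and its netCDF/dask dependencies are installed, and it writes two tiny T_*.nc files so the {data_path}/{var}_*.nc glob inside _open_mfdataset matches them.

import tempfile

import numpy as np
import pandas as pd
import xarray as xr

from e3sm_diags.driver.mp_partition_driver import _open_mfdataset

with tempfile.TemporaryDirectory() as tmp:
    # One small file per model year, named to match the T_*.nc glob.
    for year in (2000, 2001):
        time = pd.date_range(f"{year}-01-01", periods=12, freq="MS")
        ds = xr.Dataset(
            {"T": (("time", "lat"), np.zeros((12, 5)))},
            coords={"time": time, "lat": np.linspace(-90, 90, 5)},
        )
        ds.to_netcdf(f"{tmp}/T_{year}.nc")

    # Opens both files, merges them with the legacy kwargs, and returns "T"
    # subset to 70S-30S and the 2000-2001 window.
    temp = _open_mfdataset(tmp, "T", 2000, 2001)
    print(temp.sizes)  # 24 months, 1 latitude point in the 70S-30S band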

1 change: 1 addition & 0 deletions e3sm_diags/driver/tropical_subseasonal_driver.py
@@ -244,6 +244,7 @@ def wf_analysis(x, **kwargs):
background.rename("spec_background"),
],
compat="override",
join="outer",
)
spec_all = spec.drop("component")
spec_all["spec_raw_sym"].attrs = {"component": "symmetric", "type": "raw"}
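
A minimal sketch (toy arrays rather than the driver's actual spectra) of why join="outer" now has to be passed explicitly alongside compat="override":

import xarray as xr

a = xr.DataArray([1.0, 2.0], dims="freq", coords={"freq": [0.1, 0.2]}, name="spec_raw_sym")
b = xr.DataArray([3.0, 4.0], dims="freq", coords={"freq": [0.2, 0.3]}, name="spec_background")

# Under the new default join="exact", merging objects whose coordinates differ
# raises a ValueError; join="outer" keeps the previous union-style alignment
# (freq becomes [0.1, 0.2, 0.3] with NaN padding).
spec = xr.merge([a, b], compat="override", join="outer")
print(spec.sizes)  # freq has length 3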
5 changes: 3 additions & 2 deletions e3sm_diags/driver/utils/dataset_xr.py
@@ -24,6 +24,7 @@
import xarray as xr
import xcdat as xc

from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS
from e3sm_diags.derivations.derivations import (
DERIVED_VARIABLES,
FUNC_NEEDS_TARGET_VAR,
@@ -1159,7 +1160,7 @@ def _get_dataset_with_source_vars(self, vars_to_get: tuple[str, ...]) -> xr.Data
ds = self._get_time_series_dataset_obj(var)
datasets.append(ds)

ds = xr.merge(datasets)
ds = xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS) # type: ignore[ arg-type ]
ds = squeeze_time_dim(ds)

return ds
@@ -1640,7 +1641,7 @@ def _get_land_sea_mask_dataset(self, season: ClimoFreq) -> xr.Dataset | None:
datasets.append(ds_ocn)

if len(datasets) == 2:
return xr.merge(datasets)
return xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS) # type: ignore[ arg-type ]

return None
