9 changes: 9 additions & 0 deletions e3sm_diags/__init__.py
@@ -11,3 +11,12 @@
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Settings to preserve legacy Xarray behavior when merging datasets.
# See https://xarray.pydata.org/en/stable/user-guide/io.html#combining-multiple-files
# and https://xarray.pydata.org/en/stable/whats-new.html#id14
LEGACY_XARRAY_MERGE_KWARGS = {
# "override", "exact" are the new defaults as of Xarray v2025.08.0
"compat": "no_conflicts",
"join": "outer",
}
Comment on lines +15 to +22
The constant for the legacy Xarray merge settings.
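A minimal sketch of how the constant is meant to be splatted into merge calls; the two small datasets below are hypothetical and only illustrate the mismatched-index case:

```python
import xarray as xr

from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS

# Hypothetical datasets with different time lengths, for illustration only.
ds_a = xr.Dataset({"T": ("time", [270.0, 271.0])}, coords={"time": [0, 1]})
ds_b = xr.Dataset({"CLDICE": ("time", [0.1, 0.2, 0.3])}, coords={"time": [0, 1, 2]})

# compat="no_conflicts", join="outer" are the pre-v2025.08.0 defaults.
# With the new defaults (compat="override", join="exact"), the mismatched
# "time" indexes would raise instead of being outer-joined.
ds = xr.merge([ds_a, ds_b], **LEGACY_XARRAY_MERGE_KWARGS)
print(ds.sizes)  # time has length 3; "T" is padded with NaN at time=2
```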

143 changes: 78 additions & 65 deletions e3sm_diags/driver/mp_partition_driver.py
Expand Up @@ -16,7 +16,7 @@
import xarray as xr
from scipy.stats import binned_statistic

from e3sm_diags import INSTALL_PATH
from e3sm_diags import INSTALL_PATH, LEGACY_XARRAY_MERGE_KWARGS
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.logger import _setup_child_logger
from e3sm_diags.plot.mp_partition_plot import plot
@@ -60,19 +60,27 @@ def compute_lcf(cice, cliq, temp, landfrac):


def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
"""Runs the mixed-phase partition/T5050 diagnostic.

:param parameter: Parameters for the run
:type parameter: CoreParameter
:raises ValueError: Invalid run type
:return: Parameters for the run
:rtype: CoreParameter
"""
Runs the mixed-phase partition/T5050 diagnostic.

Parameters
----------
parameter : CoreParameter
Parameters for the run.

Raises
------
ValueError
If the run type is invalid.

Returns
-------
CoreParameter
Parameters for the run.
"""
run_type = parameter.run_type
season = "ANN"

# Read reference data first

benchmark_data_path = os.path.join(
INSTALL_PATH,
"control_runs",
@@ -82,35 +90,20 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
with open(benchmark_data_path, "r") as myfile:
lcf_file = myfile.read()

# parse file
metrics_dict = json.loads(lcf_file)

test_data = Dataset(parameter, data_type="test")
# test = test_data.get_timeseries_variable("LANDFRAC")
# print(dir(test))
# landfrac = test_data.get_timeseries_variable("LANDFRAC")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# temp = test_data.get_timeseries_variable("T")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# cice = test_data.get_timeseries_variable("CLDICE")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
# cliq = test_data.get_timeseries_variable("CLDLIQ")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))

test_data_path = parameter.test_data_path
start_year = parameter.test_start_yr
end_year = parameter.test_end_yr

start_year = int(parameter.test_start_yr)
end_year = int(parameter.test_end_yr)

# TODO the time subsetting and variable derivation should be replaced during cdat revamp
try:
# xr.open_mfdataset() can accept an explicit list of files.
landfrac = xr.open_mfdataset(glob.glob(f"{test_data_path}/LANDFRAC_*")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["LANDFRAC"]
temp = xr.open_mfdataset(glob.glob(f"{test_data_path}/T_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["T"]
cice = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDICE_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["CLDICE"]
cliq = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDLIQ_*.nc")).sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)["CLDLIQ"]
landfrac = _open_mfdataset(test_data_path, "LANDFRAC", start_year, end_year)
temp = _open_mfdataset(test_data_path, "T", start_year, end_year)
cice = _open_mfdataset(test_data_path, "CLDICE", start_year, end_year)
cliq = _open_mfdataset(test_data_path, "CLDLIQ", start_year, end_year)
except OSError:
logger.info(
f"No files to open for variables within {start_year} and {end_year} from {test_data_path}."
@@ -126,46 +119,22 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:

if run_type == "model-vs-model":
ref_data = Dataset(parameter, data_type="ref")

ref_data_path = parameter.reference_data_path
start_year = parameter.ref_start_yr
end_year = parameter.ref_end_yr
# xr.open_mfdataset() can accept an explicit list of files.

start_year = int(parameter.ref_start_yr)
end_year = int(parameter.ref_end_yr)

try:
landfrac = xr.open_mfdataset(glob.glob(f"{ref_data_path}/LANDFRAC_*")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["LANDFRAC"]
temp = xr.open_mfdataset(glob.glob(f"{ref_data_path}/T_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["T"]
cice = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDICE_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["CLDICE"]
cliq = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDLIQ_*.nc")).sel(
lat=slice(-70, -30),
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
)["CLDLIQ"]
landfrac = _open_mfdataset(ref_data_path, "LANDFRAC", start_year, end_year)
temp = _open_mfdataset(ref_data_path, "T", start_year, end_year)
cice = _open_mfdataset(ref_data_path, "CLDICE", start_year, end_year)
cliq = _open_mfdataset(ref_data_path, "CLDLIQ", start_year, end_year)
except OSError:
logger.info(
f"No files to open for variables within {start_year} and {end_year} from {ref_data_path}."
)
raise

# landfrac = ref_data.get_timeseries_variable("LANDFRAC")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# temp = ref_data.get_timeseries_variable("T")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# cice = ref_data.get_timeseries_variable("CLDICE")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
# cliq = ref_data.get_timeseries_variable("CLDLIQ")(
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
# )
parameter.ref_name_yrs = ref_data.get_name_yrs_attr(season)
metrics_dict["ref"] = {}
metrics_dict["ref"]["T"], metrics_dict["ref"]["LCF"] = compute_lcf(
@@ -177,3 +146,47 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
plot(metrics_dict, parameter)

return parameter


def _open_mfdataset(
data_path: str, var: str, start_year: int, end_year: int
) -> xr.DataArray:
"""
Open multiple NetCDF files as a single xarray Dataset and subset by time
and latitude.

This function reads multiple NetCDF files matching the specified variable
name and combines them into a single xarray Dataset. The data is then
subsetted based on the specified time range and latitude bounds.

Parameters
----------
data_path : str
The path to the directory containing the NetCDF files.
var : str
The variable name to match in the file pattern.
start_year : int
The starting year for the time subsetting.
end_year : int
The ending year for the time subsetting.

Returns
-------
xr.DataArray
The subsetted DataArray for the specified variable, filtered by time
and latitude.
"""
file_pattern = f"{data_path}/{var}_*.nc"
ds = xr.open_mfdataset(
glob.glob(file_pattern),
data_vars="minimal",
**LEGACY_XARRAY_MERGE_KWARGS,  # type: ignore[arg-type]
)

ds_sub = ds.sel(
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
)

da_var = ds_sub[var]

return da_var
Comment on lines +151 to +192
The DRY function.
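A short usage sketch of the helper (the path and years are placeholders; matching time-series files must exist on disk for it to actually run):

```python
from e3sm_diags.driver.mp_partition_driver import _open_mfdataset

# Placeholder inputs; the driver passes parameter.test_data_path (or
# parameter.reference_data_path) plus the configured start/end years.
data_path = "/path/to/timeseries"
start_year, end_year = 2000, 2009

# Each call globs f"{data_path}/{var}_*.nc", slices lat to (-70, -30) and
# time to the year range, and returns that variable as a DataArray.
landfrac = _open_mfdataset(data_path, "LANDFRAC", start_year, end_year)
cice = _open_mfdataset(data_path, "CLDICE", start_year, end_year)
```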

1 change: 1 addition & 0 deletions e3sm_diags/driver/tropical_subseasonal_driver.py
@@ -244,6 +244,7 @@ def wf_analysis(x, **kwargs):
background.rename("spec_background"),
],
compat="override",
join="outer",
)
spec_all = spec.drop("component")
spec_all["spec_raw_sym"].attrs = {"component": "symmetric", "type": "raw"}
5 changes: 3 additions & 2 deletions e3sm_diags/driver/utils/dataset_xr.py
@@ -24,6 +24,7 @@
import xarray as xr
import xcdat as xc

from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS
from e3sm_diags.derivations.derivations import (
DERIVED_VARIABLES,
FUNC_NEEDS_TARGET_VAR,
@@ -1159,7 +1160,7 @@ def _get_dataset_with_source_vars(self, vars_to_get: tuple[str, ...]) -> xr.Data
ds = self._get_time_series_dataset_obj(var)
datasets.append(ds)

ds = xr.merge(datasets)
ds = xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS)  # type: ignore[arg-type]
ds = squeeze_time_dim(ds)

return ds
@@ -1640,7 +1641,7 @@ def _get_land_sea_mask_dataset(self, season: ClimoFreq) -> xr.Dataset | None:
datasets.append(ds_ocn)

if len(datasets) == 2:
return xr.merge(datasets)
return xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS)  # type: ignore[arg-type]

return None
