Skip to content

Commit e65bf2a

Browse files
Preserve legacy Xarray merge settings (#1016)
1 parent aff2eae commit e65bf2a

File tree

4 files changed

+91
-67
lines changed

4 files changed

+91
-67
lines changed

e3sm_diags/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@
1111
os.environ["OPENBLAS_NUM_THREADS"] = "1"
1212
os.environ["OMP_NUM_THREADS"] = "1"
1313
os.environ["MKL_NUM_THREADS"] = "1"
14+
15+
# Settings to preserve legacy Xarray behavior when merging datasets.
16+
# See https://xarray.pydata.org/en/stable/user-guide/io.html#combining-multiple-files
17+
# and the v2025.08.0 entry in https://xarray.pydata.org/en/stable/whats-new.html
18+
LEGACY_XARRAY_MERGE_KWARGS = {
19+
# "override", "exact" are the new defaults as of Xarray v2025.08.0
20+
"compat": "no_conflicts",
21+
"join": "outer",
22+
}

e3sm_diags/driver/mp_partition_driver.py

Lines changed: 78 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import xarray as xr
1717
from scipy.stats import binned_statistic
1818

19-
from e3sm_diags import INSTALL_PATH
19+
from e3sm_diags import INSTALL_PATH, LEGACY_XARRAY_MERGE_KWARGS
2020
from e3sm_diags.driver.utils.dataset_xr import Dataset
2121
from e3sm_diags.logger import _setup_child_logger
2222
from e3sm_diags.plot.mp_partition_plot import plot
@@ -60,19 +60,27 @@ def compute_lcf(cice, cliq, temp, landfrac):
6060

6161

6262
def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
63-
"""Runs the mixed-phase partition/T5050 diagnostic.
64-
65-
:param parameter: Parameters for the run
66-
:type parameter: CoreParameter
67-
:raises ValueError: Invalid run type
68-
:return: Parameters for the run
69-
:rtype: CoreParameter
63+
"""
64+
Runs the mixed-phase partition/T5050 diagnostic.
65+
66+
Parameters
67+
----------
68+
parameter : MPpartitionParameter
69+
Parameters for the run.
70+
71+
Raises
72+
------
73+
ValueError
74+
If the run type is invalid.
75+
76+
Returns
77+
-------
78+
MPpartitionParameter
79+
Parameters for the run.
7080
"""
7181
run_type = parameter.run_type
7282
season = "ANN"
7383

74-
# Read reference data first
75-
7684
benchmark_data_path = os.path.join(
7785
INSTALL_PATH,
7886
"control_runs",
@@ -82,35 +90,20 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
8290
with open(benchmark_data_path, "r") as myfile:
8391
lcf_file = myfile.read()
8492

85-
# parse file
8693
metrics_dict = json.loads(lcf_file)
8794

8895
test_data = Dataset(parameter, data_type="test")
89-
# test = test_data.get_timeseries_variable("LANDFRAC")
90-
# print(dir(test))
91-
# landfrac = test_data.get_timeseries_variable("LANDFRAC")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
92-
# temp = test_data.get_timeseries_variable("T")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
93-
# cice = test_data.get_timeseries_variable("CLDICE")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
94-
# cliq = test_data.get_timeseries_variable("CLDLIQ")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
95-
9696
test_data_path = parameter.test_data_path
97-
start_year = parameter.test_start_yr
98-
end_year = parameter.test_end_yr
97+
98+
start_year = int(parameter.test_start_yr)
99+
end_year = int(parameter.test_end_yr)
100+
99101
# TODO the time subsetting and variable derivation should be replaced during cdat revamp
100102
try:
101-
# xr.open_mfdataset() can accept an explicit list of files.
102-
landfrac = xr.open_mfdataset(glob.glob(f"{test_data_path}/LANDFRAC_*")).sel(
103-
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
104-
)["LANDFRAC"]
105-
temp = xr.open_mfdataset(glob.glob(f"{test_data_path}/T_*.nc")).sel(
106-
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
107-
)["T"]
108-
cice = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDICE_*.nc")).sel(
109-
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
110-
)["CLDICE"]
111-
cliq = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDLIQ_*.nc")).sel(
112-
lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
113-
)["CLDLIQ"]
103+
landfrac = _open_mfdataset(test_data_path, "LANDFRAC", start_year, end_year)
104+
temp = _open_mfdataset(test_data_path, "T", start_year, end_year)
105+
cice = _open_mfdataset(test_data_path, "CLDICE", start_year, end_year)
106+
cliq = _open_mfdataset(test_data_path, "CLDLIQ", start_year, end_year)
114107
except OSError:
115108
logger.info(
116109
f"No files to open for variables within {start_year} and {end_year} from {test_data_path}."
@@ -126,46 +119,22 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
126119

127120
if run_type == "model-vs-model":
128121
ref_data = Dataset(parameter, data_type="ref")
129-
130122
ref_data_path = parameter.reference_data_path
131-
start_year = parameter.ref_start_yr
132-
end_year = parameter.ref_end_yr
133-
# xr.open_mfdataset() can accept an explicit list of files.
123+
124+
start_year = int(parameter.ref_start_yr)
125+
end_year = int(parameter.ref_end_yr)
126+
134127
try:
135-
landfrac = xr.open_mfdataset(glob.glob(f"{ref_data_path}/LANDFRAC_*")).sel(
136-
lat=slice(-70, -30),
137-
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
138-
)["LANDFRAC"]
139-
temp = xr.open_mfdataset(glob.glob(f"{ref_data_path}/T_*.nc")).sel(
140-
lat=slice(-70, -30),
141-
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
142-
)["T"]
143-
cice = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDICE_*.nc")).sel(
144-
lat=slice(-70, -30),
145-
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
146-
)["CLDICE"]
147-
cliq = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDLIQ_*.nc")).sel(
148-
lat=slice(-70, -30),
149-
time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
150-
)["CLDLIQ"]
128+
landfrac = _open_mfdataset(ref_data_path, "LANDFRAC", start_year, end_year)
129+
temp = _open_mfdataset(ref_data_path, "T", start_year, end_year)
130+
cice = _open_mfdataset(ref_data_path, "CLDICE", start_year, end_year)
131+
cliq = _open_mfdataset(ref_data_path, "CLDLIQ", start_year, end_year)
151132
except OSError:
152133
logger.info(
153134
f"No files to open for variables within {start_year} and {end_year} from {ref_data_path}."
154135
)
155136
raise
156137

157-
# landfrac = ref_data.get_timeseries_variable("LANDFRAC")(
158-
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
159-
# )
160-
# temp = ref_data.get_timeseries_variable("T")(
161-
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
162-
# )
163-
# cice = ref_data.get_timeseries_variable("CLDICE")(
164-
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
165-
# )
166-
# cliq = ref_data.get_timeseries_variable("CLDLIQ")(
167-
# cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
168-
# )
169138
parameter.ref_name_yrs = ref_data.get_name_yrs_attr(season)
170139
metrics_dict["ref"] = {}
171140
metrics_dict["ref"]["T"], metrics_dict["ref"]["LCF"] = compute_lcf(
@@ -177,3 +146,47 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
177146
plot(metrics_dict, parameter)
178147

179148
return parameter
149+
150+
151+
def _open_mfdataset(
    data_path: str, var: str, start_year: int, end_year: int
) -> xr.DataArray:
    """
    Open the NetCDF files for one variable and return that variable subset
    by time and latitude.

    All files in ``data_path`` matching ``{var}_*.nc`` are opened together
    with ``xr.open_mfdataset`` using the legacy merge settings in
    ``LEGACY_XARRAY_MERGE_KWARGS``. The combined dataset is then restricted
    to latitudes between -70 and -30 and to the period from January 1 of
    ``start_year`` through December 31 of ``end_year``.

    Parameters
    ----------
    data_path : str
        The path to the directory containing the NetCDF files.
    var : str
        The variable name. It is used both to build the file pattern and as
        the key of the data variable to extract.
    start_year : int
        The first year (inclusive) of the time subset.
    end_year : int
        The last year (inclusive) of the time subset.

    Returns
    -------
    xr.DataArray
        The data variable ``var``, subset by time and latitude.

    Raises
    ------
    OSError
        Propagated from ``xr.open_mfdataset``; the callers in ``run_diag``
        catch it to log that no files were found. NOTE(review): this relies
        on xarray raising ``OSError`` for an empty file list -- confirm
        against the pinned xarray version.
    """
    file_pattern = f"{data_path}/{var}_*.nc"

    # glob.glob() returns paths in arbitrary (filesystem-dependent) order.
    # Sort them so the list handed to xarray is the same on every run.
    filepaths = sorted(glob.glob(file_pattern))

    ds = xr.open_mfdataset(
        filepaths,
        data_vars="minimal",
        **LEGACY_XARRAY_MERGE_KWARGS,  # type: ignore[arg-type]
    )

    # Fixed latitude window for this diagnostic; the time bounds are
    # inclusive calendar-year limits expressed as date strings.
    ds_sub = ds.sel(
        lat=slice(-70, -30),
        time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
    )

    return ds_sub[var]

e3sm_diags/driver/tropical_subseasonal_driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def wf_analysis(x, **kwargs):
244244
background.rename("spec_background"),
245245
],
246246
compat="override",
247+
join="outer",
247248
)
248249
spec_all = spec.drop("component")
249250
spec_all["spec_raw_sym"].attrs = {"component": "symmetric", "type": "raw"}

e3sm_diags/driver/utils/dataset_xr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import xarray as xr
2525
import xcdat as xc
2626

27+
from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS
2728
from e3sm_diags.derivations.derivations import (
2829
DERIVED_VARIABLES,
2930
FUNC_NEEDS_TARGET_VAR,
@@ -1159,7 +1160,7 @@ def _get_dataset_with_source_vars(self, vars_to_get: tuple[str, ...]) -> xr.Data
11591160
ds = self._get_time_series_dataset_obj(var)
11601161
datasets.append(ds)
11611162

1162-
ds = xr.merge(datasets)
1163+
ds = xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS) # type: ignore[ arg-type ]
11631164
ds = squeeze_time_dim(ds)
11641165

11651166
return ds
@@ -1640,7 +1641,7 @@ def _get_land_sea_mask_dataset(self, season: ClimoFreq) -> xr.Dataset | None:
16401641
datasets.append(ds_ocn)
16411642

16421643
if len(datasets) == 2:
1643-
return xr.merge(datasets)
1644+
return xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS) # type: ignore[ arg-type ]
16441645

16451646
return None
16461647

0 commit comments

Comments
 (0)