diff --git a/e3sm_diags/__init__.py b/e3sm_diags/__init__.py
index 3e09ce500..f04c5b1ff 100644
--- a/e3sm_diags/__init__.py
+++ b/e3sm_diags/__init__.py
@@ -11,3 +11,12 @@
 os.environ["OPENBLAS_NUM_THREADS"] = "1"
 os.environ["OMP_NUM_THREADS"] = "1"
 os.environ["MKL_NUM_THREADS"] = "1"
+
+# Settings to preserve legacy Xarray behavior when merging datasets.
+# See https://xarray.pydata.org/en/stable/user-guide/io.html#combining-multiple-files
+# and https://xarray.pydata.org/en/stable/whats-new.html#id14
+LEGACY_XARRAY_MERGE_KWARGS = {
+    # "override", "exact" are the new defaults as of Xarray v2025.08.0
+    "compat": "no_conflicts",
+    "join": "outer",
+}
diff --git a/e3sm_diags/driver/mp_partition_driver.py b/e3sm_diags/driver/mp_partition_driver.py
index b4b27b522..45e020875 100644
--- a/e3sm_diags/driver/mp_partition_driver.py
+++ b/e3sm_diags/driver/mp_partition_driver.py
@@ -16,7 +16,7 @@
 import xarray as xr
 from scipy.stats import binned_statistic
 
-from e3sm_diags import INSTALL_PATH
+from e3sm_diags import INSTALL_PATH, LEGACY_XARRAY_MERGE_KWARGS
 from e3sm_diags.driver.utils.dataset_xr import Dataset
 from e3sm_diags.logger import _setup_child_logger
 from e3sm_diags.plot.mp_partition_plot import plot
@@ -60,19 +60,27 @@ def compute_lcf(cice, cliq, temp, landfrac):
 
 
 def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
-    """Runs the mixed-phase partition/T5050 diagnostic.
-
-    :param parameter: Parameters for the run
-    :type parameter: CoreParameter
-    :raises ValueError: Invalid run type
-    :return: Parameters for the run
-    :rtype: CoreParameter
+    """
+    Runs the mixed-phase partition/T5050 diagnostic.
+
+    Parameters
+    ----------
+    parameter : CoreParameter
+        Parameters for the run.
+
+    Raises
+    ------
+    ValueError
+        If the run type is invalid.
+
+    Returns
+    -------
+    CoreParameter
+        Parameters for the run.
     """
     run_type = parameter.run_type
     season = "ANN"
 
-    # Read reference data first
-
     benchmark_data_path = os.path.join(
         INSTALL_PATH,
         "control_runs",
@@ -82,35 +90,20 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
     with open(benchmark_data_path, "r") as myfile:
         lcf_file = myfile.read()
 
-    # parse file
     metrics_dict = json.loads(lcf_file)
 
     test_data = Dataset(parameter, data_type="test")
-    # test = test_data.get_timeseries_variable("LANDFRAC")
-    # print(dir(test))
-    # landfrac = test_data.get_timeseries_variable("LANDFRAC")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
-    # temp = test_data.get_timeseries_variable("T")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
-    # cice = test_data.get_timeseries_variable("CLDICE")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
-    # cliq = test_data.get_timeseries_variable("CLDLIQ")(cdutil.region.domain(latitude=(-70.0, -30, "ccb")))
-
     test_data_path = parameter.test_data_path
-    start_year = parameter.test_start_yr
-    end_year = parameter.test_end_yr
+
+    start_year = int(parameter.test_start_yr)
+    end_year = int(parameter.test_end_yr)
+
     # TODO the time subsetting and variable derivation should be replaced during cdat revamp
     try:
-        # xr.open_mfdataset() can accept an explicit list of files.
-        landfrac = xr.open_mfdataset(glob.glob(f"{test_data_path}/LANDFRAC_*")).sel(
-            lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
-        )["LANDFRAC"]
-        temp = xr.open_mfdataset(glob.glob(f"{test_data_path}/T_*.nc")).sel(
-            lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
-        )["T"]
-        cice = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDICE_*.nc")).sel(
-            lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
-        )["CLDICE"]
-        cliq = xr.open_mfdataset(glob.glob(f"{test_data_path}/CLDLIQ_*.nc")).sel(
-            lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
-        )["CLDLIQ"]
+        landfrac = _open_mfdataset(test_data_path, "LANDFRAC", start_year, end_year)
+        temp = _open_mfdataset(test_data_path, "T", start_year, end_year)
+        cice = _open_mfdataset(test_data_path, "CLDICE", start_year, end_year)
+        cliq = _open_mfdataset(test_data_path, "CLDLIQ", start_year, end_year)
     except OSError:
         logger.info(
             f"No files to open for variables within {start_year} and {end_year} from {test_data_path}."
@@ -126,46 +119,22 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
 
     if run_type == "model-vs-model":
         ref_data = Dataset(parameter, data_type="ref")
-
         ref_data_path = parameter.reference_data_path
-        start_year = parameter.ref_start_yr
-        end_year = parameter.ref_end_yr
-        # xr.open_mfdataset() can accept an explicit list of files.
+
+        start_year = int(parameter.ref_start_yr)
+        end_year = int(parameter.ref_end_yr)
+
         try:
-            landfrac = xr.open_mfdataset(glob.glob(f"{ref_data_path}/LANDFRAC_*")).sel(
-                lat=slice(-70, -30),
-                time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
-            )["LANDFRAC"]
-            temp = xr.open_mfdataset(glob.glob(f"{ref_data_path}/T_*.nc")).sel(
-                lat=slice(-70, -30),
-                time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
-            )["T"]
-            cice = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDICE_*.nc")).sel(
-                lat=slice(-70, -30),
-                time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
-            )["CLDICE"]
-            cliq = xr.open_mfdataset(glob.glob(f"{ref_data_path}/CLDLIQ_*.nc")).sel(
-                lat=slice(-70, -30),
-                time=slice(f"{start_year}-01-01", f"{end_year}-12-31"),
-            )["CLDLIQ"]
+            landfrac = _open_mfdataset(ref_data_path, "LANDFRAC", start_year, end_year)
+            temp = _open_mfdataset(ref_data_path, "T", start_year, end_year)
+            cice = _open_mfdataset(ref_data_path, "CLDICE", start_year, end_year)
+            cliq = _open_mfdataset(ref_data_path, "CLDLIQ", start_year, end_year)
         except OSError:
             logger.info(
                 f"No files to open for variables within {start_year} and {end_year} from {ref_data_path}."
             )
             raise
 
-        # landfrac = ref_data.get_timeseries_variable("LANDFRAC")(
-        #     cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
-        # )
-        # temp = ref_data.get_timeseries_variable("T")(
-        #     cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
-        # )
-        # cice = ref_data.get_timeseries_variable("CLDICE")(
-        #     cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
-        # )
-        # cliq = ref_data.get_timeseries_variable("CLDLIQ")(
-        #     cdutil.region.domain(latitude=(-70.0, -30, "ccb"))
-        # )
         parameter.ref_name_yrs = ref_data.get_name_yrs_attr(season)
         metrics_dict["ref"] = {}
         metrics_dict["ref"]["T"], metrics_dict["ref"]["LCF"] = compute_lcf(
@@ -177,3 +146,47 @@ def run_diag(parameter: MPpartitionParameter) -> MPpartitionParameter:
     plot(metrics_dict, parameter)
 
     return parameter
+
+
+def _open_mfdataset(
+    data_path: str, var: str, start_year: int, end_year: int
+) -> xr.DataArray:
+    """
+    Open multiple NetCDF files as a single xarray Dataset and subset by time
+    and latitude.
+
+    This function reads multiple NetCDF files matching the specified variable
+    name and combines them into a single xarray Dataset. The data is then
+    subsetted based on the specified time range and latitude bounds.
+
+    Parameters
+    ----------
+    data_path : str
+        The path to the directory containing the NetCDF files.
+    var : str
+        The variable name to match in the file pattern.
+    start_year : int
+        The starting year for the time subsetting.
+    end_year : int
+        The ending year for the time subsetting.
+
+    Returns
+    -------
+    xr.DataArray
+        The subsetted DataArray for the specified variable, filtered by time
+        and latitude.
+    """
+    file_pattern = f"{data_path}/{var}_*.nc"
+    ds = xr.open_mfdataset(
+        glob.glob(file_pattern),
+        data_vars="minimal",
+        **LEGACY_XARRAY_MERGE_KWARGS,  # type: ignore[arg-type]
+    )
+
+    ds_sub = ds.sel(
+        lat=slice(-70, -30), time=slice(f"{start_year}-01-01", f"{end_year}-12-31")
+    )
+
+    da_var = ds_sub[var]
+
+    return da_var
diff --git a/e3sm_diags/driver/tropical_subseasonal_driver.py b/e3sm_diags/driver/tropical_subseasonal_driver.py
index b48538f89..9e5a3aa9e 100644
--- a/e3sm_diags/driver/tropical_subseasonal_driver.py
+++ b/e3sm_diags/driver/tropical_subseasonal_driver.py
@@ -244,6 +244,7 @@ def wf_analysis(x, **kwargs):
             background.rename("spec_background"),
         ],
         compat="override",
+        join="outer",
     )
     spec_all = spec.drop("component")
     spec_all["spec_raw_sym"].attrs = {"component": "symmetric", "type": "raw"}
diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py
index ca1fc5121..5dc8932a1 100644
--- a/e3sm_diags/driver/utils/dataset_xr.py
+++ b/e3sm_diags/driver/utils/dataset_xr.py
@@ -24,6 +24,7 @@
 import xarray as xr
 import xcdat as xc
 
+from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS
 from e3sm_diags.derivations.derivations import (
     DERIVED_VARIABLES,
     FUNC_NEEDS_TARGET_VAR,
@@ -1159,7 +1160,7 @@ def _get_dataset_with_source_vars(self, vars_to_get: tuple[str, ...]) -> xr.Data
             ds = self._get_time_series_dataset_obj(var)
             datasets.append(ds)
 
-        ds = xr.merge(datasets)
+        ds = xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS)  # type: ignore[arg-type]
         ds = squeeze_time_dim(ds)
 
         return ds
@@ -1640,7 +1641,7 @@ def _get_land_sea_mask_dataset(self, season: ClimoFreq) -> xr.Dataset | None:
             datasets.append(ds_ocn)
 
         if len(datasets) == 2:
-            return xr.merge(datasets)
+            return xr.merge(datasets, **LEGACY_XARRAY_MERGE_KWARGS)  # type: ignore[arg-type]
 
         return None
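
Usage sketch (illustrative note, not part of the patch): the snippet below shows the intent of LEGACY_XARRAY_MERGE_KWARGS, namely unpacking it into xr.merge() or xr.open_mfdataset() so that the pre-v2025.08.0 defaults of compat="no_conflicts" and join="outer" are kept; the two example datasets are hypothetical stand-ins for the time series files the drivers combine.

import xarray as xr

from e3sm_diags import LEGACY_XARRAY_MERGE_KWARGS

# Hypothetical datasets with partially overlapping time coordinates.
ds_a = xr.Dataset({"ts": ("time", [1.0, 2.0])}, coords={"time": [0, 1]})
ds_b = xr.Dataset({"pr": ("time", [0.5, 0.7, 0.9])}, coords={"time": [0, 1, 2]})

# compat="no_conflicts" and join="outer" reproduce the legacy xr.merge() behavior;
# per the comment added in e3sm_diags/__init__.py above, the newer defaults
# ("override"/"exact") would instead raise on these mismatched time indexes.
ds = xr.merge([ds_a, ds_b], **LEGACY_XARRAY_MERGE_KWARGS)
print(ds.sizes["time"])  # 3: the outer join keeps the union of time values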