From c5afcdadfefad952e7214777c8e537a4bf360d1b Mon Sep 17 00:00:00 2001
From: lee1043
Date: Wed, 24 Apr 2024 16:57:16 -0700
Subject: [PATCH 1/3] first commit for lib files

---
 pcmdi_metrics/qbo/lib/__init__.py             |  19 +
 .../qbo/lib/compute_qbo_mjo_metrics.py        | 499 ++++++++++++++++++
 pcmdi_metrics/qbo/lib/const.py                |  16 +
 pcmdi_metrics/qbo/lib/kf_filter.py            | 491 +++++++++++++++++
 pcmdi_metrics/qbo/lib/utils.py                | 360 +++++++++++++
 pcmdi_metrics/qbo/lib/utils_parallel.py       |  61 +++
 6 files changed, 1446 insertions(+)
 create mode 100644 pcmdi_metrics/qbo/lib/__init__.py
 create mode 100644 pcmdi_metrics/qbo/lib/compute_qbo_mjo_metrics.py
 create mode 100644 pcmdi_metrics/qbo/lib/const.py
 create mode 100644 pcmdi_metrics/qbo/lib/kf_filter.py
 create mode 100644 pcmdi_metrics/qbo/lib/utils.py
 create mode 100644 pcmdi_metrics/qbo/lib/utils_parallel.py

diff --git a/pcmdi_metrics/qbo/lib/__init__.py b/pcmdi_metrics/qbo/lib/__init__.py
new file mode 100644
index 000000000..02795f611
--- /dev/null
+++ b/pcmdi_metrics/qbo/lib/__init__.py
@@ -0,0 +1,19 @@
+from .compute_qbo_mjo_metrics import ( # noqa
+    process_qbo_mjo_metrics,
+)
+from .kf_filter import KFfilter # noqa
+from .utils import ( # noqa
+    generate_target_grid,
+    select_time_range,
+    test_plot_time_series,
+    test_plot_maps,
+    standardize_lat_lon_name_in_dataset,
+    find_coord_key,
+    diag_plot,
+    mycolormap,
+)
+from .utils_parallel import ( # noqa
+    configure_logger,
+    LoggerWriter,
+    process,
+)
diff --git a/pcmdi_metrics/qbo/lib/compute_qbo_mjo_metrics.py b/pcmdi_metrics/qbo/lib/compute_qbo_mjo_metrics.py
new file mode 100644
index 000000000..956d99b58
--- /dev/null
+++ b/pcmdi_metrics/qbo/lib/compute_qbo_mjo_metrics.py
@@ -0,0 +1,499 @@
+#!/usr/bin/env python
+
+# QBO-MJO Metric Prototyping
+
+import json
+import os
+
+import xarray as xr
+import xcdat as xc
+from kf_filter import KFfilter
+from utils import (
+    diag_plot,
+    generate_target_grid,
+    select_time_range,
+    standardize_lat_lon_name_in_dataset,
+    test_plot_maps,
+    test_plot_time_series,
+)
+
+from pcmdi_metrics.mean_climate.lib import load_and_regrid
+
+
+def main():
+    # model = "CESM2"
+    model = "ERA5"
+
+    if model == "CESM2":
+        # User-defined parameters
+        params = {
+            "model": "CESM2",
+            "exp": "historical",
+            "member": "r1i1p1f1",
+            "input_file": "sample_data/ua_Amon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc",
+            "input_file2": "sample_data/rlut_day_CESM2_historical_r1i1p1f1_gn_19800101-19891231.nc",
+            "varname": "ua",
+            "level": 50,  # hPa (=mb)
+            "varname2": "rlut",
+            "start": "1981-01",
+            "end": "1988-12",
+            "regrid": False,
+            "regrid_tool": "xesmf",
+            "target_grid": "2x2",
+            "taper_to_mean": True,
+            "output_dir": "./output_data",
+            "debug": False,
+        }
+    elif model == "ERA5":
+        # User-defined parameters for ERA5
+        params = {
+            "model": "ERA5",
+            "exp": None,
+            "member": None,
+            "input_file": "/work/lee1043/DATA/ERA5/ERA5_u50_monthly_1979-2021_rewrite.nc",
+            "input_file2": "/work/lee1043/DATA/ERA5/ERA5_olr_daily_40s40n_1979-2021_rewrite.nc",
+            "varname": "u50",
+            "level": None,  # hPa (=mb)
+            "varname2": "olr",
+            "start": "1979-01",
+            # "end": "2014-12",
+            "end": "2010-12",
+            "regrid": True,
+            "regrid_tool": "xesmf",
+            "target_grid": "2x2",
+            "taper_to_mean": True,
+            "output_dir": "./output_data",
+            "debug": False,
+        }
+
+    output_metrics = process_qbo_mjo_metrics(params)
+    print(output_metrics)
+
+
+def process_qbo_mjo_metrics(params):
+    """
+    Process the QBO-MJO metric based on the provided parameters.
+
+    Args:
+        params (dict): A dictionary containing user-defined parameters.
+
+    Returns:
+        dict: A dictionary containing the calculated metrics.
+    """
+    output = {}
+
+    model = params["model"]
+    exp = params["exp"]
+    member = params["member"]
+    input_file = params["input_file"]
+    input_file2 = params["input_file2"]
+    varname = params["varname"]
+    level = params["level"]
+    varname2 = params["varname2"]
+    start = params["start"]
+    end = params["end"]
+    regrid = params["regrid"]
+    regrid_tool = params["regrid_tool"]
+    target_grid = params["target_grid"]
+    taper_to_mean = params["taper_to_mean"]
+    output_dir = params["output_dir"]
+    debug = params["debug"]
+
+    # =======================================================================================
+    # ## Part 1: U50
+    #
+    # 1. reads in monthly U data and handles the calendar/extracts the desired time range
+    #    (this phenomenon occurs in the recent record, ~1979 onward)
+    #
+    # 2. calculates the DJF-mean QBO index consistent with our paper:
+    #    - reads in monthly zonal wind at the 50 mb level
+    #    - calculates monthly anomalies for U50
+    #    - averages the data across latitudes for 10S-10N and all longitudes
+    #    - calculates a 3-month running mean of these and then the standard deviation
+    #      of that smoothed time series
+    #    - then, we extract and average only the DJF season
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    output_file_tag = "_" + model
+    if exp is not None:
+        output_file_tag += "_" + exp
+    if member is not None:
+        output_file_tag += "_" + member
+    output_file_tag += "_" + start + "_" + end
+
+    if regrid:
+        # generate target grid
+        t_grid = generate_target_grid(target_grid)
+
+        # reads in monthly zonal wind at the 50 mb level
+        ds = load_and_regrid(
+            data_path=input_file,
+            varname=varname,
+            level=level,
+            t_grid=t_grid,
+            decode_times=True,
+            regrid_tool=regrid_tool,
+            debug=debug,
+        )
+
+        output_file_tag += "_" + regrid_tool + "_" + target_grid
+
+    else:
+        ds = xc.open_dataset(input_file)
+        if level is not None:
+            ds = ds.sel(plev=level * 100)  # hPa to Pa
+
+    # Subset time
+    ds = select_time_range(ds, start, end)
+
+    # Standardize coordinate names (lat, lon)
+    ds = standardize_lat_lon_name_in_dataset(ds)
+
+    # Subset region (latitude 10S to 10N)
+    ds_region = ds.sel(lat=slice(-10, 10))
+
+    if debug:
+        print("range for u")
+        print(ds_region[varname].min().values)
+        print(ds_region[varname].max().values)
+        print(ds_region[varname].to_numpy().shape)
+        print("------------")
+
+    # calculates monthly anomalies for U50
+    # remove annual cycle
+    ds_region_ano = ds_region.temporal.departures(varname, "month")
+
+    if debug:
+        print("U anomalies")
+        print(ds_region_ano[varname].min().values)
+        print(ds_region_ano[varname].max().values)
+        print(ds_region_ano[varname].to_numpy().shape)
+        print("------------")
+
+    # Averages the data across latitudes for 10S-10N and all longitudes
+    ds_region_ano_ave = ds_region_ano.spatial.average(
+        varname, axis=["X", "Y"]
+    ).compute()
+
+    if debug:
+        print("range of u avgd in lat/lon")
+        print(ds_region_ano_ave[varname].min().values)
+        print(ds_region_ano_ave[varname].max().values)
+        print(ds_region_ano_ave[varname].to_numpy().shape)
+
+    # calculates a 3-month running mean and then the standard deviation of that smoothed time series
+    ds_region_ano_ave_runningmean = (
+        ds_region_ano_ave.rolling(time=3, center=True)
+        .construct("window")
+        .mean("window")
+        .isel(time=slice(1, -1))
+    )
+    ds_region_ano_ave_runningmean["time_bnds"] = ds_region_ano_ave["time_bnds"].isel(
+        time=slice(1, -1)
+    )  # re-add the time bounds dropped by the rolling-mean step above
+
+    if debug:
+        print("u range with seasonal smoothing")
+        print(ds_region_ano_ave_runningmean[varname].min().values)
+        print(ds_region_ano_ave_runningmean[varname].max().values)
+        print(ds_region_ano_ave_runningmean[varname].to_numpy().shape)
+        print("------")
+
+    try:
+        ds_region_ano_ave_runningmean.to_netcdf(
+            os.path.join(
+                output_dir,
+                "ds_region_ano_ave_runningmean" + output_file_tag + ".nc",
+            )
+        )
+    except Exception as e:
+        print(e)
+
+    # standard deviation (std)
+    std = float(ds_region_ano_ave_runningmean[varname].std(dim="time"))
+    print(
+        "sd of entire qbo index - smoothed u50 anomalies averaged 10s-10n and all lons"
+    )
+    print(std)
+
+    # calculate seasonal mean of the smoothed time series (to get average DJF)
+    ds_region_ano_ave_runningmean_season = (
+        ds_region_ano_ave_runningmean.temporal.group_average(
+            varname,
+            freq="season",
+            season_config={
+                "custom_seasons": None,
+                "dec_mode": "DJF",
+                "drop_incomplete_djf": True,
+            },
+        )
+    )
+
+    # then, we extract only the average of DJF
+    ds_region_ano_ave_runningmean_season_djf = (
+        ds_region_ano_ave_runningmean_season.groupby("time.season")["DJF"]
+    )
+
+    if debug:
+        print("range for djf-mean qbo index")
+        print(ds_region_ano_ave_runningmean_season_djf[varname].min().values)
+        print(ds_region_ano_ave_runningmean_season_djf[varname].max().values)
+        print(ds_region_ano_ave_runningmean_season_djf[varname].to_numpy().shape)
+        print("------")
+
+    try:
+        ds_region_ano_ave_runningmean_season_djf.to_netcdf(
+            os.path.join(
+                output_dir,
+                "ds_region_ano_ave_runningmean_season_djf"
+                + output_file_tag
+                + ".nc",
+            )
+        )
+    except Exception as e:
+        print(e)
+
+    # Test plots (time series)
+    if debug:
+        test_plot_time_series(
+            ds_region_ano_ave[varname],
+            output_file=os.path.join(
+                output_dir, "ds_region_ano_ave" + output_file_tag + ".png"
+            ),
+            std=std,
+            title=f"{model} ({member})",
+        )
+
+        test_plot_time_series(
+            ds_region_ano_ave_runningmean[varname],
+            output_file=os.path.join(
+                output_dir, "ds_region_ano_ave_runningmean" + output_file_tag + ".png"
+            ),
+            std=std,
+            title=f"{model} ({member})",
+        )
+
+        test_plot_time_series(
+            ds_region_ano_ave_runningmean_season_djf[varname],
+            output_file=os.path.join(
+                output_dir,
+                "ds_region_ano_ave_runningmean_season_djf" + output_file_tag + ".png",
+            ),
+            std=std,
+            title=f"{model} ({member})",
+        )
+
+    # ## Part 2: OLR
+    #
+    # 3. we read in daily OLR for 30S-30N, again extracting the same desired time
+    #    period, and calculate the MJO-filtered OLR following Wheeler and Kiladis 2009.
+    #    - use the kf_filter function (https://github.com/tmiyachi/mcclimate/blob/master/kf_filter.py),
+    #      which transforms the data to frequency space, filters out all waves except
+    #      the MJO, and back-transforms.
+    #    - use the criteria in the kf_filter for periods 20-100 days, waves 1 to 5,
+    #      and wavetype kelvin to get the MJO-filtered OLR.
+    #    - then, we extract only the DJF season
+    #    - and calculate the standard deviation of MJO-filtered OLR for DJF
+
+    # we read in daily OLR for 30S-30N, again extracting the same desired time period,
+    # and calculate the MJO-filtered OLR following Wheeler and Kiladis 2009.
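+    # Note (added for clarity): the 20-100-day periods and zonal waves 1-5 quoted
+    # above translate directly into the kelvinfilter arguments used below:
+    #     fmin = 1/100 cycles per day = 0.01,  fmax = 1/20 cycles per day = 0.05,
+    #     kmin = 1,  kmax = 5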
+
+    if regrid:
+        ds2 = load_and_regrid(
+            data_path=input_file2,
+            varname=varname2,
+            t_grid=t_grid,
+            decode_times=True,
+            regrid_tool=regrid_tool,
+            debug=debug,
+        )
+    else:
+        ds2 = xc.open_mfdataset(input_file2)
+
+    ds2 = select_time_range(ds2, start, end)
+
+    # Subset region
+    ds2_region = ds2.sel(lat=slice(-30, 30))
+
+    if debug:
+        print("range of olr for desired time range and 30s to 30n")
+        print(ds2_region[varname2].min().values)
+        print(ds2_region[varname2].max().values)
+        print(ds2_region[varname2].to_numpy().shape)
+
+    # Apply KF filter
+    kf = KFfilter(
+        datain=ds2_region[varname2].to_numpy(),
+        spd=1,
+        tim_taper=0.05,
+        taper_to_mean=taper_to_mean,
+    )  # 5% tapering on each side, following Kim et al., 2020, GRL
+    kf_filtered = kf.kelvinfilter(
+        fmin=0.01, fmax=0.05, kmin=1, kmax=5, hmin=0.000001, hmax=8
+    )  # periods 20-100 days, waves 1 to 5, equivalent depth 0 to 8
+    ds2_region["mjo_olr"] = (
+        ["time", "lat", "lon"],
+        kf_filtered,
+    )  # Convert the numpy array to an xarray DataArray and add it to the Dataset
+
+    if debug:
+        print("range of mjo_olr")
+        print(ds2_region["mjo_olr"].min().values)
+        print(ds2_region["mjo_olr"].max().values)
+        print(ds2_region["mjo_olr"].to_numpy().shape)
+
+    ds2_region["window"] = (["time"], kf.window)
+    ds2_region["mjo_olr_detrended"] = (["time", "lat", "lon"], kf.detrended)
+    ds2_region["mjo_olr_tapered"] = (["time", "lat", "lon"], kf.tapered)
+
+    # Extract DJF only
+    ds2_region_djf = ds2_region.groupby("time.season")["DJF"]
+
+    # Calculate the standard deviation of MJO-filtered OLR for DJF
+    std2_map = ds2_region_djf["mjo_olr"].std(dim="time")  # time steps collapse
+
+    ds2_region["mjo_olr_stdmap"] = std2_map
+
+    # ## Part 3: Diagnostics
+    #
+    # 4. then we partition the DJF seasons by QBO phase (easterly or westerly)
+    #    - we use a threshold of +/- 0.5 standard deviation to define the QBO phase
+    #    - and we partition all daily MJO-filtered OLR for that DJF season into east,
+    #      west, and neutral QBO phases
+    #    - then take the standard deviation of QBOE MJO-filtered OLR and QBOW
+    #      MJO-filtered OLR for plotting
+    #
+    # 5. currently in separate code, we plot the stddev of MJO-filtered OLR for DJF
+    #    along with the difference between QBOE and QBOW stddev of MJO-filtered OLR.
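+    # Illustrative summary of the phase rule implemented below (added; u is the
+    # DJF-mean QBO index and std its smoothed standard deviation from Part 1):
+    #     QBO west : u >  0.5 * std
+    #     QBO east : u < -0.5 * std
+    #     neutral  : otherwise (left out of the composites)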
+
+    std2_map_phase = dict()
+    qbo_phase_years = dict()
+
+    for qbo_phase in ["east", "west"]:
+        if qbo_phase == "west":
+            condition = ds_region_ano_ave_runningmean_season_djf[varname] > 0.5 * std
+        elif qbo_phase == "east":
+            condition = ds_region_ano_ave_runningmean_season_djf[varname] < -0.5 * std
+
+        if any(condition.values.tolist()):
+            qbo_phase_ds = ds_region_ano_ave_runningmean_season_djf.where(
+                condition, drop=True
+            )
+            qbo_phase_years[qbo_phase] = [
+                t.year - 1 for t in qbo_phase_ds.indexes["time"].to_datetimeindex()
+            ]  # define the year of the event by the December of the DJF season, hence t.year - 1
+            print(
+                "qbo_phase, qbo_phase_year_list:", qbo_phase, qbo_phase_years[qbo_phase]
+            )
+
+        datasets = []
+
+        for year in qbo_phase_years[qbo_phase]:
+            tmp_dec = (
+                ds2_region_djf["mjo_olr"]
+                .groupby("time.year")[year]
+                .groupby("time.month")[12]
+            )
+            tmp_jan = (
+                ds2_region_djf["mjo_olr"]
+                .groupby("time.year")[year + 1]
+                .groupby("time.month")[1]
+            )
+            tmp_feb = (
+                ds2_region_djf["mjo_olr"]
+                .groupby("time.year")[year + 1]
+                .groupby("time.month")[2]
+            )
+            datasets.extend([tmp_dec, tmp_jan, tmp_feb])
+
+        combined = xr.concat(datasets, dim="time")
+        std2_map_phase[qbo_phase] = combined.std(dim="time")  # time steps collapse
+
+        ds2_region["mjo_olr_stdmap_" + qbo_phase] = std2_map_phase[qbo_phase]
+
+    std2_map_diff = std2_map_phase["east"] - std2_map_phase["west"]
+    ds2_region["mjo_olr_stdmap_east_minus_west"] = std2_map_diff
+
+    if debug:
+        # Save all variables
+        ds2_region.to_netcdf(
+            os.path.join(output_dir, "mjo_olr" + output_file_tag + ".nc")
+        )
+    else:
+        # Save only the selected variables that are shown in the diagnostic plot
+        ds2_region[
+            [
+                "mjo_olr_stdmap",
+                "mjo_olr_stdmap_east",
+                "mjo_olr_stdmap_west",
+                "mjo_olr_stdmap_east_minus_west",
+            ]
+        ].to_netcdf(
+            os.path.join(output_dir, "mjo_olr_stddev_DJF" + output_file_tag + ".nc")
+        )
+
+    # Test plot (MJO-filtered OLR map)
+    if debug:
+        test_plot_maps(
+            std2_map,
+            std2_map_phase,
+            fig_title=f"{model} ({member}): OLR DJF temporal STD map ("
+            + start
+            + " to "
+            + end
+            + ")",
+            output_file=os.path.join(
+                output_dir, "test_OLR_DJF_temporal_STD_map" + output_file_tag + ".png"
+            ),
+        )
+
+    # Diagnostic plot: the associated plot, with grey contour lines for MJO activity
+    # and color-filled contours for the QBOE-QBOW MJO activity difference
+    diag_plot(
+        std2_map,
+        std2_map_diff,
+        fig_title=f"{model} ({member}): Stddev of DJF MJO-Filtered OLR",
+        output_file=os.path.join(
+            output_dir, "mjo_olr_stddev_DJF" + output_file_tag + ".png"
+        ),
+        sub_region=(50, 170, -20, 5),
+    )
+
+    # ## Part 4: Metrics
+
+    # MJO OLR activity (standard deviation) over all years, averaged over the
+    # Maritime Continent region of our Kim et al. 2020 paper (50E-170E, 20S-5N)
+    metric1 = float(
+        ds2_region.spatial.average(
+            data_var="mjo_olr_stdmap", lat_bounds=(-20, 5), lon_bounds=(50, 170)
+        )["mjo_olr_stdmap"].values
+    )
+
+    # MJO OLR activity difference between QBOE and QBOW, averaged over the
+    # Maritime Continent region
+    metric2 = float(
+        ds2_region.spatial.average(
+            data_var="mjo_olr_stdmap_east_minus_west",
+            lat_bounds=(-20, 5),
+            lon_bounds=(50, 170),
+        )["mjo_olr_stdmap_east_minus_west"].values
+    )
+
+    print("metric1 (mjo_activity):", metric1)
+    print("metric2 (mjo_activity_diff):", metric2)
+
+    # Prepare output dictionary
+    output[model] = dict()
+    output[model][member] = dict()
+    output[model][member]["mjo_activity"] = metric1
+    output[model][member]["mjo_activity_diff"] = metric2
+    output[model][member]["qbo_east_years"] = qbo_phase_years["east"]
+    output[model][member]["qbo_west_years"] = qbo_phase_years["west"]
+
+    # Write the dictionary to the JSON file
+    json_file_path = os.path.join(
+        output_dir, "mjo_olr_stddev_DJF" + output_file_tag + ".json"
+    )
+
+    with open(json_file_path, "w") as json_file:
+        json.dump(output, json_file, indent=4)
+
+    print("json file saved:", json_file_path)
+
+    return output
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pcmdi_metrics/qbo/lib/const.py b/pcmdi_metrics/qbo/lib/const.py
new file mode 100644
index 000000000..a0ec4397a
--- /dev/null
+++ b/pcmdi_metrics/qbo/lib/const.py
@@ -0,0 +1,16 @@
+# physical constants module
+
+# universal constants
+gravitational_constant = 6.673e-11  # N m^2 / kg^2
+gasconst = 8.314e3  # J K^-1 kmol^-1
+
+# earth
+gravity_earth = 9.81  # m/s^2
+radius_earth = 6.37e6  # m
+omega_earth = 7.292e-5  # s^-1
+
+# air
+gasconst_dry = 287  # J kg^-1 K^-1
+specific_heat_pressure = 1004  # J kg^-1 K^-1
+specific_heat_volume = 717  # J kg^-1 K^-1
+ratio_gamma = specific_heat_pressure / specific_heat_volume
diff --git a/pcmdi_metrics/qbo/lib/kf_filter.py b/pcmdi_metrics/qbo/lib/kf_filter.py
new file mode 100644
index 000000000..313d8d75c
--- /dev/null
+++ b/pcmdi_metrics/qbo/lib/kf_filter.py
@@ -0,0 +1,491 @@
+# This script is from https://github.com/tmiyachi/mcclimate/blob/master/kf_filter.py
+# 2023-10 taper_to_mean option added by Jiwoo Lee (LLNL)
+# 2024-04 NaN value removal added by Jiwoo Lee (LLNL)
+
+import sys
+
+import const
+import numpy
+import scipy.fftpack as fftpack
+import scipy.signal as signal
+
+NA = numpy.newaxis
+pi = numpy.pi
+g = const.gravity_earth
+a = const.radius_earth
+beta = 2.0 * const.omega_earth / const.radius_earth
+
+
+class KFfilter:
+    """class for wavenumber-frequency filtering for WK99 and WKH00"""
+
+    def __init__(self, datain, spd, tim_taper=0.1, taper_to_mean=False):
+        """Arguments:
+
+        'datain' -- numpy array, the data to be filtered. dimension must be (time, lat, lon)
+
+        'spd' -- int, samples per day
+
+        'tim_taper' -- float, cosine-taper fraction. Tapering is applied to the
+                       first and last tim_taper fraction of the samples; the
+                       default is 10% cosine tapering on each end.
+
+        'taper_to_mean' -- bool, taper toward the time mean. default is False (taper to zero)
+
+        """
+        ntim, nlat, nlon = datain.shape
+
+        # remove the lowest three harmonics of the seasonal cycle (WK99, WKW03)
+        ## if ntim > 365*spd/3:
+        ##     rf = fftpack.rfft(datain,axis=0)
+        ##     freq = fftpack.rfftfreq(ntim*spd, d=1./float(spd))
+        ##     rf[(freq <= 3./365) & (freq >=1./365),:,:] = 0.0  # freq<=3./365 only??
+        ##     datain = fftpack.irfft(rf,axis=0)
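+
+        # Note (added for clarity): the cosine taper built below ramps the first
+        # and last tp = int(ntim * tim_taper) samples with
+        #     w(t) = 0.5 * (1 - cos(pi * t / tp)),  t = 0..tp-1
+        # so the series goes smoothly to zero (or to its time mean when
+        # taper_to_mean=True) at both ends, reducing spectral leakage in the FFT.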
+
+        # remove NaN values if present in datain
+        nan_mask = numpy.isnan(datain)  # identify NaN values
+        datain[nan_mask] = 0  # replace NaN values with zero
+
+        # remove linear trend
+        data = signal.detrend(datain, axis=0)
+
+        self.detrended = data.copy()
+
+        # tapering
+        if tim_taper == "hann":
+            window = signal.windows.hann(ntim)
+            data = data * window[:, NA, NA]
+        elif tim_taper > 0:
+            # taper by cos tapering, same dtype as input array
+            tp = int(ntim * tim_taper)
+            window = numpy.ones(ntim, dtype=datain.dtype)
+            x = numpy.arange(tp)
+            window[:tp] = 0.5 * (1.0 - numpy.cos(x * pi / tp))
+            window[-tp:] = 0.5 * (1.0 - numpy.cos(x[::-1] * pi / tp))
+            if taper_to_mean is False:
+                data = data * window[:, NA, NA]
+            else:
+                mean = data.mean(axis=0)
+                print("mean:", mean)
+                data = (data - mean) * window[:, NA, NA] + mean
+
+        self.window = window
+        self.tapered = data.copy()
+
+        # FFT
+        self.fftdata = fftpack.fft2(data, axes=(0, 2))
+
+        # Note
+        # fft is defined by exp(-ikx), so to adjust to the exp(ikx) convention
+        # the wavenumber is multiplied by minus one
+        wavenumber = -fftpack.fftfreq(nlon) * nlon
+        frequency = fftpack.fftfreq(ntim, d=1.0 / float(spd))
+        knum, freq = numpy.meshgrid(wavenumber, frequency)
+
+        # make the f<0 domain the same as the f>0 domain
+        # CAUTION: the wave definition is exp(i(k*x-omega*t)) but the FFT
+        # definition is exp(-ikx), so change the sign
+        knum[freq < 0] = -knum[freq < 0]
+        freq = numpy.abs(freq)
+        self.knum = knum
+        self.freq = freq
+
+        self.wavenumber = wavenumber
+        self.frequency = frequency
+
+    def decompose_antisymm(self):
+        """decompose attribute data into symmetric and antisymmetric components"""
+        fftdata = self.fftdata
+        nf, nlat, nk = fftdata.shape
+        symm = 0.5 * (
+            fftdata[:, : nlat // 2 + 1, :] + fftdata[:, nlat : nlat // 2 - 1 : -1, :]
+        )
+        anti = 0.5 * (
+            fftdata[:, : nlat // 2, :] - fftdata[:, nlat : nlat // 2 : -1, :]
+        )
+
+        self.fftdata = numpy.concatenate([anti, symm], axis=1)
+
+    def kfmask(self, fmin=None, fmax=None, kmin=None, kmax=None):
+        """return a wavenumber-frequency mask for the wavefilter method
+
+        Arguments:
+
+        'fmin/fmax' -- frequency bounds, in cycles per day
+
+        'kmin/kmax' -- zonal wavenumber bounds
+        """
+        nf, nlat, nk = self.fftdata.shape
+        knum = self.knum
+        freq = self.freq
+
+        # wavenumber cut-off
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        return mask
+
+    def wavefilter(self, mask):
+        """apply wavenumber-frequency filtering with a user-supplied mask.
+
+        Arguments:
+
+        'mask' -- 2D boolean array with shape (frequency, wavenumber). True
+                  elements are filtered out (set to zero); False elements are kept.
+        """
+        # wavenumber = self.wavenumber
+        # frequency = self.frequency
+        fftdata = self.fftdata.copy()
+        nf, nlat, nk = fftdata.shape
+
+        if (nf, nk) != mask.shape:
+            print("mask array size is incorrect.")
+            sys.exit()
+
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+        fftdata[mask] = 0.0
+
+        # inverse FFT
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    # filter
+    def kelvinfilter(self, fmin=0.05, fmax=0.4, kmin=None, kmax=14, hmin=8, hmax=90):
+        """Kelvin wave filter
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber
+
+        'hmin/hmax' -- equivalent depth (m)
+        """
+
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+
+        # filtering ############################################################
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        if hmin is not None:
+            c = numpy.sqrt(g * hmin)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega - k < 0)
+        if hmax is not None:
+            c = numpy.sqrt(g * hmax)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega - k > 0)
+
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    def erfilter(self, fmin=None, fmax=None, kmin=-10, kmax=-1, hmin=8, hmax=90, n=1):
+        """equatorial Rossby wave filter
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber
+
+        'hmin/hmax' -- equivalent depth (m)
+
+        'n' -- meridional mode number
+        """
+
+        if n <= 0 or n % 1 != 0:
+            print("n must be an integer with n >= 1")
+            sys.exit()
+
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+
+        # filtering ############################################################
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        if hmin is not None:
+            c = numpy.sqrt(g * hmin)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega * (k**2 + (2 * n + 1)) + k < 0)
+        if hmax is not None:
+            c = numpy.sqrt(g * hmax)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega * (k**2 + (2 * n + 1)) + k > 0)
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    def igfilter(self, fmin=None, fmax=None, kmin=-15, kmax=-1, hmin=12, hmax=90, n=1):
+        """n>=1 inertio-gravity wave filter; the default is the n=1 WIG.
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber. negative is westward, positive is
+                       eastward
+
+        'hmin/hmax' -- equivalent depth (m)
+
+        'n' -- meridional mode number
+        """
+        if n <= 0 or n % 1 != 0:
+            print("n must be an integer with n >= 1. for the n=0 EIG, use the eig0filter method.")
+            sys.exit()
+
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+
+        # filtering ############################################################
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        if hmin is not None:
+            c = numpy.sqrt(g * hmin)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k**2 - (2 * n + 1) < 0)
+        if hmax is not None:
+            c = numpy.sqrt(g * hmax)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k**2 - (2 * n + 1) > 0)
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    def eig0filter(self, fmin=None, fmax=0.55, kmin=0, kmax=15, hmin=12, hmax=50):
+        """n=0 eastward inertio-gravity wave filter.
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber. negative is westward, positive is
+                       eastward
+
+        'hmin/hmax' -- equivalent depth (m)
+        """
+        if kmin < 0:
+            print("kmin must be positive. if k < 0, this mode is the MRG")
+            sys.exit()
+
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+
+        # filtering ############################################################
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        if hmin is not None:
+            c = numpy.sqrt(g * hmin)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k * omega - 1 < 0)
+        if hmax is not None:
+            c = numpy.sqrt(g * hmax)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k * omega - 1 > 0)
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    def mrgfilter(self, fmin=None, fmax=None, kmin=-10, kmax=-1, hmin=8, hmax=90):
+        """mixed Rossby-gravity wave filter
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber. negative is westward, positive is
+                       eastward
+
+        'hmin/hmax' -- equivalent depth (m)
+        """
+        if kmax > 0:
+            print("kmax must be negative. if k > 0, this mode is the same as the n=0 EIG")
+            sys.exit()
+
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+
+        # filtering ############################################################
+        mask = numpy.zeros((nf, nk), dtype=bool)
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        if hmin is not None:
+            c = numpy.sqrt(g * hmin)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k * omega - 1 < 0)
+        if hmax is not None:
+            c = numpy.sqrt(g * hmax)
+            omega = (
+                2.0 * pi * freq / 24.0 / 3600.0 / numpy.sqrt(beta * c)
+            )  # convert day^-1 to s^-1, then nondimensionalize
+            k = knum / a * numpy.sqrt(c / beta)  # convert per 2*pi*a to per m, then nondimensionalize
+            mask = mask | (omega**2 - k * omega - 1 > 0)
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+    def tdfilter(self, fmin=None, fmax=None, kmin=-20, kmax=-6):
+        """KTH05 TD-type filter.
+
+        Arguments:
+
+        'fmin/fmax' -- unit is cycles per day
+
+        'kmin/kmax' -- zonal wavenumber. negative is westward, positive is
+                       eastward
+        """
+        fftdata = self.fftdata.copy()
+        knum = self.knum
+        freq = self.freq
+        nf, nlat, nk = fftdata.shape
+        mask = numpy.zeros((nf, nk), dtype=bool)
+
+        # wavenumber cut-off
+        if kmin is not None:
+            mask = mask | (knum < kmin)
+        if kmax is not None:
+            mask = mask | (kmax < knum)
+
+        # frequency cut-off
+        if fmin is not None:
+            mask = mask | (freq < fmin)
+        if fmax is not None:
+            mask = mask | (fmax < freq)
+
+        # dispersion filter
+        mask = mask | (84 * freq + knum - 22 > 0) | (210 * freq + 2.5 * knum - 13 < 0)
+        mask = numpy.repeat(mask[:, NA, :], nlat, axis=1)
+
+        fftdata[mask] = 0.0
+
+        filtered = fftpack.ifft2(fftdata, axes=(0, 2))
+        return filtered.real
+
+
+"""
+# test #############################################
+import matplotlib.pyplot as plt
+from scipy.fftpack import fftshift
+x = fftshift(self.wavenumber)
+y = fftshift(self.frequency)
+power = numpy.abs(fftshift(fftdata[:,10,:], axes=(0,1)))**2
+z = power
+CF = plt.contourf(x,y,z,[0,0.5,1],extend='max')
+plt.axis([-17,17,-0.5,0.5])
+plt.colorbar(CF)
+plt.show()
+sys.exit()
+"""
diff --git a/pcmdi_metrics/qbo/lib/utils.py b/pcmdi_metrics/qbo/lib/utils.py
new file mode 100644
index 000000000..a522a42cb
--- /dev/null
+++ b/pcmdi_metrics/qbo/lib/utils.py
@@ -0,0 +1,360 @@
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mticker
+import numpy as np
+import xcdat as xc
+
+
+def generate_target_grid(target_grid):
+    """Generate common grid for interpolation
+
+    Parameters
+    ----------
+    target_grid : str
+        For example, "2.5x2.5"
+
+    Returns
+    -------
+    xarray.Dataset
+        uniform target grid generated with xcdat.create_uniform_grid
+    """
+
+    # generate target grid
+    res = target_grid.split("x")
+    lat_res = float(res[0])
+    lon_res = float(res[1])
+    start_lat = -90.0 + lat_res / 2
+    start_lon = 0.0
+    end_lat = 90.0 - lat_res / 2
+    end_lon = 360.0 - lon_res
+    t_grid = xc.create_uniform_grid(
+        start_lat, end_lat, lat_res, start_lon, end_lon, lon_res
+    )
+
+    return t_grid
+
+
+def select_time_range(ds, start, end):
+    """Subset time range
+
+    
Parameters + ---------- + ds : xarray dataset + dataset to subset + start : str + Starting year and month in format of "yyyy-mm" + end : str + Ending year and month in format of "yyyy-mm" + + Returns + ------- + xarray.Dataset + subsetted dataset + """ + + # USER DEFINED PERIOD + start_yr = int(start.split("-")[0]) + start_mo = int(start.split("-")[1]) + start_da = 1 + end_yr = int(end.split("-")[0]) + end_mo = int(end.split("-")[1]) + + print("end_yr:", end_yr) + print("end_mo:", end_mo) + + ds_tmp = ds.time.dt.days_in_month.sel(time=(ds.time.dt.year == end_yr)) + # end_da = int( + # ds_tmp.time.dt.days_in_month.sel(time=(ds_tmp.time.dt.month == end_mo))[-1] + # ) + end_da = int(ds_tmp.time[-1].dt.day) + + start_yr_str = str(start_yr).zfill(4) + start_mo_str = str(start_mo).zfill(2) + start_da_str = str(start_da).zfill(2) + end_yr_str = str(end_yr).zfill(4) + end_mo_str = str(end_mo).zfill(2) + end_da_str = str(end_da).zfill(2) + + # Subset given time period + ds = ds.sel( + time=slice( + start_yr_str + "-" + start_mo_str + "-" + start_da_str + " 00:00:00", + end_yr_str + "-" + end_mo_str + "-" + end_da_str + " 23:59:59", + ) + ) + + print("start_yr_str is ", start_yr_str) + print("start_mo_str is ", start_mo_str) + print("start_da is ", start_da) + print("end_yr_str is ", end_yr_str) + print("end_mo_str is ", end_mo_str) + print("end_da is", end_da) + + return ds + + +def test_plot_time_series(da, output_file, std=None, title=None): + """Plot time series to visualize interim output + + Parameters + ---------- + da : DataArray + DataArray to plot time series + output_file : str + file path and name for saving image + std : float, optional + standard deviation used for the threshold + title : str, optional + optional title, by default None + """ + + fig, ax = plt.subplots() + + da.plot(ax=ax) + + if std is not None: + ax.axhline(y=0.5 * std, c="k", ls="--") + ax.axhline(y=-0.5 * std, c="k", ls="--") + y = da.to_numpy() + x = da.time.to_numpy() + + ax.fill_between( + x, + y, + 0.5 * std, + where=y > 0.5 * std, + color="red", + interpolate=True, + alpha=0.5, + ) + + ax.fill_between( + x, + y, + -0.5 * std, + where=y < -0.5 * std, + color="blue", + interpolate=True, + alpha=0.5, + ) + + if title is not None: + ax.set_title(title) + + fig.savefig(output_file) + + +def test_plot_maps(std2_map, std2_map_phase, fig_title=None, output_file=None): + """_summary_ + + Parameters + ---------- + std2_map : xarray DataArray + _description_ + std2_map_phase : dict + _description_ + """ + + proj_setup = ccrs.PlateCarree(central_longitude=180) + proj = ccrs.PlateCarree() + + fig, axs = plt.subplots( + nrows=4, ncols=1, subplot_kw={"projection": proj_setup}, figsize=(13, 12) + ) + + axs = axs.flatten() + + for i, ax in enumerate(axs): + if i == 0: + data = std2_map + title = "All" + cmap = "viridis" + elif i == 1: + data = std2_map_phase["east"] + title = "QBO East" + cmap = "viridis" + elif i == 2: + data = std2_map_phase["west"] + title = "QBO West" + cmap = "viridis" + elif i == 3: + data = std2_map_phase["east"] - std2_map_phase["west"] + title = "Diff: East - West" + cmap = "coolwarm" + + data.plot( + ax=ax, + transform=proj, + cmap=cmap, + ) + + # Title each subplot + ax.set_title(title) + + # Draw the coastines for each subplot + ax.coastlines() + + # Create gridlines + gl = ax.gridlines( + crs=proj, linewidth=1, color="grey", alpha=0.2, linestyle="--" + ) + # Manipulate gridlines number and spaces + gl.ylocator = mticker.FixedLocator(np.arange(-90, 90, 20)) + gl.xlocator = 
mticker.FixedLocator(np.arange(-180, 180, 60)) + gl.top_labels = False + gl.bottom_labels = True + gl.left_labels = True + gl.right_labels = False + + if fig_title is not None: + fig.text(0.5, 0.95, fig_title, ha="center", va="bottom", fontsize=15) + fig.subplots_adjust(left=0.05, right=0.98, top=0.9, bottom=0.05, hspace=0.15) + + if output_file is not None: + fig.savefig(output_file) + + +def standardize_lat_lon_name_in_dataset(ds): + for coord in ("lat", "lon"): + coord_key_in_file = find_coord_key(ds, coord) + + if coord == coord_key_in_file: + pass + else: + ds = ds.rename({coord_key_in_file: coord}) + + # convert coord in descending order to ascending + if float(ds[coord][0].values) > float(ds[coord][-1].values): + if coord == "lat": + ds = ds.reindex(lat=list(ds.lat[::-1])) + elif coord == "lon": + ds = ds.reindex(lon=list(ds.lon[::-1])) + + return ds + + +def find_coord_key(ds, coord): + for coord_key in list(ds.coords.keys()): + if coord in coord_key.lower(): + return coord_key + + +def diag_plot( + std2_map, std2_map_diff, fig_title=None, output_file=None, sub_region=None +): + """_summary_ + + Parameters + ---------- + std2_map : xarray DataArray + _description_ + std2_map_diff : xarray DataArray + _description_ + """ + + proj_setup = ccrs.PlateCarree(central_longitude=180) + proj = ccrs.PlateCarree() + + fig, ax = plt.subplots(subplot_kw={"projection": proj_setup}, figsize=(8, 3)) + + lon1 = 0 + lon2 = 360 + lat1 = -30 + lat2 = 30 + ax.set_extent([lon1, lon2, lat1, lat2], crs=ccrs.PlateCarree()) + + data = std2_map.sel(lon=slice(lon1, lon2)).sel( + lat=slice(lat1, lat2) + ) # "All" --- contour + data2 = std2_map_diff.sel(lon=slice(lon1, lon2)).sel( + lat=slice(lat1, lat2) + ) # "Diff: East - West" --- color, + + # Adjust colormap + cmap = mycolormap() + + levels_shade = [-6, -4, -2, -1, 1, 2, 4, 6] + levels_contour = range(8, 30, 2) + + data.plot.contour( + ax=ax, + transform=proj, + levels=levels_contour, + colors="grey", + linewidths=1, + ) + + data2.plot( + ax=ax, + transform=proj, + cmap=cmap, + cbar_kwargs={"orientation": "horizontal", "ticks": levels_shade, "aspect": 40}, + levels=levels_shade, + extend="both", + ) + + # Title each subplot + if fig_title is not None: + ax.set_title(fig_title) + + # Add coastlines and other features if desired + ax.coastlines(resolution="50m", color="black", linewidth=1) + # ax.add_feature(cfeature.LAND, edgecolor='black') + + # plot land area in grey + land_50m = cfeature.NaturalEarthFeature( + "physical", "land", "50m", edgecolor=None, facecolor="lightgrey" + ) + ax.add_feature(land_50m) + + # Create gridlines + gl = ax.gridlines(crs=proj, linewidth=1, color="grey", alpha=0.2, linestyle="--") + # Manipulate gridlines number and spaces + gl.ylocator = mticker.FixedLocator(np.arange(-80, 80, 20)) + gl.xlocator = mticker.FixedLocator(np.arange(-180, 180, 60)) + gl.top_labels = False + gl.bottom_labels = True + gl.left_labels = True + gl.right_labels = False + + # Draw a rectangle to highlight the sub-region + if sub_region is not None: + lon_min, lon_max, lat_min, lat_max = sub_region + ax.plot( + [lon_min, lon_max, lon_max, lon_min, lon_min], + [lat_min, lat_min, lat_max, lat_max, lat_min], + color="lightgreen", + linestyle="--", + transform=ccrs.PlateCarree(), + ) + + if output_file is not None: + fig.tight_layout() + fig.savefig(output_file) + + +def mycolormap(): + """Combine two colormap to generate a new colormap for blue-white(middle)-yellow-red + + Returns + ------- + matplotlib colormap + """ + # Adjust colormap + + # sample the 
colormaps that you want to use. Use 128 from each so we get 256 + # colors in total + colors1 = plt.cm.Blues_r(np.linspace(0.0, 1, 127)) + colors2 = plt.cm.YlOrBr(np.linspace(0, 1, 127)) + + # add white in the middle + colors1 = np.append(colors1, [[0, 0, 0, 0]], axis=0) + colors2 = np.vstack((np.array([0, 0, 0, 0]), colors2)) + + # combine them and build a new colormap + colors = np.vstack((colors1, colors2)) + mymap = mcolors.LinearSegmentedColormap.from_list("my_colormap", colors) + + return mymap diff --git a/pcmdi_metrics/qbo/lib/utils_parallel.py b/pcmdi_metrics/qbo/lib/utils_parallel.py new file mode 100644 index 000000000..9ed5e1ac1 --- /dev/null +++ b/pcmdi_metrics/qbo/lib/utils_parallel.py @@ -0,0 +1,61 @@ +import logging +import sys +import time + +from compute_qbo_mjo_metrics import process_qbo_mjo_metrics + + +# Configure the logger +def configure_logger(filename): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + handler = logging.FileHandler(filename) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + + logger.addHandler(handler) + return logger + + +# Redirect stdout to logger +class LoggerWriter: + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + + def write(self, message): + if message.strip() != "": + self.logger.log(self.level, message.strip()) + + def flush(self): + pass + + +def process(params): + exp = params["exp"] + model = params["model"] + member = params["member"] + log_file = params["log_file"] + + logger = configure_logger(log_file) + logger.info( + f'Starting process for {exp}, {model}, {member} at {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}' + ) + + # Redirect stdout to logger + sys.stdout = LoggerWriter(logger, logging.INFO) + + start_time = time.time() + + # Call detection function + process_qbo_mjo_metrics(params) + + end_time = time.time() + + print( + f'Process finished at {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end_time))}. Elapsed time: {end_time - start_time} seconds.' 
+ ) + logger.info("Done") From ecda745c32f7a5542e0f226cbc5d3da282e3c05d Mon Sep 17 00:00:00 2001 From: lee1043 Date: Wed, 24 Apr 2024 16:57:41 -0700 Subject: [PATCH 2/3] first commit for driver --- pcmdi_metrics/qbo/__init__.py | 0 pcmdi_metrics/qbo/qbo_mjo_driver.py | 308 ++++++++++++++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 pcmdi_metrics/qbo/__init__.py create mode 100644 pcmdi_metrics/qbo/qbo_mjo_driver.py diff --git a/pcmdi_metrics/qbo/__init__.py b/pcmdi_metrics/qbo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pcmdi_metrics/qbo/qbo_mjo_driver.py b/pcmdi_metrics/qbo/qbo_mjo_driver.py new file mode 100644 index 000000000..2b7641646 --- /dev/null +++ b/pcmdi_metrics/qbo/qbo_mjo_driver.py @@ -0,0 +1,308 @@ +import datetime +import glob +import multiprocessing +import os + +import xsearch as xs +from compute_qbo_mjo_metrics import process_qbo_mjo_metrics +from utils_parallel import process + +# User options (reference) -------------------------------------------------------------- + +reference_name = "ERA5" + +# ref_input_file = "/work/lee1043/DATA/ERA5/ERA5_u50_monthly_1979-2021.nc" +ref_input_file = "/work/lee1043/DATA/ERA5/ERA5_u50_monthly_1979-2021_rewrite.nc" +# ref_input_file2 = "/work/lee1043/DATA/ERA5/ERA5_olr_daily_40s40n_1979-2021.nc" +ref_input_file2 = "/work/lee1043/DATA/ERA5/ERA5_olr_daily_40s40n_1979-2021_rewrite.nc" + +ref_var1 = "u50" +ref_level1 = None # hPa (=mb) + +ref_var2 = "olr" + +ref_start = "1979-01" +ref_end = "2005-12" +# ref_end = "2021-12" + +include_reference = True +# include_reference = False + +# User options (model) ------------------------------------------------------------------ + +exps = ["historical"] +# exps = ["ssp126", "ssp245", "ssp375", "ssp585"] +# exps = ["ssp370"] + +mip_era = "CMIP6" +# mip_era = "CMIP5" + +# models = "all" +# models = ["ACCESS-CM2"] +models = [] + +mip = mip_era.lower() + +# Input 1: u50 monthly +var1 = "ua" +freq1 = "mon" +cmipTable1 = "Amon" +level1 = 50 # hPa + +# Input 2: olr daily +var2 = "rlut" +freq2 = "day" +cmipTable2 = "day" + +# first_member_only = True +first_member_only = False + +debug = False +# debug = True + +# +# parallel +# +parallel = False +# parallel = True # not complete yet ... working on it! 
+ +num_processes = 10 +# num_processes = 3 + +# +# output +# +result_dir = "output" +if parallel: + result_dir = "output_parallel" +if debug: + result_dir = "output_debug" + +# Output: diagnostics --- netcdf file +output_filename_template = "qbo_mjo_%(model)_%(exp)_%(realization)_%(start)_%(end).nc" + +overwrite_output = False + +case_id = "{:v%Y%m%d}".format(datetime.datetime.now()) + +# +# Processing options +# +regrid = True +regrid_tool = "xesmf" +target_grid = "2x2" + +taper_to_mean = True + +# ------------------------------------------------------------------------------- + +params_collect = list() + +for exp in exps: + if exp in ["historical", "amip"]: + if debug: + model_start = "2000-01" + model_end = "2005-12" + else: + model_start = "1979-01" + """ + if mip_era == "CMIP6": + model_end = "2010-12" + elif mip_era == "CMIP5": + model_end = "2005-12" + else: + raise ValueError(f"{mip_era} is not defined for 'mip_era'") + """ + model_end = "2005-12" + else: + model_start = "2014-01" + model_end = "2100-12" + + outdir = os.path.join(result_dir, "diagnostics", mip, exp, case_id) + logdir = os.path.join(result_dir, "log", mip, exp, case_id) + + # Prepare a directory for output files + os.makedirs(outdir, exist_ok=True) + os.makedirs(logdir, exist_ok=True) + + # Search all available models + dpaths1 = xs.findPaths(exp, var1, freq1, cmipTable=cmipTable1, mip_era=mip_era) + dpaths2 = xs.findPaths(exp, var2, freq2, cmipTable=cmipTable2, mip_era=mip_era) + models1 = xs.natural_sort(xs.getGroupValues(dpaths1, "model")) + models2 = xs.natural_sort(xs.getGroupValues(dpaths2, "model")) + + common_models = [m for m in models1 if m in models2] + + print("exp:", exp) + print("models1:", models1) + print("number of models1:", len(models1)) + print("models2:", models2) + print("number of models2:", len(models2)) + + if models == "all": + models = common_models + + print("models:", models) + print("number of models:", len(models)) + + print("model_start:", model_start) + print("model_end:", model_end) + + if debug: + models = models[0:1] + print("exp:", exp) + print("models:", models) + print("number of models:", len(models)) + + if include_reference: + models.insert(0, "reference") + + # model loop + for model in models: + if model == "reference": + members = [reference_name] + + else: + dpaths_model1 = xs.retainDataByFacetValue(dpaths1, "model", model) + dpaths_model2 = xs.retainDataByFacetValue(dpaths2, "model", model) + members1 = xs.natural_sort(xs.getGroupValues(dpaths_model1, "member")) + members2 = xs.natural_sort(xs.getGroupValues(dpaths_model2, "member")) + members = [m for m in members1 if m in members2] + + if first_member_only or debug: + members = members[0:1] + if debug: + print("members1 (" + str(len(members1)) + "):", members1) + print("members2 (" + str(len(members2)) + "):", members2) + + if debug: + print("members (" + str(len(members)) + "):", members) + + # ensemble member loop + for member in members: + if model == "reference" and member == reference_name: + ncfiles1 = glob.glob(ref_input_file) + ncfiles2 = glob.glob(ref_input_file2) + + level_extract = ref_level1 + varname = ref_var1 + varname2 = ref_var2 + + start = ref_start + end = ref_end + + else: + dpaths_model_member_list1 = xs.getValuesForFacet( + dpaths_model1, "member", member + ) + dpaths_model_member_list2 = xs.getValuesForFacet( + dpaths_model2, "member", member + ) + + if debug: + print("dpaths_model_member_list1:", dpaths_model_member_list1) + print("dpaths_model_member_list2:", dpaths_model_member_list2) + + # 
Sanity check -- var1
+                if len(dpaths_model_member_list1) > 1:
+                    print(
+                        "Error: multiple paths detected for ",
+                        model,
+                        member,
+                        ": ",
+                        dpaths_model_member_list1,
+                    )
+                else:
+                    dpath1 = dpaths_model_member_list1[0]
+                    ncfiles1 = xs.natural_sort(glob.glob(os.path.join(dpath1, "*.nc")))
+                # Sanity check -- var2
+                if len(dpaths_model_member_list2) > 1:
+                    print(
+                        "Error: multiple paths detected for ",
+                        model,
+                        member,
+                        ": ",
+                        dpaths_model_member_list2,
+                    )
+                else:
+                    dpath2 = dpaths_model_member_list2[0]
+                    ncfiles2 = xs.natural_sort(glob.glob(os.path.join(dpath2, "*.nc")))
+
+                level_extract = level1
+                varname = var1
+                varname2 = var2
+                start = model_start
+                end = model_end
+
+            if debug:
+                print("ncfiles1:", ncfiles1)
+                print("ncfiles2:", ncfiles2)
+
+            # Set output file
+            output_filename = (
+                output_filename_template.replace("%(exp)", exp)
+                .replace("%(model)", model)
+                .replace("%(realization)", member)
+                .replace("%(start)", start)
+                .replace("%(end)", end)
+            )
+            output_file = os.path.join(outdir, output_filename)
+
+            log_filename = output_filename.replace(".nc", ".log")
+            log_file = os.path.join(logdir, log_filename)
+
+            # Set up parameters
+            params = {
+                "model": model,
+                "exp": exp,
+                "member": member,
+                "input_file": ncfiles1,
+                "input_file2": ncfiles2,
+                "varname": varname,
+                "level": level_extract,  # hPa (=mb)
+                "varname2": varname2,
+                "start": start,
+                "end": end,
+                "regrid": regrid,
+                "regrid_tool": regrid_tool,
+                "target_grid": target_grid,
+                "taper_to_mean": taper_to_mean,
+                "output_dir": outdir,
+                "debug": debug,
+                "log_file": log_file,
+            }
+
+            # Process -------------------------------------
+            if overwrite_output:
+                pass
+            else:
+                if os.path.isfile(output_file):
+                    continue  # output already exists; skip to the next member
+
+            if parallel:
+                params_collect.append(params)
+            else:
+                # Call the metrics function
+                print("call process_qbo_mjo_metrics for", model, member)
+                # if 1:
+                try:
+                    process_qbo_mjo_metrics(params)
+                    print("done process_qbo_mjo_metrics for ", model, member)
+                except Exception as e:
+                    print("process_qbo_mjo_metrics failed for ", model, member, e)
+
+
+# The below is yet to work ... in progress!
+if parallel:
+    num_task = len(params_collect)
+    print("number of total tasks: ", num_task)
+    if num_task < num_processes:
+        num_processes = num_task
+    print("number of processes for parallel: ", num_processes)
+
+    # pool object with the chosen number of processes
+    pool = multiprocessing.Pool(processes=num_processes)
+
+    # map the function over the list of parameter dicts; each task receives a
+    # single dict, so use map (starmap would unpack each dict as positional args)
+    pool.map(process, params_collect)
+    pool.close()
+    pool.join()

From 781b363bcd1c26bc60076c4b3dbd6ce2663da195 Mon Sep 17 00:00:00 2001
From: lee1043
Date: Wed, 24 Apr 2024 17:12:14 -0700
Subject: [PATCH 3/3] add qbo_mjo_driver.py

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index 3a55367ec..abda73293 100644
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@
     "pcmdi_metrics/cloud_feedback/cloud_feedback_driver.py",
     "pcmdi_metrics/extremes/extremes_driver.py",
     "pcmdi_metrics/sea_ice/sea_ice_driver.py",
+    "pcmdi_metrics/qbo/qbo_mjo_driver.py",
 ]

 entry_points = {