|
| 1 | +from typing import List, Optional, Tuple |
| 2 | + |
| 3 | +import matplotlib |
| 4 | +import numpy as np |
| 5 | +import xarray as xr |
| 6 | +import xcdat as xc |
| 7 | + |
| 8 | +from e3sm_diags.driver.utils.type_annotations import MetricsDict |
| 9 | +from e3sm_diags.logger import custom_logger |
| 10 | +from e3sm_diags.parameter.core_parameter import CoreParameter |
| 11 | +from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS |
| 12 | +from e3sm_diags.plot.utils import ( |
| 13 | + DEFAULT_PANEL_CFG, |
| 14 | + _add_colorbar, |
| 15 | + _add_contour_plot, |
| 16 | + _add_min_mean_max_text, |
| 17 | + _add_rmse_corr_text, |
| 18 | + _configure_titles, |
| 19 | + _configure_x_and_y_axes, |
| 20 | + _get_c_levels_and_norm, |
| 21 | + _save_plot, |
| 22 | +) |
| 23 | + |
| 24 | +matplotlib.use("Agg") |
| 25 | +import matplotlib.pyplot as plt # isort:skip # noqa: E402 |
| 26 | + |
| 27 | +logger = custom_logger(__name__) |
| 28 | + |
| 29 | + |
| 30 | +# Configs for x axis ticks and x axis limits. |
| 31 | +X_TICKS = np.array([-90, -60, -30, 0, 30, 60, 90]) |
| 32 | +X_LIM = -90, 90 |
| 33 | + |
| 34 | + |
| 35 | +def plot( |
| 36 | + parameter: CoreParameter, |
| 37 | + da_test: xr.DataArray, |
| 38 | + da_ref: xr.DataArray, |
| 39 | + da_diff: xr.DataArray, |
| 40 | + metrics_dict: MetricsDict, |
| 41 | +): |
| 42 | + """Plot the variable's metrics generated by the zonal_mean_2d set. |
| 43 | +
|
| 44 | + Parameters |
| 45 | + ---------- |
| 46 | + parameter : CoreParameter |
| 47 | + The CoreParameter object containing plot configurations. |
| 48 | + da_test : xr.DataArray |
| 49 | + The test data. |
| 50 | + da_ref : xr.DataArray |
| 51 | + The reference data. |
| 52 | + da_diff : xr.DataArray |
| 53 | + The difference between `da_test` and `da_ref` (both are regridded to |
| 54 | + the lower resolution of the two beforehand). |
| 55 | + metrics_dict : Metrics |
| 56 | + The metrics. |
| 57 | + """ |
| 58 | + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) |
| 59 | + fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) |
| 60 | + |
| 61 | + # The variable units. |
| 62 | + units = metrics_dict["units"] |
| 63 | + |
| 64 | + # Add the first subplot for test data. |
| 65 | + min1 = metrics_dict["test"]["min"] # type: ignore |
| 66 | + mean1 = metrics_dict["test"]["mean"] # type: ignore |
| 67 | + max1 = metrics_dict["test"]["max"] # type: ignore |
| 68 | + |
| 69 | + _add_colormap( |
| 70 | + 0, |
| 71 | + da_test, |
| 72 | + fig, |
| 73 | + parameter, |
| 74 | + parameter.test_colormap, |
| 75 | + parameter.contour_levels, |
| 76 | + title=(parameter.test_name_yrs, parameter.test_title, units), # type: ignore |
| 77 | + metrics=(max1, mean1, min1), # type: ignore |
| 78 | + ) |
| 79 | + |
| 80 | + # Add the second and third subplots for ref data and the differences, |
| 81 | + # respectively. |
| 82 | + min2 = metrics_dict["ref"]["min"] # type: ignore |
| 83 | + mean2 = metrics_dict["ref"]["mean"] # type: ignore |
| 84 | + max2 = metrics_dict["ref"]["max"] # type: ignore |
| 85 | + |
| 86 | + _add_colormap( |
| 87 | + 1, |
| 88 | + da_ref, |
| 89 | + fig, |
| 90 | + parameter, |
| 91 | + parameter.reference_colormap, |
| 92 | + parameter.contour_levels, |
| 93 | + title=(parameter.ref_name_yrs, parameter.reference_title, units), # type: ignore |
| 94 | + metrics=(max2, mean2, min2), # type: ignore |
| 95 | + ) |
| 96 | + |
| 97 | + min3 = metrics_dict["diff"]["min"] # type: ignore |
| 98 | + mean3 = metrics_dict["diff"]["mean"] # type: ignore |
| 99 | + max3 = metrics_dict["diff"]["max"] # type: ignore |
| 100 | + r = metrics_dict["misc"]["rmse"] # type: ignore |
| 101 | + c = metrics_dict["misc"]["corr"] # type: ignore |
| 102 | + |
| 103 | + _add_colormap( |
| 104 | + 2, |
| 105 | + da_diff, |
| 106 | + fig, |
| 107 | + parameter, |
| 108 | + parameter.diff_colormap, |
| 109 | + parameter.diff_levels, |
| 110 | + title=(None, parameter.diff_title, da_diff.attrs["units"]), # |
| 111 | + metrics=(max3, mean3, min3, r, c), # type: ignore |
| 112 | + ) |
| 113 | + |
| 114 | + _save_plot(fig, parameter) |
| 115 | + |
| 116 | + plt.close() |
| 117 | + |
| 118 | + |
| 119 | +def _add_colormap( |
| 120 | + subplot_num: int, |
| 121 | + var: xr.DataArray, |
| 122 | + fig: plt.Figure, |
| 123 | + parameter: CoreParameter, |
| 124 | + color_map: str, |
| 125 | + contour_levels: List[float], |
| 126 | + title: Tuple[Optional[str], str, str], |
| 127 | + metrics: Tuple[float, ...], |
| 128 | +): |
| 129 | + lat = xc.get_dim_coords(var, axis="Y") |
| 130 | + plev = xc.get_dim_coords(var, axis="Z") |
| 131 | + var = var.squeeze() |
| 132 | + |
| 133 | + # Configure contour levels |
| 134 | + # -------------------------------------------------------------------------- |
| 135 | + c_levels, norm = _get_c_levels_and_norm(contour_levels) |
| 136 | + |
| 137 | + # Add the contour plot |
| 138 | + # -------------------------------------------------------------------------- |
| 139 | + ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=None) |
| 140 | + |
| 141 | + contour_plot = _add_contour_plot( |
| 142 | + ax, parameter, var, lat, plev, color_map, None, norm, c_levels |
| 143 | + ) |
| 144 | + |
| 145 | + # Configure the aspect ratio and plot titles. |
| 146 | + # -------------------------------------------------------------------------- |
| 147 | + ax.set_aspect("auto") |
| 148 | + _configure_titles(ax, title) |
| 149 | + |
| 150 | + # Configure x and y axis. |
| 151 | + # -------------------------------------------------------------------------- |
| 152 | + _configure_x_and_y_axes(ax, X_TICKS, None, None, parameter.current_set) |
| 153 | + ax.set_xlim(X_LIM) |
| 154 | + |
| 155 | + if parameter.plot_log_plevs: |
| 156 | + ax.set_yscale("log") |
| 157 | + |
| 158 | + if parameter.plot_plevs: |
| 159 | + plev_ticks = parameter.plevs |
| 160 | + plt.yticks(plev_ticks, plev_ticks) |
| 161 | + |
| 162 | + # For default plevs, specify the pressure axis and show the 50 mb tick |
| 163 | + # at the top. |
| 164 | + if ( |
| 165 | + not parameter.plot_log_plevs |
| 166 | + and not parameter.plot_plevs |
| 167 | + and parameter.plevs == DEFAULT_PLEVS |
| 168 | + ): |
| 169 | + plev_ticks = parameter.plevs |
| 170 | + new_ticks = [plev_ticks[0]] + plev_ticks[1::2] |
| 171 | + new_ticks = [int(x) for x in new_ticks] |
| 172 | + plt.yticks(new_ticks, new_ticks) |
| 173 | + |
| 174 | + plt.ylabel("pressure (mb)") |
| 175 | + ax.invert_yaxis() |
| 176 | + |
| 177 | + # Add and configure the color bar. |
| 178 | + # -------------------------------------------------------------------------- |
| 179 | + _add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels) |
| 180 | + |
| 181 | + # Add metrics text. |
| 182 | + # -------------------------------------------------------------------------- |
| 183 | + # Min, Mean, Max |
| 184 | + _add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics) |
| 185 | + |
| 186 | + if len(metrics) == 5: |
| 187 | + _add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics) |
0 commit comments