From d92f0931ef5febd284a42ad3a19ee1c371c35d1a Mon Sep 17 00:00:00 2001 From: chengzhuzhang Date: Tue, 1 Apr 2025 20:43:54 -0700 Subject: [PATCH 1/4] Simplify make_plot_pdfs function to improve performance - Only create multi-figure PDFs for 'original' component - For non-original components, create individual PNGs for each plot - This removes the bottleneck of creating large multi-page PDFs for components that are primarily used for individual image display in the viewer - Reduce DPI from 150 to 100 for faster rendering and smaller files - Improve code organization and readability --- .../coupled_global_plotting.py | 153 ++++++++++-------- 1 file changed, 82 insertions(+), 71 deletions(-) diff --git a/zppy_interfaces/global_time_series/coupled_global_plotting.py b/zppy_interfaces/global_time_series/coupled_global_plotting.py index 42f1f6b..81a9afd 100644 --- a/zppy_interfaces/global_time_series/coupled_global_plotting.py +++ b/zppy_interfaces/global_time_series/coupled_global_plotting.py @@ -556,53 +556,45 @@ def make_plot_pdfs( # noqa: C901 if num_plots == 0: return - # If make_viewer, then we want to do 1 plot per page. - # However, the original plots are excluded from this restriction. - # Note: if the user provides nrows=ncols=1, there will still be a single plot per page - keep_user_dims = (not parameters.make_viewer) or (component == "original") - if keep_user_dims: + # Ensure output directory exists + os.makedirs(parameters.results_dir, exist_ok=True) + + # For "original" component: Create multi-figure PDF with multiple plots per page + if component == "original": + # Use user-specified dimensions for original component nrows = parameters.nrows ncols = parameters.ncols - else: - nrows = 1 - ncols = 1 - - plots_per_page = nrows * ncols - num_pages = math.ceil(num_plots / plots_per_page) - - counter = 0 - os.makedirs(parameters.results_dir, exist_ok=True) - # https://stackoverflow.com/questions/58738992/save-multiple-figures-with-subplots-into-a-pdf-with-multiple-pages - pdf = matplotlib.backends.backend_pdf.PdfPages( - f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.pdf" - ) - for page in range(num_pages): - if plots_per_page == 1: - logger.info("Using reduced figsize") - fig = plt.figure(1, figsize=[13.5 / 2, 16.5 / 4]) - else: - logger.info("Using standard figsize") + plots_per_page = nrows * ncols + num_pages = math.ceil(num_plots / plots_per_page) + + # Create PDF file + pdf = matplotlib.backends.backend_pdf.PdfPages( + f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.pdf" + ) + + # Process each page + counter = 0 + for page in range(num_pages): + # Create multi-plot figure fig = plt.figure(1, figsize=[13.5, 16.5]) - logger.info(f"Figure size={fig.get_size_inches() * fig.dpi}") - fig.suptitle(f"{parameters.figstr}_{rgn}_{component}") - for j in range(plots_per_page): - logger.info( - f"Plotting plot {j} on page {page}. This is plot {counter} in total." - ) - # The final page doesn't need to be filled out with plots. - if counter >= num_plots: - break - ax = plt.subplot( - nrows, - ncols, - j + 1, - ) - plot_name = plot_list[counter] - if component == "original": + fig.suptitle(f"{parameters.figstr}_{rgn}_{component}") + + # Process plots for this page + for j in range(plots_per_page): + # The final page doesn't need to be filled out with plots + if counter >= num_plots: + break + + # Create subplot + ax = plt.subplot(nrows, ncols, j + 1) + plot_name = plot_list[counter] + + # Generate plot try: plot_function = PLOT_DICT[plot_name] except KeyError: raise KeyError(f"Invalid plot name: {plot_name}") + try: plot_function(ax, xlim, exps, rgn) valid_plots.append(plot_name) @@ -623,36 +615,55 @@ def make_plot_pdfs( # noqa: C901 f"Failed plot_function for {plot_name}. Check that {required_vars} are available." ) invalid_plots.append(plot_name) + counter += 1 + + # Finalize and save the figure + fig.tight_layout() + pdf.savefig(fig) + + # Save PNG of the entire page + if num_pages > 1: + # Multi-page PDF - include page number in filename + png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{page}.png" else: - try: - plot_generic(ax, xlim, exps, plot_name, rgn) - valid_plots.append(plot_name) - except Exception: - traceback.print_exc() - logger.error( - f"plot_generic failed. Invalid plot={plot_name}, rgn={rgn}" - ) - invalid_plots.append(plot_name) - counter += 1 - - fig.tight_layout() - pdf.savefig(1) - # Always save individual PNGs for viewer mode - if plots_per_page == 1: - fig.savefig( - f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png", - dpi=150, - ) - elif num_pages > 1: - fig.savefig( - f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{page}.png", - dpi=150, - ) - else: - fig.savefig( - f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png", - dpi=150, - ) - plt.close(fig) - pdf.close() + # Single page PDF + png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png" + + fig.savefig(png_path, dpi=100) # Use lower DPI for better performance + plt.close(fig) + + # Close PDF file + pdf.close() + + # For non-original components: Create individual PNGs (one plot per file) + else: + # Process each plot individually + for i, plot_name in enumerate(plot_list): + # Create single-plot figure + fig = plt.figure(figsize=[13.5 / 2, 16.5 / 4]) + ax = fig.add_subplot(111) + + # Generate plot + try: + plot_generic(ax, xlim, exps, plot_name, rgn) + valid_plots.append(plot_name) + except Exception: + traceback.print_exc() + logger.error( + f"plot_generic failed. Invalid plot={plot_name}, rgn={rgn}" + ) + invalid_plots.append(plot_name) + plt.close(fig) + continue # Skip saving failed plots + + # Finalize and save PNG + fig.suptitle(f"{parameters.figstr}_{rgn}_{component}") + fig.tight_layout() + + # Save individual PNG + png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png" + fig.savefig(png_path, dpi=100) + + # Clean up + plt.close(fig) From 7a3fcccee344edd6a035acff793a81b6db8a5eb6 Mon Sep 17 00:00:00 2001 From: chengzhuzhang Date: Tue, 1 Apr 2025 21:42:25 -0700 Subject: [PATCH 2/4] Optimize data loading for multiple regions This significant optimization restructures the data loading pipeline to load NetCDF files only once for all regions, rather than multiple times (once per region). Since each NetCDF file already contains data for all regions (global, northern hemisphere, southern hemisphere), this optimization: 1. Reduces I/O operations dramatically - files are read just once 2. Improves memory efficiency - data is loaded once and reused 3. Decreases processing time - particularly for multi-region analyses Implementation: - Added new load_all_region_data function to load data for all regions at once - Created extract_region_data function to extract specific region data - Restructured the data flow in coupled_global to process regions after data load - Maintains backward compatibility with existing API - Better error handling and reporting for problematic variables This is a significant performance enhancement for use cases that process multiple regions from the same dataset. --- .../global_time_series/coupled_global.py | 396 ++++++++++++++---- .../coupled_global_dataset_wrapper.py | 14 +- 2 files changed, 316 insertions(+), 94 deletions(-) diff --git a/zppy_interfaces/global_time_series/coupled_global.py b/zppy_interfaces/global_time_series/coupled_global.py index b07bee8..1d46f08 100644 --- a/zppy_interfaces/global_time_series/coupled_global.py +++ b/zppy_interfaces/global_time_series/coupled_global.py @@ -154,15 +154,28 @@ def get_exps(parameters: Parameters) -> List[Dict[str, Any]]: return exps -def set_var( +def load_all_var_data( exp: Dict[str, Any], exp_key: str, var_list: List[Variable], valid_vars: List[str], invalid_vars: List[str], - rgn: str, -) -> List[Variable]: +) -> Tuple[List[Variable], Dict[str, Tuple[xarray.core.dataarray.DataArray, str]]]: + """Load data for all variables for all regions. + + Args: + exp: Experiment configuration + exp_key: Key for data directory + var_list: List of variables to load + valid_vars: List to track successfully loaded variables + invalid_vars: List to track variables that couldn't be loaded + + Returns: + Tuple of (list of successfully loaded variables, dictionary of data arrays) + """ new_var_list: List[Variable] = [] + all_data_dict = {} + if exp[exp_key] != "": try: dataset_wrapper: DatasetWrapper = DatasetWrapper(exp[exp_key]) @@ -172,118 +185,292 @@ def set_var( f"DatasetWrapper object could not be created for {exp_key}={exp[exp_key]}" ) raise e + for var in var_list: var_str: str = var.variable_name try: + # Get data for all regions data_array: xarray.core.dataarray.DataArray units: str - data_array, units = dataset_wrapper.globalAnnual(var) - valid_vars.append(str(var_str)) # Append the name - new_var_list.append(var) # Append the variable itself + data_array, units = dataset_wrapper.globalAnnual(var, all_regions=True) + + # Store the result keyed by variable name + all_data_dict[var_str] = (data_array, units) + + # Track successful variables + valid_vars.append(str(var_str)) + new_var_list.append(var) + + # Store year info if not already present + if "year" not in exp["annual"]: + years: np.ndarray[cftime.DatetimeNoLeap] = data_array.coords["time"].values + exp["annual"]["year"] = [x.year for x in years] + except Exception as e: logger.error(e) logger.error(f"globalAnnual failed for {var_str}") invalid_vars.append(str(var_str)) - continue - if data_array.sizes["rgn"] > 1: - # number of years x 3 regions = data_array.shape - # 3 regions = global, northern hemisphere, southern hemisphere - # We get here if we used the updated `ts` task - # (using `rgn_avg` rather than `glb_avg`). - if rgn == "glb": - n = 0 - elif rgn == "n": - n = 1 - elif rgn == "s": - n = 2 - else: - raise RuntimeError(f"Invalid rgn={rgn}") - data_array = data_array.isel(rgn=n) # Just use nth region - elif rgn != "glb": - # data_array only has one dimension -- glb. - # Therefore it is not possible to get n or s plots. - raise RuntimeError( - f"var={var_str} only has global data. Cannot process rgn={rgn}" - ) - exp["annual"][var_str] = (data_array, units) - if "year" not in exp["annual"]: - years: np.ndarray[cftime.DatetimeNoLeap] = data_array.coords[ - "time" - ].values - exp["annual"]["year"] = [x.year for x in years] + del dataset_wrapper - return new_var_list + + return new_var_list, all_data_dict -def process_data( - parameters: Parameters, requested_variables: RequestedVariables, rgn: str -) -> List[Dict[str, Any]]: +def extract_region_data( + all_data_dict: Dict[str, Tuple[xarray.core.dataarray.DataArray, str]], + rgn: str, +) -> Dict[str, Tuple[xarray.core.dataarray.DataArray, str]]: + """Extract data for a specific region from the all-regions data dictionary. + + Args: + all_data_dict: Dictionary mapping variable names to (data_array, units) tuples + rgn: Region to extract ('glb', 'n', or 's') + + Returns: + Dictionary with data arrays extracted for the specified region + """ + region_data_dict = {} + + # Map region string to index + if rgn == "glb": + n = 0 + elif rgn == "n": + n = 1 + elif rgn == "s": + n = 2 + else: + raise RuntimeError(f"Invalid rgn={rgn}") + + # Process each variable + for var_str, (data_array, units) in all_data_dict.items(): + # Extract region if the data has multiple regions + if "rgn" in data_array.dims and data_array.sizes["rgn"] > 1: + # Extract the specific region + region_data = data_array.isel(rgn=n) + elif rgn != "glb": + # If no rgn dimension but trying to get n or s, that's an error + raise RuntimeError( + f"var={var_str} only has global data. Cannot process rgn={rgn}" + ) + else: + # No rgn dimension but wanting global data, or already extracted + region_data = data_array + + # Store in output dictionary + region_data_dict[var_str] = (region_data, units) + + return region_data_dict + + +def load_all_region_data( + parameters: Parameters, requested_variables: RequestedVariables +) -> Tuple[List[Dict[str, Any]], Dict[str, List[str]], Dict[str, List[str]]]: + """Load all data for all regions at once. + + Args: + parameters: Configuration parameters + requested_variables: Variables to load for each component + + Returns: + Tuple of (experiment data, valid variables, invalid variables) + """ + # Get experiment configurations exps: List[Dict[str, Any]] = get_exps(parameters) - valid_vars: List[str] = [] - invalid_vars: List[str] = [] - exp: Dict[str, Any] + + # Track valid and invalid variables by component + component_valid_vars = {"atmos": [], "ice": [], "land": [], "ocean": []} + component_invalid_vars = {"atmos": [], "ice": [], "land": [], "ocean": []} + + # Process each experiment for exp in exps: + # Initialize annual data and all-regions data storage exp["annual"] = {} - - requested_variables.vars_original = set_var( + exp["all_regions_data"] = {} + + # Initialize component data dictionaries + exp["all_regions_data"]["atmos"] = {} + exp["all_regions_data"]["ice"] = {} + exp["all_regions_data"]["land"] = {} + exp["all_regions_data"]["ocean"] = {} + + # Load data for each component - original vars (atmosphere variables) + requested_variables.vars_original, atmos_original_data = load_all_var_data( exp, "atmos", requested_variables.vars_original, - valid_vars, - invalid_vars, - rgn, + component_valid_vars["atmos"], + component_invalid_vars["atmos"], ) - requested_variables.vars_atm = set_var( - exp, "atmos", requested_variables.vars_atm, valid_vars, invalid_vars, rgn + exp["all_regions_data"]["atmos"].update(atmos_original_data) + + # Load data for each component - atmosphere variables + requested_variables.vars_atm, atmos_data = load_all_var_data( + exp, + "atmos", + requested_variables.vars_atm, + component_valid_vars["atmos"], + component_invalid_vars["atmos"], ) - requested_variables.vars_ice = set_var( - exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn + exp["all_regions_data"]["atmos"].update(atmos_data) + + # Load data for each component - ice variables + requested_variables.vars_ice, ice_data = load_all_var_data( + exp, + "ice", + requested_variables.vars_ice, + component_valid_vars["ice"], + component_invalid_vars["ice"], ) - requested_variables.vars_land = set_var( + exp["all_regions_data"]["ice"].update(ice_data) + + # Load data for each component - land variables + requested_variables.vars_land, land_data = load_all_var_data( exp, "land", requested_variables.vars_land, - valid_vars, - invalid_vars, - rgn, + component_valid_vars["land"], + component_invalid_vars["land"], ) - requested_variables.vars_ocn = set_var( - exp, "ocean", requested_variables.vars_ocn, valid_vars, invalid_vars, rgn + exp["all_regions_data"]["land"].update(land_data) + + # Load data for each component - ocean variables + requested_variables.vars_ocn, ocn_data = load_all_var_data( + exp, + "ocean", + requested_variables.vars_ocn, + component_valid_vars["ocean"], + component_invalid_vars["ocean"], ) - - # Optionally read ohc + exp["all_regions_data"]["ocean"].update(ocn_data) + + # Special handling for ocean heat content if exp["ocean"] != "": - dataset_wrapper = DatasetWrapper(exp["ocean"]) - exp["annual"]["ohc"], _ = dataset_wrapper.globalAnnual(Variable("ohc")) - # anomalies with respect to first year - exp["annual"]["ohc"][:] = exp["annual"]["ohc"][:] - exp["annual"]["ohc"][0] - + try: + dataset_wrapper = DatasetWrapper(exp["ocean"]) + data_array, units = dataset_wrapper.globalAnnual(Variable("ohc"), all_regions=True) + + # Store in all regions data + exp["all_regions_data"]["ocean"]["ohc"] = (data_array, units) + + # Track as valid variable + component_valid_vars["ocean"].append("ohc") + + del dataset_wrapper + except Exception as e: + logger.error(e) + logger.error("Failed to load ohc data") + component_invalid_vars["ocean"].append("ohc") + + # Special handling for ocean volume if exp["vol"] != "": - dataset_wrapper = DatasetWrapper(exp["vol"]) - exp["annual"]["volume"], _ = dataset_wrapper.globalAnnual( - Variable("volume") - ) - # annomalies with respect to first year - exp["annual"]["volume"][:] = ( - exp["annual"]["volume"][:] - exp["annual"]["volume"][0] - ) + try: + dataset_wrapper = DatasetWrapper(exp["vol"]) + data_array, units = dataset_wrapper.globalAnnual(Variable("volume"), all_regions=True) + + # Store in all regions data + exp["all_regions_data"]["ocean"]["volume"] = (data_array, units) + + # Track as valid variable + component_valid_vars["ocean"].append("volume") + + del dataset_wrapper + except Exception as e: + logger.error(e) + logger.error("Failed to load volume data") + component_invalid_vars["ocean"].append("volume") + + # Log success and failures for all components + for component, valid_vars in component_valid_vars.items(): + if valid_vars: + logger.info(f"{component} variables were computed successfully: {valid_vars}") + + for component, invalid_vars in component_invalid_vars.items(): + if invalid_vars: + logger.error(f"{component} variables could not be computed: {invalid_vars}") + + return exps, component_valid_vars, component_invalid_vars - logger.info( - f"{rgn} region globalAnnual was computed successfully for these variables: {valid_vars}" - ) - logger.error( - f"{rgn} region globalAnnual could not be computed for these variables: {invalid_vars}" - ) - return exps +def process_data( + all_region_exps: List[Dict[str, Any]], rgn: str +) -> List[Dict[str, Any]]: + """Process data for a specific region. + + Args: + all_region_exps: Experiments with all-regions data already loaded + rgn: Region to process ('glb', 'n', or 's') + + Returns: + List of experiment dictionaries with region-specific data + """ + # Create a deep copy to avoid modifying the original + import copy + exps = copy.deepcopy(all_region_exps) + + # Extract region-specific data for each experiment + for exp in exps: + # Extract atmosphere data + if "atmos" in exp["all_regions_data"]: + atmos_region_data = extract_region_data(exp["all_regions_data"]["atmos"], rgn) + exp["annual"].update(atmos_region_data) + + # Extract ice data + if "ice" in exp["all_regions_data"]: + ice_region_data = extract_region_data(exp["all_regions_data"]["ice"], rgn) + exp["annual"].update(ice_region_data) + + # Extract land data + if "land" in exp["all_regions_data"]: + land_region_data = extract_region_data(exp["all_regions_data"]["land"], rgn) + exp["annual"].update(land_region_data) + + # Extract ocean data + if "ocean" in exp["all_regions_data"]: + ocean_region_data = extract_region_data(exp["all_regions_data"]["ocean"], rgn) + exp["annual"].update(ocean_region_data) + + # Process OHC anomalies if available + if "ohc" in exp["annual"]: + # anomalies with respect to first year + ohc_data, ohc_units = exp["annual"]["ohc"] + ohc_anomaly = ohc_data - ohc_data[0] + exp["annual"]["ohc"] = (ohc_anomaly, ohc_units) + + # Process volume anomalies if available + if "volume" in exp["annual"]: + # anomalies with respect to first year + volume_data, volume_units = exp["annual"]["volume"] + volume_anomaly = volume_data - volume_data[0] + exp["annual"]["volume"] = (volume_anomaly, volume_units) + + # Clean up all_regions_data to save memory + del exp["all_regions_data"] + + return exps -# Run coupled_global ########################################################## -def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: str): - # Experiments - exps: List[Dict[str, Any]] = process_data(parameters, requested_variables, rgn) +# Run coupled_global for a single region ########################################################## +def run_region( + parameters: Parameters, + requested_variables: RequestedVariables, + rgn: str, + all_region_exps: List[Dict[str, Any]] +): + """Process and plot data for a specific region. + + Args: + parameters: Configuration parameters + requested_variables: Variables to process + rgn: Region to process ('glb', 'n', or 's') + all_region_exps: Experiment data with all regions already loaded + """ + # Extract data for this specific region + exps: List[Dict[str, Any]] = process_data(all_region_exps, rgn) + + # Set up x-axis limits xlim: List[float] = [float(parameters.year1), float(parameters.year2)] + # Track successful and failed plots valid_plots: List[str] = [] invalid_plots: List[str] = [] @@ -298,6 +485,8 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st ("lnd", list(map(lambda v: v.variable_name, requested_variables.vars_land))), ("ocn", list(map(lambda v: v.variable_name, requested_variables.vars_ocn))), ] + + # Generate plots for each component for component, plot_list in mapping: make_plot_pdfs( parameters, @@ -309,13 +498,19 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st valid_plots, invalid_plots, ) - logger.info(f"These {rgn} region plots generated successfully: {valid_plots}") - logger.error( - f"These {rgn} region plots could not be generated successfully: {invalid_plots}" - ) + + # Log results + if valid_plots: + logger.info(f"These {rgn} region plots generated successfully: {valid_plots}") + + if invalid_plots: + logger.error( + f"These {rgn} region plots could not be generated successfully: {invalid_plots}" + ) def get_vars(requested_variables: RequestedVariables, component: str) -> List[Variable]: + """Get variable list for a specific component.""" vars: List[Variable] if component == "original": vars = requested_variables.vars_original @@ -333,24 +528,38 @@ def get_vars(requested_variables: RequestedVariables, component: str) -> List[Va def coupled_global(parameters: Parameters) -> None: + """Main entry point for the global time series plots. + + Changes from original version: + - Load all data for all regions once, then process each region using that data + - This reduces I/O operations significantly + """ + # Initialize variables for all components requested_variables = RequestedVariables(parameters) + + # OPTIMIZATION: Load all data for all regions once + logger.info("Loading data for all regions...") + all_region_exps, valid_vars, invalid_vars = load_all_region_data(parameters, requested_variables) + + # Process each region using the already-loaded data for rgn in parameters.regions: - run(parameters, requested_variables, rgn) + logger.info(f"Processing region: {rgn}") + run_region(parameters, requested_variables, rgn, all_region_exps) + + # Create viewer if requested if parameters.make_viewer: # In this case, we don't want the summary PDF. # Rather, we want to construct a viewer similar to E3SM Diags. title_and_url_list: List[Tuple[str, str]] = [] - for component in [ - "atm", - "ice", - "lnd", - "ocn", - ]: # Don't create viewer for original component + + # Create viewers for each component except original + for component in ["atm", "ice", "lnd", "ocn"]: # Don't create viewer for original component vars = get_vars(requested_variables, component) if vars: url = create_viewer(parameters, vars, component) logger.info(f"Viewer URL for {component}: {url}") title_and_url_list.append((component, url)) + # Special case for original plots: always use user-provided dimensions. vars = get_vars(requested_variables, "original") if vars: @@ -361,6 +570,7 @@ def coupled_global(parameters: Parameters) -> None: f"{parameters.figstr}_glb_original.pdf", ) ) - + + # Create index page for all viewers index_url: str = create_viewer_index(parameters.results_dir, title_and_url_list) logger.info(f"Viewer index URL: {index_url}") diff --git a/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py b/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py index cd5e3dc..c07845c 100644 --- a/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py +++ b/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py @@ -79,6 +79,7 @@ def globalAnnualHelper( scale_factor: float, original_units: str, final_units: str, + extract_region: bool = True, ) -> Tuple[xarray.core.dataarray.DataArray, str]: data_array: xarray.core.dataarray.DataArray @@ -181,12 +182,23 @@ def globalAnnualHelper( return data_array, units def globalAnnual( - self, var: Variable + self, var: Variable, all_regions: bool = False ) -> Tuple[xarray.core.dataarray.DataArray, str]: + """Get annual average for a variable. + + Args: + var: The variable to process + all_regions: If True, returns data for all regions without extraction + If False (default), extracts single region based on rgn param + + Returns: + Tuple of (data_array, units) + """ return self.globalAnnualHelper( var.variable_name, var.metric, var.scale_factor, var.original_units, var.final_units, + extract_region=not all_regions, ) From 619b2fa82e0e7e89dcac493a498e3df65a2c60e3 Mon Sep 17 00:00:00 2001 From: chengzhuzhang Date: Wed, 2 Apr 2025 20:28:40 -0700 Subject: [PATCH 3/4] Fix two issues with plotting and viewer creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Fix issue with missing plots in original component for non-global regions - Handle global-only plots in non-global regions by creating empty plots - Ensures all plots appear in multi-plot pages 2. Fix viewer creation when make_viewer=True - Create viewers for all components with requested plots - Properly link to generated PNGs in the viewer 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../global_time_series/coupled_global.py | 10 ++-- .../coupled_global_plotting.py | 51 +++++++++++-------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/zppy_interfaces/global_time_series/coupled_global.py b/zppy_interfaces/global_time_series/coupled_global.py index 1d46f08..344fcdb 100644 --- a/zppy_interfaces/global_time_series/coupled_global.py +++ b/zppy_interfaces/global_time_series/coupled_global.py @@ -554,15 +554,17 @@ def coupled_global(parameters: Parameters) -> None: # Create viewers for each component except original for component in ["atm", "ice", "lnd", "ocn"]: # Don't create viewer for original component - vars = get_vars(requested_variables, component) - if vars: + component_plot_list = getattr(parameters, f"plots_{component}") + # Create viewer if this component was requested in parameters + if component_plot_list: + vars = get_vars(requested_variables, component) url = create_viewer(parameters, vars, component) logger.info(f"Viewer URL for {component}: {url}") title_and_url_list.append((component, url)) # Special case for original plots: always use user-provided dimensions. - vars = get_vars(requested_variables, "original") - if vars: + if parameters.plots_original: + vars = get_vars(requested_variables, "original") logger.info("Using user provided dimensions for original plots PDF") title_and_url_list.append( ( diff --git a/zppy_interfaces/global_time_series/coupled_global_plotting.py b/zppy_interfaces/global_time_series/coupled_global_plotting.py index 81a9afd..4768c46 100644 --- a/zppy_interfaces/global_time_series/coupled_global_plotting.py +++ b/zppy_interfaces/global_time_series/coupled_global_plotting.py @@ -595,27 +595,38 @@ def make_plot_pdfs( # noqa: C901 except KeyError: raise KeyError(f"Invalid plot name: {plot_name}") - try: - plot_function(ax, xlim, exps, rgn) + # Check if this is a global-only plot and we're not in global region + is_global_only = (plot_name in ["change_ohc", "max_moc", "change_sea_level"]) and (rgn != "glb") + + if is_global_only: + # For global-only plots in non-global regions, just create an empty plot + ax.set_title(f"{plot_name}") + ax.set_xticks([]) + ax.set_yticks([]) valid_plots.append(plot_name) - except Exception: - traceback.print_exc() - required_vars = [] - if plot_name == "net_toa_flux_restom": - required_vars = ["RESTOM"] - elif plot_name == "net_atm_energy_imbalance": - required_vars = ["RESTOM", "RESSURF"] - elif plot_name == "global_surface_air_temperature": - required_vars = ["TREFHT"] - elif plot_name == "toa_radiation": - required_vars = ["FSNTOA", "FLUT"] - elif plot_name == "net_atm_water_imbalance": - required_vars = ["PRECC", "PRECL", "QFLX"] - logger.error( - f"Failed plot_function for {plot_name}. Check that {required_vars} are available." - ) - invalid_plots.append(plot_name) - + else: + # For normal plots + try: + plot_function(ax, xlim, exps, rgn) + valid_plots.append(plot_name) + except Exception: + traceback.print_exc() + required_vars = [] + if plot_name == "net_toa_flux_restom": + required_vars = ["RESTOM"] + elif plot_name == "net_atm_energy_imbalance": + required_vars = ["RESTOM", "RESSURF"] + elif plot_name == "global_surface_air_temperature": + required_vars = ["TREFHT"] + elif plot_name == "toa_radiation": + required_vars = ["FSNTOA", "FLUT"] + elif plot_name == "net_atm_water_imbalance": + required_vars = ["PRECC", "PRECL", "QFLX"] + logger.error( + f"Failed plot_function for {plot_name}. Check that {required_vars} are available." + ) + invalid_plots.append(plot_name) + counter += 1 # Finalize and save the figure From b4f8f103cde0acfc1adaa0a11935d3e70b829264 Mon Sep 17 00:00:00 2001 From: chengzhuzhang Date: Wed, 2 Apr 2025 20:44:46 -0700 Subject: [PATCH 4/4] Revert "Fix two issues with plotting and viewer creation" This reverts commit 619b2fa82e0e7e89dcac493a498e3df65a2c60e3. --- .../global_time_series/coupled_global.py | 10 ++-- .../coupled_global_plotting.py | 51 ++++++++----------- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/zppy_interfaces/global_time_series/coupled_global.py b/zppy_interfaces/global_time_series/coupled_global.py index 344fcdb..1d46f08 100644 --- a/zppy_interfaces/global_time_series/coupled_global.py +++ b/zppy_interfaces/global_time_series/coupled_global.py @@ -554,17 +554,15 @@ def coupled_global(parameters: Parameters) -> None: # Create viewers for each component except original for component in ["atm", "ice", "lnd", "ocn"]: # Don't create viewer for original component - component_plot_list = getattr(parameters, f"plots_{component}") - # Create viewer if this component was requested in parameters - if component_plot_list: - vars = get_vars(requested_variables, component) + vars = get_vars(requested_variables, component) + if vars: url = create_viewer(parameters, vars, component) logger.info(f"Viewer URL for {component}: {url}") title_and_url_list.append((component, url)) # Special case for original plots: always use user-provided dimensions. - if parameters.plots_original: - vars = get_vars(requested_variables, "original") + vars = get_vars(requested_variables, "original") + if vars: logger.info("Using user provided dimensions for original plots PDF") title_and_url_list.append( ( diff --git a/zppy_interfaces/global_time_series/coupled_global_plotting.py b/zppy_interfaces/global_time_series/coupled_global_plotting.py index 4768c46..81a9afd 100644 --- a/zppy_interfaces/global_time_series/coupled_global_plotting.py +++ b/zppy_interfaces/global_time_series/coupled_global_plotting.py @@ -595,38 +595,27 @@ def make_plot_pdfs( # noqa: C901 except KeyError: raise KeyError(f"Invalid plot name: {plot_name}") - # Check if this is a global-only plot and we're not in global region - is_global_only = (plot_name in ["change_ohc", "max_moc", "change_sea_level"]) and (rgn != "glb") - - if is_global_only: - # For global-only plots in non-global regions, just create an empty plot - ax.set_title(f"{plot_name}") - ax.set_xticks([]) - ax.set_yticks([]) + try: + plot_function(ax, xlim, exps, rgn) valid_plots.append(plot_name) - else: - # For normal plots - try: - plot_function(ax, xlim, exps, rgn) - valid_plots.append(plot_name) - except Exception: - traceback.print_exc() - required_vars = [] - if plot_name == "net_toa_flux_restom": - required_vars = ["RESTOM"] - elif plot_name == "net_atm_energy_imbalance": - required_vars = ["RESTOM", "RESSURF"] - elif plot_name == "global_surface_air_temperature": - required_vars = ["TREFHT"] - elif plot_name == "toa_radiation": - required_vars = ["FSNTOA", "FLUT"] - elif plot_name == "net_atm_water_imbalance": - required_vars = ["PRECC", "PRECL", "QFLX"] - logger.error( - f"Failed plot_function for {plot_name}. Check that {required_vars} are available." - ) - invalid_plots.append(plot_name) - + except Exception: + traceback.print_exc() + required_vars = [] + if plot_name == "net_toa_flux_restom": + required_vars = ["RESTOM"] + elif plot_name == "net_atm_energy_imbalance": + required_vars = ["RESTOM", "RESSURF"] + elif plot_name == "global_surface_air_temperature": + required_vars = ["TREFHT"] + elif plot_name == "toa_radiation": + required_vars = ["FSNTOA", "FLUT"] + elif plot_name == "net_atm_water_imbalance": + required_vars = ["PRECC", "PRECL", "QFLX"] + logger.error( + f"Failed plot_function for {plot_name}. Check that {required_vars} are available." + ) + invalid_plots.append(plot_name) + counter += 1 # Finalize and save the figure