diff --git a/zppy_interfaces/global_time_series/coupled_global.py b/zppy_interfaces/global_time_series/coupled_global.py
index b07bee8..1d46f08 100644
--- a/zppy_interfaces/global_time_series/coupled_global.py
+++ b/zppy_interfaces/global_time_series/coupled_global.py
@@ -154,15 +154,28 @@ def get_exps(parameters: Parameters) -> List[Dict[str, Any]]:
     return exps
 
 
-def set_var(
+def load_all_var_data(
     exp: Dict[str, Any],
     exp_key: str,
     var_list: List[Variable],
     valid_vars: List[str],
     invalid_vars: List[str],
-    rgn: str,
-) -> List[Variable]:
+) -> Tuple[List[Variable], Dict[str, Tuple[xarray.core.dataarray.DataArray, str]]]:
+    """Load data for all variables for all regions.
+
+    Args:
+        exp: Experiment configuration
+        exp_key: Key for data directory
+        var_list: List of variables to load
+        valid_vars: List to track successfully loaded variables
+        invalid_vars: List to track variables that couldn't be loaded
+
+    Returns:
+        Tuple of (list of successfully loaded variables, dictionary of data arrays)
+    """
     new_var_list: List[Variable] = []
+    all_data_dict = {}
+
     if exp[exp_key] != "":
         try:
             dataset_wrapper: DatasetWrapper = DatasetWrapper(exp[exp_key])
@@ -172,118 +185,292 @@ def set_var(
                 f"DatasetWrapper object could not be created for {exp_key}={exp[exp_key]}"
             )
             raise e
+
         for var in var_list:
             var_str: str = var.variable_name
             try:
+                # Get data for all regions
                 data_array: xarray.core.dataarray.DataArray
                 units: str
-                data_array, units = dataset_wrapper.globalAnnual(var)
-                valid_vars.append(str(var_str))  # Append the name
-                new_var_list.append(var)  # Append the variable itself
+                data_array, units = dataset_wrapper.globalAnnual(var, all_regions=True)
+
+                # Store the result keyed by variable name
+                all_data_dict[var_str] = (data_array, units)
+
+                # Track successful variables
+                valid_vars.append(str(var_str))
+                new_var_list.append(var)
+
+                # Store year info if not already present
+                if "year" not in exp["annual"]:
+                    years: np.ndarray[cftime.DatetimeNoLeap] = data_array.coords[
+                        "time"
+                    ].values
+                    exp["annual"]["year"] = [x.year for x in years]
+
             except Exception as e:
                 logger.error(e)
                 logger.error(f"globalAnnual failed for {var_str}")
                 invalid_vars.append(str(var_str))
-                continue
-            if data_array.sizes["rgn"] > 1:
-                # number of years x 3 regions = data_array.shape
-                # 3 regions = global, northern hemisphere, southern hemisphere
-                # We get here if we used the updated `ts` task
-                # (using `rgn_avg` rather than `glb_avg`).
-                if rgn == "glb":
-                    n = 0
-                elif rgn == "n":
-                    n = 1
-                elif rgn == "s":
-                    n = 2
-                else:
-                    raise RuntimeError(f"Invalid rgn={rgn}")
-                data_array = data_array.isel(rgn=n)  # Just use nth region
-            elif rgn != "glb":
-                # data_array only has one dimension -- glb.
-                # Therefore it is not possible to get n or s plots.
-                raise RuntimeError(
-                    f"var={var_str} only has global data. Cannot process rgn={rgn}"
-                )
-            exp["annual"][var_str] = (data_array, units)
-            if "year" not in exp["annual"]:
-                years: np.ndarray[cftime.DatetimeNoLeap] = data_array.coords[
-                    "time"
-                ].values
-                exp["annual"]["year"] = [x.year for x in years]
+
         del dataset_wrapper
-    return new_var_list
+
+    return new_var_list, all_data_dict
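Reviewer note (not part of the patch): `load_all_var_data` now returns the raw, region-indexed arrays instead of writing region-sliced data into `exp["annual"]`. Below is a minimal sketch of the data shape it traffics in, using synthetic xarray data in place of `DatasetWrapper` output; the variable name and units are invented.

```python
import numpy as np
import xarray as xr

# Stand-in for one dataset_wrapper.globalAnnual(var, all_regions=True) result:
# an annual-mean series with a 3-entry "rgn" dimension
# (global, northern hemisphere, southern hemisphere).
n_years = 5
data_array = xr.DataArray(
    np.random.rand(n_years, 3),
    dims=("time", "rgn"),
)

# load_all_var_data stores each variable once, keyed by name, with its units.
all_data_dict = {"TREFHT": (data_array, "K")}  # name/units are placeholders

da, units = all_data_dict["TREFHT"]
print(da.sizes)  # Frozen({'time': 5, 'rgn': 3})
```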
 
 
-def process_data(
-    parameters: Parameters, requested_variables: RequestedVariables, rgn: str
-) -> List[Dict[str, Any]]:
+def extract_region_data(
+    all_data_dict: Dict[str, Tuple[xarray.core.dataarray.DataArray, str]],
+    rgn: str,
+) -> Dict[str, Tuple[xarray.core.dataarray.DataArray, str]]:
+    """Extract data for a specific region from the all-regions data dictionary.
+
+    Args:
+        all_data_dict: Dictionary mapping variable names to (data_array, units) tuples
+        rgn: Region to extract ('glb', 'n', or 's')
+
+    Returns:
+        Dictionary with data arrays extracted for the specified region
+    """
+    region_data_dict = {}
+
+    # Map region string to index
+    if rgn == "glb":
+        n = 0
+    elif rgn == "n":
+        n = 1
+    elif rgn == "s":
+        n = 2
+    else:
+        raise RuntimeError(f"Invalid rgn={rgn}")
+
+    # Process each variable
+    for var_str, (data_array, units) in all_data_dict.items():
+        # Extract region if the data has multiple regions
+        if "rgn" in data_array.dims and data_array.sizes["rgn"] > 1:
+            # Extract the specific region
+            region_data = data_array.isel(rgn=n)
+        elif rgn != "glb":
+            # If no rgn dimension but trying to get n or s, that's an error
+            raise RuntimeError(
+                f"var={var_str} only has global data. Cannot process rgn={rgn}"
+            )
+        else:
+            # No rgn dimension but wanting global data, or already extracted
+            region_data = data_array
+
+        # Store in output dictionary
+        region_data_dict[var_str] = (region_data, units)
+
+    return region_data_dict
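Reviewer note (not part of the patch): the glb/n/s → 0/1/2 convention and the `isel` slicing are the heart of `extract_region_data`. A self-contained sketch of the same branching with synthetic data:

```python
import numpy as np
import xarray as xr

RGN_INDEX = {"glb": 0, "n": 1, "s": 2}  # same mapping as extract_region_data

def select_region(da: xr.DataArray, rgn: str) -> xr.DataArray:
    """Mirror the branching in extract_region_data for a single array."""
    if rgn not in RGN_INDEX:
        raise RuntimeError(f"Invalid rgn={rgn}")
    if "rgn" in da.dims and da.sizes["rgn"] > 1:
        return da.isel(rgn=RGN_INDEX[rgn])  # drop to a 1-D time series
    if rgn != "glb":
        raise RuntimeError(f"only global data available; cannot process rgn={rgn}")
    return da  # global-only data passes through unchanged

da = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=("time", "rgn"))
print(select_region(da, "n").values)  # column 1: [ 1.  4.  7. 10. 13.]
```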
 
 
+def load_all_region_data(
+    parameters: Parameters, requested_variables: RequestedVariables
+) -> Tuple[List[Dict[str, Any]], Dict[str, List[str]], Dict[str, List[str]]]:
+    """Load all data for all regions at once.
+
+    Args:
+        parameters: Configuration parameters
+        requested_variables: Variables to load for each component
+
+    Returns:
+        Tuple of (experiment data, valid variables, invalid variables)
+    """
+    # Get experiment configurations
     exps: List[Dict[str, Any]] = get_exps(parameters)
-    valid_vars: List[str] = []
-    invalid_vars: List[str] = []
-    exp: Dict[str, Any]
+
+    # Track valid and invalid variables by component
+    component_valid_vars = {"atmos": [], "ice": [], "land": [], "ocean": []}
+    component_invalid_vars = {"atmos": [], "ice": [], "land": [], "ocean": []}
+
+    # Process each experiment
     for exp in exps:
+        # Initialize annual data and all-regions data storage
         exp["annual"] = {}
-
-        requested_variables.vars_original = set_var(
+        exp["all_regions_data"] = {}
+
+        # Initialize component data dictionaries
+        exp["all_regions_data"]["atmos"] = {}
+        exp["all_regions_data"]["ice"] = {}
+        exp["all_regions_data"]["land"] = {}
+        exp["all_regions_data"]["ocean"] = {}
+
+        # Load data for each component - original vars (atmosphere variables)
+        requested_variables.vars_original, atmos_original_data = load_all_var_data(
             exp,
             "atmos",
             requested_variables.vars_original,
-            valid_vars,
-            invalid_vars,
-            rgn,
+            component_valid_vars["atmos"],
+            component_invalid_vars["atmos"],
         )
-        requested_variables.vars_atm = set_var(
-            exp, "atmos", requested_variables.vars_atm, valid_vars, invalid_vars, rgn
+        exp["all_regions_data"]["atmos"].update(atmos_original_data)
+
+        # Load data for each component - atmosphere variables
+        requested_variables.vars_atm, atmos_data = load_all_var_data(
+            exp,
+            "atmos",
+            requested_variables.vars_atm,
+            component_valid_vars["atmos"],
+            component_invalid_vars["atmos"],
         )
-        requested_variables.vars_ice = set_var(
-            exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn
+        exp["all_regions_data"]["atmos"].update(atmos_data)
+
+        # Load data for each component - ice variables
+        requested_variables.vars_ice, ice_data = load_all_var_data(
+            exp,
+            "ice",
+            requested_variables.vars_ice,
+            component_valid_vars["ice"],
+            component_invalid_vars["ice"],
         )
-        requested_variables.vars_land = set_var(
+        exp["all_regions_data"]["ice"].update(ice_data)
+
+        # Load data for each component - land variables
+        requested_variables.vars_land, land_data = load_all_var_data(
             exp,
             "land",
             requested_variables.vars_land,
-            valid_vars,
-            invalid_vars,
-            rgn,
+            component_valid_vars["land"],
+            component_invalid_vars["land"],
         )
-        requested_variables.vars_ocn = set_var(
-            exp, "ocean", requested_variables.vars_ocn, valid_vars, invalid_vars, rgn
+        exp["all_regions_data"]["land"].update(land_data)
+
+        # Load data for each component - ocean variables
+        requested_variables.vars_ocn, ocn_data = load_all_var_data(
+            exp,
+            "ocean",
+            requested_variables.vars_ocn,
+            component_valid_vars["ocean"],
+            component_invalid_vars["ocean"],
         )
-
-        # Optionally read ohc
+        exp["all_regions_data"]["ocean"].update(ocn_data)
+
+        # Special handling for ocean heat content
         if exp["ocean"] != "":
-            dataset_wrapper = DatasetWrapper(exp["ocean"])
-            exp["annual"]["ohc"], _ = dataset_wrapper.globalAnnual(Variable("ohc"))
-            # anomalies with respect to first year
-            exp["annual"]["ohc"][:] = exp["annual"]["ohc"][:] - exp["annual"]["ohc"][0]
-
+            try:
+                dataset_wrapper = DatasetWrapper(exp["ocean"])
+                data_array, units = dataset_wrapper.globalAnnual(
+                    Variable("ohc"), all_regions=True
+                )
+
+                # Store in all regions data
+                exp["all_regions_data"]["ocean"]["ohc"] = (data_array, units)
+
+                # Track as valid variable
+                component_valid_vars["ocean"].append("ohc")
+
+                del dataset_wrapper
+            except Exception as e:
+                logger.error(e)
+                logger.error("Failed to load ohc data")
+                component_invalid_vars["ocean"].append("ohc")
+
+        # Special handling for ocean volume
         if exp["vol"] != "":
-            dataset_wrapper = DatasetWrapper(exp["vol"])
-            exp["annual"]["volume"], _ = dataset_wrapper.globalAnnual(
-                Variable("volume")
-            )
-            # annomalies with respect to first year
-            exp["annual"]["volume"][:] = (
-                exp["annual"]["volume"][:] - exp["annual"]["volume"][0]
-            )
+            try:
+                dataset_wrapper = DatasetWrapper(exp["vol"])
+                data_array, units = dataset_wrapper.globalAnnual(
+                    Variable("volume"), all_regions=True
+                )
+
+                # Store in all regions data
+                exp["all_regions_data"]["ocean"]["volume"] = (data_array, units)
+
+                # Track as valid variable
+                component_valid_vars["ocean"].append("volume")
+
+                del dataset_wrapper
+            except Exception as e:
+                logger.error(e)
+                logger.error("Failed to load volume data")
+                component_invalid_vars["ocean"].append("volume")
+
+    # Log successes and failures for all components
+    for component, valid_vars in component_valid_vars.items():
+        if valid_vars:
+            logger.info(
+                f"{component} variables were computed successfully: {valid_vars}"
+            )
+
+    for component, invalid_vars in component_invalid_vars.items():
+        if invalid_vars:
+            logger.error(
+                f"{component} variables could not be computed: {invalid_vars}"
+            )
+
+    return exps, component_valid_vars, component_invalid_vars
-    logger.info(
-        f"{rgn} region globalAnnual was computed successfully for these variables: {valid_vars}"
-    )
-    logger.error(
-        f"{rgn} region globalAnnual could not be computed for these variables: {invalid_vars}"
-    )
-    return exps
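Reviewer note (not part of the patch): for orientation, this is the shape of one `exp` dict after `load_all_region_data` returns. Real leaves are `(xarray.DataArray, units)` tuples; strings stand in here, and the variable names and units are invented.

```python
# Illustrative only: nested per-component cache for a single experiment.
exp = {
    "annual": {"year": [2000, 2001, 2002]},
    "all_regions_data": {
        "atmos": {"TREFHT": ("<DataArray (time, rgn)>", "K")},
        "ice": {},
        "land": {},
        "ocean": {
            "ohc": ("<DataArray (time, rgn)>", "J"),
            "volume": ("<DataArray (time, rgn)>", "m3"),
        },
    },
}

# The parallel bookkeeping returned alongside exps:
component_valid_vars = {"atmos": ["TREFHT"], "ice": [], "land": [], "ocean": ["ohc", "volume"]}
component_invalid_vars = {"atmos": [], "ice": [], "land": [], "ocean": []}
```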
 
 
+def process_data(
+    all_region_exps: List[Dict[str, Any]], rgn: str
+) -> List[Dict[str, Any]]:
+    """Process data for a specific region.
+
+    Args:
+        all_region_exps: Experiments with all-regions data already loaded
+        rgn: Region to process ('glb', 'n', or 's')
+
+    Returns:
+        List of experiment dictionaries with region-specific data
+    """
+    # Create a deep copy to avoid modifying the original
+    import copy
+
+    exps = copy.deepcopy(all_region_exps)
+
+    # Extract region-specific data for each experiment
+    for exp in exps:
+        # Extract atmosphere data
+        if "atmos" in exp["all_regions_data"]:
+            atmos_region_data = extract_region_data(
+                exp["all_regions_data"]["atmos"], rgn
+            )
+            exp["annual"].update(atmos_region_data)
+
+        # Extract ice data
+        if "ice" in exp["all_regions_data"]:
+            ice_region_data = extract_region_data(exp["all_regions_data"]["ice"], rgn)
+            exp["annual"].update(ice_region_data)
+
+        # Extract land data
+        if "land" in exp["all_regions_data"]:
+            land_region_data = extract_region_data(
+                exp["all_regions_data"]["land"], rgn
+            )
+            exp["annual"].update(land_region_data)
+
+        # Extract ocean data
+        if "ocean" in exp["all_regions_data"]:
+            ocean_region_data = extract_region_data(
+                exp["all_regions_data"]["ocean"], rgn
+            )
+            exp["annual"].update(ocean_region_data)
+
+        # Process OHC anomalies if available
+        if "ohc" in exp["annual"]:
+            # anomalies with respect to first year
+            ohc_data, ohc_units = exp["annual"]["ohc"]
+            ohc_anomaly = ohc_data - ohc_data[0]
+            exp["annual"]["ohc"] = (ohc_anomaly, ohc_units)
+
+        # Process volume anomalies if available
+        if "volume" in exp["annual"]:
+            # anomalies with respect to first year
+            volume_data, volume_units = exp["annual"]["volume"]
+            volume_anomaly = volume_data - volume_data[0]
+            exp["annual"]["volume"] = (volume_anomaly, volume_units)
+
+        # Clean up all_regions_data to save memory
+        del exp["all_regions_data"]
+
+    return exps
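Reviewer note (not part of the patch): `process_data` deep-copies so the cached all-regions arrays stay pristine for the next region, and the anomaly step relies on xarray broadcasting a 0-d selection — `da - da[0]` subtracts the first year from every year. A tiny check:

```python
import numpy as np
import xarray as xr

ohc = xr.DataArray(np.array([5.0, 7.0, 6.5, 9.0]), dims="time")

# Same operation as process_data: subtract the first year from every year.
ohc_anomaly = ohc - ohc[0]
print(ohc_anomaly.values)  # [0.  2.  1.5 4. ]
```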
 
 
-# Run coupled_global ##########################################################
-def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: str):
-    # Experiments
-    exps: List[Dict[str, Any]] = process_data(parameters, requested_variables, rgn)
+# Run coupled_global for a single region ######################################
+def run_region(
+    parameters: Parameters,
+    requested_variables: RequestedVariables,
+    rgn: str,
+    all_region_exps: List[Dict[str, Any]],
+):
+    """Process and plot data for a specific region.
+
+    Args:
+        parameters: Configuration parameters
+        requested_variables: Variables to process
+        rgn: Region to process ('glb', 'n', or 's')
+        all_region_exps: Experiment data with all regions already loaded
+    """
+    # Extract data for this specific region
+    exps: List[Dict[str, Any]] = process_data(all_region_exps, rgn)
+
+    # Set up x-axis limits
     xlim: List[float] = [float(parameters.year1), float(parameters.year2)]
 
+    # Track successful and failed plots
     valid_plots: List[str] = []
     invalid_plots: List[str] = []
@@ -298,6 +485,8 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st
         ("lnd", list(map(lambda v: v.variable_name, requested_variables.vars_land))),
         ("ocn", list(map(lambda v: v.variable_name, requested_variables.vars_ocn))),
     ]
+
+    # Generate plots for each component
     for component, plot_list in mapping:
         make_plot_pdfs(
             parameters,
@@ -309,13 +498,19 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st
             valid_plots,
             invalid_plots,
         )
-    logger.info(f"These {rgn} region plots generated successfully: {valid_plots}")
-    logger.error(
-        f"These {rgn} region plots could not be generated successfully: {invalid_plots}"
-    )
+
+    # Log results
+    if valid_plots:
+        logger.info(f"These {rgn} region plots generated successfully: {valid_plots}")
+
+    if invalid_plots:
+        logger.error(
+            f"These {rgn} region plots could not be generated successfully: {invalid_plots}"
+        )
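Reviewer note (not part of the patch): end to end, the new control flow is load once, then plot per region. A sketch of the driver shape; the import path is assumed from the file being patched, and `parameters`/`requested_variables` construction is elided.

```python
from zppy_interfaces.global_time_series.coupled_global import (  # assumed path
    load_all_region_data,
    run_region,
)

def plot_all_regions(parameters, requested_variables) -> None:
    """One pass of file I/O, then one cheap in-memory pass per region."""
    all_region_exps, _valid, _invalid = load_all_region_data(
        parameters, requested_variables
    )
    for rgn in parameters.regions:  # typically ["glb", "n", "s"]
        run_region(parameters, requested_variables, rgn, all_region_exps)
```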
 
 
 def get_vars(requested_variables: RequestedVariables, component: str) -> List[Variable]:
+    """Get variable list for a specific component."""
     vars: List[Variable]
     if component == "original":
         vars = requested_variables.vars_original
@@ -333,24 +528,38 @@ def get_vars(requested_variables: RequestedVariables, component: str) -> List[Va
 
 
 def coupled_global(parameters: Parameters) -> None:
+    """Main entry point for the global time series plots.
+
+    Loads the data for every region once, then processes each region from
+    that cached data, which significantly reduces I/O.
+    """
+    # Initialize variables for all components
     requested_variables = RequestedVariables(parameters)
+
+    # OPTIMIZATION: Load all data for all regions once
+    logger.info("Loading data for all regions...")
+    all_region_exps, valid_vars, invalid_vars = load_all_region_data(
+        parameters, requested_variables
+    )
+
+    # Process each region using the already-loaded data
     for rgn in parameters.regions:
-        run(parameters, requested_variables, rgn)
+        logger.info(f"Processing region: {rgn}")
+        run_region(parameters, requested_variables, rgn, all_region_exps)
+
+    # Create viewer if requested
     if parameters.make_viewer:
         # In this case, we don't want the summary PDF.
         # Rather, we want to construct a viewer similar to E3SM Diags.
         title_and_url_list: List[Tuple[str, str]] = []
-        for component in [
-            "atm",
-            "ice",
-            "lnd",
-            "ocn",
-        ]:  # Don't create viewer for original component
+
+        # Create viewers for each component except original
+        for component in ["atm", "ice", "lnd", "ocn"]:
             vars = get_vars(requested_variables, component)
             if vars:
                 url = create_viewer(parameters, vars, component)
                 logger.info(f"Viewer URL for {component}: {url}")
                 title_and_url_list.append((component, url))
+
+        # Special case for original plots: always use user-provided dimensions.
         vars = get_vars(requested_variables, "original")
         if vars:
@@ -361,6 +570,7 @@ def coupled_global(parameters: Parameters) -> None:
                     f"{parameters.figstr}_glb_original.pdf",
                 )
             )
-
+
+        # Create index page for all viewers
         index_url: str = create_viewer_index(parameters.results_dir, title_and_url_list)
         logger.info(f"Viewer index URL: {index_url}")
diff --git a/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py b/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py
index cd5e3dc..c07845c 100644
--- a/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py
+++ b/zppy_interfaces/global_time_series/coupled_global_dataset_wrapper.py
@@ -79,6 +79,7 @@ def globalAnnualHelper(
         scale_factor: float,
         original_units: str,
         final_units: str,
+        extract_region: bool = True,
     ) -> Tuple[xarray.core.dataarray.DataArray, str]:
 
         data_array: xarray.core.dataarray.DataArray
@@ -181,12 +182,23 @@ def globalAnnualHelper(
         return data_array, units
 
     def globalAnnual(
-        self, var: Variable
+        self, var: Variable, all_regions: bool = False
     ) -> Tuple[xarray.core.dataarray.DataArray, str]:
+        """Get annual average for a variable.
+
+        Args:
+            var: The variable to process
+            all_regions: If True, return data for all regions without
+                extraction. If False (default), extract a single region.
+
+        Returns:
+            Tuple of (data_array, units)
+        """
         return self.globalAnnualHelper(
             var.variable_name,
             var.metric,
             var.scale_factor,
             var.original_units,
             var.final_units,
+            extract_region=not all_regions,
         )
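Reviewer note (not part of the patch): the two call modes of the updated `globalAnnual`, sketched with an assumed import path and a placeholder data directory.

```python
# Import path assumed; DatasetWrapper and Variable are the classes patched above.
from zppy_interfaces.global_time_series.coupled_global_dataset_wrapper import (
    DatasetWrapper,
    Variable,
)

wrapper = DatasetWrapper("/path/to/ts/ocn/dir")  # hypothetical data directory

# Default: the helper extracts a single region, as before this patch.
da_single, units = wrapper.globalAnnual(Variable("ohc"))

# New: keep the full `rgn` dimension so callers can slice each region later
# with isel, reading the files only once.
da_all, units = wrapper.globalAnnual(Variable("ohc"), all_regions=True)
```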
diff --git a/zppy_interfaces/global_time_series/coupled_global_plotting.py b/zppy_interfaces/global_time_series/coupled_global_plotting.py
index 42f1f6b..81a9afd 100644
--- a/zppy_interfaces/global_time_series/coupled_global_plotting.py
+++ b/zppy_interfaces/global_time_series/coupled_global_plotting.py
@@ -556,53 +556,45 @@ def make_plot_pdfs(  # noqa: C901
     if num_plots == 0:
         return
 
-    # If make_viewer, then we want to do 1 plot per page.
-    # However, the original plots are excluded from this restriction.
-    # Note: if the user provides nrows=ncols=1, there will still be a single plot per page
-    keep_user_dims = (not parameters.make_viewer) or (component == "original")
-    if keep_user_dims:
+    # Ensure output directory exists
+    os.makedirs(parameters.results_dir, exist_ok=True)
+
+    # For "original" component: create a multi-page PDF with several plots per page
+    if component == "original":
+        # Use user-specified dimensions for original component
         nrows = parameters.nrows
         ncols = parameters.ncols
-    else:
-        nrows = 1
-        ncols = 1
-
-    plots_per_page = nrows * ncols
-    num_pages = math.ceil(num_plots / plots_per_page)
-
-    counter = 0
-    os.makedirs(parameters.results_dir, exist_ok=True)
-    # https://stackoverflow.com/questions/58738992/save-multiple-figures-with-subplots-into-a-pdf-with-multiple-pages
-    pdf = matplotlib.backends.backend_pdf.PdfPages(
-        f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.pdf"
-    )
-    for page in range(num_pages):
-        if plots_per_page == 1:
-            logger.info("Using reduced figsize")
-            fig = plt.figure(1, figsize=[13.5 / 2, 16.5 / 4])
-        else:
-            logger.info("Using standard figsize")
+        plots_per_page = nrows * ncols
+        num_pages = math.ceil(num_plots / plots_per_page)
+
+        # Create PDF file
+        pdf = matplotlib.backends.backend_pdf.PdfPages(
+            f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.pdf"
+        )
+
+        # Process each page
+        counter = 0
+        for page in range(num_pages):
+            # Create multi-plot figure
             fig = plt.figure(1, figsize=[13.5, 16.5])
-            logger.info(f"Figure size={fig.get_size_inches() * fig.dpi}")
-        fig.suptitle(f"{parameters.figstr}_{rgn}_{component}")
-        for j in range(plots_per_page):
-            logger.info(
-                f"Plotting plot {j} on page {page}. This is plot {counter} in total."
-            )
-            # The final page doesn't need to be filled out with plots.
-            if counter >= num_plots:
-                break
-            ax = plt.subplot(
-                nrows,
-                ncols,
-                j + 1,
-            )
-            plot_name = plot_list[counter]
-            if component == "original":
+            fig.suptitle(f"{parameters.figstr}_{rgn}_{component}")
+
+            # Process plots for this page
+            for j in range(plots_per_page):
+                # The final page doesn't need to be filled out with plots
+                if counter >= num_plots:
+                    break
+
+                # Create subplot
+                ax = plt.subplot(nrows, ncols, j + 1)
+                plot_name = plot_list[counter]
+
+                # Generate plot
                 try:
                     plot_function = PLOT_DICT[plot_name]
                 except KeyError:
                     raise KeyError(f"Invalid plot name: {plot_name}")
+
                 try:
                     plot_function(ax, xlim, exps, rgn)
                     valid_plots.append(plot_name)
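Reviewer note (not part of the patch): the paging arithmetic in the "original" branch is unchanged in spirit. A self-contained reduction of it — `PdfPages` is the real matplotlib API; the plot contents are stubbed.

```python
import math

import matplotlib

matplotlib.use("Agg")  # headless backend for batch plotting
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

plot_names = [f"plot_{i}" for i in range(7)]  # placeholder plot list
nrows, ncols = 2, 2
plots_per_page = nrows * ncols
num_pages = math.ceil(len(plot_names) / plots_per_page)  # -> 2 pages

counter = 0
with PdfPages("example.pdf") as pdf:
    for page in range(num_pages):
        fig = plt.figure(figsize=[13.5, 16.5])
        for j in range(plots_per_page):
            if counter >= len(plot_names):
                break  # the last page may be only partially filled
            ax = plt.subplot(nrows, ncols, j + 1)
            ax.set_title(plot_names[counter])  # stand-in for a real plot
            counter += 1
        fig.tight_layout()
        pdf.savefig(fig)  # append this page to the PDF
        plt.close(fig)
```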
@@ -623,36 +615,55 @@ def make_plot_pdfs(  # noqa: C901
                         f"Failed plot_function for {plot_name}. Check that {required_vars} are available."
                     )
                     invalid_plots.append(plot_name)
+                counter += 1
+
+            # Finalize and save the figure
+            fig.tight_layout()
+            pdf.savefig(fig)
+
+            # Save PNG of the entire page
+            if num_pages > 1:
+                # Multi-page PDF - include page number in filename
+                png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{page}.png"
-            else:
-                try:
-                    plot_generic(ax, xlim, exps, plot_name, rgn)
-                    valid_plots.append(plot_name)
-                except Exception:
-                    traceback.print_exc()
-                    logger.error(
-                        f"plot_generic failed. Invalid plot={plot_name}, rgn={rgn}"
-                    )
-                    invalid_plots.append(plot_name)
-            counter += 1
-
-        fig.tight_layout()
-        pdf.savefig(1)
-        # Always save individual PNGs for viewer mode
-        if plots_per_page == 1:
-            fig.savefig(
-                f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png",
-                dpi=150,
-            )
-        elif num_pages > 1:
-            fig.savefig(
-                f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{page}.png",
-                dpi=150,
-            )
-        else:
-            fig.savefig(
-                f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png",
-                dpi=150,
-            )
-        plt.close(fig)
-    pdf.close()
+            else:
+                # Single page PDF
+                png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png"
+
+            fig.savefig(png_path, dpi=100)  # Use lower DPI for better performance
+            plt.close(fig)
+
+        # Close PDF file
+        pdf.close()
+
+    # For non-original components: create individual PNGs (one plot per file)
+    else:
+        # Process each plot individually
+        for plot_name in plot_list:
+            # Create single-plot figure
+            fig = plt.figure(figsize=[13.5 / 2, 16.5 / 4])
+            ax = fig.add_subplot(111)
+
+            # Generate plot
+            try:
+                plot_generic(ax, xlim, exps, plot_name, rgn)
+                valid_plots.append(plot_name)
+            except Exception:
+                traceback.print_exc()
+                logger.error(
+                    f"plot_generic failed. Invalid plot={plot_name}, rgn={rgn}"
+                )
+                invalid_plots.append(plot_name)
+                plt.close(fig)
+                continue  # Skip saving failed plots
+
+            # Finalize and save PNG
+            fig.suptitle(f"{parameters.figstr}_{rgn}_{component}")
+            fig.tight_layout()
+
+            # Save individual PNG
+            png_path = f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png"
+            fig.savefig(png_path, dpi=100)
+
+            # Clean up
+            plt.close(fig)
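Reviewer note (not part of the patch): the non-original branch's contract with the viewer is one PNG per plot, named `{figstr}_{rgn}_{component}_{plot_name}.png`. A stubbed reduction of that loop; all values below are placeholders.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Hypothetical values standing in for parameters.results_dir, parameters.figstr, etc.
results_dir, figstr, rgn, component = ".", "my_run", "glb", "atm"

for plot_name in ["TREFHT", "PRECT"]:  # placeholder variable names
    fig = plt.figure(figsize=[13.5 / 2, 16.5 / 4])
    ax = fig.add_subplot(111)
    ax.plot([2000, 2001, 2002], [1.0, 1.5, 1.2])  # stand-in for plot_generic
    fig.suptitle(f"{figstr}_{rgn}_{component}")
    fig.tight_layout()
    # One PNG per plot, named so the viewer can find it.
    fig.savefig(f"{results_dir}/{figstr}_{rgn}_{component}_{plot_name}.png", dpi=100)
    plt.close(fig)
```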