diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py index a194854ca8f..51fe32e6274 100644 --- a/ax/analysis/plotly/__init__.py +++ b/ax/analysis/plotly/__init__.py @@ -6,6 +6,7 @@ # pyre-strict from ax.analysis.plotly.cross_validation import CrossValidationPlot +from ax.analysis.plotly.interaction import InteractionPlot from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.scatter import ScatterPlot @@ -15,6 +16,7 @@ __all__ = [ "ContourPlot", "CrossValidationPlot", + "InteractionPlot", "PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot", diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index ccc5fd98eef..91446720af6 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -6,126 +6,42 @@ # pyre-strict -import math -from typing import Any - -import numpy as np -import numpy.typing as npt +from logging import Logger import pandas as pd - import torch from ax.analysis.analysis import AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard -from ax.core.data import Data + +from ax.analysis.plotly.surface.contour import ( + _prepare_data as _prepare_contour_data, + _prepare_plot as _prepare_contour_plot, +) +from ax.analysis.plotly.surface.slice import ( + _prepare_data as _prepare_slice_data, + _prepare_plot as _prepare_slice_plot, +) +from ax.analysis.plotly.surface.utils import is_axis_log_scale +from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface -from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.surrogate import Surrogate -from ax.plot.contour import _get_contour_predictions -from ax.plot.feature_importances import plot_feature_importance_by_feature_plotly -from ax.plot.helper import TNullableGeneratorRunsDict -from ax.plot.slice import _get_slice_predictions +from ax.utils.common.logger import get_logger from ax.utils.sensitivity.sobol_measures import ax_parameter_sens from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel from gpytorch.constraints import Positive from gpytorch.kernels import RBFKernel from gpytorch.priors import LogNormalPrior -from plotly import graph_objects as go, io as pio +from plotly import express as px, graph_objects as go from plotly.subplots import make_subplots -from pyre_extensions import none_throws - - -TOP_K_TOO_LARGE_ERROR = ( - "Interaction Analysis only supports visualizing the slice/contour for" - " up to 6 component defined by the `top_k` argument, but received" - " {} as input." -) -MAX_NUM_PLOT_COMPONENTS: int = 6 -PLOT_SIZE: int = 380 - - -def get_model_kwargs( - use_interaction: bool, - num_parameters: int, - torch_device: torch.device | None = None, -) -> dict[str, Any]: - """Method to get the specific OAK kernel used to identify parameter interactions - in an Ax experiment. The kernel is an Orthogonal Additive Kernel (OAK), which - decomposes the objective into an additive sum of main parameter effects and - pairwise interaction effects. 
The kernel comes with a sparsity-inducing prior,
-    which attempts explain the data with as few components as possible. The
-    smoothness of the components is regularized by a lengthscale prior to guard
-    against excessicely short lengthscales being fit.
-
-    Args:
-        use_interaction: Whether to use interaction effects.
-        num_parameters: Number of parameters in the experiment.
-        torch_device: The type of torch device to use for the model.
-    """
-    # A fairly restrictive prior on the lengthscale avoids spurious
-    # fits, where a single component is fit to explain all variability.
-    # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior is
-    # probably a good add, but second-order components tend to break down
-    # even for weak priors.
-    return {
-        "covar_module_class": OrthogonalAdditiveKernel,
-        "covar_module_options": {
-            "base_kernel": RBFKernel(
-                ard_num_dims=num_parameters,
-                lengthscale_prior=LogNormalPrior(2, 1),
-            ),
-            "dim": num_parameters,
-            "dtype": torch.float64,
-            "device": torch_device,
-            "second_order": use_interaction,
-            "coeff_constraint": Positive(transform=torch.exp, inv_transform=torch.log),
-        },
-        "allow_batched_models": False,
-    }
-
-
-def sort_and_filter_top_k_components(
-    indices: dict[str, dict[str, npt.NDArray]],
-    k: int,
-    most_important: bool = True,
-) -> dict[str, dict[str, npt.NDArray]]:
-    """Sorts and filter the top k components according to Sobol indices, per metric.
-
-    Args:
-        indices: A dictionary of {metric: {component: sobol_index}} Sobol indices.
-        k: The number of components to keep.
-        most_important: Whether to keep the most or least important components.
-
-    Returns:
-        A dictionary of the top k components.
-    """
-    metrics = list(indices.keys())
-    sorted_indices = {
-        metric: dict(
-            sorted(
-                metric_indices.items(),
-                key=lambda x: x[1],
-                reverse=most_important,
-            )
-        )
-        for metric, metric_indices in indices.items()
-    }
+from pyre_extensions import assert_is_instance
 
-    # filter to top k components
-    sorted_indices = {
-        metric: {
-            key: value
-            for _, (key, value) in zip(range(k), sorted_indices[metric].items())
-        }
-        for metric in metrics
-    }
-    return sorted_indices
+logger: Logger = get_logger(__name__)
 
 
 class InteractionPlot(PlotlyAnalysis):
@@ -135,100 +51,46 @@ class InteractionPlot(PlotlyAnalysis):
     relative importance of each component is quantified by its Sobol index. Each
     component may be visualized through slice or contour plots depending on
     if it is a first order or second order component, respectively.
+
+    The DataFrame computed will contain just the sensitivity analysis with one row per
+    parameter and the following columns:
+        - feature: The name of the first or second order component
+        - sensitivity: The sensitivity of the component
     """
 
     def __init__(
         self,
-        metric_name: str,
-        top_k: int = 6,
-        data: Data | None = None,
-        most_important: bool = True,
+        metric_name: str | None = None,
         fit_interactions: bool = True,
-        display_components: bool = False,
-        decompose_components: bool = False,
-        plots_share_range: bool = True,
-        num_mc_samples: int = 10_000,
-        model_fit_seed: int = 0,
+        most_important: bool = True,
+        seed: int = 0,
         torch_device: torch.device | None = None,
     ) -> None:
-        """Constructor for InteractionAnalysis.
-
+        """
         Args:
             metric_name: The metric to analyze.
-            top_k: The 'k' most imortant interactions according to Sobol indices.
-                Supports up to 6 components visualized at once.
-            data: The data to analyze. Defaults to None, in which case the data is taken
-                from the experiment.
- most_important: Whether to plot the most or least important interactions. fit_interactions: Whether to fit interaction effects in addition to main effects. - display_components: Display individual components instead of the summarized - plot of sobol index values. - decompose_components: Whether to visualize surfaces as the total effect of - x1 & x2 (False) or only the interaction term (True). Setting - decompose_components = True thus plots f(x1, x2) - f(x1) - f(x2). - plots_share_range: Whether to have all plots share the same output range in - the final visualization. - num_mc_samples: The number of Monte Carlo samples to use for the Sobol - index calculations. - model_fit_seed: The seed with which to fit the model. Defaults to 0. Used + most_important: Whether to sort by most or least important features in the + bar subplot. Also controls whether the six most or least important + features are plotted in the surface subplots. + seed: The seed with which to fit the model. Defaults to 0. Used to ensure that the model fit is identical across the generation of various plots. torch_device: The torch device to use for the model. """ - super().__init__() - if top_k > 6 and display_components: - raise UserInputError(TOP_K_TOO_LARGE_ERROR.format(str(top_k))) - self.metric_name: str = metric_name - self.top_k: int = top_k - self.data: Data | None = data - self.most_important: bool = most_important - self.fit_interactions: bool = fit_interactions - self.display_components: bool = display_components - self.decompose_components: bool = decompose_components - self.num_mc_samples: int = num_mc_samples - self.model_fit_seed: int = model_fit_seed - self.torch_device: torch.device | None = torch_device - self.plots_share_range: bool = plots_share_range - - def get_model( - self, experiment: Experiment, metric_names: list[str] | None = None - ) -> TorchModelBridge: - """ - Retrieves the modelbridge used for the analysis. The model uses an OAK - (Orthogonal Additive Kernel) with a sparsity-inducing prior, - which decomposes the objective into an additive sum of components. - """ - covar_module_kwargs = get_model_kwargs( - use_interaction=self.fit_interactions, - num_parameters=len(experiment.search_space.tunable_parameters), - torch_device=self.torch_device, - ) - data = experiment.lookup_data() if self.data is None else self.data - if metric_names: - data = data.filter(metric_names=metric_names) - with torch.random.fork_rng(): - # fixing the seed to ensure that the model is fit identically across - # different analyses of the same experiment - torch.torch.manual_seed(self.model_fit_seed) - model_bridge = Models.BOTORCH_MODULAR( - search_space=experiment.search_space, - experiment=experiment, - data=data, - surrogate=Surrogate(**covar_module_kwargs), - ) - return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge + self.metric_name = metric_name + self.fit_interactions = fit_interactions + self.most_important = most_important + self.seed = seed + self.torch_device = torch_device - # pyre-ignore[14] Must pass in an Experiment (not Experiment | None) def compute( self, experiment: Experiment | None = None, generation_strategy: GenerationStrategyInterface | None = None, ) -> PlotlyAnalysisCard: - model_bridge = self.get_model( - experiment=none_throws(experiment), metric_names=[self.metric_name] - ) """ Compute Sobol index sensitivity for one metric of an experiment. 
Sensitivity is computed by component, where a component may be either one variable
-
@@ -237,378 +99,257 @@ def compute(
         to be a sum of components, and where marginal effects can be computed
         accurately.
         """
-        experiment = none_throws(experiment)
-        model_bridge = self.get_model(experiment, [self.metric_name])
+
+        if experiment is None:
+            raise UserInputError("InteractionPlot requires an Experiment")
+
+        metric_name = self.metric_name or select_metric(experiment=experiment)
+
+        # Fix the seed to ensure that the model is fit identically across different
+        # analyses of the same experiment.
         with torch.random.fork_rng():
-            # fixing the seed to ensure that the model is fit identically across
-            # different analyses of the same experiment
-            torch.torch.manual_seed(self.model_fit_seed)
+            torch.manual_seed(self.seed)
+
+            # Fit the OAK model.
+            oak_model = self._get_oak_model(
+                experiment=experiment, metric_name=metric_name
+            )
+
+            # Calculate first- or second-order Sobol indices.
             sens = ax_parameter_sens(
-                model_bridge=model_bridge,
-                metrics=[self.metric_name],
+                model_bridge=oak_model,
+                metrics=[metric_name],
                 order="second" if self.fit_interactions else "first",
                 signed=not self.fit_interactions,
-                num_mc_samples=self.num_mc_samples,
-            )
-        sens = sort_and_filter_top_k_components(
-            indices=sens, k=self.top_k, most_important=self.most_important
-        )
-        if not self.display_components:
-            return PlotlyAnalysisCard(
-                name="Interaction Analysis",
-                title=f"Feature Importance Analysis for {self.metric_name}",
-                subtitle=(
-                    "Displays the most important features "
-                    f"for {self.metric_name} by order of importance."
-                ),
-                level=AnalysisCardLevel.MID,
-                df=pd.DataFrame(sens),
-                blob=pio.to_json(
-                    plot_feature_importance_by_feature_plotly(
-                        sensitivity_values=sens,  # pyre-ignore[6]
-                    )
-                ),
-            )
-        else:
-            metric_sens = list(sens[self.metric_name].keys())
-            return PlotlyAnalysisCard(
-                name="OAK Interaction Analysis",
-                title=(
-                    "Additive Component Feature Importance Analysis "
-                    f"for {self.metric_name}"
-                ),
-                subtitle=(
-                    "Displays the most important features' effects "
-                    f"on {self.metric_name} by order of importance."
-                ),
-                level=AnalysisCardLevel.MID,
-                df=pd.DataFrame(sens),
-                blob=pio.to_json(
-                    plot_component_surfaces_plotly(
-                        features=metric_sens,
-                        model=model_bridge,
-                        metric=self.metric_name,
-                        plots_share_range=self.plots_share_range,
-                    )
-                ),
-            )
+            )[metric_name]
 
+        sensitivity_df = pd.DataFrame(
+            [*sens.items()], columns=["feature", "sensitivity"]
+        ).sort_values(by="sensitivity", key=abs, ascending=self.most_important)
 
-def update_plot_range(max_range: list[float], new_range: list[float]) -> list[float]:
-    """Updates the range to include the value.
-    Args:
-        max_range: Current max_range among all considered ranges.
-        new_range: New range to consider to be included.
+        # Calculate feature importance bar plot. Only plot the top 15 features.
+        # Plot the absolute value of the sensitivity but color by the sign.
+        plotting_df = sensitivity_df.head(15).copy()
+        plotting_df["direction"] = plotting_df["sensitivity"].apply(
+            lambda x: "Increases Metric" if x >= 0 else "Decreases Metric"
+        )
+        plotting_df["sensitivity"] = plotting_df["sensitivity"].abs()
+        plotting_df.sort_values(
+            by="sensitivity", ascending=self.most_important, inplace=True
+        )
 
-    Returns:
-        The updated max_range.
- """ - if max_range[0] > new_range[0]: - max_range[0] = new_range[0] - if max_range[1] < new_range[1]: - max_range[1] = new_range[1] - return max_range + sensitivity_fig = px.bar( + plotting_df, + x="sensitivity", + y="feature", + color="direction", + # Increase gets blue, decrease gets orange. + color_discrete_sequence=["orange", "blue"], + orientation="h", + ) + # Calculate surface plots for six most or least important features + # Note: We use tail and reverse here because the bar plot is sorted from top to + # bottom. + top_features = [*reversed(plotting_df.tail(6)["feature"].to_list())] + surface_figs = [] + for feature_name in top_features: + try: + surface_figs.append( + _prepare_surface_plot( + experiment=experiment, + model=oak_model, + feature_name=feature_name, + metric_name=metric_name, + ) + ) + # Not all features will be able to be plotted, skip them gracefully. + except Exception as e: + logger.error(f"Failed to generate surface plot for {feature_name}: {e}") + + # Create a plotly figure to hold the bar plot in the top row and surface plots + # in a 2x3 grid below. + fig = make_subplots( + rows=4, + cols=2, + specs=[ + [{"colspan": 2}, None], + [{}, {}], + [{}, {}], + [{}, {}], + ], + ) -def plot_component_surfaces_plotly( - features: list[str], - model: TorchModelBridge, - metric: str, - plots_share_range: bool = True, - generator_runs_dict: TNullableGeneratorRunsDict = None, - density: int = 50, - slice_values: dict[str, Any] | None = None, - fixed_features: ObservationFeatures | None = None, - trial_index: int | None = None, - renormalize: bool = True, -) -> go.Figure: - """Plots the interaction surfaces for the given features. - - Args: - features: The features to plot. Can be either 1D or 2D, where 2D features are - specified as "x1 & x2". - model: The modelbridge used for prediction. - metric: The name of the metric to plot. - plots_share_range: Whether to have all plots should share the same output range. - generator_runs_dict: The generator runs dict to use. - density: The density of the grid, i.e. the number of points evaluated in each - dimension. - slice_values: The slice values to use for the parameters that are not plotted. - fixed_features: The fixed features to use. - trial_index: The trial index to include in the plot. - renormalize: Whether to renormalize the surface to have zero mean. - - Returns: - A plotly figure of all the interaction surfaces. - """ - traces = [] - titles = [] - param_names = [] - - # tracks the maximal value range so that all plots of the same type share the same - # signal range in the final visualization. We cannot just check the largest - # component's sobol index, as it may not have the largest co-domain. 
- surface_range = [float("inf"), -float("inf")] - slice_range = [float("inf"), -float("inf")] - first_surface = True - for feature in features: - if " & " in feature: - component_x, component_y = feature.split(" & ") - trace, minval, maxval = generate_interaction_component( - model=model, - component_x=component_x, - component_y=component_y, - metric=metric, - generator_runs_dict=generator_runs_dict, - density=density, - slice_values=slice_values, - fixed_features=fixed_features, - trial_index=trial_index, - first_surface=first_surface, - ) - first_surface = False - traces.append(trace) - param_names.append((component_x, component_y)) - titles.append(f"Total effect, {component_x} & {component_y}") - surface_range = update_plot_range(surface_range, [minval, maxval]) - else: - trace, minval, maxval = generate_main_effect_component( - model=model, - component=feature, - metric=metric, - generator_runs_dict=generator_runs_dict, - density=density, - slice_values=slice_values, - fixed_features=fixed_features, - trial_index=trial_index, - ) - traces.append(trace) - param_names.append(feature) - titles.append(f"Main Effect, {feature}") - slice_range = update_plot_range(slice_range, [minval, maxval]) - - # 1x3 plots if 3 total, 2x2 plots if 4 total, 3x2 plots if 6 total - num_rows = 1 if len(traces) <= (MAX_NUM_PLOT_COMPONENTS / 2) else 2 - num_cols = math.ceil(len(traces) / num_rows) - - fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=titles) - for plot_idx, trace in enumerate(traces): - row = plot_idx // num_cols + 1 - col = plot_idx % num_cols + 1 - fig.add_trace(trace, row=row, col=col) - fig = set_axis_names( - figure=fig, trace=trace, row=row, col=col, param_names=param_names[plot_idx] + for trace in sensitivity_fig.data: + fig.add_trace(trace, row=1, col=1) + # Fix order of the features in the bar plot. + fig.update_yaxes( + categoryorder="array", categoryarray=plotting_df["feature"], row=1, col=1 ) - fig = scale_traces( - figure=fig, - traces=traces, - surface_range=surface_range, - slice_range=slice_range, - plots_share_range=plots_share_range, - ) - fig.update_layout({"width": PLOT_SIZE * num_cols, "height": PLOT_SIZE * num_rows}) - return fig + for i in range(len(surface_figs)): + feature_name = top_features[i] + surface_fig = surface_figs[i] + + row = (i // 2) + 2 + col = (i % 2) + 1 + for trace in surface_fig.data: + fig.add_trace(trace, row=row, col=col) + + # Label and fix axes + if "&" in feature_name: # If the feature is a second-order component + x, y = feature_name.split(" & ") + + # Reapply log scale if necessary + fig.update_xaxes( + title_text=x, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[x] + ) + else "linear" + ), + row=row, + col=col, + ) + fig.update_yaxes( + title_text=y, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[y] + ) + else "linear" + ), + row=row, + col=col, + ) + else: # If the feature is a first-order component + fig.update_xaxes( + title_text=feature_name, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] + ) + else "linear" + ), + row=row, + col=col, + ) + + # Expand layout since default rendering in most notebooks is too small. 
+        fig.update_layout(
+            height=1500,
+            width=1000,
+        )
 
+        subtitle_substring = (
+            "one- or two-dimensional" if self.fit_interactions else "one-dimensional"
+        )
 
-def generate_main_effect_component(
-    model: TorchModelBridge,
-    component: str,
-    metric: str,
-    generator_runs_dict: TNullableGeneratorRunsDict = None,
-    density: int = 50,
-    slice_values: dict[str, Any] | None = None,
-    fixed_features: ObservationFeatures | None = None,
-    trial_index: int | None = None,
-) -> tuple[go.Scatter, float, float]:
-    """Plots a slice "main effect" of the model for a given component. The values are
-    relative to the mean of all predictions, so that the magnitude of the component is
-    communicated.
-
-    Args:
-        model: The modelbridge used for prediction.
-        component: The name of the component to plot.
-        metric: The name of the metric to plot.
-        generator_runs_dict: The generator runs dict to use.
-        density: The density of the grid, i.e. the number of points evaluated in each
-            dimension.
-        slice_values: The slice values to use for the parameters that are not plotted.
-        fixed_features: The fixed features to use.
-        trial_index: The trial index to include in the plot.
-
-    Returns:
-        A contour plot of the component interaction, and the range of the plot.
-    """
-    _, _, slice_mean, _, grid, _, _, _, _, slice_stdev, _ = _get_slice_predictions(
-        model=model,
-        param_name=component,
-        metric_name=metric,
-        generator_runs_dict=generator_runs_dict,
-        density=density,
-        slice_values=slice_values,
-        fixed_features=fixed_features,
-        trial_index=trial_index,
-    )
-    # renormalize the slice to have zero mean (done for each component)
-    slice_mean = np.array(slice_mean) - np.array(slice_mean).mean()
-
-    trace = go.Scatter(
-        x=grid,
-        y=slice_mean,
-        name=component,
-        line_shape="spline",
-        showlegend=False,
-        error_y={
-            "type": "data",
-            "array": slice_stdev,
-            "visible": True,
-            "thickness": 0.8,
-        },
-    )
+        return self._create_plotly_analysis_card(
+            title=f"Interaction Analysis for {metric_name}",
+            subtitle=(
+                f"Understand an Experiment's data as {subtitle_substring} additive "
+                "components with sparsity. Important components are visualized through "
+                "slice or contour plots"
+            ),
+            level=AnalysisCardLevel.MID,
+            df=sensitivity_df,
+            fig=fig,
+        )
 
-    return trace, np.min(slice_mean).astype(float), np.max(slice_mean).astype(float)
+    def _get_oak_model(
+        self, experiment: Experiment, metric_name: str
+    ) -> TorchModelBridge:
+        """
+        Retrieves the modelbridge used for the analysis. The model uses an OAK
+        (Orthogonal Additive Kernel) with a sparsity-inducing prior,
+        which decomposes the objective into an additive sum of components.
+        The kernel comes with a sparsity-inducing prior, which attempts to explain
+        the data with as few components as possible. The smoothness of the components
+        is regularized by a lengthscale prior to guard against excessively short
+        lengthscales being fit.
+        """
+        data = experiment.lookup_data().filter(metric_names=[metric_name])
+        model_bridge = Models.BOTORCH_MODULAR(
+            search_space=experiment.search_space,
+            experiment=experiment,
+            data=data,
+            surrogate=Surrogate(
+                covar_module_class=OrthogonalAdditiveKernel,
+                covar_module_options={
+                    # A fairly restrictive prior on the lengthscale avoids spurious
+                    # fits, where a single component is fit to explain all
+                    # variability.
+                    # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior
+                    # is probably a good add, but second-order components tend to
+                    # break down even for weak priors.
+ "base_kernel": RBFKernel( + ard_num_dims=len(experiment.search_space.tunable_parameters), + lengthscale_prior=LogNormalPrior(2, 1), + ), + "dim": len(experiment.search_space.tunable_parameters), + "dtype": torch.float64, + "device": self.torch_device, + "second_order": self.fit_interactions, + "coeff_constraint": Positive( + transform=torch.exp, inv_transform=torch.log + ), + }, + allow_batched_models=False, + ), + ) -def generate_interaction_component( - model: TorchModelBridge, - component_x: str, - component_y: str, - metric: str, - generator_runs_dict: TNullableGeneratorRunsDict = None, - density: int = 50, - slice_values: dict[str, Any] | None = None, - fixed_features: ObservationFeatures | None = None, - trial_index: int | None = None, - renormalize: bool = True, - first_surface: bool = True, -) -> tuple[go.Contour, float, float]: - """Plots a slice "main effect" of the model for a given component. The values are - relative to the mean of all predictions, so that the magnitude of the component is - communicated. - - Args: - model: The modelbridge used for prediction. - component_x: The name of the component to plot along the x-axis. - component_y: The name of the component to plot along the y-axis. - metric: The name of the metric to plot. - subtract_main_effects: Whether to subtract the main effects from the 2D surface. - If main effects are not subtracted, the 2D surface is the output of - plot_contour and models the effect of each parameter in isolation and their - interaction. If main effects are subtracted, the 2D surface visualizes only - the interaction effect of the two parameters. - generator_runs_dict: The generator runs dict to use. - density: The density of the grid, i.e. the number of points evaluated in each - dimension. - slice_values: The slice values to use for the parameters that are not plotted. - fixed_features: The fixed features to use. - trial_index: The trial index to include in the plot. - renormalize: Whether to renormalize the surface to have zero mean. - first_surface: Whether this is the first surface to be plotted. If so, we plot - its colorbar. - - Returns: - A contour plot of the component interaction, and the range of the plot. - """ - comp_name: str = f"{component_x} & {component_y}" - fixed_kwargs: dict[str, Any] = { - "model": model, - "generator_runs_dict": generator_runs_dict, - "density": density, - "slice_values": slice_values, - "fixed_features": fixed_features, - } - _, contour_mean, _, grid_x, grid_y, _ = _get_contour_predictions( - x_param_name=component_x, - y_param_name=component_y, - metric=metric, - **fixed_kwargs, - ) - contour_mean = np.reshape(contour_mean, (density, density)) - contour_mean = contour_mean - contour_mean.mean() - return ( - go.Contour( - z=contour_mean, - x=grid_x, - y=grid_y, - name=comp_name, - ncontours=50, - showscale=first_surface, - ), - np.min(contour_mean).astype(float), - np.max(contour_mean).astype(float), - ) + return assert_is_instance(model_bridge, TorchModelBridge) -def scale_traces( - figure: go.Figure, - traces: list[go.Scatter | go.Contour], - surface_range: list[float], - slice_range: list[float], - plots_share_range: bool = True, -) -> go.Figure: - """Scales the traces to have the same range. - - Args: - figure: The main plotly figure to update the traces on. - traces: The traces to scale. - surface_range: The range of the surface traces. - slice_range: The range of the slice traces. 
- plots_share_range: Whether to have all plots (and not just plots - of the same type) share the same output range. - - Returns: - A figure with the traces of the same type are scaled to have the same range. - """ - if plots_share_range: - total_range = update_plot_range(surface_range, slice_range) - slice_range = total_range - surface_range = total_range - - # plotly axis names in layout are of the form "xaxis{idx}" and "yaxis{idx}" except - # for the first one, which is "xaxis" and "yaxis". We need to keep track of the - # indices of the traces and then use the correct axis names when updating ranges. - axis_names = ["yaxis"] + [f"yaxis{idx}" for idx in range(2, len(traces) + 1)] - slice_axes = [ - axis_name - for trace, axis_name in zip(traces, axis_names) - if isinstance(trace, go.Scatter) - ] - - # scale the surface traces to have the same range - for trace_idx in range(len(figure["data"])): - trace = figure["data"][trace_idx] - if isinstance(trace, go.Contour): - trace["zmin"] = surface_range[0] - trace["zmax"] = surface_range[1] - - # and scale the slice traces to have the same range - figure.update_layout({ax: {"range": slice_range} for ax in slice_axes}) - return figure - - -def set_axis_names( - figure: go.Figure, - trace: go.Contour | go.Scatter, - row: int, - col: int, - param_names: str | tuple[str, str], +def _prepare_surface_plot( + experiment: Experiment, + model: TorchModelBridge, + feature_name: str, + metric_name: str, ) -> go.Figure: - """Sets the axis names for the given row and column. + if "&" in feature_name: + # Plot a contour plot for the second-order component. + x_parameter_name, y_parameter_name = feature_name.split(" & ") + df = _prepare_contour_data( + experiment=experiment, + model=model, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, + ) - Args: - figure: The figure to update the axes on. - trace: The trace of the plot whose axis labels to update. - row: The row index of the trace in `figure`. - col: The column index of the trace in `figure`. - param_names: The parameter names to use for the axis names. + return _prepare_contour_plot( + df=df, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[x_parameter_name] + ), + log_y=is_axis_log_scale( + parameter=experiment.search_space.parameters[y_parameter_name] + ), + ) - Returns: - A figure where the trace at (row, col) has its axis names set. - """ - if isinstance(trace, go.Contour): - X_name, Y_name = param_names - figure.update_xaxes(title_text=X_name, row=row, col=col) - figure.update_yaxes(title_text=Y_name, row=row, col=col) - else: - figure.update_xaxes(title_text=param_names, row=row, col=col) - return figure + # If the feature is a first-order component, plot a slice plot. 
+ df = _prepare_slice_data( + experiment=experiment, + model=model, + parameter_name=feature_name, + metric_name=metric_name, + ) + + return _prepare_slice_plot( + df=df, + parameter_name=feature_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] + ), + ) diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index 743a114e640..a43d2b6e861 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -86,7 +86,6 @@ def compute( fig = _prepare_plot( df=df, - experiment=experiment, x_parameter_name=self.x_parameter_name, y_parameter_name=self.y_parameter_name, metric_name=metric_name, @@ -164,7 +163,6 @@ def _prepare_data( def _prepare_plot( df: pd.DataFrame, - experiment: Experiment, x_parameter_name: str, y_parameter_name: str, metric_name: str, diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py index d60311a7223..196be8e10f3 100644 --- a/ax/analysis/plotly/tests/test_interaction.py +++ b/ax/analysis/plotly/tests/test_interaction.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the @@ -6,192 +5,76 @@ # pyre-strict -import pandas as pd -import torch -from ax.analysis.analysis import AnalysisCard -from ax.analysis.plotly.interaction import ( - generate_interaction_component, - generate_main_effect_component, - get_model_kwargs, - InteractionPlot, - TOP_K_TOO_LARGE_ERROR, -) -from ax.exceptions.core import DataRequiredError, UserInputError - +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.interaction import InteractionPlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize -from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel -from gpytorch.kernels import RBFKernel -from plotly import graph_objects as go -class InteractionTest(TestCase): - def test_interaction_get_model_kwargs(self) -> None: - kwargs = get_model_kwargs( - num_parameters=3, - use_interaction=False, - torch_device=torch.device("cpu"), +class TestInteractionPlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + }, + { + "name": "y", + "type": "range", + "bounds": [-1.0, 1.0], + }, + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, ) - self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) - covar_module_options = kwargs["covar_module_options"] - self.assertIsInstance(covar_module_options["base_kernel"], RBFKernel) - self.assertEqual(covar_module_options["dim"], 3) - # Checks that we can retrieve the modelbridge that has interaction terms - kwargs = get_model_kwargs( - num_parameters=5, - use_interaction=True, - torch_device=torch.device("cpu"), - ) - self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) - self.assertIsInstance(kwargs["covar_module_options"]["base_kernel"], RBFKernel) + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + 
trial_index=trial_index, + raw_data={ + "bar": parameterization["x"] ** 2 + parameterization["y"] ** 2 + }, + ) - @mock_botorch_optimize - def test_interaction_analysis_without_components(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=False, - num_mc_samples=11, + def test_compute(self) -> None: + analysis = InteractionPlot(metric_name="bar") + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) self.assertEqual( card.name, - "Interaction Analysis", - ) - self.assertEqual( - card.title, - "Feature Importance Analysis for branin", + "InteractionPlot", ) + self.assertEqual(card.title, "Interaction Analysis for bar") self.assertEqual( card.subtitle, - "Displays the most important features for branin by order of importance.", - ) - - # with interaction terms - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - num_mc_samples=11, + "Understand an Experiment's data as one- or two-dimensional additive " + "components with sparsity. Important components are visualized through " + "slice or contour plots", ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) - self.assertEqual(len(card.df), 3) + self.assertEqual(card.level, AnalysisCardLevel.MID) self.assertEqual( - card.subtitle, - "Displays the most important features for branin by order of importance.", + {*card.df.columns}, + {"feature", "sensitivity"}, ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") - with self.assertRaisesRegex(UserInputError, TOP_K_TOO_LARGE_ERROR.format("7")): - InteractionPlot(metric_name="branin", top_k=7, display_components=True) - - analysis = InteractionPlot(metric_name="branout", fit_interactions=False) - with self.assertRaisesRegex( - DataRequiredError, "StandardizeY` transform requires non-empty data." 
- ): - analysis.compute(experiment=exp) - - @mock_botorch_optimize - def test_interaction_with_components(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) - self.assertEqual(len(card.df), 3) - - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - top_k=2, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertEqual(len(card.df), 2) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - model_fit_seed=999, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertEqual(len(card.df), 3) - - @mock_botorch_optimize - def test_generate_main_effect_component(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - density = 13 - model = analysis.get_model(experiment=exp) - comp, _, _ = generate_main_effect_component( - model=model, - component="x1", - metric="branin", - density=density, - ) - self.assertIsInstance(comp, go.Scatter) - self.assertEqual(comp["x"].shape, (density,)) - self.assertEqual(comp["y"].shape, (density,)) - self.assertEqual(comp["name"], "x1") - - with self.assertRaisesRegex(KeyError, "braninandout"): - generate_main_effect_component( - model=model, - component="x1", - metric="braninandout", - density=density, - ) - - @mock_botorch_optimize - def test_generate_interaction_component(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - density = 3 - model = analysis.get_model(experiment=exp) - comp, _, _ = generate_interaction_component( - model=model, - component_x="x1", - component_y="x2", - metric="branin", - density=density, - ) - self.assertIsInstance(comp, go.Contour) - self.assertEqual(comp["x"].shape, (density,)) - self.assertEqual(comp["y"].shape, (density,)) - self.assertEqual(comp["z"].shape, (density, density)) - self.assertEqual(comp["name"], "x1 & x2") - - with self.assertRaisesRegex(KeyError, "braninandout"): - generate_interaction_component( - model=model, - component_x="x1", - component_y="x2", - metric="braninandout", - density=density, - ) + fig = card.get_figure() + # Ensure all subplots are present + self.assertEqual(len(fig.data), 6)
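
Reviewer note: the sensitivity DataFrame built in compute() above sorts components by
the magnitude of the signed Sobol index while keeping the sign for the bar-plot
coloring. A minimal, self-contained sketch of that pandas pattern follows; the index
values are hypothetical, not output of ax_parameter_sens, and ascending=False is used
here for readability where the diff passes ascending=self.most_important to control
row order in the horizontal bar chart:

    import pandas as pd

    # Hypothetical signed Sobol indices: two main effects and one interaction.
    sens = {"x": 0.62, "y": -0.05, "x & y": 0.33}

    # Order components by |sensitivity| while preserving the sign, as compute()
    # does before building its top-15 plotting frame.
    sensitivity_df = pd.DataFrame(
        [*sens.items()], columns=["feature", "sensitivity"]
    ).sort_values(by="sensitivity", key=abs, ascending=False)

    # Split magnitude (bar length) from sign (bar color).
    plotting_df = sensitivity_df.copy()
    plotting_df["direction"] = plotting_df["sensitivity"].apply(
        lambda s: "Increases Metric" if s >= 0 else "Decreases Metric"
    )
    plotting_df["sensitivity"] = plotting_df["sensitivity"].abs()
    print(plotting_df)  # rows ordered x, x & y, y: largest magnitude first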
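
For reviewers who want to exercise the new API end to end, a usage sketch mirroring
the test above; the two-parameter experiment, the metric name "bar", and the
quadratic evaluation function are illustrative choices from the test, not library
defaults:

    import torch
    from ax.analysis.plotly.interaction import InteractionPlot
    from ax.service.ax_client import AxClient, ObjectiveProperties

    # Set up a small experiment with two range parameters, as in setUp() above.
    client = AxClient()
    client.create_experiment(
        name="interaction_demo",
        parameters=[
            {"name": "x", "type": "range", "bounds": [-1.0, 1.0]},
            {"name": "y", "type": "range", "bounds": [-1.0, 1.0]},
        ],
        objectives={"bar": ObjectiveProperties(minimize=True)},
    )
    for _ in range(10):
        parameterization, trial_index = client.get_next_trial()
        client.complete_trial(
            trial_index=trial_index,
            raw_data={"bar": parameterization["x"] ** 2 + parameterization["y"] ** 2},
        )

    # metric_name may be omitted, in which case select_metric() picks one from the
    # experiment; fit_interactions=False restricts the model to main effects.
    analysis = InteractionPlot(
        metric_name="bar",
        seed=0,
        torch_device=torch.device("cpu"),
    )
    card = analysis.compute(
        experiment=client.experiment,
        generation_strategy=client.generation_strategy,
    )
    print(card.df)            # one row per component: "feature", "sensitivity"
    card.get_figure().show()  # Sobol bar plot plus slice/contour subplots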