diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py index 078ad5b594a..51fe32e6274 100644 --- a/ax/analysis/plotly/__init__.py +++ b/ax/analysis/plotly/__init__.py @@ -6,14 +6,20 @@ # pyre-strict from ax.analysis.plotly.cross_validation import CrossValidationPlot +from ax.analysis.plotly.interaction import InteractionPlot from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.scatter import ScatterPlot +from ax.analysis.plotly.surface.contour import ContourPlot +from ax.analysis.plotly.surface.slice import SlicePlot __all__ = [ + "ContourPlot", "CrossValidationPlot", + "InteractionPlot", "PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot", "ScatterPlot", + "SlicePlot", ] diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py new file mode 100644 index 00000000000..96930499258 --- /dev/null +++ b/ax/analysis/plotly/interaction.py @@ -0,0 +1,340 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from logging import Logger + +import pandas as pd +import torch +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard + +from ax.analysis.plotly.surface.contour import ( + _prepare_data as _prepare_contour_data, + _prepare_plot as _prepare_contour_plot, +) +from ax.analysis.plotly.surface.slice import ( + _prepare_data as _prepare_slice_data, + _prepare_plot as _prepare_slice_plot, +) +from ax.analysis.plotly.surface.utils import is_axis_log_scale +from ax.analysis.plotly.utils import select_metric +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.exceptions.core import UserInputError +from ax.modelbridge.registry import Models +from ax.modelbridge.torch import TorchModelBridge +from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.utils.common.logger import get_logger +from ax.utils.sensitivity.sobol_measures import ax_parameter_sens +from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel + +from gpytorch.constraints import Positive +from gpytorch.kernels import RBFKernel +from gpytorch.priors import LogNormalPrior +from plotly import express as px, graph_objects as go +from plotly.subplots import make_subplots +from pyre_extensions import assert_is_instance + +logger: Logger = get_logger(__name__) + + +class InteractionPlot(PlotlyAnalysis): + """ + Analysis class which tries to explain the data of an experiment as one- or two- + dimensional additive components with a level of sparsity in the components. The + relative importance of each component is quantified by its Sobol index. Each + component may be visualized through slice or contour plots depending on if it is + a first order or second order component, respectively. + + The DataFrame computed will contain just the sensitivity analyisis with one row per + parameter and the following columns: + - feature: The name of the first or second order component + - sensitivity: The sensitivity of the component + """ + + def __init__( + self, + metric_name: str | None = None, + fit_interactions: bool = True, + most_important: bool = True, + seed: int = 0, + torch_device: torch.device | None = None, + ) -> None: + """ + Args: + metric_name: The metric to analyze. + fit_interactions: Whether to fit interaction effects in addition to main + effects. + most_important: Whether to sort by most or least important features in the + bar subplot. Also controls whether the six most or least important + features are plotted in the surface subplots. + seed: The seed with which to fit the model. Defaults to 0. Used + to ensure that the model fit is identical across the generation of + various plots. + torch_device: The torch device to use for the model. + """ + + self.metric_name = metric_name + self.fit_interactions = fit_interactions + self.most_important = most_important + self.seed = seed + self.torch_device = torch_device + + def compute( + self, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("InteractionPlot requires an Experiment") + + metric_name = self.metric_name or select_metric(experiment=experiment) + + # Fix the seed to ensure that the model is fit identically across different + # analyses of the same experiment. + with torch.random.fork_rng(): + torch.torch.manual_seed(self.seed) + + # Fit the OAK model. + oak_model = self._get_oak_model( + experiment=experiment, metric_name=metric_name + ) + + # Calculate first- or second-order Sobol indices. + sens = ax_parameter_sens( + model_bridge=oak_model, + metrics=[metric_name], + order="second" if self.fit_interactions else "first", + signed=not self.fit_interactions, + )[metric_name] + + sensitivity_df = pd.DataFrame( + [*sens.items()], columns=["feature", "sensitivity"] + ).sort_values(by="sensitivity", key=abs, ascending=self.most_important) + + # Calculate feature importance bar plot. Only plot the top 15 features. + # Plot the absolute value of the sensitivity but color by the sign. + plotting_df = sensitivity_df.head(15).copy() + plotting_df["direction"] = plotting_df["sensitivity"].apply( + lambda x: "Increases Metric" if x >= 0 else "Decreases Metric" + ) + plotting_df["sensitivity"] = plotting_df["sensitivity"].abs() + + sensitivity_fig = px.bar( + plotting_df.sort_values( + by="sensitivity", key=abs, ascending=self.most_important + ), + x="sensitivity", + y="feature", + color="direction", + # Increase gets blue, decrease gets orange. + color_discrete_sequence=["orange", "blue"], + orientation="h", + ) + + # Calculate surface plots for six most or least important features + # Note: We use tail and reverse here because the bar plot is sorted from top to + # bottom. + top_features = [*reversed(sensitivity_df.tail(6)["feature"].to_list())] + surface_figs = [] + for feature_name in top_features: + try: + surface_figs.append( + _prepare_surface_plot( + experiment=experiment, + model=oak_model, + feature_name=feature_name, + metric_name=metric_name, + ) + ) + # Not all features will be able to be plotted, skip them gracefully. + except Exception as e: + logger.error(f"Failed to generate surface plot for {feature_name}: {e}") + + # Create a plotly figure to hold the bar plot in the top row and surface plots + # in a 3x2 grid below. + fig = make_subplots( + rows=4, + cols=3, + specs=[ + [{"colspan": 3}, None, None], + [{}, {}, {}], + [{}, {}, {}], + [{}, {}, {}], + ], + ) + + for trace in sensitivity_fig.data: + fig.add_trace(trace, row=1, col=1) + + for i in range(len(surface_figs)): + feature_name = top_features[i] + surface_fig = surface_figs[i] + + row = (i // 3) + 2 + col = (i % 3) + 1 + for trace in surface_fig.data: + fig.add_trace(trace, row=row, col=col) + + # Label and fix axes + if "&" in feature_name: # If the feature is a second-order component + x, y = feature_name.split(" & ") + + # Reapply log scale if necessary + fig.update_xaxes( + title_text=x, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[x] + ) + else "linear" + ), + row=row, + col=col, + ) + fig.update_yaxes( + title_text=y, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[y] + ) + else "linear" + ), + row=row, + col=col, + ) + else: # If the feature is a first-order component + fig.update_xaxes( + title_text=feature_name, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] + ) + else "linear" + ), + row=row, + col=col, + ) + + fig.update_layout( + height=1500, + width=1500, + ) + + subtitle_substring = ( + "one- or two-dimensional" if self.fit_interactions else "one-dimensional" + ) + + return self._create_plotly_analysis_card( + title=f"Interaction Analysis for {metric_name}", + subtitle=( + f"Understand an Experiment's data as {subtitle_substring} additive " + "components with sparsity. Important components are visualized through " + "slice or contour plots" + ), + level=AnalysisCardLevel.MID, + df=sensitivity_df, + fig=fig, + ) + + def _get_oak_model( + self, experiment: Experiment, metric_name: str + ) -> TorchModelBridge: + """ + Retrieves the modelbridge used for the analysis. The model uses an OAK + (Orthogonal Additive Kernel) with a sparsity-inducing prior, + which decomposes the objective into an additive sum of components. + + The kernel comes with a sparsity-inducing prior, which attempts explain the + data with as few components as possible. The smoothness of the components is + regularized by a lengthscale prior to guard against excessicely short + lengthscales being fit. + """ + data = experiment.lookup_data().filter(metric_names=[metric_name]) + model_bridge = Models.BOTORCH_MODULAR( + search_space=experiment.search_space, + experiment=experiment, + data=data, + surrogate=Surrogate( + covar_module_class=OrthogonalAdditiveKernel, + covar_module_options={ + # A fairly restrictive prior on the lengthscale avoids spurious + # fits, where a single component is fit to explain all + # variability. + # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior + # is probably a good add, but second-order components tend to + # break down even for weak priors. + "base_kernel": RBFKernel( + ard_num_dims=len(experiment.search_space.tunable_parameters), + lengthscale_prior=LogNormalPrior(2, 1), + ), + "dim": len(experiment.search_space.tunable_parameters), + "dtype": torch.float64, + "device": self.torch_device, + "second_order": self.fit_interactions, + "coeff_constraint": Positive( + transform=torch.exp, inv_transform=torch.log + ), + }, + allow_batched_models=False, + ), + ) + + return assert_is_instance(model_bridge, TorchModelBridge) + + +def _prepare_surface_plot( + experiment: Experiment, + model: TorchModelBridge, + feature_name: str, + metric_name: str, +) -> go.Figure: + if "&" in feature_name: + # Plot a contour plot for the second-order component. + x_parameter_name, y_parameter_name = feature_name.split(" & ") + df = _prepare_contour_data( + experiment=experiment, + model=model, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, + ) + + return _prepare_contour_plot( + df=df, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[x_parameter_name] + ), + log_y=is_axis_log_scale( + parameter=experiment.search_space.parameters[y_parameter_name] + ), + ) + + # If the feature is a first-order component, plot a slice plot. + df = _prepare_slice_data( + experiment=experiment, + model=model, + parameter_name=feature_name, + metric_name=metric_name, + ) + + return _prepare_slice_plot( + df=df, + parameter_name=feature_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] + ), + ) diff --git a/ax/analysis/plotly/surface/__init__.py b/ax/analysis/plotly/surface/__init__.py new file mode 100644 index 00000000000..f22e65e2769 --- /dev/null +++ b/ax/analysis/plotly/surface/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.analysis.plotly.surface.contour import ContourPlot +from ax.analysis.plotly.surface.slice import SlicePlot + +__all__ = ["ContourPlot", "SlicePlot"] diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py new file mode 100644 index 00000000000..ff96222382e --- /dev/null +++ b/ax/analysis/plotly/surface/contour.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +from typing import Optional + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.surface.utils import ( + get_parameter_values, + is_axis_log_scale, + select_fixed_value, +) +from ax.analysis.plotly.utils import select_metric +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.generation_strategy import GenerationStrategy +from plotly import graph_objects as go +from pyre_extensions import none_throws + + +class ContourPlot(PlotlyAnalysis): + """ + Plot a 2D surface of the surrogate model's predicted outcomes for a given pair of + parameters, where all other parameters are held fixed at their status-quo value or + mean if no status quo is available. + + The DataFrame computed will contain the following columns: + - PARAMETER_NAME: The value of the x parameter specified + - PARAMETER_NAME: The value of the y parameter specified + - METRIC_NAME: The predected mean of the metric specified + """ + + def __init__( + self, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str | None = None, + ) -> None: + """ + Args: + y_parameter_name: The name of the parameter to plot on the x-axis. + y_parameter_name: The name of the parameter to plot on the y-axis. + metric_name: The name of the metric to plot + """ + self.x_parameter_name = x_parameter_name + self.y_parameter_name = y_parameter_name + self.metric_name = metric_name + + def compute( + self, + experiment: Optional[Experiment] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("ContourPlot requires an Experiment") + + if not isinstance(generation_strategy, GenerationStrategy): + raise UserInputError("ContourPlot requires a GenerationStrategy") + + if generation_strategy.model is None: + generation_strategy._fit_current_model(None) + + metric_name = self.metric_name or select_metric(experiment=experiment) + + df = _prepare_data( + experiment=experiment, + model=none_throws(generation_strategy.model), + x_parameter_name=self.x_parameter_name, + y_parameter_name=self.y_parameter_name, + metric_name=metric_name, + ) + + fig = _prepare_plot( + df=df, + x_parameter_name=self.x_parameter_name, + y_parameter_name=self.y_parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.x_parameter_name] + ), + log_y=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.y_parameter_name] + ), + ) + + return self._create_plotly_analysis_card( + title=( + f"{self.x_parameter_name}, {self.y_parameter_name} vs. {metric_name}" + ), + subtitle=( + "2D contour of the surrogate model's predicted outcomes for " + f"{metric_name}" + ), + level=AnalysisCardLevel.LOW, + df=df, + fig=fig, + ) + + +def _prepare_data( + experiment: Experiment, + model: ModelBridge, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str, +) -> pd.DataFrame: + # Choose which parameter values to predict points for. + xs = get_parameter_values( + parameter=experiment.search_space.parameters[x_parameter_name], density=10 + ) + ys = get_parameter_values( + parameter=experiment.search_space.parameters[y_parameter_name], density=10 + ) + + # Construct observation features for each parameter value previously chosen by + # fixing all other parameters to their status-quo value or mean. + features = [ + ObservationFeatures( + parameters={ + x_parameter_name: x, + y_parameter_name: y, + **{ + parameter.name: select_fixed_value(parameter=parameter) + for parameter in experiment.search_space.parameters.values() + if not ( + parameter.name == x_parameter_name + or parameter.name == y_parameter_name + ) + }, + } + ) + for x in xs + for y in ys + ] + + predictions = model.predict(observation_features=features) + + return pd.DataFrame.from_records( + [ + { + x_parameter_name: features[i].parameters[x_parameter_name], + y_parameter_name: features[i].parameters[y_parameter_name], + f"{metric_name}_mean": predictions[0][metric_name][i], + } + for i in range(len(features)) + ] + ) + + +def _prepare_plot( + df: pd.DataFrame, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str, + log_x: bool, + log_y: bool, +) -> go.Figure: + z_grid = df.pivot( + index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean" + ) + + fig = go.Figure( + data=go.Contour( + z=z_grid.values, + x=z_grid.columns.values, + y=z_grid.index.values, + contours_coloring="heatmap", + showscale=False, + ), + layout=go.Layout( + xaxis_title=x_parameter_name, + yaxis_title=y_parameter_name, + ), + ) + + # Set the x-axis scale to log if relevant + if log_x: + fig.update_xaxes( + type="log", + range=[ + math.log10(df[x_parameter_name].min()), + math.log10(df[x_parameter_name].max()), + ], + ) + else: + fig.update_xaxes(range=[df[x_parameter_name].min(), df[x_parameter_name].max()]) + + if log_y: + fig.update_yaxes( + type="log", + range=[ + math.log10(df[y_parameter_name].min()), + math.log10(df[y_parameter_name].max()), + ], + ) + else: + fig.update_yaxes(range=[df[y_parameter_name].min(), df[y_parameter_name].max()]) + + return fig diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py new file mode 100644 index 00000000000..129f4275824 --- /dev/null +++ b/ax/analysis/plotly/surface/slice.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +from typing import Optional + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.surface.utils import ( + get_parameter_values, + is_axis_log_scale, + select_fixed_value, +) +from ax.analysis.plotly.utils import select_metric +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.generation_strategy import GenerationStrategy +from plotly import express as px, graph_objects as go +from pyre_extensions import none_throws + + +class SlicePlot(PlotlyAnalysis): + """ + Plot a 1D "slice" of the surrogate model's predicted outcomes for a given + parameter, where all other parameters are held fixed at their status-quo value or + mean if no status quo is available. + + The DataFrame computed will contain the following columns: + - PARAMETER_NAME: The value of the parameter specified + - METRIC_NAME_mean: The predected mean of the metric specified + - METRIC_NAME_sem: The predected sem of the metric specified + """ + + def __init__( + self, + parameter_name: str, + metric_name: str | None = None, + ) -> None: + """ + Args: + parameter_name: The name of the parameter to plot on the x axis. + metric_name: The name of the metric to plot on the y axis. If not + specified the objective will be used. + """ + self.parameter_name = parameter_name + self.metric_name = metric_name + + def compute( + self, + experiment: Optional[Experiment] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("SlicePlot requires an Experiment") + + if not isinstance(generation_strategy, GenerationStrategy): + raise UserInputError("SlicePlot requires a GenerationStrategy") + + if generation_strategy.model is None: + generation_strategy._fit_current_model(None) + + metric_name = self.metric_name or select_metric(experiment=experiment) + + df = _prepare_data( + experiment=experiment, + model=none_throws(generation_strategy.model), + parameter_name=self.parameter_name, + metric_name=metric_name, + ) + + fig = _prepare_plot( + df=df, + parameter_name=self.parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.parameter_name] + ), + ) + + return self._create_plotly_analysis_card( + title=f"{self.parameter_name} vs. {metric_name}", + subtitle=( + "1D slice of the surrogate model's predicted outcomes for " + f"{metric_name}" + ), + level=AnalysisCardLevel.LOW, + df=df, + fig=fig, + ) + + +def _prepare_data( + experiment: Experiment, + model: ModelBridge, + parameter_name: str, + metric_name: str, +) -> pd.DataFrame: + # Choose which parameter values to predict points for. + xs = get_parameter_values( + parameter=experiment.search_space.parameters[parameter_name] + ) + + # Construct observation features for each parameter value previously chosen by + # fixing all other parameters to their status-quo value or mean. + features = [ + ObservationFeatures( + parameters={ + parameter_name: x, + **{ + parameter.name: select_fixed_value(parameter=parameter) + for parameter in experiment.search_space.parameters.values() + if parameter.name != parameter_name + }, + } + ) + for x in xs + ] + + predictions = model.predict(observation_features=features) + + return pd.DataFrame.from_records( + [ + { + parameter_name: xs[i], + f"{metric_name}_mean": predictions[0][metric_name][i], + f"{metric_name}_sem": predictions[1][metric_name][metric_name][i], + } + for i in range(len(xs)) + ] + ).sort_values(by=parameter_name) + + +def _prepare_plot( + df: pd.DataFrame, + parameter_name: str, + metric_name: str, + log_x: bool = False, +) -> go.Figure: + x = df[parameter_name].tolist() + y = df[f"{metric_name}_mean"].tolist() + y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist() + y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist() + + plotly_blue = px.colors.qualitative.Plotly[0] + plotly_blue_translucent = "rgba(99, 110, 250, 0.2)" + + # Draw a line at the mean and a shaded region between the upper and lower bounds + line = go.Scatter( + x=x, + y=y, + line={"color": plotly_blue}, + mode="lines", + name=metric_name, + showlegend=False, + ) + error_band = go.Scatter( + # Concatenate x values in reverse order to create a closed polygon + x=x + x[::-1], + # Concatenate upper and lower bounds in reverse order + y=y_upper + y_lower[::-1], + fill="toself", + fillcolor=plotly_blue_translucent, + line={"color": "rgba(255,255,255,0)"}, # Make "line" transparent + hoverinfo="skip", + showlegend=False, + ) + + fig = go.Figure( + [line, error_band], + layout=go.Layout( + xaxis_title=parameter_name, + yaxis_title=metric_name, + ), + ) + + # Set the x-axis scale to log if relevant + if log_x: + fig.update_xaxes( + type="log", + range=[ + math.log10(df[parameter_name].min()), + math.log10(df[parameter_name].max()), + ], + ) + else: + fig.update_xaxes(range=[df[parameter_name].min(), df[parameter_name].max()]) + + return fig diff --git a/ax/analysis/plotly/surface/tests/test_contour.py b/ax/analysis/plotly/surface/tests/test_contour.py new file mode 100644 index 00000000000..6deec31ae4a --- /dev/null +++ b/ax/analysis/plotly/surface/tests/test_contour.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.surface.contour import ContourPlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties +from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import mock_botorch_optimize + + +class TestContourPlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + }, + { + "name": "y", + "type": "range", + "bounds": [-1.0, 1.0], + }, + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, + ) + + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": parameterization["x"] ** 2 + parameterization["y"] ** 2 + }, + ) + + def test_compute(self) -> None: + analysis = ContourPlot( + x_parameter_name="x", y_parameter_name="y", metric_name="bar" + ) + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + # Test that it fails if no GenerationStrategy is provided + with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"): + analysis.compute(experiment=self.client.experiment) + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, + ) + self.assertEqual( + card.name, + "ContourPlot", + ) + self.assertEqual(card.title, "x, y vs. bar") + self.assertEqual( + card.subtitle, + "2D contour of the surrogate model's predicted outcomes for bar", + ) + self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual( + {*card.df.columns}, + { + "x", + "y", + "bar_mean", + }, + ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") diff --git a/ax/analysis/plotly/surface/tests/test_slice.py b/ax/analysis/plotly/surface/tests/test_slice.py new file mode 100644 index 00000000000..557a7665c37 --- /dev/null +++ b/ax/analysis/plotly/surface/tests/test_slice.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.surface.slice import SlicePlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties +from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import mock_botorch_optimize + + +class TestSlicePlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + } + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, + ) + + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2} + ) + + def test_compute(self) -> None: + analysis = SlicePlot(parameter_name="x", metric_name="bar") + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + # Test that it fails if no GenerationStrategy is provided + with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"): + analysis.compute(experiment=self.client.experiment) + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, + ) + self.assertEqual( + card.name, + "SlicePlot", + ) + self.assertEqual(card.title, "x vs. bar") + self.assertEqual( + card.subtitle, + "1D slice of the surrogate model's predicted outcomes for bar", + ) + self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual( + {*card.df.columns}, + { + "x", + "bar_mean", + "bar_sem", + }, + ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") diff --git a/ax/analysis/plotly/surface/utils.py b/ax/analysis/plotly/surface/utils.py new file mode 100644 index 00000000000..4b8acd632ee --- /dev/null +++ b/ax/analysis/plotly/surface/utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import math + +import numpy as np +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + Parameter, + RangeParameter, + TParamValue, +) + + +def get_parameter_values(parameter: Parameter, density: int = 100) -> list[TParamValue]: + """ + Get a list of parameter values to predict over for a given parameter. + """ + + # For RangeParameter use linspace for the range of the parameter + if isinstance(parameter, RangeParameter): + if parameter.log_scale: + return np.logspace( + math.log10(parameter.lower), math.log10(parameter.upper), density + ).tolist() + + return np.linspace(parameter.lower, parameter.upper, density).tolist() + + # For ChoiceParameter use the values of the parameter directly + if isinstance(parameter, ChoiceParameter) and parameter.is_ordered: + return parameter.values + + raise ValueError( + f"Parameter {parameter.name} must be a RangeParameter or " + "ChoiceParameter with is_ordered=True to be used in surface plot." + ) + + +def select_fixed_value(parameter: Parameter) -> TParamValue: + """ + Select a fixed value for a parameter. Use mean for RangeParameter, "middle" value + for ChoiceParameter, and value for FixedParameter. + """ + if isinstance(parameter, RangeParameter): + return (parameter.lower * 1.0 + parameter.upper) / 2 + elif isinstance(parameter, ChoiceParameter): + return parameter.values[len(parameter.values) // 2] + elif isinstance(parameter, FixedParameter): + return parameter.value + else: + raise ValueError(f"Got unexpected parameter type {parameter}.") + + +def is_axis_log_scale(parameter: Parameter) -> bool: + """ + Check if the parameter is log scale. + """ + return isinstance(parameter, RangeParameter) and parameter.log_scale diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py new file mode 100644 index 00000000000..196be8e10f3 --- /dev/null +++ b/ax/analysis/plotly/tests/test_interaction.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.interaction import InteractionPlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties +from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import mock_botorch_optimize + + +class TestInteractionPlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + }, + { + "name": "y", + "type": "range", + "bounds": [-1.0, 1.0], + }, + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, + ) + + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": parameterization["x"] ** 2 + parameterization["y"] ** 2 + }, + ) + + def test_compute(self) -> None: + analysis = InteractionPlot(metric_name="bar") + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, + ) + self.assertEqual( + card.name, + "InteractionPlot", + ) + self.assertEqual(card.title, "Interaction Analysis for bar") + self.assertEqual( + card.subtitle, + "Understand an Experiment's data as one- or two-dimensional additive " + "components with sparsity. Important components are visualized through " + "slice or contour plots", + ) + self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual( + {*card.df.columns}, + {"feature", "sensitivity"}, + ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") + + fig = card.get_figure() + # Ensure all subplots are present + self.assertEqual(len(fig.data), 6)