From 82b5929cd3428e920ec78ed58bee6b0d4ad1e3ed Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Fri, 21 Feb 2025 13:59:04 -0800
Subject: [PATCH 1/4] Fix CI calculation in SlicePlot (#3404)

Summary: As titled. Added comments as well to show what's going on.

Differential Revision: D69987525
---
 ax/analysis/plotly/surface/slice.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py
index d0cbc6a7a7c..ae6958a58a6 100644
--- a/ax/analysis/plotly/surface/slice.py
+++ b/ax/analysis/plotly/surface/slice.py
@@ -130,7 +130,8 @@ def _prepare_data(
             {
                 parameter_name: xs[i],
                 f"{metric_name}_mean": predictions[0][metric_name][i],
-                f"{metric_name}_sem": predictions[1][metric_name][metric_name][i],
+                f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
+                ** 0.5,  # Convert the variance to the SEM
             }
             for i in range(len(xs))
         ]
@@ -145,6 +146,7 @@ def _prepare_plot(
 ) -> go.Figure:
     x = df[parameter_name].tolist()
     y = df[f"{metric_name}_mean"].tolist()
+    # Convert the SEMs to 95% confidence intervals
    y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist()
    y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist()
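
[Editor's note: the fix above touches two conversions that are easy to mix up.
predict() returns a variance, so the plotting code must take a square root
before treating the value as a SEM, and a 95% confidence interval then spans
1.96 standard errors on each side of the mean. A minimal sketch of the
arithmetic, with made-up numbers:

    import math

    # Hypothetical prediction at one x value: predict() returns variances.
    mean, variance = 0.8, 0.04

    # The SEM is the square root of the variance.
    sem = math.sqrt(variance)  # 0.2

    # A 95% CI spans +/- 1.96 standard errors around the mean.
    lower, upper = mean - 1.96 * sem, mean + 1.96 * sem  # (0.408, 1.192)

Without the ** 0.5, the old code treated the variance itself as the SEM,
producing bands that were too narrow whenever the variance was below 1 and
too wide otherwise.]
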
""" self.parameter_name = parameter_name self.metric_name = metric_name + self._display_sampled = display_sampled def compute( self, @@ -82,6 +87,7 @@ def compute( log_x=is_axis_log_scale( parameter=experiment.search_space.parameters[self.parameter_name] ), + display_sampled=self._display_sampled, ) return self._create_plotly_analysis_card( @@ -102,10 +108,16 @@ def _prepare_data( parameter_name: str, metric_name: str, ) -> pd.DataFrame: + sampled_xs = [ + arm.parameters[parameter_name] + for trial in experiment.trials.values() + for arm in trial.arms + ] # Choose which parameter values to predict points for. - xs = get_parameter_values( + unsampled_xs = get_parameter_values( parameter=experiment.search_space.parameters[parameter_name] ) + xs = [*sampled_xs, *unsampled_xs] # Construct observation features for each parameter value previously chosen by # fixing all other parameters to their status-quo value or mean. @@ -125,16 +137,19 @@ def _prepare_data( predictions = model.predict(observation_features=features) - return pd.DataFrame.from_records( - [ - { - parameter_name: xs[i], - f"{metric_name}_mean": predictions[0][metric_name][i], - f"{metric_name}_sem": predictions[1][metric_name][metric_name][i] - ** 0.5, # Convert the variance to the SEM - } - for i in range(len(xs)) - ] + return none_throws( + pd.DataFrame.from_records( + [ + { + parameter_name: xs[i], + f"{metric_name}_mean": predictions[0][metric_name][i], + f"{metric_name}_sem": predictions[1][metric_name][metric_name][i] + ** 0.5, # Convert the variance to the SEM + "sampled": xs[i] in sampled_xs, + } + for i in range(len(xs)) + ] + ).drop_duplicates() ).sort_values(by=parameter_name) @@ -142,10 +157,12 @@ def _prepare_plot( df: pd.DataFrame, parameter_name: str, metric_name: str, - log_x: bool = False, + log_x: bool, + display_sampled: bool, ) -> go.Figure: x = df[parameter_name].tolist() y = df[f"{metric_name}_mean"].tolist() + # Convert the SEMs to 95% confidence intervals y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist() y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist() @@ -182,6 +199,24 @@ def _prepare_plot( ), ) + if display_sampled: + x_sampled = df[df["sampled"]][parameter_name].tolist() + y_sampled = df[df["sampled"]][f"{metric_name}_mean"].tolist() + + samples = go.Scatter( + x=x_sampled, + y=y_sampled, + mode="markers", + marker={ + "symbol": "x", + "color": "black", + }, + name=f"Sampled {parameter_name}", + showlegend=False, + ) + + fig.add_trace(samples) + # Set the x-axis scale to log if relevant if log_x: fig.update_xaxes( diff --git a/ax/analysis/plotly/surface/tests/test_slice.py b/ax/analysis/plotly/surface/tests/test_slice.py index 557a7665c37..5b1b3ab9520 100644 --- a/ax/analysis/plotly/surface/tests/test_slice.py +++ b/ax/analysis/plotly/surface/tests/test_slice.py @@ -63,11 +63,9 @@ def test_compute(self) -> None: self.assertEqual(card.level, AnalysisCardLevel.LOW) self.assertEqual( {*card.df.columns}, - { - "x", - "bar_mean", - "bar_sem", - }, + {"x", "bar_mean", "bar_sem", "sampled"}, ) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "plotly") + + self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials)) From d9826bfe59f0356678f917cf3f53ed7ff8d03bad Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Fri, 21 Feb 2025 13:59:04 -0800 Subject: [PATCH 3/4] Add black "x"s at sampled coordinates in contour plot (#3406) Summary: As titled. 
Differential Revision: D69989849
---
 ax/analysis/plotly/interaction.py          |  1 +
 ax/analysis/plotly/surface/contour.py      | 63 +++++++++++++++----
 .../plotly/surface/tests/test_contour.py   |  3 +
 3 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py
index a01ad63db07..b38bac2cfd9 100644
--- a/ax/analysis/plotly/interaction.py
+++ b/ax/analysis/plotly/interaction.py
@@ -333,6 +333,7 @@ def _prepare_surface_plot(
             log_y=is_axis_log_scale(
                 parameter=experiment.search_space.parameters[y_parameter_name]
             ),
+            display_sampled=True,
         )

     # If the feature is a first-order component, plot a slice plot.

diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py
index 60bb32a7ede..7bed9cc9b53 100644
--- a/ax/analysis/plotly/surface/contour.py
+++ b/ax/analysis/plotly/surface/contour.py
@@ -36,6 +36,7 @@ class ContourPlot(PlotlyAnalysis):
         - PARAMETER_NAME: The value of the x parameter specified
         - PARAMETER_NAME: The value of the y parameter specified
         - METRIC_NAME: The predicted mean of the metric specified
+        - sampled: Whether the parameter values were sampled in at least one trial
     """

     def __init__(
@@ -43,12 +44,15 @@ def __init__(
         x_parameter_name: str,
         y_parameter_name: str,
         metric_name: str | None = None,
+        display_sampled: bool = True,
     ) -> None:
         """
         Args:
             x_parameter_name: The name of the parameter to plot on the x-axis.
             y_parameter_name: The name of the parameter to plot on the y-axis.
             metric_name: The name of the metric to plot
+            display_sampled: If True, plot "x"s at (x, y) coordinates which have
+                been sampled in at least one trial.
         """
         # TODO: Add a flag to specify whether or not to plot markers at the (x, y)
         # coordinates of arms (with hover text). This is fine to exclude for now because
         self.x_parameter_name = x_parameter_name
         self.y_parameter_name = y_parameter_name
         self.metric_name = metric_name
+        self._display_sampled = display_sampled

     def compute(
         self,
@@ -93,6 +98,7 @@ def compute(
             log_y=is_axis_log_scale(
                 parameter=experiment.search_space.parameters[self.y_parameter_name]
             ),
+            display_sampled=self._display_sampled,
         )

         return self._create_plotly_analysis_card(
@@ -116,14 +122,23 @@ def _prepare_data(
     y_parameter_name: str,
     metric_name: str,
 ) -> pd.DataFrame:
+    sampled = [
+        (arm.parameters[x_parameter_name], arm.parameters[y_parameter_name])
+        for trial in experiment.trials.values()
+        for arm in trial.arms
+    ]
+
     # Choose which parameter values to predict points for.
-    xs = get_parameter_values(
+    unsampled_xs = get_parameter_values(
         parameter=experiment.search_space.parameters[x_parameter_name], density=10
     )
-    ys = get_parameter_values(
+    unsampled_ys = get_parameter_values(
         parameter=experiment.search_space.parameters[y_parameter_name], density=10
     )

+    xs = [*[sample[0] for sample in sampled], *unsampled_xs]
+    ys = [*[sample[1] for sample in sampled], *unsampled_ys]
+
     # Construct observation features for each parameter value previously chosen by
     # fixing all other parameters to their status-quo value or mean.
     features = [
@@ -147,15 +162,22 @@ def _prepare_data(

     predictions = model.predict(observation_features=features)

-    return pd.DataFrame.from_records(
-        [
-            {
-                x_parameter_name: features[i].parameters[x_parameter_name],
-                y_parameter_name: features[i].parameters[y_parameter_name],
-                f"{metric_name}_mean": predictions[0][metric_name][i],
-            }
-            for i in range(len(features))
-        ]
+    return none_throws(
+        pd.DataFrame.from_records(
+            [
+                {
+                    x_parameter_name: features[i].parameters[x_parameter_name],
+                    y_parameter_name: features[i].parameters[y_parameter_name],
+                    f"{metric_name}_mean": predictions[0][metric_name][i],
+                    "sampled": (
+                        features[i].parameters[x_parameter_name],
+                        features[i].parameters[y_parameter_name],
+                    )
+                    in sampled,
+                }
+                for i in range(len(features))
+            ]
+        ).drop_duplicates()
     )

@@ -166,6 +188,7 @@ def _prepare_plot(
     metric_name: str,
     log_x: bool,
     log_y: bool,
+    display_sampled: bool,
 ) -> go.Figure:
     z_grid = df.pivot(
         index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean"
@@ -185,6 +208,24 @@ def _prepare_plot(
         ),
     )

+    if display_sampled:
+        x_sampled = df[df["sampled"]][x_parameter_name].tolist()
+        y_sampled = df[df["sampled"]][y_parameter_name].tolist()
+
+        samples = go.Scatter(
+            x=x_sampled,
+            y=y_sampled,
+            mode="markers",
+            marker={
+                "symbol": "x",
+                "color": "black",
+            },
+            name="Sampled",
+            showlegend=False,
+        )
+
+        fig.add_trace(samples)
+
     # Set the x-axis scale to log if relevant
     if log_x:
         fig.update_xaxes(

diff --git a/ax/analysis/plotly/surface/tests/test_contour.py b/ax/analysis/plotly/surface/tests/test_contour.py
index 6deec31ae4a..f8a63c1fea4 100644
--- a/ax/analysis/plotly/surface/tests/test_contour.py
+++ b/ax/analysis/plotly/surface/tests/test_contour.py
@@ -77,7 +77,10 @@ def test_compute(self) -> None:
                 "x",
                 "y",
                 "bar_mean",
+                "sampled",
             },
         )
         self.assertIsNotNone(card.blob)
         self.assertEqual(card.blob_annotation, "plotly")
+
+        self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))

From 2293eebfe2bf211148da534d2fb486b87aac84f3 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Fri, 21 Feb 2025 13:59:04 -0800
Subject: [PATCH 4/4] Improve robustness in InteractionAnalysis (#3407)

Summary:
A number of features to improve reliability of the InteractionAnalysis.

* Hide OAK kernel behind a flag which defaults to False. When False, use
  the current GenerationNode's adapter
* If ax_parameter_sens fails log an exception and fall back to the
  surrogate's feature_importances
* Do not plot samples on the slice and contour plots if there are more
  than 50 samples (it gets too cluttered)
* Changed the orange and blue colors on the importance bar chart to be in
  the plotly color scheme
* Make plot not error out on unordered choice params
* Improved subtitle

Differential Revision: D69993111
---
 ax/analysis/plotly/interaction.py            | 85 +++++++++++++++-----
 ax/analysis/plotly/tests/test_interaction.py |  2 +-
 2 files changed, 64 insertions(+), 23 deletions(-)
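
[Editor's note: the central robustness change below is a try/except around
the Sobol computation. Schematically, and assuming an adapter exposing
feature_importances() is already in hand (the helper name here is
hypothetical; the real logic lives inline in InteractionPlot.compute()):

    from logging import Logger

    from ax.utils.common.logger import get_logger
    from ax.utils.sensitivity.sobol_measures import ax_parameter_sens

    logger: Logger = get_logger(__name__)

    def sensitivities_with_fallback(adapter, metric_name: str):
        # Prefer Sobol indices; they can capture interactions and signed effects.
        try:
            return ax_parameter_sens(
                model_bridge=adapter, metrics=[metric_name], order="first"
            )[metric_name]
        except Exception:
            # Degrade gracefully to the surrogate's own feature importances
            # rather than failing the whole analysis card.
            logger.exception(
                "Sensitivity analysis failed; using feature importances."
            )
            return adapter.feature_importances(metric_name)

The diff applies the same pattern with order and signedness driven by
fit_interactions.]
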
diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py
index b38bac2cfd9..d4748fe2307 100644
--- a/ax/analysis/plotly/interaction.py
+++ b/ax/analysis/plotly/interaction.py
@@ -29,6 +29,7 @@
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.modelbridge.registry import Generators
 from ax.modelbridge.torch import TorchAdapter
+from ax.modelbridge.transforms.one_hot import OH_PARAM_INFIX
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.utils.common.logger import get_logger
 from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
@@ -39,10 +40,12 @@
 from gpytorch.priors import LogNormalPrior
 from plotly import express as px, graph_objects as go
 from plotly.subplots import make_subplots
-from pyre_extensions import assert_is_instance
+from pyre_extensions import assert_is_instance, none_throws

 logger: Logger = get_logger(__name__)

+DISPLAY_SAMPLED_THRESHOLD: int = 50
+

 class InteractionPlot(PlotlyAnalysis):
     """
@@ -63,6 +66,7 @@ def __init__(
         metric_name: str | None = None,
         fit_interactions: bool = True,
         most_important: bool = True,
+        use_oak_model: bool = False,
         seed: int = 0,
         torch_device: torch.device | None = None,
     ) -> None:
@@ -74,6 +78,8 @@ def __init__(
             most_important: Whether to sort by most or least important features in
                 the bar subplot. Also controls whether the six most or least
                 important features are plotted in the surface subplots.
+            use_oak_model: Whether to use an OAK model for the analysis. If False,
+                use the Adapter from the current GenerationNode.
             seed: The seed with which to fit the model. Defaults to 0. Used
                 to ensure that the model fit is identical across the generation of
                 various plots.
@@ -83,6 +89,7 @@ def __init__(
         self.metric_name = metric_name
         self.fit_interactions = fit_interactions
         self.most_important = most_important
+        self.use_oak_model = use_oak_model
         self.seed = seed
         self.torch_device = torch_device

     def compute(
         self,
@@ -103,6 +110,12 @@ def compute(
         if experiment is None:
             raise UserInputError("InteractionPlot requires an Experiment")

+        if generation_strategy is None and not self.use_oak_model:
+            raise UserInputError(
+                "InteractionPlot requires a GenerationStrategy when use_oak_model is "
+                "False"
+            )
+
         metric_name = self.metric_name or select_metric(experiment=experiment)

         # Fix the seed to ensure that the model is fit identically across different
@@ -110,19 +123,42 @@ def compute(
         with torch.random.fork_rng():
             torch.torch.manual_seed(self.seed)

-            # Fit the OAK model.
-            oak_model = self._get_oak_model(
-                experiment=experiment, metric_name=metric_name
-            )
+            if self.use_oak_model:
+                adapter = self._get_oak_model(
+                    experiment=experiment, metric_name=metric_name
+                )
+            else:
+                gs = none_throws(generation_strategy)
+                if gs.model is None:
+                    gs._fit_current_model(None)

-        # Calculate first- or second-order Sobol indices.
-        sens = ax_parameter_sens(
-            model_bridge=oak_model,
-            metrics=[metric_name],
-            order="second" if self.fit_interactions else "first",
-            signed=not self.fit_interactions,
-        )[metric_name]
+                adapter = assert_is_instance(gs.model, TorchAdapter)
+
+            try:
+                # Calculate first- or second-order Sobol indices.
+                sens = ax_parameter_sens(
+                    model_bridge=adapter,
+                    metrics=[metric_name],
+                    order="second" if self.fit_interactions else "first",
+                    signed=not self.fit_interactions,
+                )[metric_name]
+            except Exception as e:
+                logger.exception(
+                    f"Failed to compute sensitivity analysis with {e}. Falling back "
+                    "on the surrogate model's feature importances."
+                )
+
+                sens = {
+                    metric_name: adapter.feature_importances(metric_name)
+                    for metric_name in adapter.metric_names
+                }
+
+        # Filter out any parameters that have been added to the search space via
+        # one-hot encoding -- these make the sensitivity analysis less interpretable
+        # and break the surface plots.
+        # TODO: Do something more principled here.
+        sens = {k: v for k, v in sens.items() if OH_PARAM_INFIX not in k}
+
+        # Create a DataFrame with the sensitivity analysis.
         sensitivity_df = pd.DataFrame(
             [*sens.items()], columns=["feature", "sensitivity"]
         ).sort_values(by="sensitivity", key=abs, ascending=self.most_important)
@@ -138,13 +174,16 @@ def compute(
             by="sensitivity", ascending=self.most_important, inplace=True
         )

+        plotly_blue = px.colors.qualitative.Plotly[0]
+        plotly_orange = px.colors.qualitative.Plotly[4]
+
         sensitivity_fig = px.bar(
             plotting_df,
             x="sensitivity",
             y="feature",
             color="direction",
             # Increase gets blue, decrease gets orange.
-            color_discrete_sequence=["orange", "blue"],
+            color_discrete_sequence=[plotly_blue, plotly_orange],
             orientation="h",
         )
@@ -158,7 +197,7 @@ def compute(
             surface_figs.append(
                 _prepare_surface_plot(
                     experiment=experiment,
-                    model=oak_model,
+                    model=adapter,
                     feature_name=feature_name,
                     metric_name=metric_name,
                 )
             )
@@ -245,16 +284,18 @@ def compute(
             width=1000,
         )

-        subtitle_substring = (
-            "one- or two-dimensional" if self.fit_interactions else "one-dimensional"
-        )
+        subtitle_substring = ", or pairs of parameters" if self.fit_interactions else ""

         return self._create_plotly_analysis_card(
             title=f"Interaction Analysis for {metric_name}",
             subtitle=(
-                f"Understand an Experiment's data as {subtitle_substring} additive "
-                "components with sparsity. Important components are visualized through "
-                "slice or contour plots"
+                f"Understand how changes to your parameters affect {metric_name}. "
+                f"Parameters{subtitle_substring} which rank higher here explain more "
+                f"of the observed variation in {metric_name}. The direction of the "
+                "effect is indicated by the color of the bar plot. Additionally, the "
+                "six most important parameters are visualized through surface plots "
+                f"which show the predicted outcomes for {metric_name} as a function "
+                "of the plotted parameters with the other parameters held fixed."
             ),
             level=AnalysisCardLevel.MID,
             df=sensitivity_df,
@@ -333,7 +374,7 @@ def _prepare_surface_plot(
             log_y=is_axis_log_scale(
                 parameter=experiment.search_space.parameters[y_parameter_name]
             ),
-            display_sampled=True,
+            display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
         )

     # If the feature is a first-order component, plot a slice plot.
@@ -351,5 +392,5 @@ def _prepare_surface_plot(
         log_x=is_axis_log_scale(
             parameter=experiment.search_space.parameters[feature_name]
         ),
-        display_sampled=True,
+        display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
     )

diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py
index 8314ae362b6..6876d70d379 100644
--- a/ax/analysis/plotly/tests/test_interaction.py
+++ b/ax/analysis/plotly/tests/test_interaction.py
@@ -50,7 +50,7 @@ def setUp(self) -> None:
         )

     @mock_botorch_optimize
     def test_compute(self) -> None:
-        analysis = InteractionPlot(metric_name="bar")
+        analysis = InteractionPlot(metric_name="bar", use_oak_model=True)

         # Test that it fails if no Experiment is provided
         with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
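
[Editor's note: taken together, the series adds two user-facing knobs. A
hypothetical usage sketch, assuming the import paths shown in the diffs and
an `experiment` object already in scope:

    from ax.analysis.plotly.interaction import InteractionPlot
    from ax.analysis.plotly.surface.slice import SlicePlot

    # Sampled-point markers default to on; disable them explicitly if unwanted.
    slice_plot = SlicePlot(
        parameter_name="x", metric_name="bar", display_sampled=False
    )

    # Without use_oak_model=True, InteractionPlot now requires a
    # GenerationStrategy and reuses the current GenerationNode's adapter
    # instead of fitting a fresh OAK model.
    interaction = InteractionPlot(metric_name="bar", use_oak_model=True)
    card = interaction.compute(experiment=experiment)

Within InteractionPlot itself, the marker overlay is suppressed automatically
once more than DISPLAY_SAMPLED_THRESHOLD (50) points have been sampled.]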