From b7095998f2bd9c19309b60e53367a68d4c0cc587 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Fri, 21 Feb 2025 13:58:56 -0800 Subject: [PATCH 1/2] Fix CI calculation in SlicePlot (#3404) Summary: As titled. Added comments as well to show what's going on. Differential Revision: D69987525 --- ax/analysis/plotly/surface/slice.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index d0cbc6a7a7c..ae6958a58a6 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -130,7 +130,8 @@ def _prepare_data( { parameter_name: xs[i], f"{metric_name}_mean": predictions[0][metric_name][i], - f"{metric_name}_sem": predictions[1][metric_name][metric_name][i], + f"{metric_name}_sem": predictions[1][metric_name][metric_name][i] + ** 0.5, # Convert the variance to the SEM } for i in range(len(xs)) ] @@ -145,6 +146,7 @@ def _prepare_plot( ) -> go.Figure: x = df[parameter_name].tolist() y = df[f"{metric_name}_mean"].tolist() + # Convert the SEMs to 95% confidence intervals y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist() y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist() From 23509c6da18c16c9216b30f1d3af2d3cc1d6253a Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Fri, 21 Feb 2025 13:58:56 -0800 Subject: [PATCH 2/2] Add black "x"s at sampled x coordinates in slice plot (#3405) Summary: As titled. These are moderately useful here, but I really want to add them to the contour plot and I want the slice plot to match since they're often displayed together. 
Differential Revision: D69988707 --- ax/analysis/plotly/interaction.py | 1 + ax/analysis/plotly/surface/slice.py | 59 +++++++++++++++---- .../plotly/surface/tests/test_slice.py | 8 +-- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index cd29d38ae71..a01ad63db07 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -350,4 +350,5 @@ def _prepare_surface_plot( log_x=is_axis_log_scale( parameter=experiment.search_space.parameters[feature_name] ), + display_sampled=True, ) diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index ae6958a58a6..58abec24d02 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis): - PARAMETER_NAME: The value of the parameter specified - METRIC_NAME_mean: The predected mean of the metric specified - METRIC_NAME_sem: The predected sem of the metric specified + - sampled: Whether the parameter value was sampled in at least one trial """ def __init__( self, parameter_name: str, metric_name: str | None = None, + display_sampled: bool = True, ) -> None: """ Args: parameter_name: The name of the parameter to plot on the x axis. metric_name: The name of the metric to plot on the y axis. If not specified the objective will be used. + display_sampled: If True, plot "x"s at x coordinates which have been + sampled in at least one trial. 
""" self.parameter_name = parameter_name self.metric_name = metric_name + self._display_sampled = display_sampled def compute( self, @@ -82,6 +87,7 @@ def compute( log_x=is_axis_log_scale( parameter=experiment.search_space.parameters[self.parameter_name] ), + display_sampled=self._display_sampled, ) return self._create_plotly_analysis_card( @@ -102,10 +108,16 @@ def _prepare_data( parameter_name: str, metric_name: str, ) -> pd.DataFrame: + sampled_xs = [ + arm.parameters[parameter_name] + for trial in experiment.trials.values() + for arm in trial.arms + ] # Choose which parameter values to predict points for. - xs = get_parameter_values( + unsampled_xs = get_parameter_values( parameter=experiment.search_space.parameters[parameter_name] ) + xs = [*sampled_xs, *unsampled_xs] # Construct observation features for each parameter value previously chosen by # fixing all other parameters to their status-quo value or mean. @@ -125,16 +137,19 @@ def _prepare_data( predictions = model.predict(observation_features=features) - return pd.DataFrame.from_records( - [ - { - parameter_name: xs[i], - f"{metric_name}_mean": predictions[0][metric_name][i], - f"{metric_name}_sem": predictions[1][metric_name][metric_name][i] - ** 0.5, # Convert the variance to the SEM - } - for i in range(len(xs)) - ] + return none_throws( + pd.DataFrame.from_records( + [ + { + parameter_name: xs[i], + f"{metric_name}_mean": predictions[0][metric_name][i], + f"{metric_name}_sem": predictions[1][metric_name][metric_name][i] + ** 0.5, # Convert the variance to the SEM + "sampled": xs[i] in sampled_xs, + } + for i in range(len(xs)) + ] + ).drop_duplicates() ).sort_values(by=parameter_name) @@ -142,10 +157,12 @@ def _prepare_plot( df: pd.DataFrame, parameter_name: str, metric_name: str, - log_x: bool = False, + log_x: bool, + display_sampled: bool, ) -> go.Figure: x = df[parameter_name].tolist() y = df[f"{metric_name}_mean"].tolist() + # Convert the SEMs to 95% confidence intervals y_upper = 
(df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist() y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist() @@ -182,6 +199,24 @@ def _prepare_plot( ), ) + if display_sampled: + x_sampled = df[df["sampled"]][parameter_name].tolist() + y_sampled = df[df["sampled"]][f"{metric_name}_mean"].tolist() + + samples = go.Scatter( + x=x_sampled, + y=y_sampled, + mode="markers", + marker={ + "symbol": "x", + "color": "black", + }, + name=f"Sampled {parameter_name}", + showlegend=False, + ) + + fig.add_trace(samples) + # Set the x-axis scale to log if relevant if log_x: fig.update_xaxes( diff --git a/ax/analysis/plotly/surface/tests/test_slice.py b/ax/analysis/plotly/surface/tests/test_slice.py index 557a7665c37..5b1b3ab9520 100644 --- a/ax/analysis/plotly/surface/tests/test_slice.py +++ b/ax/analysis/plotly/surface/tests/test_slice.py @@ -63,11 +63,9 @@ def test_compute(self) -> None: self.assertEqual(card.level, AnalysisCardLevel.LOW) self.assertEqual( {*card.df.columns}, - { - "x", - "bar_mean", - "bar_sem", - }, + {"x", "bar_mean", "bar_sem", "sampled"}, ) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "plotly") + + self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))