Skip to content

Commit

Permalink
Add black "x"s at sampled x coordinates in slice plot (#3405)
Browse files Browse the repository at this point in the history
Summary:

As titled. These are moderately useful here, but I really want to add them to the contour plot and I want the slice plot to match since theyre often displayed together.

Differential Revision: D69988707
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Feb 21, 2025
1 parent b709599 commit 23509c6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
1 change: 1 addition & 0 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,5 @@ def _prepare_surface_plot(
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[feature_name]
),
display_sampled=True,
)
59 changes: 47 additions & 12 deletions ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
- PARAMETER_NAME: The value of the parameter specified
- METRIC_NAME_mean: The predected mean of the metric specified
- METRIC_NAME_sem: The predected sem of the metric specified
- sampled: Whether the parameter value was sampled in at least one trial
"""

def __init__(
self,
parameter_name: str,
metric_name: str | None = None,
display_sampled: bool = True,
) -> None:
"""
Args:
parameter_name: The name of the parameter to plot on the x axis.
metric_name: The name of the metric to plot on the y axis. If not
specified the objective will be used.
display_sampled: If True, plot "x"s at x coordinates which have been
sampled in at least one trial.
"""
self.parameter_name = parameter_name
self.metric_name = metric_name
self._display_sampled = display_sampled

def compute(
self,
Expand Down Expand Up @@ -82,6 +87,7 @@ def compute(
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[self.parameter_name]
),
display_sampled=self._display_sampled,
)

return self._create_plotly_analysis_card(
Expand All @@ -102,10 +108,16 @@ def _prepare_data(
parameter_name: str,
metric_name: str,
) -> pd.DataFrame:
sampled_xs = [
arm.parameters[parameter_name]
for trial in experiment.trials.values()
for arm in trial.arms
]
# Choose which parameter values to predict points for.
xs = get_parameter_values(
unsampled_xs = get_parameter_values(
parameter=experiment.search_space.parameters[parameter_name]
)
xs = [*sampled_xs, *unsampled_xs]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
Expand All @@ -125,27 +137,32 @@ def _prepare_data(

predictions = model.predict(observation_features=features)

return pd.DataFrame.from_records(
[
{
parameter_name: xs[i],
f"{metric_name}_mean": predictions[0][metric_name][i],
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
** 0.5, # Convert the variance to the SEM
}
for i in range(len(xs))
]
return none_throws(
pd.DataFrame.from_records(
[
{
parameter_name: xs[i],
f"{metric_name}_mean": predictions[0][metric_name][i],
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
** 0.5, # Convert the variance to the SEM
"sampled": xs[i] in sampled_xs,
}
for i in range(len(xs))
]
).drop_duplicates()
).sort_values(by=parameter_name)


def _prepare_plot(
df: pd.DataFrame,
parameter_name: str,
metric_name: str,
log_x: bool = False,
log_x: bool,
display_sampled: bool,
) -> go.Figure:
x = df[parameter_name].tolist()
y = df[f"{metric_name}_mean"].tolist()

# Convert the SEMs to 95% confidence intervals
y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist()
y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist()
Expand Down Expand Up @@ -182,6 +199,24 @@ def _prepare_plot(
),
)

if display_sampled:
x_sampled = df[df["sampled"]][parameter_name].tolist()
y_sampled = df[df["sampled"]][f"{metric_name}_mean"].tolist()

samples = go.Scatter(
x=x_sampled,
y=y_sampled,
mode="markers",
marker={
"symbol": "x",
"color": "black",
},
name=f"Sampled {parameter_name}",
showlegend=False,
)

fig.add_trace(samples)

# Set the x-axis scale to log if relevant
if log_x:
fig.update_xaxes(
Expand Down
8 changes: 3 additions & 5 deletions ax/analysis/plotly/surface/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ def test_compute(self) -> None:
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(
{*card.df.columns},
{
"x",
"bar_mean",
"bar_sem",
},
{"x", "bar_mean", "bar_sem", "sampled"},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")

self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))

0 comments on commit 23509c6

Please sign in to comment.