Skip to content

Commit

Permalink
Add black "x"s at sampled coordinates in contour plot (#3406)
Browse files Browse the repository at this point in the history
Summary:

As titled.

Differential Revision: D69989849
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Feb 21, 2025
1 parent d82564b commit d9826bf
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
1 change: 1 addition & 0 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _prepare_surface_plot(
log_y=is_axis_log_scale(
parameter=experiment.search_space.parameters[y_parameter_name]
),
display_sampled=True,
)

# If the feature is a first-order component, plot a slice plot.
Expand Down
63 changes: 52 additions & 11 deletions ax/analysis/plotly/surface/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,23 @@ class ContourPlot(PlotlyAnalysis):
- PARAMETER_NAME: The value of the x parameter specified
- PARAMETER_NAME: The value of the y parameter specified
- METRIC_NAME: The predected mean of the metric specified
- sampled: Whether the parameter values were sampled in at least one trial
"""

def __init__(
self,
x_parameter_name: str,
y_parameter_name: str,
metric_name: str | None = None,
display_sampled: bool = True,
) -> None:
"""
Args:
y_parameter_name: The name of the parameter to plot on the x-axis.
y_parameter_name: The name of the parameter to plot on the y-axis.
metric_name: The name of the metric to plot
display_sampled: If True, plot "x"s at x coordinates which have been
sampled in at least one trial.
"""
# TODO: Add a flag to specify whether or not to plot markers at the (x, y)
# coordinates of arms (with hover text). This is fine to exlude for now because
Expand All @@ -57,6 +61,7 @@ def __init__(
self.x_parameter_name = x_parameter_name
self.y_parameter_name = y_parameter_name
self.metric_name = metric_name
self._display_sampled = display_sampled

def compute(
self,
Expand Down Expand Up @@ -93,6 +98,7 @@ def compute(
log_y=is_axis_log_scale(
parameter=experiment.search_space.parameters[self.y_parameter_name]
),
display_sampled=self._display_sampled,
)

return self._create_plotly_analysis_card(
Expand All @@ -116,14 +122,23 @@ def _prepare_data(
y_parameter_name: str,
metric_name: str,
) -> pd.DataFrame:
sampled = [
(arm.parameters[x_parameter_name], arm.parameters[y_parameter_name])
for trial in experiment.trials.values()
for arm in trial.arms
]

# Choose which parameter values to predict points for.
xs = get_parameter_values(
unsampled_xs = get_parameter_values(
parameter=experiment.search_space.parameters[x_parameter_name], density=10
)
ys = get_parameter_values(
unsampled_ys = get_parameter_values(
parameter=experiment.search_space.parameters[y_parameter_name], density=10
)

xs = [*[sample[0] for sample in sampled], *unsampled_xs]
ys = [*[sample[1] for sample in sampled], *unsampled_ys]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
features = [
Expand All @@ -147,15 +162,22 @@ def _prepare_data(

predictions = model.predict(observation_features=features)

return pd.DataFrame.from_records(
[
{
x_parameter_name: features[i].parameters[x_parameter_name],
y_parameter_name: features[i].parameters[y_parameter_name],
f"{metric_name}_mean": predictions[0][metric_name][i],
}
for i in range(len(features))
]
return none_throws(
pd.DataFrame.from_records(
[
{
x_parameter_name: features[i].parameters[x_parameter_name],
y_parameter_name: features[i].parameters[y_parameter_name],
f"{metric_name}_mean": predictions[0][metric_name][i],
"sampled": (
features[i].parameters[x_parameter_name],
features[i].parameters[y_parameter_name],
)
in sampled,
}
for i in range(len(features))
]
).drop_duplicates()
)


Expand All @@ -166,6 +188,7 @@ def _prepare_plot(
metric_name: str,
log_x: bool,
log_y: bool,
display_sampled: bool,
) -> go.Figure:
z_grid = df.pivot(
index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean"
Expand All @@ -185,6 +208,24 @@ def _prepare_plot(
),
)

if display_sampled:
x_sampled = df[df["sampled"]][x_parameter_name].tolist()
y_sampled = df[df["sampled"]][y_parameter_name].tolist()

samples = go.Scatter(
x=x_sampled,
y=y_sampled,
mode="markers",
marker={
"symbol": "x",
"color": "black",
},
name="Sampled",
showlegend=False,
)

fig.add_trace(samples)

# Set the x-axis scale to log if relevant
if log_x:
fig.update_xaxes(
Expand Down
3 changes: 3 additions & 0 deletions ax/analysis/plotly/surface/tests/test_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def test_compute(self) -> None:
"x",
"y",
"bar_mean",
"sampled",
},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")

self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))

0 comments on commit d9826bf

Please sign in to comment.