Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add black "x"s at sampled x coordinates in slice plot #3405

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,5 @@ def _prepare_surface_plot(
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[feature_name]
),
display_sampled=True,
)
59 changes: 48 additions & 11 deletions ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
- PARAMETER_NAME: The value of the parameter specified
- METRIC_NAME_mean: The predicted mean of the metric specified
- METRIC_NAME_sem: The predicted sem of the metric specified
- sampled: Whether the parameter value was sampled in at least one trial
"""

def __init__(
self,
parameter_name: str,
metric_name: str | None = None,
display_sampled: bool = True,
) -> None:
"""
Args:
parameter_name: The name of the parameter to plot on the x axis.
metric_name: The name of the metric to plot on the y axis. If not
specified the objective will be used.
display_sampled: If True, plot "x"s at x coordinates which have been
sampled in at least one trial.
"""
self.parameter_name = parameter_name
self.metric_name = metric_name
self._display_sampled = display_sampled

def compute(
self,
Expand Down Expand Up @@ -82,6 +87,7 @@ def compute(
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[self.parameter_name]
),
display_sampled=self._display_sampled,
)

return self._create_plotly_analysis_card(
Expand All @@ -102,10 +108,16 @@ def _prepare_data(
parameter_name: str,
metric_name: str,
) -> pd.DataFrame:
sampled_xs = [
arm.parameters[parameter_name]
for trial in experiment.trials.values()
for arm in trial.arms
]
# Choose which parameter values to predict points for.
xs = get_parameter_values(
unsampled_xs = get_parameter_values(
parameter=experiment.search_space.parameters[parameter_name]
)
xs = [*sampled_xs, *unsampled_xs]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
Expand All @@ -125,26 +137,33 @@ def _prepare_data(

predictions = model.predict(observation_features=features)

return pd.DataFrame.from_records(
[
{
parameter_name: xs[i],
f"{metric_name}_mean": predictions[0][metric_name][i],
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i],
}
for i in range(len(xs))
]
return none_throws(
pd.DataFrame.from_records(
[
{
parameter_name: xs[i],
f"{metric_name}_mean": predictions[0][metric_name][i],
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
** 0.5, # Convert the variance to the SEM
"sampled": xs[i] in sampled_xs,
}
for i in range(len(xs))
]
).drop_duplicates()
).sort_values(by=parameter_name)


def _prepare_plot(
df: pd.DataFrame,
parameter_name: str,
metric_name: str,
log_x: bool = False,
log_x: bool,
display_sampled: bool,
) -> go.Figure:
x = df[parameter_name].tolist()
y = df[f"{metric_name}_mean"].tolist()

# Convert the SEMs to 95% confidence intervals
y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist()
y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist()

Expand Down Expand Up @@ -180,6 +199,24 @@ def _prepare_plot(
),
)

if display_sampled:
x_sampled = df[df["sampled"]][parameter_name].tolist()
y_sampled = df[df["sampled"]][f"{metric_name}_mean"].tolist()

samples = go.Scatter(
x=x_sampled,
y=y_sampled,
mode="markers",
marker={
"symbol": "x",
"color": "black",
},
name=f"Sampled {parameter_name}",
showlegend=False,
)

fig.add_trace(samples)

# Set the x-axis scale to log if relevant
if log_x:
fig.update_xaxes(
Expand Down
8 changes: 3 additions & 5 deletions ax/analysis/plotly/surface/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ def test_compute(self) -> None:
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(
{*card.df.columns},
{
"x",
"bar_mean",
"bar_sem",
},
{"x", "bar_mean", "bar_sem", "sampled"},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")

self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))