Skip to content

Commit 23509c6

Browse files
mpolson64 authored and
facebook-github-bot committed
Add black "x"s at sampled x coordinates in slice plot (#3405)
Summary: As titled. These are moderately useful here, but I really want to add them to the contour plot, and I want the slice plot to match since they're often displayed together. Differential Revision: D69988707
1 parent b709599 commit 23509c6

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

ax/analysis/plotly/interaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,4 +350,5 @@ def _prepare_surface_plot(
350350
log_x=is_axis_log_scale(
351351
parameter=experiment.search_space.parameters[feature_name]
352352
),
353+
display_sampled=True,
353354
)

ax/analysis/plotly/surface/slice.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
3636
- PARAMETER_NAME: The value of the parameter specified
3737
- METRIC_NAME_mean: The predicted mean of the metric specified
3838
- METRIC_NAME_sem: The predicted SEM of the metric specified
39+
- sampled: Whether the parameter value was sampled in at least one trial
3940
"""
4041

4142
def __init__(
4243
self,
4344
parameter_name: str,
4445
metric_name: str | None = None,
46+
display_sampled: bool = True,
4547
) -> None:
4648
"""
4749
Args:
4850
parameter_name: The name of the parameter to plot on the x axis.
4951
metric_name: The name of the metric to plot on the y axis. If not
5052
specified the objective will be used.
53+
display_sampled: If True, plot "x"s at x coordinates which have been
54+
sampled in at least one trial.
5155
"""
5256
self.parameter_name = parameter_name
5357
self.metric_name = metric_name
58+
self._display_sampled = display_sampled
5459

5560
def compute(
5661
self,
@@ -82,6 +87,7 @@ def compute(
8287
log_x=is_axis_log_scale(
8388
parameter=experiment.search_space.parameters[self.parameter_name]
8489
),
90+
display_sampled=self._display_sampled,
8591
)
8692

8793
return self._create_plotly_analysis_card(
@@ -102,10 +108,16 @@ def _prepare_data(
102108
parameter_name: str,
103109
metric_name: str,
104110
) -> pd.DataFrame:
111+
sampled_xs = [
112+
arm.parameters[parameter_name]
113+
for trial in experiment.trials.values()
114+
for arm in trial.arms
115+
]
105116
# Choose which parameter values to predict points for.
106-
xs = get_parameter_values(
117+
unsampled_xs = get_parameter_values(
107118
parameter=experiment.search_space.parameters[parameter_name]
108119
)
120+
xs = [*sampled_xs, *unsampled_xs]
109121

110122
# Construct observation features for each parameter value previously chosen by
111123
# fixing all other parameters to their status-quo value or mean.
@@ -125,27 +137,32 @@ def _prepare_data(
125137

126138
predictions = model.predict(observation_features=features)
127139

128-
return pd.DataFrame.from_records(
129-
[
130-
{
131-
parameter_name: xs[i],
132-
f"{metric_name}_mean": predictions[0][metric_name][i],
133-
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
134-
** 0.5, # Convert the variance to the SEM
135-
}
136-
for i in range(len(xs))
137-
]
140+
return none_throws(
141+
pd.DataFrame.from_records(
142+
[
143+
{
144+
parameter_name: xs[i],
145+
f"{metric_name}_mean": predictions[0][metric_name][i],
146+
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
147+
** 0.5, # Convert the variance to the SEM
148+
"sampled": xs[i] in sampled_xs,
149+
}
150+
for i in range(len(xs))
151+
]
152+
).drop_duplicates()
138153
).sort_values(by=parameter_name)
139154

140155

141156
def _prepare_plot(
142157
df: pd.DataFrame,
143158
parameter_name: str,
144159
metric_name: str,
145-
log_x: bool = False,
160+
log_x: bool,
161+
display_sampled: bool,
146162
) -> go.Figure:
147163
x = df[parameter_name].tolist()
148164
y = df[f"{metric_name}_mean"].tolist()
165+
149166
# Convert the SEMs to 95% confidence intervals
150167
y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist()
151168
y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist()
@@ -182,6 +199,24 @@ def _prepare_plot(
182199
),
183200
)
184201

202+
if display_sampled:
203+
x_sampled = df[df["sampled"]][parameter_name].tolist()
204+
y_sampled = df[df["sampled"]][f"{metric_name}_mean"].tolist()
205+
206+
samples = go.Scatter(
207+
x=x_sampled,
208+
y=y_sampled,
209+
mode="markers",
210+
marker={
211+
"symbol": "x",
212+
"color": "black",
213+
},
214+
name=f"Sampled {parameter_name}",
215+
showlegend=False,
216+
)
217+
218+
fig.add_trace(samples)
219+
185220
# Set the x-axis scale to log if relevant
186221
if log_x:
187222
fig.update_xaxes(

ax/analysis/plotly/surface/tests/test_slice.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,9 @@ def test_compute(self) -> None:
6363
self.assertEqual(card.level, AnalysisCardLevel.LOW)
6464
self.assertEqual(
6565
{*card.df.columns},
66-
{
67-
"x",
68-
"bar_mean",
69-
"bar_sem",
70-
},
66+
{"x", "bar_mean", "bar_sem", "sampled"},
7167
)
7268
self.assertIsNotNone(card.blob)
7369
self.assertEqual(card.blob_annotation, "plotly")
70+
71+
self.assertEqual(card.df["sampled"].sum(), len(self.client.experiment.trials))

0 commit comments

Comments
 (0)