Skip to content

Commit c80324f

Browse files
mpolson64facebook-github-bot
authored andcommitted
Add black "x"s at sampled x coordinates in slice plot
Summary: As titled. These are moderately useful here, but I really want to add them to the contour plot and I want the slice plot to match since theyre often displayed together. Differential Revision: D69988707
1 parent f1ee58f commit c80324f

File tree

3 files changed

+66
-17
lines changed

3 files changed

+66
-17
lines changed

ax/analysis/plotly/interaction.py

+1
Original file line numberDiff line numberDiff line change
@@ -353,4 +353,5 @@ def _prepare_surface_plot(
353353
log_x=is_axis_log_scale(
354354
parameter=experiment.search_space.parameters[feature_name]
355355
),
356+
display_sampled=True,
356357
)

ax/analysis/plotly/surface/slice.py

+48-12
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
3636
- PARAMETER_NAME: The value of the parameter specified
3737
- METRIC_NAME_mean: The predected mean of the metric specified
3838
- METRIC_NAME_sem: The predected sem of the metric specified
39+
- sampled: Whether the parameter value was sampled in at least one trial
3940
"""
4041

4142
def __init__(
4243
self,
4344
parameter_name: str,
4445
metric_name: str | None = None,
46+
display_sampled: bool = True,
4547
) -> None:
4648
"""
4749
Args:
4850
parameter_name: The name of the parameter to plot on the x axis.
4951
metric_name: The name of the metric to plot on the y axis. If not
5052
specified the objective will be used.
53+
display_sampled: If True, plot "x"s at x coordinates which have been
54+
sampled in at least one trial.
5155
"""
5256
self.parameter_name = parameter_name
5357
self.metric_name = metric_name
58+
self._display_sampled = display_sampled
5459

5560
def compute(
5661
self,
@@ -83,6 +88,7 @@ def compute(
8388
log_x=is_axis_log_scale(
8489
parameter=experiment.search_space.parameters[self.parameter_name]
8590
),
91+
display_sampled=self._display_sampled,
8692
)
8793

8894
return self._create_plotly_analysis_card(
@@ -104,10 +110,16 @@ def _prepare_data(
104110
parameter_name: str,
105111
metric_name: str,
106112
) -> pd.DataFrame:
113+
sampled_xs = [
114+
arm.parameters[parameter_name]
115+
for trial in experiment.trials.values()
116+
for arm in trial.arms
117+
]
107118
# Choose which parameter values to predict points for.
108-
xs = get_parameter_values(
119+
unsampled_xs = get_parameter_values(
109120
parameter=experiment.search_space.parameters[parameter_name]
110121
)
122+
xs = [*sampled_xs, *unsampled_xs]
111123

112124
# Construct observation features for each parameter value previously chosen by
113125
# fixing all other parameters to their status-quo value or mean.
@@ -127,27 +139,32 @@ def _prepare_data(
127139

128140
predictions = model.predict(observation_features=features)
129141

130-
return pd.DataFrame.from_records(
131-
[
132-
{
133-
parameter_name: xs[i],
134-
f"{metric_name}_mean": predictions[0][metric_name][i],
135-
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
136-
** 0.5, # Convert the variance to the SEM
137-
}
138-
for i in range(len(xs))
139-
]
142+
return none_throws(
143+
pd.DataFrame.from_records(
144+
[
145+
{
146+
parameter_name: xs[i],
147+
f"{metric_name}_mean": predictions[0][metric_name][i],
148+
f"{metric_name}_sem": predictions[1][metric_name][metric_name][i]
149+
** 0.5, # Convert the variance to the SEM
150+
"sampled": xs[i] in sampled_xs,
151+
}
152+
for i in range(len(xs))
153+
]
154+
).drop_duplicates()
140155
).sort_values(by=parameter_name)
141156

142157

143158
def _prepare_plot(
144159
df: pd.DataFrame,
145160
parameter_name: str,
146161
metric_name: str,
147-
log_x: bool = False,
162+
log_x: bool,
163+
display_sampled: bool,
148164
) -> go.Figure:
149165
x = df[parameter_name].tolist()
150166
y = df[f"{metric_name}_mean"].tolist()
167+
151168
# Convert the SEMs to 95% confidence intervals
152169
y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist()
153170
y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist()
@@ -184,6 +201,25 @@ def _prepare_plot(
184201
),
185202
)
186203

204+
if display_sampled:
205+
sampled = df[df["sampled"]]
206+
x_sampled = sampled[parameter_name].tolist()
207+
y_sampled = sampled[f"{metric_name}_mean"].tolist()
208+
209+
samples = go.Scatter(
210+
x=x_sampled,
211+
y=y_sampled,
212+
mode="markers",
213+
marker={
214+
"symbol": "x",
215+
"color": "black",
216+
},
217+
name=f"Sampled {parameter_name}",
218+
showlegend=False,
219+
)
220+
221+
fig.add_trace(samples)
222+
187223
# Set the x-axis scale to log if relevant
188224
if log_x:
189225
fig.update_xaxes(

ax/analysis/plotly/surface/tests/test_slice.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
99
from ax.analysis.plotly.surface.slice import SlicePlot
10+
from ax.core.trial import Trial
1011
from ax.exceptions.core import UserInputError
1112
from ax.service.ax_client import AxClient, ObjectiveProperties
1213
from ax.utils.common.testutils import TestCase
1314
from ax.utils.testing.mock import mock_botorch_optimize
1415

16+
from pyre_extensions import assert_is_instance, none_throws
17+
1518

1619
class TestSlicePlot(TestCase):
1720
@mock_botorch_optimize
@@ -64,11 +67,20 @@ def test_compute(self) -> None:
6467
self.assertEqual(card.category, AnalysisCardCategory.INSIGHT)
6568
self.assertEqual(
6669
{*card.df.columns},
67-
{
68-
"x",
69-
"bar_mean",
70-
"bar_sem",
71-
},
70+
{"x", "bar_mean", "bar_sem", "sampled"},
7271
)
7372
self.assertIsNotNone(card.blob)
7473
self.assertEqual(card.blob_annotation, "plotly")
74+
75+
# Assert that any row where sampled is True has a value of x that is
76+
# sampled in at least one trial.
77+
x_values_sampled = {
78+
none_throws(assert_is_instance(trial, Trial).arm).parameters["x"]
79+
for trial in self.client.experiment.trials.values()
80+
}
81+
self.assertTrue(
82+
card.df.apply(
83+
lambda row: row["x"] in x_values_sampled if row["sampled"] else True,
84+
axis=1,
85+
).all()
86+
)

0 commit comments

Comments
 (0)