Skip to content

Commit 2aa191f

Browse files
mpolson64facebook-github-bot
authored andcommitted
Add black "x"s at sampled coordinates in contour plot (#3406)
Summary: Pull Request resolved: #3406 As titled. Reviewed By: mgarrard Differential Revision: D69989849
1 parent c80324f commit 2aa191f

File tree

3 files changed

+80
-11
lines changed

3 files changed

+80
-11
lines changed

ax/analysis/plotly/interaction.py

+1
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def _prepare_surface_plot(
336336
log_y=is_axis_log_scale(
337337
parameter=experiment.search_space.parameters[y_parameter_name]
338338
),
339+
display_sampled=True,
339340
)
340341

341342
# If the feature is a first-order component, plot a slice plot.

ax/analysis/plotly/surface/contour.py

+52-11
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,23 @@ class ContourPlot(PlotlyAnalysis):
3636
- PARAMETER_NAME: The value of the x parameter specified
3737
- PARAMETER_NAME: The value of the y parameter specified
3838
- METRIC_NAME: The predected mean of the metric specified
39+
- sampled: Whether the parameter values were sampled in at least one trial
3940
"""
4041

4142
def __init__(
4243
self,
4344
x_parameter_name: str,
4445
y_parameter_name: str,
4546
metric_name: str | None = None,
47+
display_sampled: bool = True,
4648
) -> None:
4749
"""
4850
Args:
4951
y_parameter_name: The name of the parameter to plot on the x-axis.
5052
y_parameter_name: The name of the parameter to plot on the y-axis.
5153
metric_name: The name of the metric to plot
54+
display_sampled: If True, plot "x"s at x coordinates which have been
55+
sampled in at least one trial.
5256
"""
5357
# TODO: Add a flag to specify whether or not to plot markers at the (x, y)
5458
# coordinates of arms (with hover text). This is fine to exlude for now because
@@ -57,6 +61,7 @@ def __init__(
5761
self.x_parameter_name = x_parameter_name
5862
self.y_parameter_name = y_parameter_name
5963
self.metric_name = metric_name
64+
self._display_sampled = display_sampled
6065

6166
def compute(
6267
self,
@@ -94,6 +99,7 @@ def compute(
9499
log_y=is_axis_log_scale(
95100
parameter=experiment.search_space.parameters[self.y_parameter_name]
96101
),
102+
display_sampled=self._display_sampled,
97103
)
98104

99105
return self._create_plotly_analysis_card(
@@ -118,14 +124,23 @@ def _prepare_data(
118124
y_parameter_name: str,
119125
metric_name: str,
120126
) -> pd.DataFrame:
127+
sampled = [
128+
(arm.parameters[x_parameter_name], arm.parameters[y_parameter_name])
129+
for trial in experiment.trials.values()
130+
for arm in trial.arms
131+
]
132+
121133
# Choose which parameter values to predict points for.
122-
xs = get_parameter_values(
134+
unsampled_xs = get_parameter_values(
123135
parameter=experiment.search_space.parameters[x_parameter_name], density=10
124136
)
125-
ys = get_parameter_values(
137+
unsampled_ys = get_parameter_values(
126138
parameter=experiment.search_space.parameters[y_parameter_name], density=10
127139
)
128140

141+
xs = [*[sample[0] for sample in sampled], *unsampled_xs]
142+
ys = [*[sample[1] for sample in sampled], *unsampled_ys]
143+
129144
# Construct observation features for each parameter value previously chosen by
130145
# fixing all other parameters to their status-quo value or mean.
131146
features = [
@@ -149,15 +164,22 @@ def _prepare_data(
149164

150165
predictions = model.predict(observation_features=features)
151166

152-
return pd.DataFrame.from_records(
153-
[
154-
{
155-
x_parameter_name: features[i].parameters[x_parameter_name],
156-
y_parameter_name: features[i].parameters[y_parameter_name],
157-
f"{metric_name}_mean": predictions[0][metric_name][i],
158-
}
159-
for i in range(len(features))
160-
]
167+
return none_throws(
168+
pd.DataFrame.from_records(
169+
[
170+
{
171+
x_parameter_name: features[i].parameters[x_parameter_name],
172+
y_parameter_name: features[i].parameters[y_parameter_name],
173+
f"{metric_name}_mean": predictions[0][metric_name][i],
174+
"sampled": (
175+
features[i].parameters[x_parameter_name],
176+
features[i].parameters[y_parameter_name],
177+
)
178+
in sampled,
179+
}
180+
for i in range(len(features))
181+
]
182+
).drop_duplicates()
161183
)
162184

163185

@@ -168,6 +190,7 @@ def _prepare_plot(
168190
metric_name: str,
169191
log_x: bool,
170192
log_y: bool,
193+
display_sampled: bool,
171194
) -> go.Figure:
172195
z_grid = df.pivot(
173196
index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean"
@@ -187,6 +210,24 @@ def _prepare_plot(
187210
),
188211
)
189212

213+
if display_sampled:
214+
x_sampled = df[df["sampled"]][x_parameter_name].tolist()
215+
y_sampled = df[df["sampled"]][y_parameter_name].tolist()
216+
217+
samples = go.Scatter(
218+
x=x_sampled,
219+
y=y_sampled,
220+
mode="markers",
221+
marker={
222+
"symbol": "x",
223+
"color": "black",
224+
},
225+
name="Sampled",
226+
showlegend=False,
227+
)
228+
229+
fig.add_trace(samples)
230+
190231
# Set the x-axis scale to log if relevant
191232
if log_x:
192233
fig.update_xaxes(

ax/analysis/plotly/surface/tests/test_contour.py

+27
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
99
from ax.analysis.plotly.surface.contour import ContourPlot
10+
from ax.core.trial import Trial
1011
from ax.exceptions.core import UserInputError
1112
from ax.service.ax_client import AxClient, ObjectiveProperties
1213
from ax.utils.common.testutils import TestCase
1314
from ax.utils.testing.mock import mock_botorch_optimize
1415

16+
from pyre_extensions import assert_is_instance, none_throws
17+
1518

1619
class TestContourPlot(TestCase):
1720
@mock_botorch_optimize
@@ -78,7 +81,31 @@ def test_compute(self) -> None:
7881
"x",
7982
"y",
8083
"bar_mean",
84+
"sampled",
8185
},
8286
)
8387
self.assertIsNotNone(card.blob)
8488
self.assertEqual(card.blob_annotation, "plotly")
89+
90+
# Assert that any row where sampled is True has a value of x that is
91+
# sampled in at least one trial.
92+
x_values_sampled = {
93+
none_throws(assert_is_instance(trial, Trial).arm).parameters["x"]
94+
for trial in self.client.experiment.trials.values()
95+
}
96+
y_values_sampled = {
97+
none_throws(assert_is_instance(trial, Trial).arm).parameters["y"]
98+
for trial in self.client.experiment.trials.values()
99+
}
100+
self.assertTrue(
101+
card.df.apply(
102+
lambda row: row["x"] in x_values_sampled
103+
and row["y"] in y_values_sampled
104+
if row["sampled"]
105+
else True,
106+
axis=1,
107+
).all()
108+
)
109+
110+
# Less-than-or-equal to because we may have removed some duplicates
111+
self.assertTrue(card.df["sampled"].sum() <= len(self.client.experiment.trials))

0 commit comments

Comments
 (0)