@@ -36,19 +36,23 @@ class ContourPlot(PlotlyAnalysis):
36
36
- PARAMETER_NAME: The value of the x parameter specified
37
37
- PARAMETER_NAME: The value of the y parameter specified
38
38
- METRIC_NAME: The predected mean of the metric specified
39
+ - sampled: Whether the parameter values were sampled in at least one trial
39
40
"""
40
41
41
42
def __init__ (
42
43
self ,
43
44
x_parameter_name : str ,
44
45
y_parameter_name : str ,
45
46
metric_name : str | None = None ,
47
+ display_sampled : bool = True ,
46
48
) -> None :
47
49
"""
48
50
Args:
49
51
y_parameter_name: The name of the parameter to plot on the x-axis.
50
52
y_parameter_name: The name of the parameter to plot on the y-axis.
51
53
metric_name: The name of the metric to plot
54
+ display_sampled: If True, plot "x"s at x coordinates which have been
55
+ sampled in at least one trial.
52
56
"""
53
57
# TODO: Add a flag to specify whether or not to plot markers at the (x, y)
54
58
# coordinates of arms (with hover text). This is fine to exlude for now because
@@ -57,6 +61,7 @@ def __init__(
57
61
self .x_parameter_name = x_parameter_name
58
62
self .y_parameter_name = y_parameter_name
59
63
self .metric_name = metric_name
64
+ self ._display_sampled = display_sampled
60
65
61
66
def compute (
62
67
self ,
@@ -94,6 +99,7 @@ def compute(
94
99
log_y = is_axis_log_scale (
95
100
parameter = experiment .search_space .parameters [self .y_parameter_name ]
96
101
),
102
+ display_sampled = self ._display_sampled ,
97
103
)
98
104
99
105
return self ._create_plotly_analysis_card (
@@ -118,14 +124,23 @@ def _prepare_data(
118
124
y_parameter_name : str ,
119
125
metric_name : str ,
120
126
) -> pd .DataFrame :
127
+ sampled = [
128
+ (arm .parameters [x_parameter_name ], arm .parameters [y_parameter_name ])
129
+ for trial in experiment .trials .values ()
130
+ for arm in trial .arms
131
+ ]
132
+
121
133
# Choose which parameter values to predict points for.
122
- xs = get_parameter_values (
134
+ unsampled_xs = get_parameter_values (
123
135
parameter = experiment .search_space .parameters [x_parameter_name ], density = 10
124
136
)
125
- ys = get_parameter_values (
137
+ unsampled_ys = get_parameter_values (
126
138
parameter = experiment .search_space .parameters [y_parameter_name ], density = 10
127
139
)
128
140
141
+ xs = [* [sample [0 ] for sample in sampled ], * unsampled_xs ]
142
+ ys = [* [sample [1 ] for sample in sampled ], * unsampled_ys ]
143
+
129
144
# Construct observation features for each parameter value previously chosen by
130
145
# fixing all other parameters to their status-quo value or mean.
131
146
features = [
@@ -149,15 +164,22 @@ def _prepare_data(
149
164
150
165
predictions = model .predict (observation_features = features )
151
166
152
- return pd .DataFrame .from_records (
153
- [
154
- {
155
- x_parameter_name : features [i ].parameters [x_parameter_name ],
156
- y_parameter_name : features [i ].parameters [y_parameter_name ],
157
- f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
158
- }
159
- for i in range (len (features ))
160
- ]
167
+ return none_throws (
168
+ pd .DataFrame .from_records (
169
+ [
170
+ {
171
+ x_parameter_name : features [i ].parameters [x_parameter_name ],
172
+ y_parameter_name : features [i ].parameters [y_parameter_name ],
173
+ f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
174
+ "sampled" : (
175
+ features [i ].parameters [x_parameter_name ],
176
+ features [i ].parameters [y_parameter_name ],
177
+ )
178
+ in sampled ,
179
+ }
180
+ for i in range (len (features ))
181
+ ]
182
+ ).drop_duplicates ()
161
183
)
162
184
163
185
@@ -168,6 +190,7 @@ def _prepare_plot(
168
190
metric_name : str ,
169
191
log_x : bool ,
170
192
log_y : bool ,
193
+ display_sampled : bool ,
171
194
) -> go .Figure :
172
195
z_grid = df .pivot (
173
196
index = y_parameter_name , columns = x_parameter_name , values = f"{ metric_name } _mean"
@@ -187,6 +210,24 @@ def _prepare_plot(
187
210
),
188
211
)
189
212
213
+ if display_sampled :
214
+ x_sampled = df [df ["sampled" ]][x_parameter_name ].tolist ()
215
+ y_sampled = df [df ["sampled" ]][y_parameter_name ].tolist ()
216
+
217
+ samples = go .Scatter (
218
+ x = x_sampled ,
219
+ y = y_sampled ,
220
+ mode = "markers" ,
221
+ marker = {
222
+ "symbol" : "x" ,
223
+ "color" : "black" ,
224
+ },
225
+ name = "Sampled" ,
226
+ showlegend = False ,
227
+ )
228
+
229
+ fig .add_trace (samples )
230
+
190
231
# Set the x-axis scale to log if relevant
191
232
if log_x :
192
233
fig .update_xaxes (
0 commit comments