@@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
36
36
- PARAMETER_NAME: The value of the parameter specified
37
37
- METRIC_NAME_mean: The predected mean of the metric specified
38
38
- METRIC_NAME_sem: The predected sem of the metric specified
39
+ - sampled: Whether the parameter value was sampled in at least one trial
39
40
"""
40
41
41
42
def __init__ (
42
43
self ,
43
44
parameter_name : str ,
44
45
metric_name : str | None = None ,
46
+ display_sampled : bool = True ,
45
47
) -> None :
46
48
"""
47
49
Args:
48
50
parameter_name: The name of the parameter to plot on the x axis.
49
51
metric_name: The name of the metric to plot on the y axis. If not
50
52
specified the objective will be used.
53
+ display_sampled: If True, plot "x"s at x coordinates which have been
54
+ sampled in at least one trial.
51
55
"""
52
56
self .parameter_name = parameter_name
53
57
self .metric_name = metric_name
58
+ self ._display_sampled = display_sampled
54
59
55
60
def compute (
56
61
self ,
@@ -83,6 +88,7 @@ def compute(
83
88
log_x = is_axis_log_scale (
84
89
parameter = experiment .search_space .parameters [self .parameter_name ]
85
90
),
91
+ display_sampled = self ._display_sampled ,
86
92
)
87
93
88
94
return self ._create_plotly_analysis_card (
@@ -104,10 +110,16 @@ def _prepare_data(
104
110
parameter_name : str ,
105
111
metric_name : str ,
106
112
) -> pd .DataFrame :
113
+ sampled_xs = [
114
+ arm .parameters [parameter_name ]
115
+ for trial in experiment .trials .values ()
116
+ for arm in trial .arms
117
+ ]
107
118
# Choose which parameter values to predict points for.
108
- xs = get_parameter_values (
119
+ unsampled_xs = get_parameter_values (
109
120
parameter = experiment .search_space .parameters [parameter_name ]
110
121
)
122
+ xs = [* sampled_xs , * unsampled_xs ]
111
123
112
124
# Construct observation features for each parameter value previously chosen by
113
125
# fixing all other parameters to their status-quo value or mean.
@@ -127,27 +139,32 @@ def _prepare_data(
127
139
128
140
predictions = model .predict (observation_features = features )
129
141
130
- return pd .DataFrame .from_records (
131
- [
132
- {
133
- parameter_name : xs [i ],
134
- f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
135
- f"{ metric_name } _sem" : predictions [1 ][metric_name ][metric_name ][i ]
136
- ** 0.5 , # Convert the variance to the SEM
137
- }
138
- for i in range (len (xs ))
139
- ]
142
+ return none_throws (
143
+ pd .DataFrame .from_records (
144
+ [
145
+ {
146
+ parameter_name : xs [i ],
147
+ f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
148
+ f"{ metric_name } _sem" : predictions [1 ][metric_name ][metric_name ][i ]
149
+ ** 0.5 , # Convert the variance to the SEM
150
+ "sampled" : xs [i ] in sampled_xs ,
151
+ }
152
+ for i in range (len (xs ))
153
+ ]
154
+ ).drop_duplicates ()
140
155
).sort_values (by = parameter_name )
141
156
142
157
143
158
def _prepare_plot (
144
159
df : pd .DataFrame ,
145
160
parameter_name : str ,
146
161
metric_name : str ,
147
- log_x : bool = False ,
162
+ log_x : bool ,
163
+ display_sampled : bool ,
148
164
) -> go .Figure :
149
165
x = df [parameter_name ].tolist ()
150
166
y = df [f"{ metric_name } _mean" ].tolist ()
167
+
151
168
# Convert the SEMs to 95% confidence intervals
152
169
y_upper = (df [f"{ metric_name } _mean" ] + 1.96 * df [f"{ metric_name } _sem" ]).tolist ()
153
170
y_lower = (df [f"{ metric_name } _mean" ] - 1.96 * df [f"{ metric_name } _sem" ]).tolist ()
@@ -184,6 +201,25 @@ def _prepare_plot(
184
201
),
185
202
)
186
203
204
+ if display_sampled :
205
+ sampled = df [df ["sampled" ]]
206
+ x_sampled = sampled [parameter_name ].tolist ()
207
+ y_sampled = sampled [f"{ metric_name } _mean" ].tolist ()
208
+
209
+ samples = go .Scatter (
210
+ x = x_sampled ,
211
+ y = y_sampled ,
212
+ mode = "markers" ,
213
+ marker = {
214
+ "symbol" : "x" ,
215
+ "color" : "black" ,
216
+ },
217
+ name = f"Sampled { parameter_name } " ,
218
+ showlegend = False ,
219
+ )
220
+
221
+ fig .add_trace (samples )
222
+
187
223
# Set the x-axis scale to log if relevant
188
224
if log_x :
189
225
fig .update_xaxes (
0 commit comments