@@ -36,21 +36,26 @@ class SlicePlot(PlotlyAnalysis):
36
36
- PARAMETER_NAME: The value of the parameter specified
37
37
- METRIC_NAME_mean: The predected mean of the metric specified
38
38
- METRIC_NAME_sem: The predected sem of the metric specified
39
+ - sampled: Whether the parameter value was sampled in at least one trial
39
40
"""
40
41
41
42
def __init__ (
42
43
self ,
43
44
parameter_name : str ,
44
45
metric_name : str | None = None ,
46
+ display_sampled : bool = True ,
45
47
) -> None :
46
48
"""
47
49
Args:
48
50
parameter_name: The name of the parameter to plot on the x axis.
49
51
metric_name: The name of the metric to plot on the y axis. If not
50
52
specified the objective will be used.
53
+ display_sampled: If True, plot "x"s at x coordinates which have been
54
+ sampled in at least one trial.
51
55
"""
52
56
self .parameter_name = parameter_name
53
57
self .metric_name = metric_name
58
+ self ._display_sampled = display_sampled
54
59
55
60
def compute (
56
61
self ,
@@ -82,6 +87,7 @@ def compute(
82
87
log_x = is_axis_log_scale (
83
88
parameter = experiment .search_space .parameters [self .parameter_name ]
84
89
),
90
+ display_sampled = self ._display_sampled ,
85
91
)
86
92
87
93
return self ._create_plotly_analysis_card (
@@ -102,10 +108,16 @@ def _prepare_data(
102
108
parameter_name : str ,
103
109
metric_name : str ,
104
110
) -> pd .DataFrame :
111
+ sampled_xs = [
112
+ arm .parameters [parameter_name ]
113
+ for trial in experiment .trials .values ()
114
+ for arm in trial .arms
115
+ ]
105
116
# Choose which parameter values to predict points for.
106
- xs = get_parameter_values (
117
+ unsampled_xs = get_parameter_values (
107
118
parameter = experiment .search_space .parameters [parameter_name ]
108
119
)
120
+ xs = [* sampled_xs , * unsampled_xs ]
109
121
110
122
# Construct observation features for each parameter value previously chosen by
111
123
# fixing all other parameters to their status-quo value or mean.
@@ -125,27 +137,32 @@ def _prepare_data(
125
137
126
138
predictions = model .predict (observation_features = features )
127
139
128
- return pd .DataFrame .from_records (
129
- [
130
- {
131
- parameter_name : xs [i ],
132
- f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
133
- f"{ metric_name } _sem" : predictions [1 ][metric_name ][metric_name ][i ]
134
- ** 0.5 , # Convert the variance to the SEM
135
- }
136
- for i in range (len (xs ))
137
- ]
140
+ return none_throws (
141
+ pd .DataFrame .from_records (
142
+ [
143
+ {
144
+ parameter_name : xs [i ],
145
+ f"{ metric_name } _mean" : predictions [0 ][metric_name ][i ],
146
+ f"{ metric_name } _sem" : predictions [1 ][metric_name ][metric_name ][i ]
147
+ ** 0.5 , # Convert the variance to the SEM
148
+ "sampled" : xs [i ] in sampled_xs ,
149
+ }
150
+ for i in range (len (xs ))
151
+ ]
152
+ ).drop_duplicates ()
138
153
).sort_values (by = parameter_name )
139
154
140
155
141
156
def _prepare_plot (
142
157
df : pd .DataFrame ,
143
158
parameter_name : str ,
144
159
metric_name : str ,
145
- log_x : bool = False ,
160
+ log_x : bool ,
161
+ display_sampled : bool ,
146
162
) -> go .Figure :
147
163
x = df [parameter_name ].tolist ()
148
164
y = df [f"{ metric_name } _mean" ].tolist ()
165
+
149
166
# Convert the SEMs to 95% confidence intervals
150
167
y_upper = (df [f"{ metric_name } _mean" ] + 1.96 * df [f"{ metric_name } _sem" ]).tolist ()
151
168
y_lower = (df [f"{ metric_name } _mean" ] - 1.96 * df [f"{ metric_name } _sem" ]).tolist ()
@@ -182,6 +199,24 @@ def _prepare_plot(
182
199
),
183
200
)
184
201
202
+ if display_sampled :
203
+ x_sampled = df [df ["sampled" ]][parameter_name ].tolist ()
204
+ y_sampled = df [df ["sampled" ]][f"{ metric_name } _mean" ].tolist ()
205
+
206
+ samples = go .Scatter (
207
+ x = x_sampled ,
208
+ y = y_sampled ,
209
+ mode = "markers" ,
210
+ marker = {
211
+ "symbol" : "x" ,
212
+ "color" : "black" ,
213
+ },
214
+ name = f"Sampled { parameter_name } " ,
215
+ showlegend = False ,
216
+ )
217
+
218
+ fig .add_trace (samples )
219
+
185
220
# Set the x-axis scale to log if relevant
186
221
if log_x :
187
222
fig .update_xaxes (
0 commit comments