forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcross_validation.py
More file actions
441 lines (402 loc) · 17.5 KB
/
cross_validation.py
File metadata and controls
441 lines (402 loc) · 17.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Mapping, Sequence
from typing import final
import pandas as pd
from ax.adapter.base import Adapter
from ax.adapter.cross_validation import cross_validate, CVResult
from ax.analysis.analysis import Analysis
from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
from ax.analysis.plotly.color_constants import AX_BLUE
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI
from ax.analysis.utils import extract_relevant_adapter, validate_adapter_can_predict
from ax.core.analysis_card import AnalysisCardBase
from ax.core.experiment import Experiment
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.stats.model_fit_stats import coefficient_of_determination
from plotly import graph_objects as go
from pyre_extensions import override
# Heading shown on the card group that bundles every cross-validation card.
CV_CARDGROUP_TITLE = "Cross Validation: Assessing model fit"

# Explanatory text shown under the card-group heading. The pieces are joined
# verbatim; the "<br><br>" inserts a paragraph break in the rendered HTML.
CV_CARDGROUP_SUBTITLE = "".join(
    (
        "Cross-validation plots display the model fit for each metric in the ",
        "experiment. The model is trained on a subset of the data and then predicts the ",
        "outcome for the remaining subset. The plots show the predicted outcome for the ",
        "validation set on the y-axis against its actual value on the x-axis. Points ",
        "that align closely with the dotted diagonal line indicate a strong model fit, ",
        "signifying accurate predictions. Additionally, the plots include ",
        "confidence intervals that provide insight into the noise in observations and ",
        "the uncertainty in model predictions. <br><br>",
        "NOTE: A horizontal, flat line of predictions ",
        "indicates that the model has not picked up on sufficient signal in the data, ",
        "and instead is just predicting the mean.",
    )
)
@final
class CrossValidationPlot(Analysis):
    """
    Plotly Scatter plot for cross validation for model predictions using the current
    model on the GenerationStrategy. This plot is useful for understanding how well
    the model is able to predict out-of-sample which in turn is indicative of its
    ability to suggest valuable candidates.

    Splits the model's training data into train/test folds and makes
    out-of-sample predictions on the test folds.

    A well fit model will have points clustered around the y=x line, and a model with
    poor fit may have points in a horizontal band in the center of the plot
    indicating a tendency to predict the observed mean of the specified metric for
    all arms.

    Each per-metric DataFrame contains one row per cross-validation result that
    reports the metric, with the following columns:
        - arm_name: The name of the arm
        - observed: The observed mean of the metric specified
        - predicted: The predicted mean of the metric specified
        - observed_95_ci: Half-width of the 95% confidence interval around the
          observed mean
        - predicted_95_ci: Half-width of the 95% confidence interval around the
          predicted mean

    Each metric card title includes the R² (coefficient of determination) score
    for the metric, and a final summary card tabulates the R² of every plotted
    metric.
    """

    def __init__(
        self,
        metric_names: Sequence[str] | None = None,
        folds: int = -1,
        untransform: bool = False,
        trial_index: int | None = None,
        labels: Mapping[str, str] | None = None,
    ) -> None:
        """
        Args:
            metric_names: The names of the metrics to plot. If not specified all metrics
                available on the underlying model will be used.
            folds: Number of subsamples to partition observations into. Use -1 for
                leave-one-out cross validation.
            untransform: Whether to untransform the model predictions before cross
                validating. Generators are trained on transformed data, and candidate
                generation is performed in the transformed space. Computing the model
                quality metric based on the cross-validation results in the
                untransformed space may not be representative of the model that
                is actually used for candidate generation in case of non-invertible
                transforms, e.g., Winsorize or LogY. While the model in the
                transformed space may not be representative of the original data in
                regions where outliers have been removed, we have found it to better
                reflect how good the model used for candidate generation actually
                is.
            trial_index: Optional trial index that the model from generation_strategy
                was used to generate. Useful card attribute to filter to only specific
                trial.
            labels: Optional dictionary of labels for the plot. Useful for when metric
                names are too long or otherwise challenging to read.
        """
        self.metric_names = metric_names
        self.folds = folds
        self.untransform = untransform
        self.trial_index = trial_index
        # Copy so later mutation of the caller's mapping does not affect the plot.
        self.labels: dict[str, str] = {**labels} if labels is not None else {}
        # R² per displayed metric title; repopulated on every call to ``compute``.
        self._r2s: dict[str, float] = {}

    @override
    def validate_applicable_state(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> str | None:
        """
        CrossValidationPlot requires only an Adapter which can predict.
        """
        return validate_adapter_can_predict(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
            required_metric_names=None,
        )

    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> AnalysisCardBase:
        """
        Cross validate the relevant adapter and build one scatter-plot card per
        metric, followed by a summary table card of the R² values, all grouped
        into a single card group.

        Args:
            experiment: Optional experiment, passed to ``extract_relevant_adapter``.
            generation_strategy: Optional generation strategy, passed to
                ``extract_relevant_adapter``.
            adapter: Optional adapter, passed to ``extract_relevant_adapter``.

        Returns:
            A card group titled ``CV_CARDGROUP_TITLE`` containing the cards.

        Raises:
            ValueError: If a metric to plot has no cross-validation results.
        """
        relevant_adapter = extract_relevant_adapter(
            experiment=experiment,
            generation_strategy=generation_strategy,
            adapter=adapter,
        )
        cv_results = cross_validate(
            adapter=relevant_adapter, folds=self.folds, untransform=self.untransform
        )
        relevant_adapter_metric_names = [
            relevant_adapter._experiment.signature_to_metric[signature].name
            for signature in relevant_adapter._metric_signatures
        ]

        # The description of the fold scheme is the same for every metric, so it
        # is built once instead of on every loop iteration.
        if self.folds > 0:
            k_folds_substring = f"{self.folds}-fold"
            cv_description = (
                f"the data is split into {self.folds} subsets and the model is "
                f"trained on {self.folds - 1} subsets while the remaining subset "
                "is used for validation"
            )
        else:
            k_folds_substring = "leave-one-out"
            cv_description = (
                "the model is trained on all data except one sample, which is "
                "used for validation"
            )

        cards = []
        self._r2s = {}
        for metric_name in self.metric_names or relevant_adapter_metric_names:
            df = _prepare_data(
                metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter
            )
            if df.empty:
                # Fail with a clear message instead of an opaque KeyError on the
                # missing "observed" column further down.
                raise ValueError(
                    f"No cross-validation results found for metric {metric_name}. "
                    "Make sure it is one of the metrics the model was fit on."
                )
            fig = _prepare_plot(df=df)
            # If a human readable metric name is provided, use it in the title
            metric_title = self.labels.get(metric_name, metric_name)
            r_squared = coefficient_of_determination(
                y_obs=df["observed"].to_numpy(),
                y_pred=df["predicted"].to_numpy(),
            )
            self._r2s[metric_title] = r_squared
            card = create_plotly_analysis_card(
                name=self.__class__.__name__,
                title=(
                    f"Cross Validation for {metric_title} (R\u00b2 = {r_squared:.2f})"
                ),
                subtitle=(
                    "The cross-validation plot displays the model fit for each "
                    f"metric in the experiment. It employs a {k_folds_substring} "
                    f"approach, where {cv_description}. The plot shows the "
                    "predicted outcome for the validation set on the y-axis against "
                    "its actual value on the x-axis. Points that align closely with "
                    "the dotted diagonal line indicate a strong model fit, signifying "
                    "accurate predictions. Additionally, the plot includes 95% "
                    "confidence intervals that provide insight into the noise in "
                    "observations and the uncertainty in model predictions. A "
                    "horizontal, flat line of predictions indicates that the model "
                    "has not picked up on sufficient signal in the data, and instead "
                    "is just predicting the mean."
                ),
                df=df,
                fig=fig,
            )
            cards.append(card)

        # Create a summary table of R2 values for all metrics
        if self._r2s:
            threshold = DEFAULT_MODEL_FIT_THRESHOLD
            metric_names_list = list(self._r2s.keys())
            r2_values = [f"{v:.2f}" for v in self._r2s.values()]
            # Highlight metrics whose fit meets the health-check threshold.
            fill_colors = [
                "rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white"
                for r2 in self._r2s.values()
            ]
            r2_fig = go.Figure(
                data=[
                    go.Table(
                        columnwidth=[4, 1],
                        header={
                            "values": ["Metric", "R\u00b2"],
                            "align": "left",
                        },
                        cells={
                            "values": [metric_names_list, r2_values],
                            "align": "left",
                            "fill_color": [fill_colors, fill_colors],
                        },
                    )
                ]
            )
            r2_card = create_plotly_analysis_card(
                name=self.__class__.__name__,
                title="Summary of model fits",
                subtitle=(
                    "R\u00b2 (coefficient of determination) measures how well"
                    " the model predicts each metric. Higher values indicate"
                    " better model fit. Metrics with R\u00b2 >="
                    f" {threshold} are highlighted in green."
                ),
                df=pd.DataFrame(
                    {
                        "Metric": metric_names_list,
                        "R\u00b2": list(self._r2s.values()),
                    }
                ),
                fig=r2_fig,
            )
            cards.append(r2_card)

        return self._create_analysis_card_group(
            title=CV_CARDGROUP_TITLE,
            subtitle=CV_CARDGROUP_SUBTITLE,
            children=cards,
        )
def compute_cross_validation_adhoc(
    metric_names: Sequence[str] | None = None,
    folds: int = -1,
    untransform: bool = True,
    labels: Mapping[str, str] | None = None,
    experiment: Experiment | None = None,
    generation_strategy: GenerationStrategy | None = None,
    adapter: Adapter | None = None,
) -> AnalysisCardBase:
    """
    Helper method to expose adhoc cross validation plotting. Only for advanced users in
    a notebook setting.

    Note that ``untransform`` defaults to ``True`` here, whereas
    ``CrossValidationPlot`` defaults it to ``False``.

    Args:
        metric_names: The names of the metrics to plot. If not specified all metrics
            available on the underlying model will be used.
        folds: Number of subsamples to partition observations into. Use -1 for
            leave-one-out cross validation.
        untransform: Whether to untransform the model predictions before cross
            validating. Generators are trained on transformed data, and candidate
            generation is performed in the transformed space. Computing the model
            quality metric based on the cross-validation results in the
            untransformed space may not be representative of the model that
            is actually used for candidate generation in case of non-invertible
            transforms, e.g., Winsorize or LogY. While the model in the
            transformed space may not be representative of the original data in
            regions where outliers have been removed, we have found it to better
            reflect how good the model used for candidate generation actually
            is.
        labels: Optional dictionary of labels for the plot. Useful for when metric
            names are too long or otherwise challenging to read.
        experiment: Optional. The experiment to extract data from.
        generation_strategy: Optional. The generation strategy to extract the adapter
            from.
        adapter: Optional. The adapter to cross validate. If provided, this adapter
            will be used instead of the current adapter on the ``GenerationStrategy``.

    Returns:
        The card group produced by ``CrossValidationPlot.compute``.
    """
    # Resolve the adapter up front so ``compute`` receives it directly and does
    # not need the generation strategy.
    relevant_adapter = extract_relevant_adapter(
        experiment=experiment,
        generation_strategy=generation_strategy,
        adapter=adapter,
    )
    analysis = CrossValidationPlot(
        metric_names=metric_names,
        folds=folds,
        untransform=untransform,
        labels=labels,
    )
    return analysis.compute(
        experiment=experiment,
        adapter=relevant_adapter,
    )
def _prepare_data(
    metric_name: str, cv_results: list[CVResult], adapter: Adapter
) -> pd.DataFrame:
    """
    Collect observed vs. cross-validated predicted values for a single metric.

    Args:
        metric_name: Name of the metric to extract.
        cv_results: Cross-validation results; each one unpacks into an
            ``(observed, predicted)`` pair.
        adapter: Adapter whose experiment maps metric signatures to metrics.

    Returns:
        A DataFrame with one row per CV result whose observed and predicted data
        both report ``metric_name``, with columns ``arm_name``, ``observed``,
        ``predicted``, ``observed_95_ci`` and ``predicted_95_ci`` (the latter two
        are CI half-widths: standard deviation times ``Z_SCORE_95_CI``). When no
        result reports the metric, the DataFrame is empty and has no columns.
    """

    def _position_of_metric(signatures) -> int | None:
        # Translate every signature to its metric name, then return the first
        # position matching ``metric_name`` (or None when it is absent).
        names = [
            adapter._experiment.signature_to_metric[sig].name for sig in signatures
        ]
        return names.index(metric_name) if metric_name in names else None

    rows = []
    for obs, pred in cv_results:
        obs_pos = _position_of_metric(obs.data.metric_signatures)
        pred_pos = _position_of_metric(pred.metric_signatures)
        if obs_pos is None or pred_pos is None:
            # The metric is missing on one side of this result; skip it.
            continue
        rows.append(
            {
                "arm_name": obs.arm_name,
                "observed": obs.data.means[obs_pos],
                "predicted": pred.means[pred_pos],
                # 95% CI half-widths (sqrt of the variance times the z-score),
                # used for the error bars in the plot.
                "observed_95_ci": obs.data.covariance[obs_pos][obs_pos] ** 0.5
                * Z_SCORE_95_CI,
                "predicted_95_ci": pred.covariance[pred_pos][pred_pos] ** 0.5
                * Z_SCORE_95_CI,
            }
        )
    return pd.DataFrame.from_records(rows)
def _prepare_plot(
    df: pd.DataFrame,
) -> go.Figure:
    """
    Build a square scatter plot of predicted vs. observed values, with 95% CI
    error bars on both axes and a dotted y=x reference line.

    Args:
        df: DataFrame with columns ``arm_name``, ``observed``, ``predicted``,
            ``observed_95_ci`` and ``predicted_95_ci`` (CI half-widths).

    Returns:
        The Plotly figure.
    """
    # Create a scatter plot using Plotly Graph Objects for more control
    fig = go.Figure()
    transparent_ax_blue: str = get_scatter_point_color(
        hex_color=AX_BLUE,
        ci_transparency=True,
    )
    filled_ax_blue: str = get_scatter_point_color(
        hex_color=AX_BLUE,
        ci_transparency=False,
    )
    fig.add_trace(
        go.Scatter(
            x=df["observed"],
            y=df["predicted"],
            mode="markers",
            marker={
                "color": filled_ax_blue,
            },
            error_x={
                "type": "data",
                "array": df["observed_95_ci"],
                "visible": True,
                "color": transparent_ax_blue,
            },
            error_y={
                "type": "data",
                "array": df["predicted_95_ci"],
                "visible": True,
                "color": transparent_ax_blue,
            },
            text=df["arm_name"],
            hovertemplate=(
                "<b>Arm Name: %{text}</b><br>"
                + "Predicted: %{y}<br>"
                + "Observed: %{x}<br>"
                + "<extra></extra>"  # Removes the trace name from the hover
            ),
            hoverlabel={
                "bgcolor": transparent_ax_blue,
                "font": {"color": "black"},
            },
        )
    )

    # Add a gray dashed line at y=x starting and ending just outside of the region of
    # interest for reference. A well fit model should have points clustered around
    # this line.
    lower_bound, upper_bound = _compute_axis_bounds(df=df)
    fig.add_shape(
        type="line",
        x0=lower_bound,
        y0=lower_bound,
        x1=upper_bound,
        y1=upper_bound,
        line={"color": "gray", "dash": "dot"},
    )

    # Update axes with tight autozoom that remains square
    fig.update_xaxes(
        range=[lower_bound, upper_bound], constrain="domain", title="Actual Outcome"
    )
    fig.update_yaxes(
        range=[lower_bound, upper_bound],
        scaleanchor="x",
        scaleratio=1,
        title="Predicted Outcome",
    )
    return fig


def _compute_axis_bounds(
    df: pd.DataFrame, rel_padding: float = 0.001
) -> tuple[float, float]:
    """
    Return a shared ``(lower, upper)`` axis range covering every point and its
    error bar on both axes. Missing CI half-widths count as zero.

    Each bound is pushed outward by ``rel_padding`` times its own magnitude. For
    positive data this equals the former ``* 0.999`` / ``* 1.001`` scaling. For
    negative data, that scaling moved the bounds inward and clipped the extreme
    points; padding by the magnitude always widens the range.
    """
    raw_lower = min(
        (df["observed"] - df["observed_95_ci"].fillna(0)).min(),
        (df["predicted"] - df["predicted_95_ci"].fillna(0)).min(),
    )
    raw_upper = max(
        (df["observed"] + df["observed_95_ci"].fillna(0)).max(),
        (df["predicted"] + df["predicted_95_ci"].fillna(0)).max(),
    )
    lower_bound = raw_lower - rel_padding * abs(raw_lower)
    upper_bound = raw_upper + rel_padding * abs(raw_upper)
    return lower_bound, upper_bound