Skip to content

Commit 40fe62e

Browse files
bletham and facebook-github-bot
authored and committed
option to use a particular trial for CV test set
Summary: In situations where a target trial is significantly different from the other trials on the experiment (for instance, if there was a system rebase in between two trials), we care especially about our ability to predict the target trial, as that is what we expect to see moving forward. This adds a `trial_index` kwarg to DiagnosticAnalysis (forwarded to CrossValidationPlot as `cv_trial_index`) to specify a particular trial for the diagnostics. All trials will still be used as train arms in each CV fold, but the test arms will be limited to arms in that trial. Differential Revision: D95824210
1 parent 4635a18 commit 40fe62e

5 files changed

Lines changed: 59 additions & 4 deletions

File tree

ax/analysis/diagnostics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,23 @@ class DiagnosticAnalysis(Analysis):
3636
of leave-one-out cross validation.
3737
"""
3838

39-
def __init__(self, include_tracking_metrics: bool = False) -> None:
39+
def __init__(
40+
self,
41+
include_tracking_metrics: bool = False,
42+
trial_index: int | None = None,
43+
) -> None:
4044
"""Initialize the DiagnosticAnalysis.
4145
4246
Args:
4347
include_tracking_metrics: Whether to include tracking metrics or just use
4448
the optimization config metrics.
49+
trial_index: If provided, limits cross validation to only evaluate
50+
predictions for observations from this trial. Other trials'
51+
observations will still be used for training but will not
52+
appear as test points.
4553
"""
4654
self.include_tracking_metrics = include_tracking_metrics
55+
self.trial_index = trial_index
4756

4857
@override
4958
def validate_applicable_state(
@@ -80,7 +89,9 @@ def compute(
8089
cross_validation_plots = []
8190
metric_r2_card = []
8291
if not is_bandit:
83-
cv_analysis = CrossValidationPlot(metric_names=metric_names)
92+
cv_analysis = CrossValidationPlot(
93+
metric_names=metric_names, cv_trial_index=self.trial_index
94+
)
8495
cv_card = cv_analysis.compute_or_error_card(
8596
experiment=experiment,
8697
generation_strategy=generation_strategy,

ax/analysis/plotly/cross_validation.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-strict
77

88

9-
from collections.abc import Mapping, Sequence
9+
from collections.abc import Callable, Mapping, Sequence
1010
from typing import final
1111

1212
import pandas as pd
@@ -19,6 +19,7 @@
1919
from ax.analysis.utils import extract_relevant_adapter, validate_adapter_can_predict
2020
from ax.core.analysis_card import AnalysisCardBase
2121
from ax.core.experiment import Experiment
22+
from ax.core.observation import Observation
2223
from ax.generation_strategy.generation_strategy import GenerationStrategy
2324
from ax.utils.stats.model_fit_stats import coefficient_of_determination
2425
from plotly import graph_objects as go
@@ -76,6 +77,7 @@ def __init__(
7677
untransform: bool = False,
7778
trial_index: int | None = None,
7879
labels: Mapping[str, str] | None = None,
80+
cv_trial_index: int | None = None,
7981
) -> None:
8082
"""
8183
Args:
@@ -99,13 +101,18 @@ def __init__(
99101
trial.
100102
labels: Optional dictionary of labels for the plot. Useful for when metric
101103
names are too long or otherwise challenging to read.
104+
cv_trial_index: If provided, limits cross validation to only evaluate
105+
predictions for observations from this trial. Other trials'
106+
observations will still be used for training but will not
107+
appear as test points.
102108
"""
103109

104110
self.metric_names = metric_names
105111
self.folds = folds
106112
self.untransform = untransform
107113
self.trial_index = trial_index
108114
self.labels: dict[str, str] = {**labels} if labels is not None else {}
115+
self.cv_trial_index = cv_trial_index
109116
self._r2s: dict[str, float] = {}
110117

111118
@override
@@ -138,8 +145,25 @@ def compute(
138145
)
139146

140147
cards = []
148+
149+
def _make_test_selector(
150+
trial_index: int,
151+
) -> Callable[[Observation], bool]:
152+
def test_selector(obs: Observation) -> bool:
153+
return obs.features.trial_index == trial_index
154+
155+
return test_selector
156+
157+
test_selector = (
158+
_make_test_selector(self.cv_trial_index)
159+
if self.cv_trial_index is not None
160+
else None
161+
)
141162
cv_results = cross_validate(
142-
adapter=relevant_adapter, folds=self.folds, untransform=self.untransform
163+
adapter=relevant_adapter,
164+
folds=self.folds,
165+
untransform=self.untransform,
166+
test_selector=test_selector,
143167
)
144168
relevant_adapter_metric_names = [
145169
relevant_adapter._experiment.signature_to_metric[signature].name

ax/analysis/plotly/tests/test_constraint_feasibility.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def test_offline(self) -> None:
218218
generation_strategy=generation_strategy,
219219
)
220220

221+
@mock_botorch_optimize
221222
def test_online(self) -> None:
222223
for experiment in get_online_experiments():
223224
# Skip if no outcome constraints

ax/analysis/plotly/tests/test_cross_validation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
139139
card.df["arm_name"].unique(),
140140
)
141141

142+
def test_cv_trial_index_filters_to_single_trial(self) -> None:
143+
# cv_trial_index filters CV to only evaluate predictions for observations
144+
# from that trial. Use trial 0 which is in the model's training data.
145+
analysis = CrossValidationPlot(metric_names=["bar"], cv_trial_index=0)
146+
(card,) = analysis.compute(
147+
generation_strategy=self.client.generation_strategy
148+
).flatten()
149+
# Only the arm from trial 0 should appear as a test point
150+
trial_0_arm_name = none_throws(
151+
assert_is_instance(self.client.experiment.trials[0], Trial).arm
152+
).name
153+
self.assertEqual(
154+
list(card.df["arm_name"].unique()),
155+
[trial_0_arm_name],
156+
)
157+
142158
@mock.patch(
143159
"ax.analysis.plotly.cross_validation.cross_validate", wraps=cross_validate
144160
)

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ def test_compute_adhoc(self) -> None:
236236
**kwargs,
237237
)
238238

239+
# Normalize timestamps since cards are computed at different times
240+
for card, adhoc_card in zip(cards.flatten(), adhoc_cards.flatten()):
241+
adhoc_card._timestamp = card._timestamp
239242
self.assertEqual(cards, adhoc_cards)
240243

241244
@TestCase.ax_long_test(

0 commit comments

Comments
 (0)