Skip to content

Commit 316efcd

Browse files
bletham authored and facebook-github-bot committed
option to use a particular trial for CV test set (#4996)
Summary: When a target trial differs significantly from the other trials in the experiment (for instance, if there was a system rebase between two trials), we care most about our ability to predict the target trial, since that is what we expect to see going forward. This adds a `test_trial_index` kwarg to `DiagnosticAnalysis` (and to `CrossValidationPlot`) to specify a particular trial for the diagnostics. All trials are still used as training arms in each CV fold, but the test arms are limited to arms in that trial. Differential Revision: D95824210
1 parent 56bdb25 commit 316efcd

4 files changed

Lines changed: 54 additions & 3 deletions

File tree

ax/analysis/diagnostics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,23 @@ class DiagnosticAnalysis(Analysis):
3636
of leave-one-out cross validation.
3737
"""
3838

39-
def __init__(self, include_tracking_metrics: bool = False) -> None:
39+
def __init__(
40+
self,
41+
include_tracking_metrics: bool = False,
42+
test_trial_index: int | None = None,
43+
) -> None:
4044
"""Initialize the DiagnosticAnalysis.
4145
4246
Args:
4347
include_tracking_metrics: Whether to include tracking metrics or just use
4448
the optimization config metrics.
49+
test_trial_index: If provided, limits cross validation to only evaluate
50+
predictions for observations from this trial. Other trials'
51+
observations will still be used for training but will not
52+
appear as test points.
4553
"""
4654
self.include_tracking_metrics = include_tracking_metrics
55+
self.test_trial_index = test_trial_index
4756

4857
@override
4958
def validate_applicable_state(
@@ -80,7 +89,9 @@ def compute(
8089
cross_validation_plots = []
8190
metric_r2_card = []
8291
if not is_bandit:
83-
cv_analysis = CrossValidationPlot(metric_names=metric_names)
92+
cv_analysis = CrossValidationPlot(
93+
metric_names=metric_names, test_trial_index=self.test_trial_index
94+
)
8495
cv_card = cv_analysis.compute_or_error_card(
8596
experiment=experiment,
8697
generation_strategy=generation_strategy,

ax/analysis/plotly/cross_validation.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
untransform: bool = False,
7777
trial_index: int | None = None,
7878
labels: Mapping[str, str] | None = None,
79+
test_trial_index: int | None = None,
7980
) -> None:
8081
"""
8182
Args:
@@ -99,13 +100,18 @@ def __init__(
99100
trial.
100101
labels: Optional dictionary of labels for the plot. Useful for when metric
101102
names are too long or otherwise challenging to read.
103+
test_trial_index: If provided, limits cross validation to only evaluate
104+
predictions for observations from this trial. Other trials'
105+
observations will still be used for training but will not
106+
appear as test points.
102107
"""
103108

104109
self.metric_names = metric_names
105110
self.folds = folds
106111
self.untransform = untransform
107112
self.trial_index = trial_index
108113
self.labels: dict[str, str] = {**labels} if labels is not None else {}
114+
self.test_trial_index = test_trial_index
109115
self._r2s: dict[str, float] = {}
110116

111117
@override
@@ -138,8 +144,17 @@ def compute(
138144
)
139145

140146
cards = []
147+
148+
test_selector = (
149+
(lambda obs: obs.features.trial_index == self.test_trial_index)
150+
if self.test_trial_index is not None
151+
else None
152+
)
141153
cv_results = cross_validate(
142-
adapter=relevant_adapter, folds=self.folds, untransform=self.untransform
154+
adapter=relevant_adapter,
155+
folds=self.folds,
156+
untransform=self.untransform,
157+
test_selector=test_selector,
143158
)
144159
relevant_adapter_metric_names = [
145160
relevant_adapter._experiment.signature_to_metric[signature].name
@@ -217,6 +232,7 @@ def compute_cross_validation_adhoc(
217232
folds: int = -1,
218233
untransform: bool = True,
219234
labels: Mapping[str, str] | None = None,
235+
test_trial_index: int | None = None,
220236
experiment: Experiment | None = None,
221237
generation_strategy: GenerationStrategy | None = None,
222238
adapter: Adapter | None = None,
@@ -243,6 +259,10 @@ def compute_cross_validation_adhoc(
243259
is.
244260
labels: Optional dictionary of labels for the plot. Useful for when metric
245261
names are too long or otherwise challenging to read.
262+
test_trial_index: If provided, limits cross validation to only evaluate
263+
predictions for observations from this trial. Other trials'
264+
observations will still be used for training but will not
265+
appear as test points.
246266
experiment: Optional. The experiment to extract data from.
247267
generation_strategy: Optional. The generation strategy to extract the adapter
248268
from.
@@ -260,6 +280,7 @@ def compute_cross_validation_adhoc(
260280
folds=folds,
261281
untransform=untransform,
262282
labels=labels,
283+
test_trial_index=test_trial_index,
263284
)
264285

265286
return analysis.compute(

ax/analysis/plotly/tests/test_cross_validation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
139139
card.df["arm_name"].unique(),
140140
)
141141

142+
def test_test_trial_index_filters_to_single_trial(self) -> None:
143+
# test_trial_index filters CV to only evaluate predictions for observations
144+
# from that trial. Use trial 0 which is in the model's training data.
145+
analysis = CrossValidationPlot(metric_names=["bar"], test_trial_index=0)
146+
(card,) = analysis.compute(
147+
generation_strategy=self.client.generation_strategy
148+
).flatten()
149+
# Only the arm from trial 0 should appear as a test point
150+
trial_0_arm_name = none_throws(
151+
assert_is_instance(self.client.experiment.trials[0], Trial).arm
152+
).name
153+
self.assertEqual(
154+
list(card.df["arm_name"].unique()),
155+
[trial_0_arm_name],
156+
)
157+
142158
@mock.patch(
143159
"ax.analysis.plotly.cross_validation.cross_validate", wraps=cross_validate
144160
)

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ def test_compute_adhoc(self) -> None:
236236
**kwargs,
237237
)
238238

239+
# Normalize timestamps since cards are computed at different times
240+
for card, adhoc_card in zip(cards.flatten(), adhoc_cards.flatten()):
241+
adhoc_card._timestamp = card._timestamp
239242
self.assertEqual(cards, adhoc_cards)
240243

241244
@TestCase.ax_long_test(

0 commit comments

Comments
 (0)