Commit f8106ac

ItsMrLin authored and meta-codesync[bot] committed
Preference-aware cross-validation (#5219)
Summary:
Pull Request resolved: #5219

## Summary

When `CrossValidationPlot` encounters preference metrics (e.g., `pairwise_pref_query` backed by PairwiseGP), it switches to a preference-appropriate model quality evaluation:

**Regression metrics** (unchanged): predicted-vs-observed scatter + R², answering "are predictions accurate?"

**Preference metrics** (new): pairwise classification accuracy, answering "does the model correctly predict which arm is preferred?"

Both answer the same question ("is the model trustworthy?") with the appropriate methodology.

### Adapter layer: pair-aware fold splitting

`_pairwise_kfold_train_test_split` in `ax/adapter/cross_validation.py` splits by trial_index instead of arm_name. Each pairwise trial contains exactly two arms forming a comparison pair, so splitting by trial keeps pairs intact.

Only trials with pairwise data are used as fold boundaries. Non-pairwise trials (e.g., the BO trial with tracking metrics) remain in every training fold so the full ModelList can be refitted: all metrics need data in every fold.

`compute_pairwise_accuracy` computes the fraction of held-out comparison pairs where the model correctly predicts which arm is preferred (random baseline = 50%).

### Analysis layer: visualization switching

`CrossValidationPlot` accepts `preference_metrics: set[str] | None`. When the current metric is in this set, it:

1. Uses `_pairwise_kfold_train_test_split` as the `fold_generator`
2. Computes classification accuracy via `compute_pairwise_accuracy`
3. Renders an accuracy bar chart (scatter + R² is meaningless when observed = binary 0/1 and predicted = latent utility on an incommensurate scale)

`DiagnosticAnalysis` no longer filters preference metrics from CV. Instead, it passes `preference_metrics` to `CrossValidationPlot` so it can switch visualization mode per metric.

### Additional fixes in this diff

- **Pairwise-aware fold splitting in `best_point.py`**: `get_best_parameters_from_model_predictions_with_trial_index` calls `cross_validate()` internally (for model fit assessment). Previously it used default arm-based fold splitting, which broke pairwise comparison pairs apart: a fold with an odd number of pairwise observations crashes `prep_pairwise_data(reshape to [-1, 2])`. It now passes `_pairwise_kfold_train_test_split` as the fold generator when preference metrics are present.

### Compatibility with SAAS + PairwiseGP ModelLists

A fully-Bayesian SAAS outcome model combined with a PairwiseGP in a `ModelList` yields posteriors with an extra leading MCMC-sample batch dimension. This is already handled in `predict_from_model` (landed in D99037272), which the CV path reaches via `BoTorchGenerator.cross_validate -> Surrogate.predict`, so no change in `torch.py` is required. This diff adds an end-to-end regression test (`test_pairwise_cv_with_saas_pairwise_modellist`) that fits a real SAAS + PairwiseGP `ModelList` and runs preference-aware cross-validation, a path that previously had no coverage.

Differential Revision: D99151833

fbshipit-source-id: 9e7a0df0521ab556ddb7adbb9f66cecf47aed4b4
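The pair-aware fold splitting described in the summary can be illustrated without Ax. The sketch below is a simplified, standard-library-only approximation, not the code from this commit: `pairwise_trial_folds`, its `always_train` argument, and the `seed` parameter are hypothetical names introduced here, and plain lists stand in for `ExperimentData`. It splits on trial indices with the same rotate-and-slice scheme and adds the non-pairwise trials to every training fold.

```python
import random
from collections.abc import Iterator


def pairwise_trial_folds(
    split_trials: list[int],
    always_train: set[int],
    folds: int,
    seed: int = 0,
) -> Iterator[tuple[set[int], set[int]]]:
    """Yield (train_trials, test_trials), splitting by trial index.

    Hypothetical helper for illustration only. ``split_trials`` are the
    trials that carry pairwise data; ``always_train`` are trials that must
    stay in every training fold (e.g., trials with only tracking metrics).
    """
    n = len(split_trials)
    if n < 2:
        raise ValueError("Need at least two trials to split on.")
    if folds == -1:  # leave-one-out: one trial per fold
        folds = n
    if folds < 2 or folds > n:
        raise ValueError("folds must be -1, or between 2 and the trial count.")

    order = sorted(split_trials)
    if folds != n:
        # Shuffle only for true k-fold; LOO visits every trial anyway.
        random.Random(seed).shuffle(order)

    test_size = n // folds
    final_size = test_size + (n - folds * test_size)  # last fold takes the remainder
    for fold in range(folds):
        # Rotate right by test_size (the list equivalent of np.roll).
        order = order[-test_size:] + order[:-test_size]
        n_test = test_size if fold < folds - 1 else final_size
        test = set(order[-n_test:])
        train = set(order[:-n_test]) | always_train
        yield train, test


if __name__ == "__main__":
    # Trials 0-5 each hold one comparison pair; trial 6 has tracking metrics only.
    for train, test in pairwise_trial_folds(list(range(6)), {6}, folds=3):
        print(f"train={sorted(train)} test={sorted(test)}")
```

Because whole trials move between train and test, both arms of a comparison always land on the same side of the split, so a held-out fold never contains an odd number of pairwise observations.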
1 parent c40dfae commit f8106ac

6 files changed

Lines changed: 627 additions & 72 deletions


ax/adapter/cross_validation.py

Lines changed: 139 additions & 0 deletions
@@ -656,6 +656,145 @@ def _kfold_train_test_split(
         )
 
 
+def _pairwise_kfold_train_test_split(
+    folds: int,
+    training_data: ExperimentData,
+    preference_metric_name: str | None = None,
+) -> Iterable[CVData]:
+    """Return train/test CV splits based on trial indices.
+
+    For pairwise/preference data, each comparison trial contains exactly two
+    arms forming a pair. Splitting by trial_index (instead of arm_name) keeps
+    comparison pairs intact -- both arms of a pair are always together in
+    the same fold.
+
+    When ``preference_metric_name`` is provided, only trials with observations
+    for that metric are used as fold boundaries. This prevents holding out
+    trials that have no pairwise data (e.g., a BO trial with only tracking
+    metrics), which would remove all data for those metrics and cause
+    downstream model-fitting failures.
+
+    Args:
+        folds: Number of folds. Use -1 for leave-one-out CV (one trial per fold).
+        training_data: Training data to split.
+        preference_metric_name: If provided, only split on trials that have
+            non-NaN observations for this metric.
+
+    Returns:
+        Yields CVData with train/test splits.
+    """
+    if preference_metric_name is not None:
+        obs_mean = training_data.observation_data["mean"]
+        if preference_metric_name in obs_mean.columns:
+            non_null = obs_mean[preference_metric_name].dropna()
+            trial_indices = sorted(set(non_null.index.get_level_values("trial_index")))
+        else:
+            trial_indices = sorted(
+                set(training_data.arm_data.index.get_level_values("trial_index"))
+            )
+    else:
+        trial_indices = sorted(
+            set(training_data.arm_data.index.get_level_values("trial_index"))
+        )
+    n = len(trial_indices)
+    if n < 2:
+        raise UnsupportedError(
+            "Pairwise cross validation requires at least two trials in the "
+            f"training data. Only {n} trials were found."
+        )
+    elif folds > n:
+        raise ValueError(
+            f"Training data only has {n} trials, which is less than {folds} folds."
+        )
+    elif folds < 2 and folds != -1:
+        raise ValueError("Folds must be -1 for LOO, or > 1.")
+    elif folds == -1:
+        folds = n
+
+    trial_arr = np.array(trial_indices)
+    if folds != n:
+        np.random.shuffle(trial_arr)
+    test_size = n // folds
+    final_size = test_size + (n - folds * test_size)
+    for fold in range(folds):
+        trial_arr = np.roll(trial_arr, test_size)
+        n_test = test_size if fold < folds - 1 else final_size
+        train_trials = set(trial_arr[:-n_test].tolist())
+        test_trials_set = set(trial_arr[-n_test:].tolist())
+        # Non-split trials (e.g., BO trials with only tracking metrics)
+        # must stay in every training fold so the full ModelList can be
+        # refitted (all metrics need data).
+        # TODO(D94970662): Once Experiment._trial_type_to_metric_names
+        # lands, use experiment.trials_for_type() instead of data-driven
+        # inference to determine split-eligible vs always-train trials.
+        all_trials = set(training_data.arm_data.index.get_level_values("trial_index"))
+        train_trials = train_trials | (all_trials - set(trial_indices))
+        yield CVData(
+            training_data=training_data.filter_by_trial_index(
+                trial_indices=train_trials
+            ),
+            test_data=training_data.filter_by_trial_index(
+                trial_indices=test_trials_set
+            ),
+        )
+
+
+def compute_pairwise_accuracy(
+    cv_results: list[CVResult],
+    metric_name: str,
+) -> float:
+    """Compute classification accuracy for pairwise preference CV.
+
+    For each held-out comparison pair (two arms from the same trial), checks
+    whether the model correctly predicts which arm is preferred (i.e., the arm
+    with higher predicted utility matches the arm with observed label = 1).
+
+    Args:
+        cv_results: Cross-validation results from cross_validate().
+        metric_name: The preference metric name to evaluate.
+
+    Returns:
+        Fraction of correctly predicted comparisons (random baseline = 0.5).
+    """
+    # Group CV results by trial_index to find comparison pairs.
+    trial_groups: dict[int | None, list[CVResult]] = defaultdict(list)
+    for result in cv_results:
+        trial_groups[result.observed.features.trial_index].append(result)
+
+    correct = 0
+    total = 0
+    for _trial_idx, results in trial_groups.items():
+        if len(results) != 2:
+            continue
+        r0, r1 = results
+
+        # Find the metric index in the observation data.
+        try:
+            idx0 = list(r0.observed.data.metric_signatures).index(metric_name)
+            idx1 = list(r1.observed.data.metric_signatures).index(metric_name)
+            pred_idx0 = list(r0.predicted.metric_signatures).index(metric_name)
+            pred_idx1 = list(r1.predicted.metric_signatures).index(metric_name)
+        except ValueError:
+            continue
+
+        obs0 = r0.observed.data.means[idx0]
+        obs1 = r1.observed.data.means[idx1]
+        pred0 = r0.predicted.means[pred_idx0]
+        pred1 = r1.predicted.means[pred_idx1]
+
+        # The arm with observed label = 1 is preferred. Check if the model
+        # predicts higher utility for the preferred arm.
+        if obs0 == obs1:
+            continue  # Skip ties in observed labels
+        preferred_is_0 = obs0 > obs1
+        model_predicts_0 = pred0 > pred1
+        if preferred_is_0 == model_predicts_0:
+            correct += 1
+        total += 1
+
+    return correct / total if total > 0 else 0.0
+
+
 def gen_trial_split(
     training_data: ExperimentData,
     test_trials: list[int],
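The accuracy computation in `compute_pairwise_accuracy` can be followed on toy data. The sketch below is a simplified stand-in under stated assumptions, not the Ax function: `pairwise_accuracy` is a hypothetical name, and each held-out arm is reduced to a `(trial_index, observed_label, predicted_utility)` tuple instead of a `CVResult`. It applies the same rules as the diff: group by trial, skip groups that are not exactly two arms, skip observed ties, and count a pair as correct when the predicted ordering matches the observed one.

```python
from collections import defaultdict


def pairwise_accuracy(rows: list[tuple[int, float, float]]) -> float:
    """Fraction of comparison pairs whose predicted ordering matches the labels.

    Hypothetical helper; each row is (trial_index, observed_label,
    predicted_utility) for one held-out arm.
    """
    # Group arms by trial so that each group is one candidate comparison pair.
    groups: dict[int, list[tuple[float, float]]] = defaultdict(list)
    for trial_index, observed, predicted in rows:
        groups[trial_index].append((observed, predicted))

    correct = 0
    total = 0
    for pair in groups.values():
        if len(pair) != 2:
            continue  # not a complete comparison pair
        (obs0, pred0), (obs1, pred1) = pair
        if obs0 == obs1:
            continue  # tie in the observed labels carries no preference
        if (obs0 > obs1) == (pred0 > pred1):
            correct += 1
        total += 1
    return correct / total if total > 0 else 0.0


rows = [
    (0, 1.0, 0.8), (0, 0.0, -0.3),  # model ranks the preferred arm higher
    (1, 0.0, 0.5), (1, 1.0, 0.1),   # model ranks the preferred arm lower
    (2, 1.0, 2.0), (2, 0.0, 1.5),   # correct again
    (3, 1.0, 0.4),                  # single arm: skipped
    (4, 1.0, 0.2), (4, 1.0, 0.9),   # observed tie: skipped
]
accuracy = pairwise_accuracy(rows)
```

Only the ordering of the two predictions matters, so the latent utility scale never needs to be comparable to the binary labels. One consequence of the fallback shown here (and in the diff) is that a result of 0.0 can mean either "no evaluable pairs" or "every pair mispredicted"; a caller that needs to tell these apart would have to check the pair count separately.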
