Commit f8106ac
Preference-aware cross-validation (#5219)
Summary:
Pull Request resolved: #5219
## Summary
When `CrossValidationPlot` encounters preference metrics (e.g., `pairwise_pref_query` backed by PairwiseGP), it switches to a preference-appropriate model quality evaluation:
- **Regression metrics** (unchanged): predicted-vs-observed scatter + R², answering "are predictions accurate?"
- **Preference metrics** (new): pairwise classification accuracy, answering "does the model correctly predict which arm is preferred?"
Both answer the same question ("is the model trustworthy?") with the appropriate methodology.
### Adapter layer: pair-aware fold splitting
`_pairwise_kfold_train_test_split` in `ax/adapter/cross_validation.py` splits by `trial_index` instead of `arm_name`. Each pairwise trial contains exactly two arms forming a comparison pair, so splitting by trial keeps pairs intact.
Only trials with pairwise data are used as fold boundaries. Non-pairwise trials (e.g., the BO trial with tracking metrics) remain in every training fold so the full ModelList can be refitted — all metrics need data in every fold.
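The splitting rule described above can be sketched in plain Python. This is a simplified illustration, not the code in `ax/adapter/cross_validation.py`: the `Obs` record, the `trial_based_folds` name, and the round-robin fold assignment are all assumptions made for the example.

```python
from typing import Iterator

# Hypothetical observation record: (trial_index, arm_name, has_pairwise_data).
Obs = tuple[int, str, bool]


def trial_based_folds(
    observations: list[Obs], folds: int
) -> Iterator[tuple[list[Obs], list[Obs]]]:
    """Yield (train, test) splits whose boundaries are pairwise trials."""
    # Only trials that carry pairwise data are eligible to be held out.
    pairwise_trials = sorted({t for t, _, is_pair in observations if is_pair})
    if not 2 <= folds <= len(pairwise_trials):
        raise ValueError("folds must be between 2 and the number of pairwise trials")
    for k in range(folds):
        # Round-robin assignment of whole trials to fold k (an assumption).
        held_out = set(pairwise_trials[k::folds])
        # A held-out trial moves both of its arms to the test set together.
        test = [o for o in observations if o[0] in held_out]
        # Non-pairwise trials are never held out, so they stay in every train set.
        train = [o for o in observations if o[0] not in held_out]
        yield train, test


# Trial 0 is a non-pairwise trial; trials 1-4 each hold one comparison pair.
observations = [(0, "0_0", False)] + [
    (t, f"{t}_{i}", True) for t in range(1, 5) for i in range(2)
]
splits = list(trial_based_folds(observations, folds=2))
```

In this sketch every test set has an even number of observations, and the non-pairwise trial appears in each training set.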
`compute_pairwise_accuracy` computes the fraction of held-out comparison pairs where the model correctly predicts which arm is preferred (random baseline = 50%).
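As a rough sketch of the metric (not the actual `compute_pairwise_accuracy` signature, which is not shown here), the function below takes hypothetical `(utility_a, utility_b, a_preferred)` tuples and counts how often the predicted ordering matches the observed preference. The tie handling is an assumption of the example.

```python
def pairwise_accuracy(pairs: list[tuple[float, float, bool]]) -> float:
    """Fraction of pairs where the predicted ordering matches the observed one.

    Each pair is (predicted_utility_a, predicted_utility_b, a_was_preferred).
    """
    if not pairs:
        raise ValueError("at least one held-out pair is required")
    # A tie in predicted utility is treated as predicting arm b (an assumption).
    correct = sum((ua > ub) == a_preferred for ua, ub, a_preferred in pairs)
    return correct / len(pairs)


held_out_pairs = [
    (1.0, 0.2, True),    # predicts a, a preferred: correct
    (0.1, 0.5, True),    # predicts b, a preferred: wrong
    (-0.3, 0.4, False),  # predicts b, b preferred: correct
    (2.0, 1.0, False),   # predicts a, b preferred: wrong
]
accuracy = pairwise_accuracy(held_out_pairs)  # 0.5, the random baseline
```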
### Analysis layer: visualization switching
`CrossValidationPlot` accepts `preference_metrics: set[str] | None`. When the current metric is in this set, it:
1. Uses `_pairwise_kfold_train_test_split` as the `fold_generator`
2. Computes classification accuracy via `compute_pairwise_accuracy`
3. Renders an accuracy bar chart (scatter + R² is meaningless when observed = binary 0/1 and predicted = latent utility on an incommensurate scale)
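The claim in step 3 can be checked with a small numeric example. The values below are made up for illustration: the model orders every pair correctly, yet R² against the binary labels is strongly negative because latent utilities are not on the 0/1 scale.

```python
def r_squared(observed: list[float], predicted: list[float]) -> float:
    """Coefficient of determination, 1 - SS_res / SS_tot."""
    mean_obs = sum(observed) / len(observed)
    ss_res = sum((o - p) ** 2 for o, p in zip(observed, predicted))
    ss_tot = sum((o - mean_obs) ** 2 for o in observed)
    return 1.0 - ss_res / ss_tot


# Two comparison pairs laid out as [a1, b1, a2, b2]; 1 marks the preferred arm.
observed = [1, 0, 1, 0]
# Illustrative latent utilities on an arbitrary scale.
predicted = [5.0, -5.0, 3.0, -3.0]

r2 = r_squared(observed, predicted)  # -53.0

# Within each pair, does the predicted ordering match the observed one?
pairs_correct = [
    (predicted[i] > predicted[i + 1]) == (observed[i] > observed[i + 1])
    for i in (0, 2)
]
ordering_accuracy = sum(pairs_correct) / len(pairs_correct)  # 1.0
```

A perfect ranking with a very poor R² is the mismatch that motivates switching to an accuracy chart for preference metrics.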
`DiagnosticAnalysis` no longer filters preference metrics from CV. Instead, it passes `preference_metrics` to `CrossValidationPlot` so it can switch visualization mode per metric.
### Additional fixes in this diff
- **Pairwise-aware fold splitting in `best_point.py`**: `get_best_parameters_from_model_predictions_with_trial_index` calls `cross_validate()` internally (for model fit assessment). Previously it used the default arm-based fold splitting, which broke comparison pairs apart: a fold with an odd number of pairwise observations crashes `prep_pairwise_data`, which reshapes the observations to `[-1, 2]`. It now passes `_pairwise_kfold_train_test_split` as the fold generator when preference metrics are present.
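The failure mode can be reproduced with a pure-Python stand-in for the reshape step. `reshape_to_pairs` is a hypothetical helper written for this illustration; it is not `prep_pairwise_data`, and it only mimics the constraint that a reshape to `[-1, 2]` needs an even number of elements.

```python
def reshape_to_pairs(values: list[str]) -> list[list[str]]:
    """Group a flat list into rows of two, like a reshape to [-1, 2]."""
    if len(values) % 2 != 0:
        raise ValueError(f"cannot reshape {len(values)} elements into rows of 2")
    return [values[i : i + 2] for i in range(0, len(values), 2)]


# Trial-based split: whole pairs stay together, so the fold reshapes cleanly.
trial_based_fold = ["1_0", "1_1", "2_0", "2_1"]
pairs = reshape_to_pairs(trial_based_fold)  # [["1_0", "1_1"], ["2_0", "2_1"]]

# Arm-based split: arm "2_1" landed in a different fold, leaving an odd count.
arm_based_fold = ["1_0", "1_1", "2_0"]
try:
    reshape_to_pairs(arm_based_fold)
    arm_based_failed = False
except ValueError:
    arm_based_failed = True
```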
### Compatibility with SAAS + PairwiseGP ModelLists
A fully-Bayesian SAAS outcome model combined with a PairwiseGP in a `ModelList` yields posteriors with an extra leading MCMC-sample batch dimension. This is already handled in `predict_from_model` (landed in D99037272), which the CV path reaches via `BoTorchGenerator.cross_validate -> Surrogate.predict`, so no change in `torch.py` is required. This diff adds an end-to-end regression test (`test_pairwise_cv_with_saas_pairwise_modellist`) that fits a real SAAS + PairwiseGP `ModelList` and runs preference-aware cross-validation — a path that previously had no coverage.
Differential Revision: D99151833
fbshipit-source-id: 9e7a0df0521ab556ddb7adbb9f66cecf47aed4b4
1 parent c40dfae
6 files changed
Lines changed: 627 additions & 72 deletions
File tree
- ax
  - adapter
    - tests
  - analysis
    - plotly
      - tests
  - service/utils