Skip to content

Commit f5976b8

Browse files
saitcakmak and meta-codesync[bot]
authored and committed
Use total-order sensitivity for high-dimensional experiments (#5115)
Summary: Pull Request resolved: #5115. Switch InsightsAnalysis to use total-order Sobol sensitivity analysis instead of second-order when the experiment has more than 25 parameters. Second-order sensitivity computes pairwise interaction effects at O(p^2) cost, which becomes expensive for high-dimensional search spaces. Total-order captures each parameter's overall importance, including all interactions, at O(p) cost. For a 93-parameter GAIN experiment (GAIN_35259), total-order sensitivity is 2.7x faster than second-order (16s vs 43s per metric, with a pre-fitted adapter). The threshold of 25 parameters preserves second-order analysis (and contour plots showing pairwise interactions) for lower-dimensional experiments where the O(p^2) cost is manageable. Reviewed By: mpolson64. Differential Revision: D98506213. fbshipit-source-id: 3bb7d4d006d25fc62c0823a4e3ee108f8263745d
1 parent 60600e0 commit f5976b8

2 files changed

Lines changed: 84 additions & 7 deletions

File tree

ax/analysis/insights.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
from pyre_extensions import none_throws, override
2525

2626

27+
# When the number of parameters exceeds this threshold, use total-order
28+
# sensitivity analysis instead of second-order. Second-order computes pairwise
29+
# interaction effects (O(p^2)) which becomes expensive for high-dimensional
30+
# search spaces, while total-order captures each parameter's overall importance
31+
# including all interactions at O(p) cost.
32+
_MAX_NUM_PARAMS_FOR_SECOND_ORDER: int = 25
33+
2734
INSIGHTS_CARDGROUP_TITLE = "Insights Analysis"
2835

2936
INSIGHTS_CARDGROUP_SUBTITLE = (
@@ -33,6 +40,26 @@
3340
)
3441

3542

43+
def _choose_sensitivity_order(
44+
num_params: int,
45+
) -> Literal["first", "second", "total"]:
46+
"""Choose the sensitivity analysis order based on parameter count.
47+
48+
- 1 parameter: first-order (second-order requires >= 2 for interaction
49+
effects).
50+
- Many parameters (> threshold): total-order to avoid the O(p^2) cost of
51+
second-order pairwise interactions.
52+
- Otherwise: second-order to surface pairwise interactions for contour
53+
plots.
54+
"""
55+
if num_params == 1:
56+
return "first"
57+
elif num_params > _MAX_NUM_PARAMS_FOR_SECOND_ORDER:
58+
return "total"
59+
else:
60+
return "second"
61+
62+
3663
@final
3764
class InsightsAnalysis(Analysis):
3865
"""
@@ -116,18 +143,16 @@ def compute(
116143
# For non-bandit experiments, for each objective and constraint, compute a
117144
# sensitivity analysis and plot the top 3 surfaces.
118145
else:
119-
# Default to second-order sensitivity analysis, but fall back to first-order
120-
# if there is only one parameter (second-order requires at least 2
121-
# parameters for interaction effects).
122-
order: Literal["first", "second"] = (
123-
"first" if len(experiment.search_space.parameters) == 1 else "second"
146+
num_params = len(experiment.search_space.parameters)
147+
sensitivity_order: Literal["first", "second", "total"] = (
148+
_choose_sensitivity_order(num_params=num_params)
124149
)
125150
top_surfaces_groups = [
126151
TopSurfacesAnalysis(
127152
metric_name=metric_name,
128153
top_k=3,
129154
relativize=relativize,
130-
order=order,
155+
order=sensitivity_order,
131156
).compute_or_error_card(
132157
experiment=experiment,
133158
generation_strategy=generation_strategy,

ax/analysis/tests/test_overview.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77

88

99
from datetime import datetime
10+
from unittest.mock import patch
1011

1112
import pandas as pd
1213
from ax.adapter.base import Adapter
1314
from ax.adapter.registry import Generators
14-
from ax.analysis.insights import InsightsAnalysis
15+
from ax.analysis.insights import _MAX_NUM_PARAMS_FOR_SECOND_ORDER, InsightsAnalysis
1516
from ax.analysis.overview import OverviewAnalysis
1617
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
1718
from ax.analysis.plotly.scatter import ScatterPlot
19+
from ax.analysis.plotly.top_surfaces import TopSurfacesAnalysis
1820
from ax.analysis.results import ResultsAnalysis
1921
from ax.api.client import Client
2022
from ax.api.configs import RangeParameterConfig
@@ -448,3 +450,53 @@ def test_insights_analysis_single_parameter(self) -> None:
448450
# Check that none of the cards are error cards
449451
for card in all_cards:
450452
self.assertNotIsInstance(card, ErrorAnalysisCard)
453+
454+
@mock_botorch_optimize
def test_insights_analysis_many_parameters_uses_total_order(self) -> None:
    """InsightsAnalysis requests total-order sensitivity in high dimensions.

    Builds an experiment with one parameter more than
    ``_MAX_NUM_PARAMS_FOR_SECOND_ORDER``, runs a few trials, and checks that
    every ``TopSurfacesAnalysis`` constructed during ``compute`` receives
    ``order="total"`` rather than the O(p^2) second-order setting.
    """
    dimension = _MAX_NUM_PARAMS_FOR_SECOND_ORDER + 1

    client = Client()
    parameter_configs = [
        RangeParameterConfig(
            name=f"x{i}",
            bounds=(0.0, 1.0),
            parameter_type="float",
        )
        for i in range(dimension)
    ]
    client.configure_experiment(name="many_params", parameters=parameter_configs)
    client.configure_optimization(objective="objective_metric")

    # Run dimension + 2 single-trial rounds, reporting the sum of the
    # parameter values as the objective.
    for _ in range(dimension + 2):
        next_trials = client.get_next_trials(max_trials=1)
        for trial_index, parameters in next_trials.items():
            objective_value = sum(float(v) for v in parameters.values())
            client.complete_trial(
                trial_index=trial_index,
                raw_data={"objective_metric": objective_value},
            )

    # Wrap TopSurfacesAnalysis.__init__ so each construction records the
    # `order` keyword it was given before delegating to the real initializer.
    real_init: object = TopSurfacesAnalysis.__init__
    seen_orders: list[object] = []

    def recording_init(self_inner: TopSurfacesAnalysis, **kwargs: object) -> None:
        seen_orders.append(kwargs.get("order", "second"))
        # pyre-ignore[29]: `object` is not callable
        real_init(self_inner, **kwargs)

    with patch.object(TopSurfacesAnalysis, "__init__", recording_init):
        InsightsAnalysis().compute(
            experiment=client._experiment,
            generation_strategy=client._generation_strategy,
        )

    # At least one analysis must have been built, and all must be total-order.
    self.assertGreater(len(seen_orders), 0)
    for seen in seen_orders:
        self.assertEqual(seen, "total")

0 commit comments

Comments
 (0)