Skip to content

Commit f3e8831

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Reduce botorch dependency creep in Ax (facebook#5161)
Summary: Pull Request resolved: facebook#5161 Reviewed By: saitcakmak Differential Revision: D100111247 fbshipit-source-id: b74e2a913669e5a70f34f0bf24e484b3adcf36a9
1 parent 6244798 commit f3e8831

3 files changed

Lines changed: 26 additions & 26 deletions

File tree

ax/analysis/healthcheck/complexity_rating.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,14 @@ def compute(
121121
options = self.options if self.options is not None else OrchestratorOptions()
122122
optimization_summary = summarize_ax_optimization_complexity(
123123
experiment=experiment,
124-
options=options,
125124
tier_metadata=self.tier_metadata,
125+
uses_early_stopping=options.early_stopping_strategy is not None,
126+
uses_global_stopping=options.global_stopping_strategy is not None,
127+
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
128+
max_pending_trials=options.max_pending_trials,
129+
min_failed_trials_for_failure_rate_check=(
130+
options.min_failed_trials_for_failure_rate_check
131+
),
126132
)
127133

128134
# Determine tier

ax/utils/common/complexity_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from ax.adapter.parameter_utils import can_map_to_binary, is_unordered_choice
1111
from ax.core.experiment import Experiment
1212
from ax.exceptions.core import OptimizationNotConfiguredError
13-
from ax.orchestration.orchestrator import OrchestratorOptions
1413
from ax.utils.common.tier_utils import ( # noqa: F401
1514
check_if_in_standard,
1615
DEFAULT_TIER_MESSAGES,
@@ -22,8 +21,12 @@
2221

2322
def summarize_ax_optimization_complexity(
2423
experiment: Experiment,
25-
options: OrchestratorOptions,
2624
tier_metadata: dict[str, Any],
25+
uses_early_stopping: bool = False,
26+
uses_global_stopping: bool = False,
27+
tolerated_trial_failure_rate: float | None = 0.5,
28+
max_pending_trials: int | None = 10,
29+
min_failed_trials_for_failure_rate_check: int | None = 5,
2730
) -> OptimizationSummary:
2831
"""Summarize the experiment's optimization complexity.
2932
@@ -32,8 +35,13 @@ def summarize_ax_optimization_complexity(
3235
3336
Args:
3437
experiment: The Ax Experiment.
35-
options: The orchestrator options.
3638
tier_metadata: tier-related meta-data from the orchestrator.
39+
uses_early_stopping: Whether early stopping is enabled.
40+
uses_global_stopping: Whether global stopping is enabled.
41+
tolerated_trial_failure_rate: The tolerated trial failure rate.
42+
max_pending_trials: The maximum number of pending trials.
43+
min_failed_trials_for_failure_rate_check: The minimum number of failed
44+
trials before checking the failure rate.
3745
3846
Returns:
3947
A dictionary summarizing the experiment.
@@ -60,8 +68,6 @@ def summarize_ax_optimization_complexity(
6068
else 1
6169
)
6270
num_outcome_constraints = len(optimization_config.outcome_constraints)
63-
uses_early_stopping = options.early_stopping_strategy is not None
64-
uses_global_stopping = options.global_stopping_strategy is not None
6571

6672
# Check if any metrics use merge_multiple_curves
6773
uses_merge_multiple_curves = False
@@ -93,9 +99,9 @@ def summarize_ax_optimization_complexity(
9399
uses_global_stopping=uses_global_stopping,
94100
uses_merge_multiple_curves=uses_merge_multiple_curves,
95101
uses_standard_api=uses_standard_api,
96-
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
97-
max_pending_trials=options.max_pending_trials,
102+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
103+
max_pending_trials=max_pending_trials,
98104
min_failed_trials_for_failure_rate_check=(
99-
options.min_failed_trials_for_failure_rate_check
105+
min_failed_trials_for_failure_rate_check
100106
),
101107
)

ax/utils/common/tests/test_complexity_utils.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from ax.core.metric import Metric
1010
from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError
11-
from ax.orchestration.orchestrator import OrchestratorOptions
1211
from ax.utils.common.complexity_utils import (
1312
check_if_in_standard,
1413
DEFAULT_TIER_MESSAGES,
@@ -29,7 +28,6 @@ class TestSummarizeAxOptimizationComplexity(TestCase):
2928
def setUp(self) -> None:
3029
super().setUp()
3130
self.experiment = get_experiment()
32-
self.options = OrchestratorOptions()
3331
self.tier_metadata: dict[str, object] = {}
3432

3533
def test_basic_experiment_summary(self) -> None:
@@ -38,7 +36,6 @@ def test_basic_experiment_summary(self) -> None:
3836
# WHEN we summarize the experiment
3937
summary = summarize_ax_optimization_complexity(
4038
experiment=self.experiment,
41-
options=self.options,
4239
tier_metadata=self.tier_metadata,
4340
)
4441

@@ -57,7 +54,6 @@ def test_multi_objective_experiment(self) -> None:
5754
# WHEN we summarize the experiment
5855
summary = summarize_ax_optimization_complexity(
5956
experiment=experiment,
60-
options=self.options,
6157
tier_metadata=self.tier_metadata,
6258
)
6359

@@ -75,7 +71,6 @@ def test_experiment_without_optimization_config_raises(self) -> None:
7571
):
7672
summarize_ax_optimization_complexity(
7773
experiment=self.experiment,
78-
options=self.options,
7974
tier_metadata=self.tier_metadata,
8075
)
8176

@@ -106,7 +101,6 @@ def test_tier_metadata_extraction(self) -> None:
106101
# WHEN we summarize the experiment
107102
summary = summarize_ax_optimization_complexity(
108103
experiment=self.experiment,
109-
options=self.options,
110104
tier_metadata=tier_metadata,
111105
)
112106

@@ -115,21 +109,17 @@ def test_tier_metadata_extraction(self) -> None:
115109
self.assertEqual(summary.uses_standard_api, expected_all_configs)
116110

117111
def test_orchestrator_options_extraction(self) -> None:
118-
# GIVEN custom orchestrator options
119-
options = OrchestratorOptions(
120-
tolerated_trial_failure_rate=0.25,
121-
max_pending_trials=5,
122-
min_failed_trials_for_failure_rate_check=10,
123-
)
124-
112+
# GIVEN custom options
125113
# WHEN we summarize the experiment
126114
summary = summarize_ax_optimization_complexity(
127115
experiment=self.experiment,
128-
options=options,
129116
tier_metadata=self.tier_metadata,
117+
tolerated_trial_failure_rate=0.25,
118+
max_pending_trials=5,
119+
min_failed_trials_for_failure_rate_check=10,
130120
)
131121

132-
# THEN the summary should reflect orchestrator options
122+
# THEN the summary should reflect the options
133123
self.assertEqual(summary.tolerated_trial_failure_rate, 0.25)
134124
self.assertEqual(summary.max_pending_trials, 5)
135125
self.assertEqual(summary.min_failed_trials_for_failure_rate_check, 10)
@@ -141,7 +131,6 @@ def test_parameter_constraints_counted(self) -> None:
141131
# WHEN we summarize the experiment
142132
summary = summarize_ax_optimization_complexity(
143133
experiment=experiment,
144-
options=self.options,
145134
tier_metadata=self.tier_metadata,
146135
)
147136

@@ -158,7 +147,6 @@ def test_merge_multiple_curves_detection(self) -> None:
158147
# WHEN we summarize the experiment
159148
summary = summarize_ax_optimization_complexity(
160149
experiment=self.experiment,
161-
options=self.options,
162150
tier_metadata=self.tier_metadata,
163151
)
164152

0 commit comments

Comments
 (0)