Skip to content

Commit 5966881

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Allow custom tolerated_trial_failure_rate
Summary: Allow custom tolerated_trial_failure_rate for failure aware benchmark problems where Standard BO without failure awareness is expected to have a high trial failure rate Differential Revision: D87092423
1 parent 0d105cf commit 5966881

2 files changed

Lines changed: 24 additions & 0 deletions

File tree

ax/benchmark/benchmark.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def get_benchmark_orchestrator_options(
260260
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
261261
include_status_quo: bool = False,
262262
logging_level: int = DEFAULT_LOG_LEVEL,
263+
tolerated_trial_failure_rate: float = 0.5,
263264
) -> OrchestratorOptions:
264265
"""
265266
Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``.
@@ -274,6 +275,8 @@ def get_benchmark_orchestrator_options(
274275
early_stopping_strategy: The early stopping strategy to use (if any).
275276
include_status_quo: Whether to include the status quo in each trial.
276277
logging_level: The logging level to use for the Orchestrator.
278+
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
279+
aborting the optimization. Expects value between 0 and 1. Default is 0.5.
277280
278281
Returns:
279282
``OrchestratorOptions``
@@ -295,6 +298,7 @@ def get_benchmark_orchestrator_options(
295298
early_stopping_strategy=early_stopping_strategy,
296299
status_quo_weight=1.0 if include_status_quo else 0.0,
297300
logging_level=logging_level,
301+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
298302
)
299303

300304

@@ -578,6 +582,7 @@ def run_optimization_with_orchestrator(
578582
run_trials_in_batches: bool = False,
579583
timeout_hours: float | None = None,
580584
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
585+
tolerated_trial_failure_rate: float = 0.5,
581586
) -> Experiment:
582587
"""
583588
Optimize the ``problem`` using the ``method`` and ``Orchestrator``, seeding
@@ -614,6 +619,7 @@ def run_optimization_with_orchestrator(
614619
early_stopping_strategy=method.early_stopping_strategy,
615620
include_status_quo=sq_arm is not None,
616621
logging_level=orchestrator_logging_level,
622+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
617623
)
618624

619625
# Use custom runner if provided on the problem, otherwise create standard runner
@@ -671,6 +677,7 @@ def benchmark_replication(
671677
timeout_hours: float = 4.0,
672678
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
673679
strip_runner_before_saving: bool = True,
680+
tolerated_trial_failure_rate: float = 0.5,
674681
) -> BenchmarkResult:
675682
"""
676683
Run one benchmarking replication (equivalent to one optimization loop).
@@ -708,6 +715,7 @@ def benchmark_replication(
708715
run_trials_in_batches=run_trials_in_batches,
709716
timeout_hours=timeout_hours,
710717
orchestrator_logging_level=orchestrator_logging_level,
718+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
711719
)
712720

713721
benchmark_result = get_benchmark_result_from_experiment_and_gs(
@@ -789,6 +797,7 @@ def benchmark_one_method_problem(
789797
run_trials_in_batches: bool = False,
790798
timeout_hours: float = 4.0,
791799
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
800+
tolerated_trial_failure_rate: float = 0.5,
792801
) -> AggregatedBenchmarkResult:
793802
return AggregatedBenchmarkResult.from_benchmark_results(
794803
results=[
@@ -799,6 +808,7 @@ def benchmark_one_method_problem(
799808
run_trials_in_batches=run_trials_in_batches,
800809
timeout_hours=timeout_hours,
801810
orchestrator_logging_level=orchestrator_logging_level,
811+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
802812
)
803813
for seed in seeds
804814
]
@@ -812,6 +822,7 @@ def benchmark_multiple_problems_methods(
812822
run_trials_in_batches: bool = False,
813823
timeout_hours: float = 4.0,
814824
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
825+
tolerated_trial_failure_rate: float = 0.5,
815826
) -> list[AggregatedBenchmarkResult]:
816827
"""
817828
For each `problem` and `method` in the Cartesian product of `problems` and
@@ -827,6 +838,7 @@ def benchmark_multiple_problems_methods(
827838
run_trials_in_batches=run_trials_in_batches,
828839
timeout_hours=timeout_hours,
829840
orchestrator_logging_level=orchestrator_logging_level,
841+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
830842
)
831843
for p, m in product(problems, methods)
832844
]

ax/benchmark/tests/test_benchmark.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,18 @@ def test_get_benchmark_orchestrator_options(self) -> None:
11311131
self.assertEqual(
11321132
orchestrator_options.status_quo_weight, 1.0 if include_sq else 0.0
11331133
)
1134+
# Default tolerated_trial_failure_rate should be 0.5
1135+
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.5)
1136+
1137+
with self.subTest("custom tolerated_trial_failure_rate"):
1138+
orchestrator_options = get_benchmark_orchestrator_options(
1139+
batch_size=1,
1140+
run_trials_in_batches=False,
1141+
max_pending_trials=2,
1142+
early_stopping_strategy=None,
1143+
tolerated_trial_failure_rate=0.9,
1144+
)
1145+
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.9)
11341146

11351147
def test_replication_with_status_quo(self) -> None:
11361148
method = BenchmarkMethod(

0 commit comments

Comments
 (0)