From a371d26aa38aa5dfabe6f4df95b644353d102a01 Mon Sep 17 00:00:00 2001 From: Sunny Shen Date: Wed, 11 Mar 2026 23:46:47 -0700 Subject: [PATCH] Allow custom tolerated_trial_failure_rate (#4954) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4954 Allow custom tolerated_trial_failure_rate for failure aware benchmark problems where Standard BO without failure awareness is expected to have a high trial failure rate Reviewed By: Balandat Differential Revision: D87092423 --- ax/benchmark/benchmark.py | 12 ++++++++++++ ax/benchmark/tests/test_benchmark.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index b698bfdbe24..b8b081f8f41 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -264,6 +264,7 @@ def get_benchmark_orchestrator_options( early_stopping_strategy: BaseEarlyStoppingStrategy | None, include_status_quo: bool = False, logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> OrchestratorOptions: """ Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``. @@ -278,6 +279,8 @@ def get_benchmark_orchestrator_options( early_stopping_strategy: The early stopping strategy to use (if any). include_status_quo: Whether to include the status quo in each trial. logging_level: The logging level to use for the Orchestrator. + tolerated_trial_failure_rate: Fraction of trials allowed to fail without + aborting the optimization. Expects value between 0 and 1. Default is 0.5. Returns: ``OrchestratorOptions`` @@ -299,6 +302,7 @@ def get_benchmark_orchestrator_options( early_stopping_strategy=early_stopping_strategy, status_quo_weight=1.0 if include_status_quo else 0.0, logging_level=logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) @@ -590,6 +594,7 @@ def run_optimization_with_orchestrator( run_trials_in_batches: bool = False, timeout_hours: float | None = None, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> Experiment: """ Optimize the ``problem`` using the ``method`` and ``Orchestrator``, seeding @@ -626,6 +631,7 @@ def run_optimization_with_orchestrator( early_stopping_strategy=method.early_stopping_strategy, include_status_quo=sq_arm is not None, logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) runner = get_benchmark_runner( problem=problem, @@ -677,6 +683,7 @@ def benchmark_replication( timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, strip_runner_before_saving: bool = True, + tolerated_trial_failure_rate: float = 0.5, ) -> BenchmarkResult: """ Run one benchmarking replication (equivalent to one optimization loop). @@ -714,6 +721,7 @@ def benchmark_replication( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) benchmark_result = get_benchmark_result_from_experiment_and_gs( @@ -797,6 +805,7 @@ def benchmark_one_method_problem( run_trials_in_batches: bool = False, timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> AggregatedBenchmarkResult: return AggregatedBenchmarkResult.from_benchmark_results( results=[ @@ -807,6 +816,7 @@ def benchmark_one_method_problem( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) for seed in seeds ] @@ -820,6 +830,7 @@ def benchmark_multiple_problems_methods( run_trials_in_batches: bool = False, timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> list[AggregatedBenchmarkResult]: """ For each `problem` and `method` in the Cartesian product of `problems` and @@ -835,6 +846,7 @@ def benchmark_multiple_problems_methods( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) for p, m in product(problems, methods) ] diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 37e01fdf186..69cbea4bafe 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -1131,6 +1131,18 @@ def test_get_benchmark_orchestrator_options(self) -> None: self.assertEqual( orchestrator_options.status_quo_weight, 1.0 if include_sq else 0.0 ) + # Default tolerated_trial_failure_rate should be 0.5 + self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.5) + + with self.subTest("custom tolerated_trial_failure_rate"): + orchestrator_options = get_benchmark_orchestrator_options( + batch_size=1, + run_trials_in_batches=False, + max_pending_trials=2, + early_stopping_strategy=None, + tolerated_trial_failure_rate=0.9, + ) + self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.9) def test_replication_with_status_quo(self) -> None: method = BenchmarkMethod(