diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index b698bfdbe24..b8b081f8f41 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -264,6 +264,7 @@ def get_benchmark_orchestrator_options( early_stopping_strategy: BaseEarlyStoppingStrategy | None, include_status_quo: bool = False, logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> OrchestratorOptions: """ Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``. @@ -278,6 +279,8 @@ def get_benchmark_orchestrator_options( early_stopping_strategy: The early stopping strategy to use (if any). include_status_quo: Whether to include the status quo in each trial. logging_level: The logging level to use for the Orchestrator. + tolerated_trial_failure_rate: Fraction of trials allowed to fail without + aborting the optimization. Expects value between 0 and 1. Default is 0.5. Returns: ``OrchestratorOptions`` @@ -299,6 +302,7 @@ def get_benchmark_orchestrator_options( early_stopping_strategy=early_stopping_strategy, status_quo_weight=1.0 if include_status_quo else 0.0, logging_level=logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) @@ -590,6 +594,7 @@ def run_optimization_with_orchestrator( run_trials_in_batches: bool = False, timeout_hours: float | None = None, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> Experiment: """ Optimize the ``problem`` using the ``method`` and ``Orchestrator``, seeding @@ -626,6 +631,7 @@ def run_optimization_with_orchestrator( early_stopping_strategy=method.early_stopping_strategy, include_status_quo=sq_arm is not None, logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) runner = get_benchmark_runner( problem=problem, @@ -677,6 +683,7 @@ def benchmark_replication( timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, strip_runner_before_saving: bool = True, + tolerated_trial_failure_rate: float = 0.5, ) -> BenchmarkResult: """ Run one benchmarking replication (equivalent to one optimization loop). @@ -714,6 +721,7 @@ def benchmark_replication( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) benchmark_result = get_benchmark_result_from_experiment_and_gs( @@ -797,6 +805,7 @@ def benchmark_one_method_problem( run_trials_in_batches: bool = False, timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> AggregatedBenchmarkResult: return AggregatedBenchmarkResult.from_benchmark_results( results=[ @@ -807,6 +816,7 @@ def benchmark_one_method_problem( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) for seed in seeds ] @@ -820,6 +830,7 @@ def benchmark_multiple_problems_methods( run_trials_in_batches: bool = False, timeout_hours: float = 4.0, orchestrator_logging_level: int = DEFAULT_LOG_LEVEL, + tolerated_trial_failure_rate: float = 0.5, ) -> list[AggregatedBenchmarkResult]: """ For each `problem` and `method` in the Cartesian product of `problems` and @@ -835,6 +846,7 @@ def benchmark_multiple_problems_methods( run_trials_in_batches=run_trials_in_batches, timeout_hours=timeout_hours, orchestrator_logging_level=orchestrator_logging_level, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, ) for p, m in product(problems, methods) ] diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 37e01fdf186..69cbea4bafe 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -1131,6 +1131,18 @@ def test_get_benchmark_orchestrator_options(self) -> None: self.assertEqual( orchestrator_options.status_quo_weight, 1.0 if include_sq else 0.0 ) + # Default tolerated_trial_failure_rate should be 0.5 + self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.5) + + with self.subTest("custom tolerated_trial_failure_rate"): + orchestrator_options = get_benchmark_orchestrator_options( + batch_size=1, + run_trials_in_batches=False, + max_pending_trials=2, + early_stopping_strategy=None, + tolerated_trial_failure_rate=0.9, + ) + self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.9) def test_replication_with_status_quo(self) -> None: method = BenchmarkMethod(