Skip to content

Commit 6460a1a

Browse files
Sunny Shen, meta-codesync[bot]
authored and committed
Allow custom tolerated_trial_failure_rate (facebook#4954)
Summary: Pull Request resolved: facebook#4954. Allow a custom tolerated_trial_failure_rate for failure-aware benchmark problems, where standard BO without failure awareness is expected to have a high trial failure rate. Differential Revision: D87092423
1 parent b304ab3 commit 6460a1a

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

ax/benchmark/benchmark.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def get_benchmark_orchestrator_options(
264264
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
265265
include_status_quo: bool = False,
266266
logging_level: int = DEFAULT_LOG_LEVEL,
267+
tolerated_trial_failure_rate: float = 0.5,
267268
) -> OrchestratorOptions:
268269
"""
269270
Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``.
@@ -278,6 +279,8 @@ def get_benchmark_orchestrator_options(
278279
early_stopping_strategy: The early stopping strategy to use (if any).
279280
include_status_quo: Whether to include the status quo in each trial.
280281
logging_level: The logging level to use for the Orchestrator.
282+
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
283+
aborting the optimization. Expects value between 0 and 1. Default is 0.5.
281284
282285
Returns:
283286
``OrchestratorOptions``
@@ -299,6 +302,7 @@ def get_benchmark_orchestrator_options(
299302
early_stopping_strategy=early_stopping_strategy,
300303
status_quo_weight=1.0 if include_status_quo else 0.0,
301304
logging_level=logging_level,
305+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
302306
)
303307

304308

@@ -590,6 +594,7 @@ def run_optimization_with_orchestrator(
590594
run_trials_in_batches: bool = False,
591595
timeout_hours: float | None = None,
592596
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
597+
tolerated_trial_failure_rate: float = 0.5,
593598
) -> Experiment:
594599
"""
595600
Optimize the ``problem`` using the ``method`` and ``Orchestrator``, seeding
@@ -626,6 +631,7 @@ def run_optimization_with_orchestrator(
626631
early_stopping_strategy=method.early_stopping_strategy,
627632
include_status_quo=sq_arm is not None,
628633
logging_level=orchestrator_logging_level,
634+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
629635
)
630636
runner = get_benchmark_runner(
631637
problem=problem,
@@ -677,6 +683,7 @@ def benchmark_replication(
677683
timeout_hours: float = 4.0,
678684
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
679685
strip_runner_before_saving: bool = True,
686+
tolerated_trial_failure_rate: float = 0.5,
680687
) -> BenchmarkResult:
681688
"""
682689
Run one benchmarking replication (equivalent to one optimization loop).
@@ -714,6 +721,7 @@ def benchmark_replication(
714721
run_trials_in_batches=run_trials_in_batches,
715722
timeout_hours=timeout_hours,
716723
orchestrator_logging_level=orchestrator_logging_level,
724+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
717725
)
718726

719727
benchmark_result = get_benchmark_result_from_experiment_and_gs(
@@ -797,6 +805,7 @@ def benchmark_one_method_problem(
797805
run_trials_in_batches: bool = False,
798806
timeout_hours: float = 4.0,
799807
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
808+
tolerated_trial_failure_rate: float = 0.5,
800809
) -> AggregatedBenchmarkResult:
801810
return AggregatedBenchmarkResult.from_benchmark_results(
802811
results=[
@@ -807,6 +816,7 @@ def benchmark_one_method_problem(
807816
run_trials_in_batches=run_trials_in_batches,
808817
timeout_hours=timeout_hours,
809818
orchestrator_logging_level=orchestrator_logging_level,
819+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
810820
)
811821
for seed in seeds
812822
]
@@ -820,6 +830,7 @@ def benchmark_multiple_problems_methods(
820830
run_trials_in_batches: bool = False,
821831
timeout_hours: float = 4.0,
822832
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
833+
tolerated_trial_failure_rate: float = 0.5,
823834
) -> list[AggregatedBenchmarkResult]:
824835
"""
825836
For each `problem` and `method` in the Cartesian product of `problems` and
@@ -835,6 +846,7 @@ def benchmark_multiple_problems_methods(
835846
run_trials_in_batches=run_trials_in_batches,
836847
timeout_hours=timeout_hours,
837848
orchestrator_logging_level=orchestrator_logging_level,
849+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
838850
)
839851
for p, m in product(problems, methods)
840852
]

ax/benchmark/tests/test_benchmark.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,18 @@ def test_get_benchmark_orchestrator_options(self) -> None:
11311131
self.assertEqual(
11321132
orchestrator_options.status_quo_weight, 1.0 if include_sq else 0.0
11331133
)
1134+
# Default tolerated_trial_failure_rate should be 0.5
1135+
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.5)
1136+
1137+
with self.subTest("custom tolerated_trial_failure_rate"):
1138+
orchestrator_options = get_benchmark_orchestrator_options(
1139+
batch_size=1,
1140+
run_trials_in_batches=False,
1141+
max_pending_trials=2,
1142+
early_stopping_strategy=None,
1143+
tolerated_trial_failure_rate=0.9,
1144+
)
1145+
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.9)
11341146

11351147
def test_replication_with_status_quo(self) -> None:
11361148
method = BenchmarkMethod(

0 commit comments

Comments
 (0)