Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def get_benchmark_orchestrator_options(
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
include_status_quo: bool = False,
logging_level: int = DEFAULT_LOG_LEVEL,
tolerated_trial_failure_rate: float = 0.5,
) -> OrchestratorOptions:
"""
Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``.
Expand All @@ -278,6 +279,8 @@ def get_benchmark_orchestrator_options(
early_stopping_strategy: The early stopping strategy to use (if any).
include_status_quo: Whether to include the status quo in each trial.
logging_level: The logging level to use for the Orchestrator.
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
aborting the optimization. Expects value between 0 and 1. Default is 0.5.

Returns:
``OrchestratorOptions``
Expand All @@ -299,6 +302,7 @@ def get_benchmark_orchestrator_options(
early_stopping_strategy=early_stopping_strategy,
status_quo_weight=1.0 if include_status_quo else 0.0,
logging_level=logging_level,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
)


Expand Down Expand Up @@ -590,6 +594,7 @@ def run_optimization_with_orchestrator(
run_trials_in_batches: bool = False,
timeout_hours: float | None = None,
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
tolerated_trial_failure_rate: float = 0.5,
) -> Experiment:
"""
Optimize the ``problem`` using the ``method`` and ``Orchestrator``, seeding
Expand Down Expand Up @@ -626,6 +631,7 @@ def run_optimization_with_orchestrator(
early_stopping_strategy=method.early_stopping_strategy,
include_status_quo=sq_arm is not None,
logging_level=orchestrator_logging_level,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
)
runner = get_benchmark_runner(
problem=problem,
Expand Down Expand Up @@ -677,6 +683,7 @@ def benchmark_replication(
timeout_hours: float = 4.0,
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
strip_runner_before_saving: bool = True,
tolerated_trial_failure_rate: float = 0.5,
) -> BenchmarkResult:
"""
Run one benchmarking replication (equivalent to one optimization loop).
Expand Down Expand Up @@ -714,6 +721,7 @@ def benchmark_replication(
run_trials_in_batches=run_trials_in_batches,
timeout_hours=timeout_hours,
orchestrator_logging_level=orchestrator_logging_level,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
)

benchmark_result = get_benchmark_result_from_experiment_and_gs(
Expand Down Expand Up @@ -797,6 +805,7 @@ def benchmark_one_method_problem(
run_trials_in_batches: bool = False,
timeout_hours: float = 4.0,
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
tolerated_trial_failure_rate: float = 0.5,
) -> AggregatedBenchmarkResult:
return AggregatedBenchmarkResult.from_benchmark_results(
results=[
Expand All @@ -807,6 +816,7 @@ def benchmark_one_method_problem(
run_trials_in_batches=run_trials_in_batches,
timeout_hours=timeout_hours,
orchestrator_logging_level=orchestrator_logging_level,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
)
for seed in seeds
]
Expand All @@ -820,6 +830,7 @@ def benchmark_multiple_problems_methods(
run_trials_in_batches: bool = False,
timeout_hours: float = 4.0,
orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
tolerated_trial_failure_rate: float = 0.5,
) -> list[AggregatedBenchmarkResult]:
"""
For each `problem` and `method` in the Cartesian product of `problems` and
Expand All @@ -835,6 +846,7 @@ def benchmark_multiple_problems_methods(
run_trials_in_batches=run_trials_in_batches,
timeout_hours=timeout_hours,
orchestrator_logging_level=orchestrator_logging_level,
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
)
for p, m in product(problems, methods)
]
Expand Down
12 changes: 12 additions & 0 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,18 @@ def test_get_benchmark_orchestrator_options(self) -> None:
self.assertEqual(
orchestrator_options.status_quo_weight, 1.0 if include_sq else 0.0
)
# Default tolerated_trial_failure_rate should be 0.5
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.5)

with self.subTest("custom tolerated_trial_failure_rate"):
orchestrator_options = get_benchmark_orchestrator_options(
batch_size=1,
run_trials_in_batches=False,
max_pending_trials=2,
early_stopping_strategy=None,
tolerated_trial_failure_rate=0.9,
)
self.assertEqual(orchestrator_options.tolerated_trial_failure_rate, 0.9)

def test_replication_with_status_quo(self) -> None:
method = BenchmarkMethod(
Expand Down
Loading