Skip to content

Commit 214a2d5

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Allow BenchmarkProblem to take in custom BenchmarkRunner (facebook#4948)
Summary: `BenchmarkRunner` by default assumes all trials are completed in `poll_trial_status`. We need a custom runner (`FailureAwareBenchmarkRunner` from previous diff) to override `poll_trial_status` which stores trial status depending on if the trial data is NaN (should be abandoned) or is available (completed) Differential Revision: D87091317
1 parent c2f9aec commit 214a2d5

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

ax/benchmark/benchmark.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,17 @@ def run_optimization_with_orchestrator(
576576
include_status_quo=sq_arm is not None,
577577
logging_level=orchestrator_logging_level,
578578
)
579-
runner = get_benchmark_runner(
580-
problem=problem,
581-
max_concurrency=orchestrator_options.max_pending_trials,
582-
force_use_simulated_backend=method.early_stopping_strategy is not None,
583-
)
579+
580+
# Use custom runner if provided on the problem, otherwise create standard runner
581+
if problem.runner is not None:
582+
runner = problem.runner
583+
else:
584+
runner = get_benchmark_runner(
585+
problem=problem,
586+
max_concurrency=orchestrator_options.max_pending_trials,
587+
force_use_simulated_backend=method.early_stopping_strategy is not None,
588+
)
589+
584590
experiment = Experiment(
585591
name=f"{problem.name}|{method.name}_{int(time())}",
586592
search_space=problem.search_space,

ax/benchmark/benchmark_problem.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dataclasses import dataclass, field
1111

1212
from ax.benchmark.benchmark_metric import BenchmarkMapMetric, BenchmarkMetric
13+
from ax.benchmark.benchmark_runner import BenchmarkRunner
1314
from ax.benchmark.benchmark_step_runtime_function import TBenchmarkStepRuntimeFunction
1415
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
1516
from ax.benchmark.noise import GaussianNoise, Noise
@@ -107,6 +108,7 @@ class BenchmarkProblem(Base):
107108
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None
108109
) = None
109110
tracking_metrics: list[Metric] | None = None
111+
runner: BenchmarkRunner | None = None
110112

111113
def __post_init__(self) -> None:
112114
# Handle backward compatibility for noise_std parameter

0 commit comments

Comments
 (0)