Skip to content

Commit d7e7620

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Update BenchmarkRunner to mark trials with NaN data as ABANDONED
Summary: Update `BenchmarkRunner.poll_trial_status` to mark trials with NaN data as `ABANDONED`. Differential Revision: D87094541
1 parent 71c9895 commit d7e7620

File tree

2 files changed

+118
-2
lines changed

2 files changed

+118
-2
lines changed

ax/benchmark/benchmark_runner.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# pyre-strict
77

8+
import logging
89
import warnings
910
from collections.abc import Iterable, Mapping, Sequence
1011
from dataclasses import dataclass, field
@@ -26,6 +27,8 @@
2627
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
2728
from ax.utils.testing.backend_simulator import BackendSimulator, BackendSimulatorOptions
2829

30+
logger: logging.Logger = logging.getLogger(__name__)
31+
2932

3033
def _dict_of_arrays_to_df(
3134
Y_true_by_arm: Mapping[str, npt.NDArray],
@@ -264,12 +267,60 @@ def run(self, trial: BaseTrial) -> dict[str, BenchmarkTrialMetadata]:
264267
)
265268
return {"benchmark_metadata": metadata}
266269

270+
def _trial_has_nan(self, trial: BaseTrial) -> bool:
271+
"""Check if a trial's benchmark metadata contains NaN in the mean column.
272+
273+
Args:
274+
trial: The trial to check.
275+
276+
Returns:
277+
True if any metric's DataFrame has NaN in the "mean" column.
278+
"""
279+
run_metadata = trial.run_metadata
280+
if not isinstance(run_metadata, dict):
281+
return False
282+
metadata = run_metadata.get("benchmark_metadata")
283+
if metadata is None:
284+
return False
285+
for df in metadata.dfs.values():
286+
if not df.empty and df["mean"].isna().any():
287+
return True
288+
return False
289+
267290
def poll_trial_status(
268291
self, trials: Iterable[BaseTrial]
269292
) -> dict[TrialStatus, set[int]]:
270293
if self.simulated_backend_runner is None:
271-
return {TrialStatus.COMPLETED: {t.index for t in trials}}
272-
return self.simulated_backend_runner.poll_trial_status(trials=trials)
294+
statuses: dict[TrialStatus, set[int]] = {
295+
TrialStatus.COMPLETED: {t.index for t in trials}
296+
}
297+
else:
298+
statuses = self.simulated_backend_runner.poll_trial_status(trials=trials)
299+
300+
# Move completed trials with NaN data to ABANDONED.
301+
completed_indices = statuses.get(TrialStatus.COMPLETED, set())
302+
if completed_indices:
303+
trials_by_index = {t.index: t for t in trials}
304+
nan_indices = {
305+
idx
306+
for idx in completed_indices
307+
if (trial := trials_by_index.get(idx)) is not None
308+
and self._trial_has_nan(trial)
309+
}
310+
if nan_indices:
311+
for idx in nan_indices:
312+
logger.info(
313+
f"Trial {idx} has NaN in metrics and will be "
314+
"marked as ABANDONED."
315+
)
316+
statuses.pop(TrialStatus.COMPLETED)
317+
remaining = completed_indices - nan_indices
318+
if remaining:
319+
statuses[TrialStatus.COMPLETED] = remaining
320+
statuses[TrialStatus.ABANDONED] = (
321+
statuses.get(TrialStatus.ABANDONED, set()) | nan_indices
322+
)
323+
return statuses
273324

274325
@classmethod
275326
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:

ax/benchmark/tests/test_benchmark_runner.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,68 @@ def test_heterogeneous_step_runtime(self) -> None:
445445
ValueError, "Step duration must be non-negative"
446446
):
447447
runner.run(trial=trial)
448+
449+
def test_nan_trials_abandoned(self) -> None:
450+
"""Trials with NaN data are marked ABANDONED, valid trials COMPLETED."""
451+
test_function = IdentityTestFunction(outcome_names=["objective"])
452+
453+
with self.subTest("without simulated backend"):
454+
runner = BenchmarkRunner(
455+
test_function=test_function,
456+
noise=GaussianNoise(noise_std=0.0),
457+
)
458+
experiment = Experiment(
459+
name="test_nan_sync",
460+
is_test=True,
461+
runner=runner,
462+
search_space=Mock(spec=SearchSpace),
463+
)
464+
trial_valid = Trial(experiment=experiment)
465+
trial_valid.add_arm(Arm(name="0_0", parameters={"x0": 1.5}))
466+
trial_valid.run()
467+
468+
trial_nan = Trial(experiment=experiment)
469+
trial_nan.add_arm(Arm(name="1_0", parameters={"x0": float("nan")}))
470+
trial_nan.run()
471+
472+
trial_valid2 = Trial(experiment=experiment)
473+
trial_valid2.add_arm(Arm(name="2_0", parameters={"x0": 2.0}))
474+
trial_valid2.run()
475+
476+
statuses = runner.poll_trial_status([trial_valid, trial_nan, trial_valid2])
477+
self.assertEqual(
478+
statuses[TrialStatus.COMPLETED],
479+
{trial_valid.index, trial_valid2.index},
480+
)
481+
self.assertEqual(statuses[TrialStatus.ABANDONED], {trial_nan.index})
482+
483+
with self.subTest("with simulated backend"):
484+
runner = BenchmarkRunner(
485+
test_function=test_function,
486+
noise=GaussianNoise(noise_std=0.0),
487+
force_use_simulated_backend=True,
488+
)
489+
experiment = Experiment(
490+
name="test_nan_async",
491+
is_test=True,
492+
runner=runner,
493+
search_space=Mock(spec=SearchSpace),
494+
)
495+
trial_nan = Trial(experiment=experiment)
496+
trial_nan.add_arm(Arm(name="0_0", parameters={"x0": float("nan")}))
497+
trial_nan.run()
498+
499+
trial_valid = Trial(experiment=experiment)
500+
trial_valid.add_arm(Arm(name="1_0", parameters={"x0": 3.0}))
501+
trial_valid.run()
502+
503+
# Advance simulated time so trials complete
504+
simulator = none_throws(
505+
none_throws(runner.simulated_backend_runner).simulator
506+
)
507+
while simulator.num_completed < 2:
508+
simulator.update()
509+
510+
statuses = runner.poll_trial_status([trial_nan, trial_valid])
511+
self.assertEqual(statuses[TrialStatus.COMPLETED], {trial_valid.index})
512+
self.assertEqual(statuses[TrialStatus.ABANDONED], {trial_nan.index})

0 commit comments

Comments
 (0)