Skip to content

Commit 8ec8faa

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Use compute_metric_availability in AxClient.fit_model (#5111)
Summary: Pull Request resolved: #5111 The previous data check in `fit_model` only verified that the DataFrame was non-empty (`lookup_data().df.empty`), which would pass even when data existed for only a subset of required metrics. This could allow model fitting to proceed with incomplete data, leading to downstream errors. Replace the manual check with `compute_metric_availability()` from `ax.core.utils`, which inspects per-trial metric coverage against the optimization config's required metrics. `fit_model` now raises `DataRequiredError` unless at least one completed trial has data for **all** required metrics (`MetricAvailability.COMPLETE`). Reviewed By: saitcakmak Differential Revision: D98208718 fbshipit-source-id: 9c5b898a92e8656516e19af5a12174528e6ce9e4
1 parent 7c67537 commit 8ec8faa

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

ax/service/ax_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
TParameterization,
4141
TParamValue,
4242
)
43+
from ax.core.utils import compute_metric_availability, MetricAvailability
4344
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
4445
from ax.early_stopping.utils import estimate_early_stopping_savings
4546
from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION
@@ -1407,9 +1408,14 @@ def fit_model(self) -> None:
14071408
raise DataRequiredError(
14081409
"At least one trial must be completed with data to fit a model."
14091410
)
1410-
if self.experiment.lookup_data(trial_indices=completed_trial_indices).df.empty:
1411+
availability = compute_metric_availability(
1412+
experiment=self.experiment,
1413+
trial_indices=completed_trial_indices,
1414+
)
1415+
if not any(v == MetricAvailability.COMPLETE for v in availability.values()):
14111416
raise DataRequiredError(
1412-
"At least one completed trial must have data attached to fit a model."
1417+
"At least one completed trial must have data for all required "
1418+
"metrics to fit a model."
14131419
)
14141420
self.generation_strategy.fit(experiment=self.experiment)
14151421

ax/service/tests/test_ax_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,6 +2581,28 @@ def test_get_model_predictions_no_next_trial_no_completed_trial(self) -> None:
25812581
):
25822582
ax_client.get_model_predictions()
25832583

2584+
def test_fit_model_partial_metric_data(self) -> None:
2585+
"""Test that fit_model raises when completed trials only have data for
2586+
a subset of required metrics."""
2587+
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
2588+
# Attach a trial and complete it with data for only one of the two
2589+
# required metrics (test_metric1 is the objective, test_metric2 is the
2590+
# constraint). We bypass complete_trial() because it marks the trial as
2591+
# failed when required metrics are missing. Instead, we attach data and
2592+
# mark completed directly, simulating the case where the data check at
2593+
# completion time is skipped (e.g., data is attached asynchronously).
2594+
trial: TParameterization = {"x1": 0.1, "x2": 0.1}
2595+
_parameters, trial_index = ax_client.attach_trial(trial)
2596+
ax_trial = ax_client.get_trial(trial_index)
2597+
ax_trial.update_trial_data(raw_data={"test_metric1": (1.0, 0.0)})
2598+
ax_trial.mark_completed()
2599+
2600+
with self.assertRaisesRegex(
2601+
DataRequiredError,
2602+
"At least one completed trial must have data for all required metrics",
2603+
):
2604+
ax_client.fit_model()
2605+
25842606
def test_get_model_predictions_no_next_trial_filtered(self) -> None:
25852607
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
25862608
_attach_completed_trials(ax_client)

0 commit comments

Comments
 (0)