Skip to content

Commit c08c4a9

Browse files
ItsMrLin authored and meta-codesync[bot] committed
Exclude LILO labeling trials from trials_expecting_data (#5144)
Summary:
Pull Request resolved: #5144

LILO labeling trials have their pairwise preference data fetched inline during the labeling loop (`_run_lilo_labeling_loop`), not via the normal data refetch paths (orchestrator `poll_and_process_results`, PTSClient `refetch_data`). Including them in `trials_expecting_data` causes unnecessary data fetch attempts for metrics (e.g., Deltoid) that don't exist on these trials, producing noisy errors and wasting time.

This filters LILO labeling trials (`trial_type == Keys.LILO_LABELING`) from the `trials_expecting_data` property on `Experiment`, which is the centralized source used by all downstream data refetch consumers.

Reviewed By: saitcakmak

Differential Revision: D99571562

fbshipit-source-id: 7064918189335ac2630fcd7ea57046ae296e437a
1 parent e942453 commit c08c4a9

6 files changed

Lines changed: 52 additions & 18 deletions

File tree

ax/core/base_trial.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,16 @@ def status(self) -> TrialStatus:
198198
self._mark_stale_if_past_TTL()
199199
return none_throws(self._status)
200200

201+
@property
202+
def expecting_data(self) -> bool:
203+
"""Whether this trial expects data via the standard data-fetch pipeline.
204+
205+
Returns ``False`` for LILO labeling trials because their pairwise
206+
preference data is fetched inline during the labeling loop and is
207+
never refetched through the normal orchestration path.
208+
"""
209+
return self.status.expecting_data and self.trial_type != Keys.LILO_LABELING
210+
201211
@status.setter
202212
def status(self, status: TrialStatus) -> None:
203213
raise NotImplementedError("Use `trial.mark_*` methods to set trial status.")

ax/core/experiment.py

Lines changed: 7 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -45,11 +45,7 @@
4545
from ax.core.runner import Runner
4646
from ax.core.search_space import SearchSpace
4747
from ax.core.trial import Trial
48-
from ax.core.trial_status import (
49-
DEFAULT_STATUSES_TO_WARM_START,
50-
STATUSES_EXPECTING_DATA,
51-
TrialStatus,
52-
)
48+
from ax.core.trial_status import DEFAULT_STATUSES_TO_WARM_START, TrialStatus
5349
from ax.core.types import ComparisonOp, TParameterization
5450
from ax.exceptions.core import (
5551
AxError,
@@ -1406,10 +1402,10 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]:
14061402

14071403
@property
14081404
def trials_expecting_data(self) -> list[BaseTrial]:
1409-
"""list[BaseTrial]: the list of all trials for which data has arrived
1410-
or is expected to arrive.
1405+
"""list[BaseTrial]: the list of all trials that expect data via the
1406+
standard data-fetch pipeline.
14111407
"""
1412-
return [trial for trial in self.trials.values() if trial.status.expecting_data]
1408+
return [trial for trial in self.trials.values() if trial.expecting_data]
14131409

14141410
@property
14151411
def completed_trials(self) -> list[BaseTrial]:
@@ -1433,15 +1429,10 @@ def running_trial_indices(self) -> set[int]:
14331429

14341430
@property
14351431
def trial_indices_expecting_data(self) -> set[int]:
1436-
"""Set of indices of trials, statuses of which indicate that we expect
1437-
these trials to have data, either already or in the future.
1432+
"""Set of indices of trials that expect data via the standard
1433+
data-fetch pipeline.
14381434
"""
1439-
return set.union(
1440-
*(
1441-
self.trial_indices_by_status[status]
1442-
for status in STATUSES_EXPECTING_DATA
1443-
)
1444-
)
1435+
return {trial.index for trial in self.trials.values() if trial.expecting_data}
14451436

14461437
def trial_indices_with_data(
14471438
self, critical_metrics_only: bool | None = True

ax/core/multi_type_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -275,7 +275,7 @@ def fetch_data(
275275
[
276276
(
277277
trial.fetch_data(**kwargs, metrics=metrics)
278-
if trial.status.expecting_data
278+
if trial.expecting_data
279279
else Data()
280280
)
281281
for trial in self.trials.values()

ax/core/tests/test_experiment.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -881,14 +881,26 @@ def test_experiment_runner(self) -> None:
881881
candidate_batch.run()
882882
candidate_batch._status = TrialStatus.CANDIDATE
883883
self.assertEqual(self.experiment.trials_expecting_data, [batch])
884+
885+
# LILO labeling trials are excluded from trials_expecting_data
886+
# (their data is fetched inline during the labeling loop).
887+
lilo_batch = self.experiment.new_batch_trial(
888+
trial_type=Keys.LILO_LABELING,
889+
)
890+
lilo_batch.run()
891+
lilo_batch.mark_completed()
892+
self.assertEqual(self.experiment.trials_expecting_data, [batch])
893+
884894
tbs = self.experiment.trials_by_status # All statuses should be present
885895
self.assertEqual(len(tbs), len(TrialStatus))
886896
self.assertEqual(tbs[TrialStatus.RUNNING], [batch])
887897
self.assertEqual(tbs[TrialStatus.CANDIDATE], [candidate_batch])
898+
self.assertEqual(tbs[TrialStatus.COMPLETED], [lilo_batch])
888899
tibs = self.experiment.trial_indices_by_status
889900
self.assertEqual(len(tibs), len(TrialStatus))
890901
self.assertEqual(tibs[TrialStatus.RUNNING], {0})
891902
self.assertEqual(tibs[TrialStatus.CANDIDATE], {1})
903+
self.assertEqual(tibs[TrialStatus.COMPLETED], {2})
892904

893905
identifier = {"new_runner": True}
894906
# pyre-fixme[6]: For 1st param expected `Optional[str]` but got `Dict[str,
@@ -1727,6 +1739,12 @@ def test_trial_indices(self) -> None:
17271739
)
17281740
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})
17291741

1742+
# LILO labeling trials are excluded from trial_indices_expecting_data.
1743+
lilo_trial = experiment.new_batch_trial(trial_type=Keys.LILO_LABELING)
1744+
lilo_trial.mark_running(no_runner_required=True)
1745+
lilo_trial.mark_completed()
1746+
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})
1747+
17301748
def test_trial_indices_with_data(self) -> None:
17311749
exp = get_branin_experiment_with_multi_objective(
17321750
with_status_quo=True,

ax/core/tests/test_trial.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from ax.exceptions.core import TrialMutationError, UnsupportedError, UserInputError
2727
from ax.metrics.branin import BraninMetric
2828
from ax.runners.synthetic import SyntheticRunner
29+
from ax.utils.common.constants import Keys
2930
from ax.utils.common.result import Ok
3031
from ax.utils.common.testutils import TestCase
3132
from ax.utils.testing.core_stubs import (
@@ -271,8 +272,22 @@ def test_mark_as(self) -> None:
271272
TrialStatus.COMPLETED,
272273
]:
273274
self.assertTrue(self.trial.status.expecting_data)
275+
# trial.expecting_data follows status for normal trials.
276+
self.assertTrue(self.trial.expecting_data)
274277
else:
275278
self.assertFalse(self.trial.status.expecting_data)
279+
self.assertFalse(self.trial.expecting_data)
280+
281+
def test_expecting_data_excludes_lilo(self) -> None:
282+
"""LILO labeling trials never expect data via the standard pipeline."""
283+
self.trial._trial_type = Keys.LILO_LABELING
284+
self.trial.mark_running(no_runner_required=True)
285+
self.assertTrue(self.trial.status.expecting_data)
286+
self.assertFalse(self.trial.expecting_data)
287+
288+
self.trial.mark_completed()
289+
self.assertTrue(self.trial.status.expecting_data)
290+
self.assertFalse(self.trial.expecting_data)
276291

277292
def test_stop(self) -> None:
278293
# test bad old status

ax/orchestration/orchestrator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -477,7 +477,7 @@ def trials_expecting_data(self) -> list[BaseTrial]:
477477
"""
478478
trials = []
479479
for trial in self.experiment.trials.values():
480-
if trial.status.expecting_data:
480+
if trial.expecting_data:
481481
if self.trial_type is None or trial.trial_type == self.trial_type:
482482
trials.append(trial)
483483
return trials

0 commit comments

Comments
 (0)