Skip to content

Commit 9c72dc8

Browse files
Lena Kashtelyanmeta-codesync[bot]
authored andcommitted
Avoid rewriting timestamps when status is not changed (#5074)
Summary: Pull Request resolved: #5074 Avoid rewriting timestamps when status is not changed Reviewed By: andycylmeta Differential Revision: D97329836 fbshipit-source-id: 8415932cef2e5b4cb47dd0286230364b4cafde81
1 parent c8124ff commit 9c72dc8

2 files changed

Lines changed: 25 additions & 20 deletions

File tree

ax/core/base_trial.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,9 @@ def mark_stale(self, unsafe: bool = False) -> Self:
693693
def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> Self:
694694
"""Mark trial with a new TrialStatus.
695695
696+
If the trial is already in the given status, this is a no-op -- the
697+
trial is returned unchanged and timestamps are not overwritten.
698+
696699
Args:
697700
status: The new status of the trial.
698701
unsafe: Ignore sanity checks on state transitions.
@@ -702,6 +705,8 @@ def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> S
702705
Returns:
703706
The trial instance.
704707
"""
708+
if self._status == status:
709+
return self
705710
if status == TrialStatus.STAGED:
706711
self.mark_staged(unsafe=unsafe)
707712
elif status == TrialStatus.RUNNING:

ax/orchestration/orchestrator.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,7 +1311,7 @@ def poll_and_process_results(self, poll_all_trial_statuses: bool = False) -> boo
13111311
self._sleep_if_too_early_to_poll()
13121312

13131313
# POLL TRIAL STATUSES
1314-
new_status_to_trial_idcs = self.poll_trial_status(
1314+
polled_status_to_trial_idcs = self.poll_trial_status(
13151315
poll_all_trial_statuses=poll_all_trial_statuses
13161316
)
13171317

@@ -1321,12 +1321,12 @@ def poll_and_process_results(self, poll_all_trial_statuses: bool = False) -> boo
13211321
# This must be done before updating the trial statuses, so we can differentiate
13221322
# newly and previously completed trials.
13231323
trial_indices_to_fetch = self._get_trial_indices_to_fetch(
1324-
new_status_to_trial_idcs=new_status_to_trial_idcs
1324+
polled_status_to_trial_idcs=polled_status_to_trial_idcs
13251325
)
13261326

13271327
# UPDATE TRIAL STATUSES
1328-
trial_indices_with_updated_statuses = self._apply_new_trial_statuses(
1329-
new_status_to_trial_idcs=new_status_to_trial_idcs,
1328+
trial_indices_with_updated_statuses = self._apply_trial_statuses(
1329+
polled_status_to_trial_idcs=polled_status_to_trial_idcs,
13301330
)
13311331
updated_any_trial_status = len(trial_indices_with_updated_statuses) > 0
13321332
trial_indices_with_updated_data_or_status.update(
@@ -1423,20 +1423,20 @@ def _num_ran_in_orchestrator(self) -> int:
14231423
"""Returns the number of trials that have been run by the orchestrator."""
14241424
return len(self.experiment.trials) - self._num_preexisting_trials
14251425

1426-
def _apply_new_trial_statuses(
1427-
self, new_status_to_trial_idcs: dict[TrialStatus, set[int]]
1426+
def _apply_trial_statuses(
1427+
self, polled_status_to_trial_idcs: dict[TrialStatus, set[int]]
14281428
) -> set[int]:
1429-
"""Apply new trial statuses to the experiment according to poll results.
1429+
"""Apply polled trial statuses to the experiment.
14301430
14311431
Args:
1432-
new_status_to_trial_idcs: Changes to be applied to trial statuses from
1433-
poll_trial_status.
1432+
polled_status_to_trial_idcs: Statuses as reported by poll_trial_status.
1433+
May include trials whose status has not changed.
14341434
14351435
Returns:
1436-
Set of trial indices that were updated with new statuses.
1436+
Set of trial indices that were processed.
14371437
"""
14381438
updated_trial_indices = set()
1439-
for status, trial_idcs in new_status_to_trial_idcs.items():
1439+
for status, trial_idcs in polled_status_to_trial_idcs.items():
14401440
if status.is_candidate or status.is_deployed:
14411441
# No need to consider candidate, staged or running trials here (none of
14421442
# these trials should actually be candidates, but we can filter on that)
@@ -1465,15 +1465,15 @@ def _apply_new_trial_statuses(
14651465
def _identify_trial_indices_to_fetch(
14661466
self,
14671467
old_status_to_trial_idcs: Mapping[TrialStatus, set[int]],
1468-
new_status_to_trial_idcs: Mapping[TrialStatus, set[int]],
1468+
polled_status_to_trial_idcs: Mapping[TrialStatus, set[int]],
14691469
) -> set[int]:
14701470
"""
14711471
Identify trial indices to fetch data for based on changes in trial statuses.
14721472
14731473
Args:
14741474
old_status_to_trial_idcs: Mapping of old trial statuses
14751475
to their corresponding trial indices.
1476-
new_status_to_trial_idcs: Mapping of new trial statuses
1476+
polled_status_to_trial_idcs: Mapping of new trial statuses
14771477
to their corresponding trial indices.
14781478
Returns:
14791479
Set of trial indices to fetch data for.
@@ -1484,7 +1484,7 @@ def _identify_trial_indices_to_fetch(
14841484
) | old_status_to_trial_idcs.get(TrialStatus.EARLY_STOPPED, set())
14851485

14861486
newly_completed = (
1487-
new_status_to_trial_idcs.get(TrialStatus.COMPLETED, set())
1487+
polled_status_to_trial_idcs.get(TrialStatus.COMPLETED, set())
14881488
- prev_completed_trial_idcs
14891489
)
14901490

@@ -1499,11 +1499,11 @@ def _identify_trial_indices_to_fetch(
14991499
if any(
15001500
m.is_available_while_running() for m in self.experiment.metrics.values()
15011501
):
1502-
running_trial_indices_with_metrics = new_status_to_trial_idcs.get(
1502+
running_trial_indices_with_metrics = polled_status_to_trial_idcs.get(
15031503
TrialStatus.RUNNING, set()
15041504
) | old_status_to_trial_idcs.get(TrialStatus.RUNNING, set())
15051505

1506-
for status, indices in new_status_to_trial_idcs.items():
1506+
for status, indices in polled_status_to_trial_idcs.items():
15071507
if status.is_terminal and indices:
15081508
running_trial_indices_with_metrics -= indices
15091509

@@ -1533,17 +1533,17 @@ def _identify_trial_indices_to_fetch(
15331533
return trial_indices_to_fetch
15341534

15351535
def _get_trial_indices_to_fetch(
1536-
self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
1536+
self, polled_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
15371537
) -> set[int]:
15381538
"""Get trial indices to fetch data for the experiment given
1539-
`new_status_to_trial_idcs` and metric properties. This should include:
1539+
`polled_status_to_trial_idcs` and metric properties. This should include:
15401540
- newly completed trials
15411541
- running trials if the experiment has metrics available while running
15421542
- previously completed (or early stopped) trials if the experiment
15431543
has metrics with new data after completion which finished recently
15441544
15451545
Args:
1546-
new_status_to_trial_idcs: Changes about to be applied to trial statuses.
1546+
polled_status_to_trial_idcs: Changes about to be applied to trial statuses.
15471547
15481548
Returns:
15491549
Set of trial indices to fetch data for.
@@ -1555,7 +1555,7 @@ def _get_trial_indices_to_fetch(
15551555

15561556
return self._identify_trial_indices_to_fetch(
15571557
old_status_to_trial_idcs=old_status_to_trial_idcs,
1558-
new_status_to_trial_idcs=new_status_to_trial_idcs,
1558+
polled_status_to_trial_idcs=polled_status_to_trial_idcs,
15591559
)
15601560

15611561
def _get_recently_completed_trial_indices(self) -> set[int]:

0 commit comments

Comments
 (0)