Skip to content

Commit e167d7c

Browse files
Lena Kashtelyan, meta-codesync[bot]
authored and committed
Make all mark_* methods on BaseTrial no-op when status is unchanged (#5097)
Summary:
Pull Request resolved: #5097

NOTE: Please see D97329836 for context; curious what folks think of this change. IMO, applying the change in that diff only to `mark_as` will cause confusion in the future.

## Claude below

Extend the no-op-on-same-status pattern from `mark_as` to every individual `mark_*` method on `BaseTrial`. This ensures that calling any status-marking method on a trial that is already in that status is a safe no-op that does not overwrite timestamps, regardless of whether the caller goes through `mark_as` or calls the specific method directly.

Reviewed By: Cesar-Cardoso

Differential Revision: D97785459

fbshipit-source-id: e622bb1bff3365783c43bc0068cf7b3a2b126e79
1 parent 6f88d2d commit e167d7c

2 files changed

Lines changed: 32 additions & 2 deletions

File tree

ax/core/base_trial.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -528,11 +528,15 @@ def status_reason(self) -> str | None:
528528
def mark_staged(self, unsafe: bool = False) -> Self:
529529
"""Mark the trial as being staged for running.
530530
531+
No-op if the trial is already staged.
532+
531533
Args:
532534
unsafe: Ignore sanity checks on state transitions.
533535
Returns:
534536
The trial instance.
535537
"""
538+
if self._status == TrialStatus.STAGED:
539+
return self
536540
if not unsafe and self._status != TrialStatus.CANDIDATE:
537541
raise TrialMutationError(
538542
f"Can only stage a candidate trial. This trial is {self._status}"
@@ -546,6 +550,8 @@ def mark_running(
546550
) -> Self:
547551
"""Mark trial has started running.
548552
553+
No-op if the trial is already running.
554+
549555
Args:
550556
no_runner_required: Whether to skip the check for presence of a
551557
``Runner`` on the experiment.
@@ -554,6 +560,8 @@ def mark_running(
554560
Returns:
555561
The trial instance.
556562
"""
563+
if self._status == TrialStatus.RUNNING:
564+
return self
557565
if self.runner is None and not no_runner_required:
558566
raise ValueError("Cannot mark trial running without setting runner.")
559567

@@ -576,6 +584,8 @@ def mark_completed(
576584
) -> Self:
577585
"""Mark trial as completed.
578586
587+
No-op if the trial is already completed.
588+
579589
Args:
580590
unsafe: Ignore sanity checks on state transitions.
581591
time_completed: The time the trial was completed. If None, defaults to
@@ -585,6 +595,8 @@ def mark_completed(
585595
Returns:
586596
The trial instance.
587597
"""
598+
if self._status == TrialStatus.COMPLETED:
599+
return self
588600
if not unsafe and self._status != TrialStatus.RUNNING:
589601
raise TrialMutationError(
590602
"Can only complete trial that is currently running."
@@ -600,6 +612,8 @@ def mark_completed(
600612
def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Self:
601613
"""Mark trial as abandoned.
602614
615+
No-op if the trial is already abandoned.
616+
603617
NOTE: Arms in abandoned trials are considered to be 'pending points'
604618
in experiment after their abandonment to avoid Ax models suggesting
605619
the same arm again as a new candidate. Arms in abandoned trials are
@@ -613,6 +627,8 @@ def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Sel
613627
Returns:
614628
The trial instance.
615629
"""
630+
if self._status == TrialStatus.ABANDONED:
631+
return self
616632
if not unsafe and none_throws(self._status).is_terminal:
617633
raise ValueError("Cannot abandon a trial in a terminal state.")
618634

@@ -626,11 +642,15 @@ def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Sel
626642
def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self:
627643
"""Mark trial as failed.
628644
645+
No-op if the trial is already failed.
646+
629647
Args:
630648
unsafe: Ignore sanity checks on state transitions.
631649
Returns:
632650
The trial instance.
633651
"""
652+
if self._status == TrialStatus.FAILED:
653+
return self
634654
if not unsafe and self._status != TrialStatus.RUNNING:
635655
raise TrialMutationError(
636656
"Can only mark failed a trial that is currently running."
@@ -646,12 +666,16 @@ def mark_early_stopped(
646666
) -> Self:
647667
"""Mark trial as early stopped.
648668
669+
No-op if the trial is already early stopped.
670+
649671
Args:
650672
reason: The reason the trial was early stopped.
651673
unsafe: Ignore sanity checks on state transitions.
652674
Returns:
653675
The trial instance.
654676
"""
677+
if self._status == TrialStatus.EARLY_STOPPED:
678+
return self
655679
if not unsafe:
656680
if self._status != TrialStatus.RUNNING:
657681
raise TrialMutationError(
@@ -672,11 +696,15 @@ def mark_early_stopped(
672696
def mark_stale(self, unsafe: bool = False) -> Self:
673697
"""Mark trial as stale.
674698
699+
No-op if the trial is already stale.
700+
675701
Args:
676702
unsafe: Ignore sanity checks on state transitions.
677703
Returns:
678704
The trial instance.
679705
"""
706+
if self._status == TrialStatus.STALE:
707+
return self
680708
if not unsafe and self._status != TrialStatus.CANDIDATE:
681709
raise TrialMutationError(
682710
message=(

ax/core/tests/test_batch_trial.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -274,8 +274,10 @@ def test_BatchLifecycle(self) -> None:
274274
with self.assertRaises(TrialMutationError):
275275
self.batch.mark_staged()
276276

277-
with self.assertRaises(TrialMutationError):
278-
self.batch.mark_completed()
277+
# Re-marking as completed is a no-op (no error, no timestamp change)
278+
time_completed_before = self.batch.time_completed
279+
self.batch.mark_completed()
280+
self.assertEqual(self.batch.time_completed, time_completed_before)
279281

280282
with self.assertRaises(TrialMutationError):
281283
self.batch.mark_running()

0 commit comments

Comments
 (0)