@@ -324,7 +324,7 @@ def update_stop_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
324324 self ._stop_metadata .update (metadata )
325325 return self ._stop_metadata
326326
327- def run (self ) -> Self :
327+ def run (self ) -> BaseTrial :
328328 """Deploys the trial according to the behavior on the runner.
329329
330330 The runner returns a `run_metadata` dict containining metadata
@@ -349,7 +349,7 @@ def run(self) -> Self:
349349 self .mark_running ()
350350 return self
351351
352- def stop (self , new_status : TrialStatus , reason : str | None = None ) -> Self :
352+ def stop (self , new_status : TrialStatus , reason : str | None = None ) -> BaseTrial :
353353 """Stops the trial according to the behavior on the runner.
354354
355355 The runner returns a `stop_metadata` dict containining metadata
@@ -384,7 +384,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self:
384384 self .mark_as (new_status )
385385 return self
386386
387- def complete (self , reason : str | None = None ) -> Self :
387+ def complete (self , reason : str | None = None ) -> BaseTrial :
388388 """Stops the trial if functionality is defined on runner
389389 and marks trial completed.
390390
@@ -524,7 +524,7 @@ def status_reason(self) -> str | None:
524524 """Reason string for the trial status (failed, abandoned, or early stopped)."""
525525 return self ._status_reason
526526
527- def mark_staged (self , unsafe : bool = False ) -> Self :
527+ def mark_staged (self , unsafe : bool = False ) -> BaseTrial :
528528 """Mark the trial as being staged for running.
529529
530530 Args:
@@ -542,7 +542,7 @@ def mark_staged(self, unsafe: bool = False) -> Self:
542542
543543 def mark_running (
544544 self , no_runner_required : bool = False , unsafe : bool = False
545- ) -> Self :
545+ ) -> BaseTrial :
546546 """Mark trial has started running.
547547
548548 Args:
@@ -572,7 +572,7 @@ def mark_running(
572572
573573 def mark_completed (
574574 self , unsafe : bool = False , time_completed : str | None = None
575- ) -> Self :
575+ ) -> BaseTrial :
576576 """Mark trial as completed.
577577
578578 Args:
@@ -596,7 +596,9 @@ def mark_completed(
596596 )
597597 return self
598598
599- def mark_abandoned (self , reason : str | None = None , unsafe : bool = False ) -> Self :
599+ def mark_abandoned (
600+ self , reason : str | None = None , unsafe : bool = False
601+ ) -> BaseTrial :
600602 """Mark trial as abandoned.
601603
602604 NOTE: Arms in abandoned trials are considered to be 'pending points'
@@ -622,7 +624,7 @@ def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Sel
622624 self ._time_completed = datetime .now ()
623625 return self
624626
625- def mark_failed (self , reason : str | None = None , unsafe : bool = False ) -> Self :
627+ def mark_failed (self , reason : str | None = None , unsafe : bool = False ) -> BaseTrial :
626628 """Mark trial as failed.
627629
628630 Args:
@@ -642,7 +644,7 @@ def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self:
642644
643645 def mark_early_stopped (
644646 self , reason : str | None = None , unsafe : bool = False
645- ) -> Self :
647+ ) -> BaseTrial :
646648 """Mark trial as early stopped.
647649
648650 Args:
@@ -668,7 +670,7 @@ def mark_early_stopped(
668670 self ._time_completed = datetime .now ()
669671 return self
670672
671- def mark_stale (self , unsafe : bool = False ) -> Self :
673+ def mark_stale (self , unsafe : bool = False ) -> BaseTrial :
672674 """Mark trial as stale.
673675
674676 Args:
@@ -689,7 +691,9 @@ def mark_stale(self, unsafe: bool = False) -> Self:
689691 self ._time_completed = datetime .now ()
690692 return self
691693
692- def mark_as (self , status : TrialStatus , unsafe : bool = False , ** kwargs : Any ) -> Self :
694+ def mark_as (
695+ self , status : TrialStatus , unsafe : bool = False , ** kwargs : Any
696+ ) -> BaseTrial :
693697 """Mark trial with a new TrialStatus.
694698
695699 Args:
@@ -720,7 +724,7 @@ def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> S
720724 raise TrialMutationError (f"Cannot mark trial as { status } ." )
721725 return self
722726
723- def mark_arm_abandoned (self , arm_name : str , reason : str | None = None ) -> Self :
727+ def mark_arm_abandoned (self , arm_name : str , reason : str | None = None ) -> BaseTrial :
724728 raise NotImplementedError (
725729 "Abandoning arms is only supported for `BatchTrial`. "
726730 "Use `trial.mark_abandoned` if applicable."
0 commit comments