@@ -100,6 +100,7 @@ def __init__(
100100 default_data_type : DataType | None = None ,
101101 auxiliary_experiments_by_purpose : None
102102 | (dict [AuxiliaryExperimentPurpose , list [AuxiliaryExperiment ]]) = None ,
103+ default_trial_type : str | None = None ,
103104 ) -> None :
104105 """Inits Experiment.
105106
@@ -130,7 +131,7 @@ def __init__(
130131
131132 self ._name = name
132133 self .description = description
133- self .runner = runner
134+ self ._runner = runner
134135 self .is_test : bool = is_test
135136
136137 self ._data_by_trial : dict [int , OrderedDict [int , Data ]] = {}
@@ -141,6 +142,12 @@ def __init__(
141142 self ._trials : dict [int , BaseTrial ] = {}
142143 self ._properties : dict [str , Any ] = properties or {}
143144 self ._default_data_type : DataType = default_data_type or DataType .DATA
145+
146+ # Initialize trial type to runner mapping
147+ self ._default_trial_type = default_trial_type
148+ self ._trial_type_to_runner : dict [str | None , Runner | None ] = {
149+ default_trial_type : runner
150+ }
144151 # Used to keep track of whether any trials on the experiment
145152 # specify a TTL. Since trials need to be checked for their TTL's
146153 # expiration often, having this attribute helps avoid unnecessary
@@ -391,6 +398,26 @@ def status_quo(self, status_quo: Arm | None) -> None:
391398
392399 self ._status_quo = status_quo
393400
401+ @property
402+ def runner (self ) -> Runner | None :
403+ """Default runner used for trials on this experiment."""
404+ return self ._runner
405+
406+ @runner .setter
407+ def runner (self , runner : Runner | None ) -> None :
408+ """Set the default runner and update trial type mapping."""
409+ self ._runner = runner
410+ if runner is not None :
411+ self ._trial_type_to_runner [self ._default_trial_type ] = runner
412+ else :
413+ self ._trial_type_to_runner = {None : None }
414+
415+ @runner .deleter
416+ def runner (self ) -> None :
417+ """Delete the runner."""
418+ self ._runner = None
419+ self ._trial_type_to_runner = {None : None }
420+
394421 @property
395422 def parameters (self ) -> dict [str , Parameter ]:
396423 """The parameters in the experiment's search space."""
@@ -1327,7 +1354,7 @@ def stop_trial_runs(
13271354 reasons = [None ] * len (trials )
13281355
13291356 for trial , reason in zip (trials , reasons ):
1330- runner = self .runner_for_trial ( trial = trial )
1357+ runner = self .runner_for_trial_type ( trial_type = trial . trial_type )
13311358 if runner is None :
13321359 raise RunnerNotFoundError (
13331360 "Unable to stop trial runs: Runner not configured "
@@ -1336,17 +1363,6 @@ def stop_trial_runs(
13361363 runner .stop (trial = trial , reason = reason )
13371364 trial .mark_early_stopped ()
13381365
1339- def reset_runners (self , runner : Runner ) -> None :
1340- """Replace all candidate trials runners.
1341-
1342- Args:
1343- runner: New runner to replace with.
1344- """
1345- for trial in self ._trials .values ():
1346- if trial .status == TrialStatus .CANDIDATE :
1347- trial .runner = runner
1348- self .runner = runner
1349-
13501366 def _attach_trial (self , trial : BaseTrial , index : int | None = None ) -> int :
13511367 """Attach a trial to this experiment.
13521368
@@ -1648,15 +1664,18 @@ def default_trial_type(self) -> str | None:
16481664 In the base experiment class this is always None. For experiments
16491665 with multiple trial types, use the MultiTypeExperiment class.
16501666 """
1651- return None
1667+ return self . _default_trial_type
16521668
1653- def runner_for_trial (self , trial : BaseTrial ) -> Runner | None :
1654- """The default runner to use for a given trial.
1669+ def runner_for_trial_type (self , trial_type : str | None ) -> Runner | None :
1670+ """The default runner to use for a given trial type .
16551671
1656- In the base experiment class, this is always the default experiment runner.
1657- For experiments with multiple trial types, use the MultiTypeExperiment class.
1672+ Looks up the appropriate runner for this trial type in the trial_type_to_runner.
16581673 """
1659- return trial ._runner if trial ._runner else self .runner
1674+ if not self .supports_trial_type (trial_type ):
1675+ raise ValueError (f"Trial type `{ trial_type } ` is not supported." )
1676+ if (runner := self ._trial_type_to_runner .get (trial_type )) is None :
1677+ return self .runner # return the default runner
1678+ return runner
16601679
16611680 def supports_trial_type (self , trial_type : str | None ) -> bool :
16621681 """Whether this experiment allows trials of the given type.
0 commit comments