Skip to content

Commit 54dc100

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Move add_trial_type, update_runner, supports_trial_type to base Experiment
Summary: Phase 3 of moving MultiTypeExperiment features into base Experiment. Moves `add_trial_type` and `update_runner` from MultiTypeExperiment to the base Experiment class, making them available to all experiments. Updates `supports_trial_type` to unify the logic: for multi-type experiments (where `default_trial_type` is set), only trial types registered in `_trial_type_to_runner` are supported. For single-type experiments, `None` is supported along with `SHORT_RUN` and `LONG_RUN` for backward compatibility with generation strategies that use those trial types. Removes the corresponding overrides from MultiTypeExperiment — all three methods are now inherited from the base class. Differential Revision: D94988577
1 parent fcb63d7 commit 54dc100

2 files changed

Lines changed: 32 additions & 36 deletions

File tree

ax/core/experiment.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,31 @@ def default_trials(self) -> set[int]:
20472047
if trial.trial_type == self.default_trial_type
20482048
}
20492049

2050+
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
2051+
"""Add a new trial type to be supported by this experiment.
2052+
2053+
Args:
2054+
trial_type: The new trial type to be added.
2055+
runner: The default runner for trials of this type.
2056+
"""
2057+
if self.supports_trial_type(trial_type):
2058+
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
2059+
self._trial_type_to_runner[trial_type] = runner
2060+
return self
2061+
2062+
def update_runner(self, trial_type: str, runner: Runner) -> Self:
2063+
"""Update the default runner for an existing trial type.
2064+
2065+
Args:
2066+
trial_type: The trial type whose runner should be updated.
2067+
runner: The new runner for trials of this type.
2068+
"""
2069+
if not self.supports_trial_type(trial_type):
2070+
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
2071+
self._trial_type_to_runner[trial_type] = runner
2072+
self._runner = runner
2073+
return self
2074+
20502075
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
20512076
"""The default runner to use for a given trial type.
20522077
@@ -2061,9 +2086,14 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
20612086
def supports_trial_type(self, trial_type: str | None) -> bool:
20622087
"""Whether this experiment allows trials of the given type.
20632088
2064-
The base experiment class only supports None. For experiments
2065-
with multiple trial types, use the MultiTypeExperiment class.
2089+
For experiments with a ``default_trial_type`` (multi-type experiments),
2090+
only trial types registered in ``_trial_type_to_runner`` are supported.
2091+
For single-type experiments, ``None`` is always supported, along with
2092+
``SHORT_RUN`` and ``LONG_RUN`` for backward compatibility with
2093+
generation strategies that use those trial types.
20662094
"""
2095+
if self._default_trial_type is not None:
2096+
return trial_type in self._trial_type_to_runner
20672097
return (
20682098
trial_type is None
20692099
or trial_type == Keys.SHORT_RUN

ax/core/multi_type_experiment.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,6 @@ def __init__(
104104
self._default_trial_type
105105
)
106106

107-
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
108-
"""Add a new trial_type to be supported by this experiment.
109-
110-
Args:
111-
trial_type: The new trial_type to be added.
112-
runner: The default runner for trials of this type.
113-
"""
114-
if self.supports_trial_type(trial_type):
115-
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
116-
117-
self._trial_type_to_runner[trial_type] = runner
118-
return self
119-
120107
# pyre does not support inferring the type of property setter decorators
121108
# or the `.fset` attribute on properties.
122109
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator.
@@ -130,20 +117,6 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
130117
self.default_trial_type
131118
)
132119

133-
def update_runner(self, trial_type: str, runner: Runner) -> Self:
134-
"""Update the default runner for an existing trial_type.
135-
136-
Args:
137-
trial_type: The new trial_type to be added.
138-
runner: The new runner for trials of this type.
139-
"""
140-
if not self.supports_trial_type(trial_type):
141-
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
142-
143-
self._trial_type_to_runner[trial_type] = runner
144-
self._runner = runner
145-
return self
146-
147120
def add_tracking_metric(
148121
self,
149122
metric: Metric,
@@ -232,13 +205,6 @@ def _fetch_trial_data(
232205
# Invoke parent's fetch method using only metrics for this trial_type
233206
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
234207

235-
def supports_trial_type(self, trial_type: str | None) -> bool:
236-
"""Whether this experiment allows trials of the given type.
237-
238-
Only trial types defined in the trial_type_to_runner are allowed.
239-
"""
240-
return trial_type in self._trial_type_to_runner.keys()
241-
242208

243209
def filter_trials_by_type(
244210
trials: Sequence[BaseTrial], trial_type: str | None

0 commit comments

Comments
 (0)